├── python ├── jtorch │ ├── utils │ │ ├── hooks.py │ │ ├── checkpoint.py │ │ ├── __init__.py │ │ ├── dtype.py │ │ ├── pip_publish.py │ │ └── data.py │ ├── fx.py │ ├── nn │ │ ├── utils │ │ │ ├── __init__.py │ │ │ └── rnn.py │ │ ├── init.py │ │ └── __init__.py │ ├── vision │ │ ├── transforms.py │ │ ├── datasets │ │ │ ├── __init__.py │ │ │ ├── vision.py │ │ │ ├── utils.py │ │ │ └── mnist.py │ │ ├── _internally_replaced_utils.py │ │ └── utils.py │ ├── misc.py │ ├── test │ │ ├── test_conflict_func.py │ │ ├── test_misc.py │ │ ├── test_function.py │ │ └── test_tutorial.py │ ├── src │ │ ├── jtorch_core.h │ │ └── jtorch_core.cc │ ├── distributed.py │ ├── cuda.py │ ├── compiler.py │ ├── tutorial │ │ ├── auto_grad1.py │ │ ├── auto_grad5_optim.py │ │ ├── auto_grad6_module.py │ │ ├── auto_grad2.py │ │ ├── auto_grad7_dynet.py │ │ ├── auto_grad4.py │ │ ├── auto_grad3.py │ │ └── quickstart.py │ ├── autograd.py │ ├── __init__.py │ ├── gradscaler.py │ └── gradscaler_old.py └── torch │ ├── autograd.py │ └── __init__.py ├── MANIFEST.in ├── .gitignore ├── setup.py ├── README.md └── LICENSE.txt /python/jtorch/utils/hooks.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /python/jtorch/fx.py: -------------------------------------------------------------------------------- 1 | class Proxy: 2 | pass -------------------------------------------------------------------------------- /python/jtorch/nn/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from . import rnn -------------------------------------------------------------------------------- /python/torch/autograd.py: -------------------------------------------------------------------------------- 1 | from jtorch.autograd import * -------------------------------------------------------------------------------- /python/jtorch/utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | detach_variable = None -------------------------------------------------------------------------------- /python/jtorch/vision/transforms.py: -------------------------------------------------------------------------------- 1 | from jittor.transform import * -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | exclude __data__ 2 | exclude __pycache__ 3 | prune **/__data__/ 4 | prune **/__pycache__ 5 | prune *.pyc -------------------------------------------------------------------------------- /python/jtorch/utils/__init__.py: -------------------------------------------------------------------------------- 1 | cpp_extension = None 2 | _flatten_dense_tensors = None 3 | _unflatten_dense_tensors = None 4 | 5 | tensorboard = None -------------------------------------------------------------------------------- /python/jtorch/vision/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .mnist import EMNIST, FashionMNIST, KMNIST, MNIST, QMNIST 2 | 3 | __all__ = ( 4 | "EMNIST", 5 | "FashionMNIST", 6 | "QMNIST", 7 | "MNIST", 8 | "KMNIST", 9 | ) -------------------------------------------------------------------------------- /python/jtorch/misc.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | def _jit_set_profiling_mode(x): pass 4 | def 
_jit_set_profiling_executor(x): pass 5 | def _jit_override_can_fuse_on_cpu(x): pass 6 | def _jit_override_can_fuse_on_gpu(x): pass 7 | 8 | def script(func): 9 | return func 10 | 11 | inf = math.inf 12 | nan = math.nan -------------------------------------------------------------------------------- /python/jtorch/nn/init.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | 3 | for k,v in jt.nn.init.__dict__.items(): 4 | if callable(v): 5 | globals()[k] = v 6 | 7 | 8 | normal = gauss 9 | normal_ = gauss_ 10 | xavier_normal = xavier_gauss 11 | xavier_normal_ = xavier_gauss_ 12 | 13 | jt.Var.normal_ = normal_ 14 | 15 | -------------------------------------------------------------------------------- /python/jtorch/utils/dtype.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Union 2 | Dtype = Union[Callable, str] 3 | 4 | def get_string_dtype(dtype): 5 | if callable(dtype): 6 | dtype = dtype.__name__ 7 | if not isinstance(dtype, str): 8 | raise ValueError(f"dtype is expected to be str, python type function, or jittor type function, but got {dtype}.") 9 | return dtype -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | my 2 | .refresh 3 | .DS_Store 4 | __pycache__ 5 | .ipynb_checkpoints/ 6 | .vscode/ 7 | __res/ 8 | perf.data 9 | perf.data.old 10 | *.swp 11 | *.ipynb 12 | *.pdf 13 | *.zip 14 | *.tgz 15 | *.obj 16 | test.py 17 | extern/mkl/mkldnn_lnx*/* 18 | data/ 19 | build/ 20 | venv/ 21 | *.md 22 | !*.src.md 23 | !README.md 24 | !README.cn.md 25 | !CHANGELOG.md 26 | python/jittor.egg-info 27 | python/jtorch.egg-info 28 | dist/ 29 | !doc/source/* 30 | core 31 | __data__ 32 | -------------------------------------------------------------------------------- /python/jtorch/test/test_conflict_func.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | import torch 4 | import jittor as jt 5 | 6 | class TestConflictFunc(unittest.TestCase): 7 | def test_max(self): 8 | a = torch.Tensor([1,4,2]) 9 | assert a.max() == 4 10 | v, k = a.max(dim=0) 11 | assert v==4 and k==1 12 | 13 | def test_argsort(self): 14 | a = torch.Tensor([1,4,2]) 15 | k = a.argsort() 16 | assert jt.all_equal(k, [0,2,1]) 17 | 18 | with jt.flag_scope(th_mode=0): 19 | k, v = a.argsort() 20 | assert jt.all_equal(k, [0,2,1]) 21 | 22 | 23 | 24 | if __name__ == "__main__": 25 | unittest.main() 26 | -------------------------------------------------------------------------------- /python/jtorch/nn/utils/rnn.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | 3 | PackedSequence = None 4 | 5 | def pad_sequence(sequences,batch_first=False,padding_value=0.0): 6 | max_f = max([len(s) for s in sequences]) 7 | # max_f = 512 8 | b = len(sequences) 9 | if batch_first: 10 | ret = sequences[0].new_full([b,max_f,]+list(sequences[0].shape[1:]),padding_value) 11 | for i,s in enumerate(sequences): 12 | ret[i,:len(s)] = s 13 | else: 14 | ret = sequences[0].new_full([max_f,b,]+list(sequences[0].shape[1:]),padding_value) 15 | for i,s in enumerate(sequences): 16 | ret[:len(s),i] = s 17 | # print(ret.shape) 18 | # ret = ret[:,:406] 19 | return ret 20 | -------------------------------------------------------------------------------- /python/jtorch/test/test_misc.py: 
-------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | import torch 4 | 5 | class TestMisc(unittest.TestCase): 6 | def test_update_grad(self): 7 | class Net(torch.nn.Module): 8 | def __init__(self): 9 | super().__init__() 10 | self.a = torch.nn.Parameter(torch.Tensor([1.0, 2.0])) 11 | net = Net() 12 | assert(net.a.requires_grad) 13 | net.load_state_dict({"a": torch.Tensor([3.0, 4.0])}) 14 | assert(net.a.requires_grad) 15 | 16 | def test_reshape(self): 17 | a = torch.ones(3,3) 18 | a.requires_grad = True 19 | b = torch.reshape(a, [9]) 20 | assert b.requires_grad == True 21 | 22 | 23 | if __name__ == "__main__": 24 | unittest.main() 25 | -------------------------------------------------------------------------------- /python/jtorch/utils/pip_publish.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import shutil 4 | import sys 5 | 6 | home_path = os.path.join(os.path.dirname(__file__), "..", "..", "..") 7 | home_path = os.path.abspath(home_path) 8 | 9 | def callback(func, path, exc_info): 10 | print(f"remove \"{path}\" failed.") 11 | 12 | def rmtree(path): 13 | if os.path.isdir(path): 14 | print(f"remove \"{path}\" recursive.") 15 | shutil.rmtree(path, onerror=callback) 16 | 17 | def remove_tmpfile(): 18 | dist_file = home_path+"/dist" 19 | egg_file = glob.glob(home_path+"/**/*egg-info") 20 | rmtree(dist_file) 21 | for e in egg_file: 22 | rmtree(e) 23 | 24 | def run_cmd(cmd): 25 | print("[CMD]", cmd) 26 | assert os.system(cmd)==0 27 | 28 | os.chdir(home_path) 29 | remove_tmpfile() 30 | 31 | run_cmd(f"{sys.executable} ./setup.py sdist") 32 | run_cmd(f"{sys.executable} -m twine upload dist/*") 33 | 34 | remove_tmpfile() -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r", encoding="utf-8") as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | name="jtorch", 8 | version="0.1.8", 9 | author="jtorch", 10 | author_email="jtorch@qq.com", 11 | description="jtorch project", 12 | long_description=long_description, 13 | long_description_content_type="text/markdown", 14 | url="https://github.com/JITTorch/jtorch", 15 | project_urls={ 16 | "Bug Tracker": "https://github.com/JITTorch/jtorch/issues", 17 | }, 18 | classifiers=[ 19 | "Programming Language :: Python :: 3", 20 | "Operating System :: OS Independent", 21 | ], 22 | packages=["jtorch", "torch"], 23 | package_dir={"": "python"}, 24 | package_data={'': ['*', '*/*', '*/*/*','*/*/*/*','*/*/*/*/*','*/*/*/*/*/*']}, 25 | python_requires=">=3.7", 26 | install_requires=[ 27 | "jittor>=1.3.8.6", 28 | "requests", 29 | ], 30 | ) -------------------------------------------------------------------------------- /python/jtorch/src/jtorch_core.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "common.h" 3 | #include "var_holder.h" 4 | #include "misc/fast_shared_ptr.h" 5 | 6 | namespace jittor { 7 | 8 | // @pyjt(device) 9 | // @attrs(heaptype) 10 | struct Device { 11 | string name; 12 | 13 | // @pyjt(__init__) 14 | Device(const string& name, int ordinal=0); 15 | // @pyjt(__get__type, __str__) 16 | inline string get_type() {return name;} 17 | // @pyjt(__get__index) 18 | inline int index() {return 0;} 19 | }; 20 | 21 | // @pyjt(backward) 22 | void backward(VarHolder* x); 23 | 
24 | // @pyjt(grad_set)
25 | void grad_set(VarHolder* x, Maybe<VarHolder> v); // x.grad = v; an empty Maybe deletes the entry
26 | // @pyjt(grad_get)
27 | Maybe<VarHolder> grad_get(VarHolder* x); // read back x.grad, or None if absent
28 | // @pyjt(grad_del)
29 | void grad_del(VarHolder* x);
30 | 
31 | // @pyjt(retain_grad_set)
32 | inline void retain_grad_set(VarHolder* x, bool v) {
33 |     x->var->flags.set(NodeFlags::_th_require_grad, v);
34 | }
35 | // @pyjt(retain_grad_get)
36 | inline bool retain_grad_get(VarHolder* x) {
37 |     return x->var->flags.get(NodeFlags::_th_require_grad);
38 | }
39 | 
40 | }
--------------------------------------------------------------------------------
/python/jtorch/distributed.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | from enum import Enum
3 | import jittor as jt
4 | 
5 | 
6 | class DistributedDataParallel:
7 |     def __new__(cls, model, *args, **kw):
8 |         return model
9 | 
10 | def is_initialized():
11 |     return True
12 | 
13 | def get_rank(group=None):
14 |     return 0
15 | 
16 | def get_world_size(group=None):
17 |     return 1
18 | 
19 | def get_backend(group=None):
20 |     return "nccl"
21 | 
22 | def new_group(ranks=None, timeout=datetime.timedelta(seconds=1800), backend=None, pg_options=None):
23 |     return 1
24 | 
25 | def barrier():
26 |     pass
27 | 
28 | def is_available():
29 |     return True
30 | 
31 | def is_built():
32 |     return True
33 | 
34 | class ReduceOp:
35 |     SUM = 0
36 | 
37 | class GroupMember:
38 |     WORLD = 0
39 | 
40 | class ProcessGroup:
41 |     pass
42 | 
43 | class Join:
44 |     pass
45 | 
46 | dist_backend = Enum("dist_backend", ("GLOO", "MPI", "NCCL"))
47 | _backend = dist_backend.NCCL
48 | 
49 | def is_mpi_available():
50 |     return jt.in_mpi
51 | 
--------------------------------------------------------------------------------
/python/jtorch/cuda.py:
--------------------------------------------------------------------------------
1 | import jittor as jt
2 | import jtorch
3 | 
4 | def is_available():
5 |     return jt.has_cuda
6 | 
7 | def device_count():
8 |     return int(jt.has_cuda)
9 | 
10 | def set_device(device=None):
11 |     pass
12 | 
13 | def get_rng_state(device=None):
14 |     pass
15 | 
16 | def current_device():
17 |     return jtorch.device("cuda")
18 | 
19 | def mem_get_info(i):
20 |     return ("75GB",)
21 | 
22 | 
23 | class Generator:
24 |     def __init__(self):
25 |         pass
26 | 
27 |     def set_state(self, state):
28 |         self.state = state
29 | 
30 | default_generators = [Generator()]
31 | _lazy_call = lambda func: func()
32 | device = None
33 | 
34 | LongTensor = jt.int64
35 | FloatTensor = jt.float
36 | HalfTensor = jt.float16
37 | BoolTensor = jt.bool
38 | 
39 | manual_seed = jt.set_global_seed
40 | manual_seed_all = jt.set_global_seed
41 | 
42 | def synchronize():
43 |     jt.sync_all(True)
44 | 
45 | class Event:
46 |     pass
47 | 
48 | class Stream:
49 |     pass
50 | 
51 | from typing import Any
52 | 
53 | from .gradscaler import GradScaler
54 | 
55 | class autocast:
56 |     def __init__(self, **kwargs):
57 |         pass
58 | 
59 |     def __enter__(self):
60 |         pass
61 | 
62 |     def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):
63 |         pass
64 | 
65 | 
--------------------------------------------------------------------------------
/python/jtorch/compiler.py:
--------------------------------------------------------------------------------
1 | import jittor as jt
2 | import jittor_utils
3 | import glob
4 | import os
5 | from jittor import pyjt_compiler
6 | import sys
7 | from jittor_utils import lock
8 | 
9 | 
10 | jtorch_path = os.path.dirname(__file__)
11 | cache_path = 
os.path.join(jt.compiler.cache_path, "jtorch") 12 | # os.makedirs(cache_path, exist_ok=True) 13 | os.makedirs(os.path.join(cache_path, "gen"), exist_ok=True) 14 | 15 | with lock.lock_scope(): 16 | pyjt_gen_src = pyjt_compiler.compile(cache_path, jtorch_path) 17 | 18 | ext_args = 'c[cu]' if jt.has_cuda else 'cc' 19 | files = glob.glob(jtorch_path+"/src/**/*."+ext_args, recursive=True) 20 | files += pyjt_gen_src 21 | cc_flags = " -I\""+os.path.join(jtorch_path, "src")+"\" " 22 | if os.environ.get("use_data_o", "1") == "1": 23 | files += glob.glob(jtorch_path+"/src/**/*.o", recursive=True) 24 | files = [f for f in files if "__data__" not in f] 25 | 26 | 27 | with lock.lock_scope(): 28 | jt.compiler.compile( 29 | jt.compiler.cc_path, 30 | jt.compiler.cc_flags+jt.compiler.opt_flags+ cc_flags, 31 | files, 32 | "jtorch_core"+jt.compiler.extension_suffix, 33 | obj_dirname="jtorch_objs") 34 | 35 | 36 | with jittor_utils.import_scope(jt.compiler.import_flags): 37 | import jtorch_core as core 38 | 39 | jt.flags.th_mode = 1 40 | -------------------------------------------------------------------------------- /python/jtorch/tutorial/auto_grad1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | dtype = torch.float 5 | device = torch.device("cpu") 6 | # device = torch.device("cuda:0") # Uncomment this to run on GPU 7 | 8 | # Create random input and output data 9 | x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype) 10 | y = torch.sin(x) 11 | 12 | # Randomly initialize weights 13 | a = torch.randn((), device=device, dtype=dtype) 14 | b = torch.randn((), device=device, dtype=dtype) 15 | c = torch.randn((), device=device, dtype=dtype) 16 | d = torch.randn((), device=device, dtype=dtype) 17 | 18 | learning_rate = 1e-6 19 | for t in range(20000): 20 | # Forward pass: compute predicted y 21 | y_pred = a + b * x + c * x ** 2 + d * x ** 3 22 | 23 | # Compute and print loss 24 | loss = (y_pred - y).pow(2).sum().item() 25 | if t % 1000 == 999: 26 | print(t, loss) 27 | 28 | # Backprop to compute gradients of a, b, c, d with respect to loss 29 | grad_y_pred = 2.0 * (y_pred - y) 30 | grad_a = grad_y_pred.sum() 31 | grad_b = (grad_y_pred * x).sum() 32 | grad_c = (grad_y_pred * x ** 2).sum() 33 | grad_d = (grad_y_pred * x ** 3).sum() 34 | 35 | # Update weights using gradient descent 36 | a -= learning_rate * grad_a 37 | b -= learning_rate * grad_b 38 | c -= learning_rate * grad_c 39 | d -= learning_rate * grad_d 40 | # print(t, torch.liveness_info()) 41 | # torch.sync_all() 42 | 43 | 44 | print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3') -------------------------------------------------------------------------------- /python/jtorch/vision/_internally_replaced_utils.py: -------------------------------------------------------------------------------- 1 | import importlib.machinery 2 | import os 3 | 4 | 5 | def _download_file_from_remote_location(fpath: str, url: str) -> None: 6 | pass 7 | 8 | 9 | def _is_remote_location_available() -> bool: 10 | return False 11 | 12 | 13 | def _get_extension_path(lib_name): 14 | 15 | lib_dir = os.path.dirname(__file__) 16 | if os.name == "nt": 17 | # Register the main torchvision library location on the default DLL path 18 | import ctypes 19 | import sys 20 | 21 | kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True) 22 | with_load_library_flags = hasattr(kernel32, "AddDllDirectory") 23 | prev_error_mode = kernel32.SetErrorMode(0x0001) 24 | 25 | if 
with_load_library_flags: 26 | kernel32.AddDllDirectory.restype = ctypes.c_void_p 27 | 28 | if sys.version_info >= (3, 8): 29 | os.add_dll_directory(lib_dir) 30 | elif with_load_library_flags: 31 | res = kernel32.AddDllDirectory(lib_dir) 32 | if res is None: 33 | err = ctypes.WinError(ctypes.get_last_error()) 34 | err.strerror += f' Error adding "{lib_dir}" to the DLL directories.' 35 | raise err 36 | 37 | kernel32.SetErrorMode(prev_error_mode) 38 | 39 | loader_details = (importlib.machinery.ExtensionFileLoader, importlib.machinery.EXTENSION_SUFFIXES) 40 | 41 | extfinder = importlib.machinery.FileFinder(lib_dir, loader_details) 42 | ext_specs = extfinder.find_spec(lib_name) 43 | if ext_specs is None: 44 | raise ImportError 45 | 46 | return ext_specs.origin -------------------------------------------------------------------------------- /python/jtorch/test/test_function.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | import torch 4 | 5 | class TestFunction(unittest.TestCase): 6 | def test_example1(self): 7 | import jtorch 8 | from jtorch import Function 9 | 10 | class MyFunc(Function): 11 | @staticmethod 12 | def forward(self, x, y): 13 | self.x = x 14 | self.y = y 15 | return x*y, x/y 16 | 17 | @staticmethod 18 | def backward(self, grad0, grad1): 19 | return grad0 * self.y, grad1 * self.x 20 | 21 | a = jtorch.array(3.0) 22 | a.requires_grad = True 23 | b = jtorch.array(4.0) 24 | b.requires_grad = True 25 | func = MyFunc.apply 26 | c,d = func(a, b) 27 | (c+d*3).backward() 28 | assert a.grad.data == 4 29 | assert b.grad.data == 9 30 | 31 | def test_example2(self): 32 | import jtorch as jt 33 | from jtorch import Function 34 | 35 | class MyFunc(Function): 36 | @staticmethod 37 | def forward(self, x, y): 38 | self.x = x 39 | self.y = y 40 | return x*y, x/y 41 | 42 | @staticmethod 43 | def backward(self, grad0, grad1): 44 | assert grad1 is None 45 | return grad0 * self.y, None 46 | a = jt.array(3.0) 47 | a.requires_grad = True 48 | b = jt.array(4.0) 49 | b.requires_grad = True 50 | func = MyFunc.apply 51 | c,d = func(a, b) 52 | d.stop_grad() 53 | da, db = jt.grad(c+d*3, [a, b]) 54 | assert da.data == 4 55 | assert db.data == 0 56 | 57 | if __name__ == "__main__": 58 | unittest.main() 59 | -------------------------------------------------------------------------------- /python/jtorch/nn/__init__.py: -------------------------------------------------------------------------------- 1 | import jtorch 2 | import jittor as jt 3 | 4 | from jtorch import make_module, Tensor, ModuleMisc, wrapper 5 | 6 | for k,v in jt.nn.__dict__.items(): 7 | if callable(v): 8 | globals()[k] = wrapper(v) 9 | 10 | for k,v in jt.nn.__dict__.items(): 11 | if isinstance(v, type) and issubclass(v, jt.Module): 12 | globals()[k] = make_module(v) 13 | 14 | class Module(ModuleMisc, jt.Module): 15 | 16 | def __call__(self, *args, **kw): 17 | return self.forward(*args, **kw) 18 | 19 | def execute(self, *args, **kw): 20 | return self.forward(*args, **kw) 21 | 22 | 23 | 24 | 25 | def Parameter(x:Tensor, requires_grad:bool=True) -> Tensor: 26 | x = x.clone() 27 | x.requires_grad = requires_grad 28 | x.retains_grad = requires_grad 29 | return x 30 | 31 | def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False): 32 | return jt.nn.embedding(input, weight) 33 | 34 | def dropout(x, p=0.5, training=False): 35 | return jt.nn.dropout(x, p, training) 36 | 37 | 38 | class Flatten(Module): 39 | ''' 
Flattens the contiguous range of dimensions in a Var. 40 | :param start_dim: the first dimension to be flattened. Defaults: 1. 41 | :type start_dim: int 42 | :param end_dim: the last dimension to be flattened. Defaults: -1. 43 | :type end_dim: int 44 | ''' 45 | def __init__(self, start_dim=1, end_dim=-1): 46 | self.start_dim = start_dim 47 | self.end_dim = end_dim 48 | 49 | def forward(self, x) -> jt.Var: 50 | return x.flatten(self.start_dim, self.end_dim) 51 | 52 | class _IncompatibleKeys: 53 | def __init__(self, missing_keys, unexpected_keys): 54 | self.missing_keys = missing_keys 55 | self.unexpected_keys = unexpected_keys 56 | 57 | _BatchNorm = None 58 | 59 | from . import utils -------------------------------------------------------------------------------- /python/jtorch/tutorial/auto_grad5_optim.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import math 4 | 5 | 6 | # Create Tensors to hold input and outputs. 7 | x = torch.linspace(-math.pi, math.pi, 2000) 8 | y = torch.sin(x) 9 | 10 | # Prepare the input tensor (x, x^2, x^3). 11 | p = torch.tensor([1, 2, 3]) 12 | xx = x.unsqueeze(-1).pow(p) 13 | 14 | # Use the nn package to define our model and loss function. 15 | model = torch.nn.Sequential( 16 | torch.nn.Linear(3, 1), 17 | torch.nn.Flatten(0, 1) 18 | ) 19 | loss_fn = torch.nn.MSELoss(reduction='sum') 20 | 21 | # Use the optim package to define an Optimizer that will update the weights of 22 | # the model for us. Here we will use RMSprop; the optim package contains many other 23 | # optimization algorithms. The first argument to the RMSprop constructor tells the 24 | # optimizer which Tensors it should update. 25 | learning_rate = 1e-3 26 | optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate) 27 | for t in range(8000): 28 | # Forward pass: compute predicted y by passing x to the model. 29 | y_pred = model(xx) 30 | 31 | # Compute and print loss. 32 | loss = loss_fn(y_pred, y) 33 | if t % 1000 == 999: 34 | print(t, loss.item()) 35 | 36 | # Before the backward pass, use the optimizer object to zero all of the 37 | # gradients for the variables it will update (which are the learnable 38 | # weights of the model). This is because by default, gradients are 39 | # accumulated in buffers( i.e, not overwritten) whenever .backward() 40 | # is called. Checkout docs of torch.autograd.backward for more details. 41 | optimizer.zero_grad() 42 | 43 | # Backward pass: compute gradient of the loss with respect to model 44 | # parameters 45 | loss.backward() 46 | 47 | # Calling the step function on an Optimizer makes an update to its 48 | # parameters 49 | optimizer.step() 50 | 51 | 52 | linear_layer = model[0] 53 | print(f'Result: y = {linear_layer.bias.item()} + {linear_layer.weight[:, 0].item()} x + {linear_layer.weight[:, 1].item()} x^2 + {linear_layer.weight[:, 2].item()} x^3') -------------------------------------------------------------------------------- /python/jtorch/tutorial/auto_grad6_module.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import math 4 | 5 | 6 | class Polynomial3(torch.nn.Module): 7 | def __init__(self): 8 | """ 9 | In the constructor we instantiate four parameters and assign them as 10 | member parameters. 
11 | """ 12 | super().__init__() 13 | self.a = torch.nn.Parameter(torch.randn(())) 14 | self.b = torch.nn.Parameter(torch.randn(())) 15 | self.c = torch.nn.Parameter(torch.randn(())) 16 | self.d = torch.nn.Parameter(torch.randn(())) 17 | 18 | def forward(self, x): 19 | """ 20 | In the forward function we accept a Tensor of input data and we must return 21 | a Tensor of output data. We can use Modules defined in the constructor as 22 | well as arbitrary operators on Tensors. 23 | """ 24 | return self.a + self.b * x + self.c * x ** 2 + self.d * x ** 3 25 | 26 | def string(self): 27 | """ 28 | Just like any class in Python, you can also define custom method on PyTorch modules 29 | """ 30 | return f'y = {self.a.item()} + {self.b.item()} x + {self.c.item()} x^2 + {self.d.item()} x^3' 31 | 32 | 33 | # Create Tensors to hold input and outputs. 34 | x = torch.linspace(-math.pi, math.pi, 2000) 35 | y = torch.sin(x) 36 | 37 | # Construct our model by instantiating the class defined above 38 | model = Polynomial3() 39 | 40 | # Construct our loss function and an Optimizer. The call to model.parameters() 41 | # in the SGD constructor will contain the learnable parameters (defined 42 | # with torch.nn.Parameter) which are members of the model. 43 | criterion = torch.nn.MSELoss(reduction='sum') 44 | optimizer = torch.optim.SGD(model.parameters(), lr=1e-6) 45 | for t in range(8000): 46 | # Forward pass: Compute predicted y by passing x to the model 47 | y_pred = model(x) 48 | 49 | # Compute and print loss 50 | loss = criterion(y_pred, y) 51 | if t % 1000 == 999: 52 | print(t, loss.item()) 53 | 54 | # Zero gradients, perform a backward pass, and update the weights. 55 | optimizer.zero_grad() 56 | loss.backward() 57 | optimizer.step() 58 | 59 | print(f'Result: {model.string()}') -------------------------------------------------------------------------------- /python/jtorch/tutorial/auto_grad2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import math 4 | 5 | dtype = torch.float 6 | device = torch.device("cpu") 7 | # device = torch.device("cuda:0") # Uncomment this to run on GPU 8 | 9 | # Create Tensors to hold input and outputs. 10 | # By default, requires_grad=False, which indicates that we do not need to 11 | # compute gradients with respect to these Tensors during the backward pass. 12 | x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype) 13 | y = torch.sin(x) 14 | 15 | # Create random Tensors for weights. For a third order polynomial, we need 16 | # 4 weights: y = a + b x + c x^2 + d x^3 17 | # Setting requires_grad=True indicates that we want to compute gradients with 18 | # respect to these Tensors during the backward pass. 19 | a = torch.randn((), device=device, dtype=dtype, requires_grad=True) 20 | b = torch.randn((), device=device, dtype=dtype, requires_grad=True) 21 | c = torch.randn((), device=device, dtype=dtype, requires_grad=True) 22 | d = torch.randn((), device=device, dtype=dtype, requires_grad=True) 23 | 24 | learning_rate = 1e-6 25 | for t in range(20000): 26 | # Forward pass: compute predicted y using operations on Tensors. 27 | y_pred = a + b * x + c * x ** 2 + d * x ** 3 28 | # print(y_pred.requires_grad) 29 | # y_pred.requires_grad = False 30 | 31 | # Compute and print loss using operations on Tensors. 32 | # Now loss is a Tensor of shape (1,) 33 | # loss.item() gets the scalar value held in the loss. 
34 |     loss = (y_pred - y).pow(2).sum()
35 |     if t % 1000 == 999:
36 |         print(t, loss.item())
37 | 
38 |     # Use autograd to compute the backward pass. This call will compute the
39 |     # gradient of loss with respect to all Tensors with requires_grad=True.
40 |     # After this call a.grad, b.grad, c.grad and d.grad will be Tensors holding
41 |     # the gradient of the loss with respect to a, b, c, d respectively.
42 |     # torch.backward(loss)
43 |     loss.backward()
44 | 
45 |     # Manually update weights using gradient descent. Wrap in torch.no_grad()
46 |     # because weights have requires_grad=True, but we don't need to track this
47 |     # in autograd.
48 |     with torch.no_grad():
49 |         a -= learning_rate * a.grad
50 |         b -= learning_rate * b.grad
51 |         c -= learning_rate * c.grad
52 |         d -= learning_rate * d.grad
53 | 
54 |         # Manually zero the gradients after updating weights
55 |         a.grad = None
56 |         b.grad = None
57 |         c.grad = None
58 |         d.grad = None
59 | 
60 | print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')
--------------------------------------------------------------------------------
/python/jtorch/tutorial/auto_grad7_dynet.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import random
3 | import torch
4 | import math
5 | 
6 | 
7 | class DynamicNet(torch.nn.Module):
8 |     def __init__(self):
9 |         """
10 |         In the constructor we instantiate five parameters and assign them as members.
11 |         """
12 |         super().__init__()
13 |         self.a = torch.nn.Parameter(torch.randn(()))
14 |         self.b = torch.nn.Parameter(torch.randn(()))
15 |         self.c = torch.nn.Parameter(torch.randn(()))
16 |         self.d = torch.nn.Parameter(torch.randn(()))
17 |         self.e = torch.nn.Parameter(torch.randn(()))
18 | 
19 |     def forward(self, x):
20 |         """
21 |         For the forward pass of the model, we randomly choose either 4, 5
22 |         and reuse the e parameter to compute the contribution of these orders.
23 | 
24 |         Since each forward pass builds a dynamic computation graph, we can use normal
25 |         Python control-flow operators like loops or conditional statements when
26 |         defining the forward pass of the model.
27 | 
28 |         Here we also see that it is perfectly safe to reuse the same parameter many
29 |         times when defining a computational graph.
30 |         """
31 |         y = self.a + self.b * x + self.c * x ** 2 + self.d * x ** 3
32 |         for exp in range(4, random.randint(4, 6)):
33 |             y = y + self.e * x ** exp
34 |         return y
35 | 
36 |     def string(self):
37 |         """
38 |         Just like any class in Python, you can also define custom method on PyTorch modules
39 |         """
40 |         return f'y = {self.a.item()} + {self.b.item()} x + {self.c.item()} x^2 + {self.d.item()} x^3 + {self.e.item()} x^4 ? + {self.e.item()} x^5 ?'
41 | 
42 | 
43 | # Create Tensors to hold input and outputs.
44 | x = torch.linspace(-math.pi, math.pi, 2000)
45 | y = torch.sin(x)
46 | 
47 | # Construct our model by instantiating the class defined above
48 | model = DynamicNet()
49 | 
50 | # Construct our loss function and an Optimizer. 
Training this strange model with 51 | # vanilla stochastic gradient descent is tough, so we use momentum 52 | criterion = torch.nn.MSELoss(reduction='sum') 53 | optimizer = torch.optim.SGD(model.parameters(), lr=1e-8, momentum=0.9) 54 | for t in range(60000): 55 | # Forward pass: Compute predicted y by passing x to the model 56 | y_pred = model(x) 57 | 58 | # Compute and print loss 59 | loss = criterion(y_pred, y) 60 | if t % 2000 == 1999: 61 | print(t, loss.item()) 62 | 63 | # Zero gradients, perform a backward pass, and update the weights. 64 | optimizer.zero_grad() 65 | loss.backward() 66 | optimizer.step() 67 | # print(torch.liveness_info()) 68 | 69 | print(f'Result: {model.string()}') -------------------------------------------------------------------------------- /python/jtorch/test/test_tutorial.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | import os 4 | import subprocess as sp 5 | import sys 6 | 7 | def check_two(cmd, parser=None, checker=None): 8 | jtorch_out = sp.getoutput(cmd) 9 | print("=========JTORCH OUT==========") 10 | print(jtorch_out) 11 | torch_out = sp.getoutput("PYTHONPATH= "+cmd) 12 | print("=========TORCH OUT==========") 13 | print(torch_out) 14 | if parser: 15 | torch_out = parser(torch_out) 16 | jtorch_out = parser(jtorch_out) 17 | if checker: 18 | checker(torch_out, jtorch_out) 19 | else: 20 | assert torch_out == jtorch_out 21 | return jtorch_out, torch_out 22 | 23 | jtorch_path = os.path.join(os.path.dirname(__file__), "..") 24 | # come from https://pytorch.org/tutorials/beginner/pytorch_with_examples.html 25 | class TestTutorial(unittest.TestCase): 26 | def test_auto_grad1(self): 27 | check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad1.py", 28 | parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float), 29 | checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4)) 30 | def test_auto_grad2(self): 31 | check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad2.py", 32 | parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float), 33 | checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4)) 34 | def test_auto_grad3(self): 35 | check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad3.py", 36 | parser=lambda s: np.array(s.split())[[-9,-7,-4,-2]].astype(float), 37 | checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4)) 38 | def test_auto_grad4(self): 39 | check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad4.py", 40 | parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float), 41 | checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4)) 42 | def test_auto_grad5(self): 43 | check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad5_optim.py", 44 | parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float), 45 | checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-2)) 46 | def test_auto_grad6(self): 47 | check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad6_module.py", 48 | parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float), 49 | checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4)) 50 | def test_auto_grad7(self): 51 | check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad7_dynet.py", 52 | parser=lambda s: np.array(s.split())[[-13,-10,-7,-3]].astype(float), 53 | checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-2)) 54 | 55 | if __name__ == "__main__": 56 | unittest.main() 
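Note on the file below: jtorch_core.cc implements the PyTorch-style gradient contract that the tests above exercise. `backward()` walks the graph breadth-first from the loss, computes gradients for every Var flagged `_th_require_grad`, and accumulates them into the global `grad_backup` map keyed by Var id, which `x.grad` then reads back out. The following is a minimal illustrative sketch of that contract, not a file in this repo; it assumes the repo's `python` directory is on `PYTHONPATH`, so that `import torch` resolves to the jtorch shim in python/torch/__init__.py:

```python
import torch  # resolves to the jtorch shim, not real PyTorch

x = torch.ones(3)
x.requires_grad = True   # sets the _th_require_grad flag (retain_grad_set)
y = (x * 2).sum()
y.backward()             # C++ backward(): bfs_backward + grad() + add_grad
print(x.grad)            # grad_get(): reads grad_backup[x.id] -> [2. 2. 2.]
x.grad = None            # grad_set() with an empty value falls through to grad_del()
```

This mirrors the behavior asserted in test_function.py and the manual `a.grad = None` resets used in the tutorial scripts.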
--------------------------------------------------------------------------------
/python/jtorch/src/jtorch_core.cc:
--------------------------------------------------------------------------------
1 | 
2 | #include "pyjt/py_obj_holder.h"
3 | #include "utils/str_utils.h"
4 | #include "jtorch_core.h"
5 | #include "graph.h"
6 | #include "grad.h"
7 | #include "ops/op_register.h"
8 | 
9 | namespace jittor {
10 | 
11 | void pyjt_def_all(PyObject* m);
12 | 
13 | EXTERN_LIB void setter_use_cuda(int value);
14 | 
15 | Device::Device(const string& name, int ordinal) : name(name) {
16 |     if (startswith(name, "cpu"))
17 |         setter_use_cuda(0);
18 |     else
19 |         setter_use_cuda(1);
20 | }
21 | 
22 | unordered_map<int64, VarPtr> grad_backup; // per-Var gradient storage, keyed by Var id
23 | EXTERN_LIB void (*_var_free_hook)(Var*);
24 | EXTERN_LIB unordered_map<int64, VarPtr>* _grad_backup_ptr;
25 | 
26 | void jtorch_var_free_hook(Var* v) {
27 |     auto iter = grad_backup.find(v->id);
28 |     if (iter != grad_backup.end()) {
29 |         grad_backup.erase(iter);
30 |     }
31 | }
32 | 
33 | void jtorch_init() {
34 |     _var_free_hook = &jtorch_var_free_hook;
35 |     _grad_backup_ptr = &grad_backup;
36 | }
37 | 
38 | inline static VarPtr& get_grad(Var* v) {
39 |     return grad_backup[v->id];
40 | }
41 | static auto make_binary = get_op_info("binary")
42 |     .get_constructor<VarPtr, Var*, Var*, NanoString>();
43 | 
44 | inline static void add_grad(VarPtr& a, VarPtr&& b) {
45 |     if (!a) a = move(b);
46 |     else {
47 |         a = make_binary(a, b, ns_add);
48 |     }
49 | }
50 | 
51 | 
52 | void grad_set(VarHolder* x, Maybe<VarHolder> v) {
53 |     if (!v) {
54 |         grad_del(x);
55 |         return;
56 |     }
57 |     grad_backup[x->var->id] = v.ptr->var;
58 | }
59 | 
60 | Maybe<VarHolder> grad_get(VarHolder* x) {
61 |     auto iter = grad_backup.find(x->var->id);
62 |     if (iter != grad_backup.end()) {
63 |         if (!iter->second.ptr) return nullptr;
64 |         return new VarHolder(iter->second.ptr);
65 |     }
66 |     return nullptr;
67 | }
68 | 
69 | void grad_del(VarHolder* x) {
70 |     auto iter = grad_backup.find(x->var->id);
71 |     if (iter != grad_backup.end())
72 |         grad_backup.erase(iter);
73 | }
74 | 
75 | void backward(VarHolder* x) {
76 |     vector<Node*> gnodes({x->var});
77 |     bfs_backward(gnodes, [&](Node* node) {
78 |         if (node->is_stop_grad())
79 |             return false;
80 |         return true;
81 |     });
82 |     vector<Var*> targets;
83 |     for (auto* node : gnodes) {
84 |         if (node->is_var() && node->flags.get(NodeFlags::_th_require_grad))
85 |             targets.push_back(node->var());
86 |     }
87 |     auto grads = grad(x->var, targets);
88 |     for (int i=0; i<grads.size(); i++) {
89 |         auto& g = get_grad(targets[i]); // accumulate into grad_backup
90 |         add_grad(g, move(grads[i]));
91 |     }
92 | }
93 | 
94 | } // namespace jittor
95 | 
96 | 
97 | 
98 | static void init_module(PyModuleDef* mdef, PyObject* m) {
99 |     mdef->m_doc = "Inner c++ core of jtorch";
100 |     jittor::pyjt_def_all(m);
101 | }
102 | PYJT_MODULE_INIT(jtorch_core);
103 | 
--------------------------------------------------------------------------------
/python/torch/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["FIX_TORCH_ERROR"] = "0"
3 | 
4 | import jittor as jt
5 | import jtorch
6 | from jtorch import *
7 | __version__ = "2.0.0"
8 | 
9 | import sys
10 | def load_mod(name):
11 |     exec("import "+name)
12 |     return eval(name)
13 | 
14 | autograd = sys.modules["torch.autograd"] = load_mod("jtorch.autograd")
15 | 
16 | cuda = load_mod("jtorch.cuda")
17 | cuda.amp = cuda
18 | sys.modules["torch.cuda"] = load_mod("jtorch.cuda")
19 | sys.modules["torch.npu"] = load_mod("jtorch.cuda")
20 | npu = sys.modules["torch.npu"]
21 | sys.modules["torch.cuda.amp"] = load_mod("jtorch.cuda")
22 | sys.modules['torch.optim'] = load_mod("jtorch.optim")
23 | sys.modules['torch.optim.lr_scheduler'] = load_mod("jtorch.optim")
24 | jtorch.optim.lr_scheduler = jtorch.optim
25 | 
26 | sys.modules["torch.nn"] = load_mod("jtorch.nn")
27 | 
sys.modules["torch.nn.functional"] = load_mod("jtorch.nn") 28 | sys.modules["torch.nn.parallel"] = load_mod("jtorch.distributed") 29 | jtorch.nn.parallel = load_mod("jtorch.distributed") 30 | sys.modules["torch.nn.modules"] = load_mod("jtorch.nn") 31 | sys.modules['torch.nn.modules.module'] = load_mod("jtorch.nn") 32 | jtorch.nn.module = jtorch.nn 33 | jtorch.nn.modules = jtorch.nn 34 | sys.modules["torch.nn.parameter"] = load_mod("jtorch.nn") 35 | jtorch.nn.parameter = jtorch.nn 36 | sys.modules["torch.nn.utils"] = load_mod("jtorch.nn") 37 | jtorch.nn.functional = jtorch.nn 38 | sys.modules["torch.utils"] = load_mod("jtorch.utils") 39 | sys.modules["torch._utils"] = load_mod("jtorch.utils") 40 | _utils = jtorch.utils 41 | sys.modules["torch.utils.data"] = load_mod("jtorch.utils.data") 42 | sys.modules["torch.utils.data.sampler"] = load_mod("jtorch.utils.data") 43 | sys.modules["torch.utils.data.distributed"] = load_mod("jtorch.utils.data") 44 | jtorch.utils.data.sampler = jtorch.utils.data 45 | sys.modules["torch.utils.checkpoint"] = load_mod("jtorch.utils.checkpoint") 46 | sys.modules["torch.utils.hooks"] = load_mod("jtorch.utils.hooks") 47 | 48 | distributed = sys.modules["torch.distributed"] = load_mod("jtorch.distributed") 49 | sys.modules['torch.distributed.algorithms.join'] = distributed 50 | sys.modules['torch.backends.mps'] = distributed 51 | sys.modules['torch.backends'] = distributed 52 | backends = distributed 53 | backends.mps = distributed 54 | 55 | sys.modules['torch.fx'] = load_mod("jtorch.fx") 56 | 57 | sys.modules["torch.nn.parallel"] = load_mod("jtorch.distributed") 58 | sys.modules["torch.nn.parallel.distributed"] = load_mod("jtorch.distributed") 59 | _C = sys.modules["torch._C"] = load_mod("jtorch.misc") 60 | _six = sys.modules["torch._six"] = load_mod("jtorch.misc") 61 | jit = sys.modules["torch.jit"] = load_mod("jtorch.misc") 62 | 63 | sys.modules["torchvision"] = load_mod("jtorch.vision") 64 | sys.modules["torchvision.datasets"] = load_mod("jtorch.vision.datasets") 65 | sys.modules["torchvision.transforms"] = load_mod("jtorch.vision.transforms") 66 | sys.modules["torchvision.transforms.functional"] = load_mod("jtorch.vision.transforms") 67 | jtorch.vision.transforms.functional = jtorch.vision.transforms 68 | 69 | -------------------------------------------------------------------------------- /python/jtorch/tutorial/auto_grad4.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import math 4 | 5 | 6 | # Create Tensors to hold input and outputs. 7 | x = torch.linspace(-math.pi, math.pi, 2000) 8 | y = torch.sin(x) 9 | 10 | # For this example, the output y is a linear function of (x, x^2, x^3), so 11 | # we can consider it as a linear layer neural network. Let's prepare the 12 | # tensor (x, x^2, x^3). 13 | p = torch.tensor([1, 2, 3]) 14 | xx = x.unsqueeze(-1).pow(p) 15 | 16 | # In the above code, x.unsqueeze(-1) has shape (2000, 1), and p has shape 17 | # (3,), for this case, broadcasting semantics will apply to obtain a tensor 18 | # of shape (2000, 3) 19 | 20 | # Use the nn package to define our model as a sequence of layers. nn.Sequential 21 | # is a Module which contains other Modules, and applies them in sequence to 22 | # produce its output. The Linear Module computes output from input using a 23 | # linear function, and holds internal Tensors for its weight and bias. 24 | # The Flatten layer flatens the output of the linear layer to a 1D tensor, 25 | # to match the shape of `y`. 
26 | model = torch.nn.Sequential( 27 | torch.nn.Linear(3, 1), 28 | torch.nn.Flatten(0, 1) 29 | ) 30 | 31 | # The nn package also contains definitions of popular loss functions; in this 32 | # case we will use Mean Squared Error (MSE) as our loss function. 33 | loss_fn = torch.nn.MSELoss(reduction='sum') 34 | # print(model[0].weight.requires_grad) 35 | 36 | learning_rate = 1e-6 37 | for t in range(8000): 38 | 39 | # Forward pass: compute predicted y by passing x to the model. Module objects 40 | # override the __call__ operator so you can call them like functions. When 41 | # doing so you pass a Tensor of input data to the Module and it produces 42 | # a Tensor of output data. 43 | y_pred = model(xx) 44 | 45 | # Compute and print loss. We pass Tensors containing the predicted and true 46 | # values of y, and the loss function returns a Tensor containing the 47 | # loss. 48 | loss = loss_fn(y_pred, y) 49 | if t % 1000 == 999: 50 | print(t, loss.item()) 51 | 52 | # Zero the gradients before running the backward pass. 53 | model.zero_grad() 54 | 55 | # Backward pass: compute gradient of the loss with respect to all the learnable 56 | # parameters of the model. Internally, the parameters of each Module are stored 57 | # in Tensors with requires_grad=True, so this call will compute gradients for 58 | # all learnable parameters in the model. 59 | loss.backward() 60 | 61 | # Update the weights using gradient descent. Each parameter is a Tensor, so 62 | # we can access its gradients like we did before. 63 | with torch.no_grad(): 64 | for param in model.parameters(): 65 | param -= learning_rate * param.grad 66 | 67 | # You can access the first layer of `model` like accessing the first item of a list 68 | linear_layer = model[0] 69 | 70 | # For linear layer, its parameters are stored as `weight` and `bias`. 71 | print(f'Result: y = {linear_layer.bias.item()} + {linear_layer.weight[:, 0].item()} x + {linear_layer.weight[:, 1].item()} x^2 + {linear_layer.weight[:, 2].item()} x^3') -------------------------------------------------------------------------------- /python/jtorch/tutorial/auto_grad3.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import math 4 | 5 | 6 | class LegendrePolynomial3(torch.autograd.Function): 7 | """ 8 | We can implement our own custom autograd Functions by subclassing 9 | torch.autograd.Function and implementing the forward and backward passes 10 | which operate on Tensors. 11 | """ 12 | 13 | @staticmethod 14 | def forward(ctx, input): 15 | """ 16 | In the forward pass we receive a Tensor containing the input and return 17 | a Tensor containing the output. ctx is a context object that can be used 18 | to stash information for backward computation. You can cache arbitrary 19 | objects for use in the backward pass using the ctx.save_for_backward method. 20 | """ 21 | ctx.save_for_backward(input) 22 | return 0.5 * (5 * input ** 3 - 3 * input) 23 | 24 | @staticmethod 25 | def backward(ctx, grad_output): 26 | """ 27 | In the backward pass we receive a Tensor containing the gradient of the loss 28 | with respect to the output, and we need to compute the gradient of the loss 29 | with respect to the input. 30 | """ 31 | input, = ctx.saved_tensors 32 | return grad_output * 1.5 * (5 * input ** 2 - 1) 33 | 34 | 35 | dtype = torch.float 36 | device = torch.device("cpu") 37 | # device = torch.device("cuda:0") # Uncomment this to run on GPU 38 | 39 | # Create Tensors to hold input and outputs. 
40 | # By default, requires_grad=False, which indicates that we do not need to 41 | # compute gradients with respect to these Tensors during the backward pass. 42 | x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype) 43 | y = torch.sin(x) 44 | 45 | # Create random Tensors for weights. For this example, we need 46 | # 4 weights: y = a + b * P3(c + d * x), these weights need to be initialized 47 | # not too far from the correct result to ensure convergence. 48 | # Setting requires_grad=True indicates that we want to compute gradients with 49 | # respect to these Tensors during the backward pass. 50 | a = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True) 51 | b = torch.full((), -1.0, device=device, dtype=dtype, requires_grad=True) 52 | c = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True) 53 | d = torch.full((), 0.3, device=device, dtype=dtype, requires_grad=True) 54 | 55 | learning_rate = 5e-6 56 | for t in range(2000): 57 | # To apply our Function, we use Function.apply method. We alias this as 'P3'. 58 | P3 = LegendrePolynomial3.apply 59 | 60 | # Forward pass: compute predicted y using operations; we compute 61 | # P3 using our custom autograd operation. 62 | y_pred = a + b * P3(c + d * x) 63 | 64 | # Compute and print loss 65 | loss = (y_pred - y).pow(2).sum() 66 | if t % 100 == 99: 67 | print(t, loss.item()) 68 | 69 | # Use autograd to compute the backward pass. 70 | loss.backward() 71 | 72 | # Update weights using gradient descent 73 | with torch.no_grad(): 74 | a -= learning_rate * a.grad 75 | b -= learning_rate * b.grad 76 | c -= learning_rate * c.grad 77 | d -= learning_rate * d.grad 78 | 79 | # Manually zero the gradients after updating weights 80 | a.grad = None 81 | b.grad = None 82 | c.grad = None 83 | d.grad = None 84 | 85 | print(f'Result: y = {a.item()} + {b.item()} * P3( {c.item()} + {d.item()} x)') -------------------------------------------------------------------------------- /python/jtorch/tutorial/quickstart.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | # from jtorch.utils import DataLoader 4 | from torch.utils.data import DataLoader 5 | from torchvision import datasets 6 | from torchvision.transforms import ToTensor 7 | 8 | # Download training data from open datasets. 9 | training_data = datasets.FashionMNIST( 10 | root="data", 11 | train=True, 12 | download=True, 13 | transform=ToTensor(), 14 | ) 15 | 16 | # Download test data from open datasets. 17 | test_data = datasets.FashionMNIST( 18 | root="data", 19 | train=False, 20 | download=True, 21 | transform=ToTensor(), 22 | ) 23 | 24 | batch_size = 64 25 | 26 | # Create data loaders. 27 | train_dataloader = DataLoader(training_data, batch_size=batch_size) 28 | test_dataloader = DataLoader(test_data, batch_size=batch_size) 29 | 30 | print(len(train_dataloader)) 31 | for X, y in test_dataloader: 32 | print(f"Shape of X [N, C, H, W]: {X.shape}") 33 | print(f"Shape of y: {y.shape} {y.dtype}") 34 | break 35 | 36 | # Get cpu or gpu device for training. 
37 | device = "cuda" if torch.cuda.is_available() else "cpu" 38 | print(f"Using {device} device") 39 | 40 | # Define model 41 | class NeuralNetwork(nn.Module): 42 | def __init__(self): 43 | super(NeuralNetwork, self).__init__() 44 | self.flatten = nn.Flatten() 45 | self.linear_relu_stack = nn.Sequential( 46 | nn.Linear(28*28, 512), 47 | nn.ReLU(), 48 | nn.Linear(512, 512), 49 | nn.ReLU(), 50 | nn.Linear(512, 10) 51 | ) 52 | 53 | def forward(self, x): 54 | x = self.flatten(x) 55 | logits = self.linear_relu_stack(x) 56 | return logits 57 | 58 | model = NeuralNetwork().to(device) 59 | print(model) 60 | 61 | 62 | loss_fn = nn.CrossEntropyLoss() 63 | optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) 64 | 65 | def train(dataloader, model, loss_fn, optimizer): 66 | size = len(dataloader.dataset) 67 | model.train() 68 | for batch, (X, y) in enumerate(dataloader): 69 | X, y = X.to(device), y.to(device) 70 | 71 | # Compute prediction error 72 | pred = model(X) 73 | loss = loss_fn(pred, y) 74 | 75 | # Backpropagation 76 | optimizer.zero_grad() 77 | loss.backward() 78 | optimizer.step() 79 | 80 | if batch % 100 == 0: 81 | loss, current = loss.item(), batch * len(X) 82 | print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]") 83 | 84 | def test(dataloader, model, loss_fn): 85 | size = len(dataloader.dataset) 86 | num_batches = len(dataloader) 87 | model.eval() 88 | test_loss, correct = 0, 0 89 | with torch.no_grad(): 90 | for X, y in dataloader: 91 | X, y = X.to(device), y.to(device) 92 | pred = model(X) 93 | test_loss += loss_fn(pred, y).item() 94 | correct += (pred.argmax(1) == y).type(torch.float).sum().item() 95 | test_loss /= num_batches 96 | correct /= size 97 | print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n") 98 | 99 | 100 | epochs = 5 101 | test(test_dataloader, model, loss_fn) 102 | for t in range(epochs): 103 | print(f"Epoch {t+1}\n-------------------------------") 104 | train(train_dataloader, model, loss_fn, optimizer) 105 | test(test_dataloader, model, loss_fn) 106 | print("Done!") -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # JTorch: 一个全兼容 PyTorch 接口的高性能动态编译深度学习框架 2 | 3 | JTorch 是一个完全兼容 PyTorch 接口的深度学习框架,同时基于 Jittor 元算子与统一计算图特性的加持,实现高性能动态编译,同时,用户原来使用的PyTorch代码,不需要进行任何修改,即可加速运行。总结而言,JTorch具有以下几点优势: 4 | 5 | 1. 零成本:完全兼容原生 PyTorch 接口, 用户代码不需要作任何更改。 6 | 2. 速度快:通过统一计算图执行方法,JTorch可以实现对代码的动态编译以及加速,相比原版 PyTorch拥有更好的性能。 7 | 3. 支持硬件多:JTorch底层通过元算子抽象,可以快速兼容适配多种人工智能芯片。 8 | 4. 兼容生态: 对原有 PyTorch 生态形成兼容,如各种第三方开发的 PyTorch 模型库。 9 | 5. 兼容计图: JTorch完全兼容计图,计图中的接口可以混合使用,性能高。 10 | 6. 完全自主可控: JTorch 具有完全的自主知识产权,用户完全不需要安装 Torch,即可直接使用。 11 | 12 | 13 | JTorch相关连接: 14 | 15 | * [Github](https://github.com/JITTorch/jtorch) 16 | * [Jittor 论坛](https://discuss.jittor.org/) 17 | * 即时通信: QQ Group(761222083) 18 | 19 | # 安装与测试 20 | 21 | 安装方法如下: 22 | 23 | ``` 24 | python3 -m pip install jtorch 25 | ``` 26 | 27 | 注意,请使用python3.7及以上的版本 28 | 29 | 运行简单测试: 30 | 31 | ``` 32 | python3 -m jtorch.test.test_tutorial 33 | ``` 34 | 35 | # 快速入门 36 | 37 | ## 使用 JTorch 实现简单动态网络(PyTorch兼容) 38 | 39 | ```python 40 | # -*- coding: utf-8 -*- 41 | import random 42 | import torch 43 | import math 44 | 45 | 46 | class DynamicNet(torch.nn.Module): 47 | def __init__(self): 48 | """ 49 | In the constructor we instantiate five parameters and assign them as members. 
50 | """ 51 | super().__init__() 52 | self.a = torch.nn.Parameter(torch.randn(())) 53 | self.b = torch.nn.Parameter(torch.randn(())) 54 | self.c = torch.nn.Parameter(torch.randn(())) 55 | self.d = torch.nn.Parameter(torch.randn(())) 56 | self.e = torch.nn.Parameter(torch.randn(())) 57 | 58 | def forward(self, x): 59 | """ 60 | For the forward pass of the model, we randomly choose either 4, 5 61 | and reuse the e parameter to compute the contribution of these orders. 62 | 63 | Since each forward pass builds a dynamic computation graph, we can use normal 64 | Python control-flow operators like loops or conditional statements when 65 | defining the forward pass of the model. 66 | 67 | Here we also see that it is perfectly safe to reuse the same parameter many 68 | times when defining a computational graph. 69 | """ 70 | y = self.a + self.b * x + self.c * x ** 2 + self.d * x ** 3 71 | for exp in range(4, random.randint(4, 6)): 72 | y = y + self.e * x ** exp 73 | return y 74 | 75 | def string(self): 76 | """ 77 | Just like any class in Python, you can also define custom method on PyTorch modules 78 | """ 79 | return f'y = {self.a.item()} + {self.b.item()} x + {self.c.item()} x^2 + {self.d.item()} x^3 + {self.e.item()} x^4 ? + {self.e.item()} x^5 ?' 80 | 81 | 82 | # Create Tensors to hold input and outputs. 83 | x = torch.linspace(-math.pi, math.pi, 2000) 84 | y = torch.sin(x) 85 | 86 | # Construct our model by instantiating the class defined above 87 | model = DynamicNet() 88 | 89 | # Construct our loss function and an Optimizer. Training this strange model with 90 | # vanilla stochastic gradient descent is tough, so we use momentum 91 | criterion = torch.nn.MSELoss(reduction='sum') 92 | optimizer = torch.optim.SGD(model.parameters(), lr=1e-8, momentum=0.9) 93 | for t in range(60000): 94 | # Forward pass: Compute predicted y by passing x to the model 95 | y_pred = model(x) 96 | 97 | # Compute and print loss 98 | loss = criterion(y_pred, y) 99 | if t % 2000 == 1999: 100 | print(t, loss.item()) 101 | 102 | # Zero gradients, perform a backward pass, and update the weights. 103 | optimizer.zero_grad() 104 | loss.backward() 105 | optimizer.step() 106 | # print(torch.liveness_info()) 107 | 108 | print(f'Result: {model.string()}') 109 | ``` 110 | 111 | ## 联系我们 112 | 113 | 电子邮件:jtorch@qq.com 114 | 115 | 提出issue:https://github.com/jittorch/jtorch/issues 116 | 117 | QQ 群:761222083 118 | 119 | 120 | ## 版权声明 121 | 122 | 如LICENSE.txt文件中所示, JTorch 使用Apache 2.0版权协议。 123 | 124 | -------------------------------------------------------------------------------- /python/jtorch/vision/datasets/vision.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Callable, List, Optional, Tuple 3 | 4 | import torch 5 | import torch.utils.data as data 6 | 7 | from ..utils import _log_api_usage_once 8 | 9 | 10 | class VisionDataset(data.Dataset): 11 | """ 12 | Base Class For making datasets which are compatible with torchvision. 13 | It is necessary to override the ``__getitem__`` and ``__len__`` method. 14 | Args: 15 | root (string): Root directory of dataset. 16 | transforms (callable, optional): A function/transforms that takes in 17 | an image and a label and returns the transformed versions of both. 18 | transform (callable, optional): A function/transform that takes in an PIL image 19 | and returns a transformed version. 
E.g, ``transforms.RandomCrop`` 20 | target_transform (callable, optional): A function/transform that takes in the 21 | target and transforms it. 22 | .. note:: 23 | :attr:`transforms` and the combination of :attr:`transform` and :attr:`target_transform` are mutually exclusive. 24 | """ 25 | 26 | _repr_indent = 4 27 | 28 | def __init__( 29 | self, 30 | root: str, 31 | transforms: Optional[Callable] = None, 32 | transform: Optional[Callable] = None, 33 | target_transform: Optional[Callable] = None, 34 | ) -> None: 35 | self.root = root 36 | 37 | has_transforms = transforms is not None 38 | has_separate_transform = transform is not None or target_transform is not None 39 | if has_transforms and has_separate_transform: 40 | raise ValueError("Only transforms or transform/target_transform can be passed as argument") 41 | 42 | # for backwards-compatibility 43 | self.transform = transform 44 | self.target_transform = target_transform 45 | 46 | if has_separate_transform: 47 | transforms = StandardTransform(transform, target_transform) 48 | self.transforms = transforms 49 | 50 | def __getitem__(self, index: int) -> Any: 51 | """ 52 | Args: 53 | index (int): Index 54 | Returns: 55 | (Any): Sample and meta data, optionally transformed by the respective transforms. 56 | """ 57 | raise NotImplementedError 58 | 59 | def __len__(self) -> int: 60 | raise NotImplementedError 61 | 62 | def __repr__(self) -> str: 63 | head = "Dataset " + self.__class__.__name__ 64 | body = [f"Number of datapoints: {self.__len__()}"] 65 | if self.root is not None: 66 | body.append(f"Root location: {self.root}") 67 | body += self.extra_repr().splitlines() 68 | if hasattr(self, "transforms") and self.transforms is not None: 69 | body += [repr(self.transforms)] 70 | lines = [head] + [" " * self._repr_indent + line for line in body] 71 | return "\n".join(lines) 72 | 73 | def _format_transform_repr(self, transform: Callable, head: str) -> List[str]: 74 | lines = transform.__repr__().splitlines() 75 | return [f"{head}{lines[0]}"] + ["{}{}".format(" " * len(head), line) for line in lines[1:]] 76 | 77 | def extra_repr(self) -> str: 78 | return "" 79 | 80 | 81 | class StandardTransform: 82 | def __init__(self, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None) -> None: 83 | self.transform = transform 84 | self.target_transform = target_transform 85 | 86 | def __call__(self, input: Any, target: Any) -> Tuple[Any, Any]: 87 | if self.transform is not None: 88 | input = self.transform(input) 89 | if self.target_transform is not None: 90 | target = self.target_transform(target) 91 | return input, target 92 | 93 | def _format_transform_repr(self, transform: Callable, head: str) -> List[str]: 94 | lines = transform.__repr__().splitlines() 95 | return [f"{head}{lines[0]}"] + ["{}{}".format(" " * len(head), line) for line in lines[1:]] 96 | 97 | def __repr__(self) -> str: 98 | body = [self.__class__.__name__] 99 | if self.transform is not None: 100 | body += self._format_transform_repr(self.transform, "Transform: ") 101 | if self.target_transform is not None: 102 | body += self._format_transform_repr(self.target_transform, "Target transform: ") 103 | 104 | return "\n".join(body) -------------------------------------------------------------------------------- /python/jtorch/autograd.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | from jittor import Var 3 | from collections.abc import Sequence, Mapping 4 | 5 | Variable = Var 6 | 7 | class 
--------------------------------------------------------------------------------
/python/jtorch/autograd.py:
--------------------------------------------------------------------------------
1 | import jittor as jt
2 | from jittor import Var
3 | from collections.abc import Sequence, Mapping
4 | 
5 | Variable = Var
6 | 
7 | class FunctionContext:
8 |     def save_for_backward(self, *args):
9 |         self.saved_tensors = args
10 | 
11 | class Function:
12 |     ''' Function Module for customized backward operations
13 | 
14 |     Example 1 (a Function can have multiple inputs and multiple outputs, and the user
15 |     can store values for the backward computation)::
16 | 
17 |         import jtorch
18 |         from jtorch import Function
19 | 
20 |         class MyFunc(Function):
21 |             @staticmethod
22 |             def forward(self, x, y):
23 |                 self.x = x
24 |                 self.y = y
25 |                 return x*y, x/y
26 | 
27 |             @staticmethod
28 |             def backward(self, grad0, grad1):
29 |                 return grad0 * self.y, grad1 * self.x
30 | 
31 |         a = jtorch.array(3.0)
32 |         a.requires_grad = True
33 |         b = jtorch.array(4.0)
34 |         b.requires_grad = True
35 |         func = MyFunc.apply
36 |         c,d = func(a, b)
37 |         (c+d*3).backward()
38 |         assert a.grad.data == 4
39 |         assert b.grad.data == 9
40 | 
41 |     Example 2 (a Function can return None for no gradient, and incoming gradients
42 |     can also be None)::
43 | 
44 |         import jtorch
45 |         from jtorch import Function
46 | 
47 |         class MyFunc(Function):
48 |             @staticmethod
49 |             def forward(self, x, y):
50 |                 self.x = x
51 |                 self.y = y
52 |                 return x*y, x/y
53 | 
54 |             @staticmethod
55 |             def backward(self, grad0, grad1):
56 |                 assert grad1 is None
57 |                 return grad0 * self.y, None
58 |         a = jt.array(3.0)
59 |         a.requires_grad = True
60 |         b = jt.array(4.0)
61 |         b.requires_grad = True
62 |         func = MyFunc.apply
63 |         c,d = func(a, b)
64 |         d.stop_grad()
65 |         da, db = jt.grad(c+d*3, [a, b])
66 |         assert da.data == 4
67 |         assert db.data == 0
68 | 
69 |     '''
70 |     def __call__(self, *args):
71 |         backup = args
72 |         args = list(args)
73 |         taped_inputs = []
74 |         taped_outputs = []
75 |         input_mask = [-1] * len(args)
76 |         for i,v in enumerate(args):
77 |             if isinstance(v, Var):
78 |                 if v.is_stop_grad():
79 |                     # -2 in input_mask represents it is stop_grad
80 |                     input_mask[i] = -2
81 |                     continue
82 |                 v = v.tape()
83 |                 input_mask[i] = len(taped_inputs)
84 |                 args[i] = v
85 |                 taped_inputs.append(v)
86 |         ctx = FunctionContext()
87 |         ori_res = self.forward(ctx, *args)
88 |         # ori_res = self.execute(*args)
89 |         if not isinstance(ori_res, Sequence):
90 |             res = [ori_res]
91 |         else:
92 |             res = list(ori_res)
93 |         output_mask = [-1] * len(res)
94 |         for i,v in enumerate(res):
95 |             if isinstance(v, Var):
96 |                 v = v.tape()
97 |                 output_mask[i] = len(taped_outputs)
98 |                 res[i] = v
99 |                 taped_outputs.append(v)
100 |         ctx.input_mask = input_mask
101 |         ctx.output_mask = output_mask
102 |         # tape output and input together so
103 |         # backward treats them as one operator
104 |         jt.tape_together(taped_inputs, taped_outputs,
105 |             lambda *args: self._grad(ctx, self, *args))
106 |         if isinstance(ori_res, Sequence):
107 |             return res
108 |         else:
109 |             return res[0]
110 | 
111 |     @staticmethod
112 |     def _grad(ctx, func, *args):
113 |         new_args = ( (args[i] if i>=0 else None) for i in ctx.output_mask )
114 |         ret = func.backward(ctx, *new_args)
115 |         if not isinstance(ret, Sequence):
116 |             ret = (ret,)
117 |         new_ret = []
118 |         for i, r in enumerate(ret):
119 |             j = ctx.input_mask[i]
120 |             if j < 0:
121 |                 # -2 in input_mask represents it is stop_grad
122 |                 assert r is None or j == -2, f"{type(func)}'s {i}-th returned grad should be None, "\
123 |                     "because the input value is not a jittor variable."
124 |             else:
125 |                 new_ret.append(r)
126 |         return new_ret
127 | 
128 |     def dfs(self, parents, k, callback, callback_leave=None):
129 |         pass
130 | 
131 |     @classmethod
132 |     def apply(cls, *args, **kw):
133 |         func = cls()
134 |         return func(*args, **kw)
135 | 
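
# A minimal sketch of the ctx-style API above (illustrative): forward
# receives a FunctionContext as its first argument, save_for_backward()
# stores tensors on it, and backward() reads them back via ctx.saved_tensors.
if __name__ == "__main__":
    class MySquare(Function):
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)
            return x * x

        @staticmethod
        def backward(ctx, grad):
            (x,) = ctx.saved_tensors
            return grad * 2 * x

    a = jt.array(3.0)
    a.requires_grad = True
    y = MySquare.apply(a)
    da, = jt.grad(y, [a])
    assert da.data == 6  # d(x^2)/dx at x = 3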
--------------------------------------------------------------------------------
/python/jtorch/utils/data.py:
--------------------------------------------------------------------------------
1 | import jittor as jt
2 | import jittor.dataset
3 | from jittor.dataset import Dataset as JDataset
4 | 
5 | from collections import namedtuple
6 | from typing import Any, Callable, Iterable, Optional, Sequence, Union
7 | 
8 | 
9 | class Dataset:
10 |     def __getitem__(self, index):
11 |         raise NotImplementedError
12 | 
13 | class IterableDataset:
14 |     def __iter__(self):
15 |         raise NotImplementedError
16 | 
17 | 
18 | class DataLoader(JDataset):
19 |     def __init__(self, dataset,
20 |                  batch_size: Optional[int] = 1,
21 |                  shuffle: Optional[bool] = False,
22 |                  sampler = None,
23 |                  batch_sampler = None,
24 |                  num_workers: int = 0,
25 |                  collate_fn = None,
26 |                  pin_memory: bool = False,
27 |                  drop_last: bool = False,
28 |                  timeout: float = 0,
29 |                  worker_init_fn = None,
30 |                  multiprocessing_context=None,
31 |                  generator=None,
32 |                  *, prefetch_factor: int = 2,
33 |                  persistent_workers: bool = False,
34 |                  pin_memory_device: str = "") -> None:
35 |         super().__init__(batch_size=batch_size,
36 |                          shuffle=shuffle,
37 |                          num_workers=num_workers,
38 |                          drop_last=drop_last)
39 | 
40 |         unsupported_kwargs = {
41 |             "batch_sampler": batch_sampler,
42 |             "pin_memory": pin_memory,
43 |             "timeout": timeout,
44 |             "worker_init_fn": worker_init_fn,
45 |             "multiprocessing_context": multiprocessing_context,
46 |             "generator": generator,
47 |             "persistent_workers": persistent_workers,
48 |             "pin_memory_device": pin_memory_device
49 |         }
50 |         for kwarg, value in unsupported_kwargs.items():
51 |             if value:
52 |                 jt.LOG.w(f"Not implemented Dataloader kwarg: {kwarg}")
53 | 
54 |         self.dataset = dataset
55 |         self.collate_fn = collate_fn
56 |         self.sampler = sampler
57 | 
58 |         if not isinstance(dataset, IterableDataset):
59 |             self.total_len = len(dataset)
60 |         else:
61 |             # TODO: support multiple workers for iterable datasets
62 |             assert num_workers == 0
63 | 
64 |     def collate_batch(self, batch):
65 |         if self.collate_fn is not None:
66 |             return self.collate_fn(batch)
67 |         else:
68 |             return super().collate_batch(batch)
69 | 
70 |     def __getitem__(self, i):
71 |         return self.dataset[i]
72 | 
73 |     def __iter__(self):
74 |         if isinstance(self.dataset, IterableDataset):
75 |             return self.inner_iter()
76 |         else:
77 |             return super().__iter__()
78 | 
79 |     def inner_iter(self):
80 |         current_batch = []
81 | 
82 |         if jt.world_size > 1:
83 |             assert self.batch_size % jt.world_size == 0, \
84 |                 f"IterableDataset does not support a batch size ({self.batch_size}) that is not evenly divisible by the number of processes {jt.world_size}"
85 |             real_batch_size = int(self.batch_size / jt.world_size)
86 |         else:
87 |             real_batch_size = self.batch_size
88 | 
89 |         for element in self.dataset:
90 |             current_batch.append(element)
91 | 
92 |             if len(current_batch) == real_batch_size:
93 |                 current_batch = self.collate_batch(current_batch)
94 |                 current_batch = self.to_jittor(current_batch)
95 |                 yield current_batch
96 |                 current_batch = []
97 | 
98 |         if not self.drop_last and len(current_batch) > 0:
99 |             current_batch = self.collate_batch(current_batch)
100 |             yield self.to_jittor(current_batch)
101 | 
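# Usage sketch (illustrative; `my_dataset` stands for any map-style dataset
# with __getitem__/__len__): batching, shuffling and workers are delegated
# to jittor's Dataset machinery, while unsupported PyTorch kwargs only
# produce a warning.
#
#     loader = DataLoader(my_dataset, batch_size=32, shuffle=True,
#                         num_workers=2, drop_last=True)
#     for batch in loader:
#         ...  # batches are collated and converted to jittor Vars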
102 | def get_worker_info():
103 |     # always return the fake worker info
104 |     return namedtuple('WorkerInfo', 'id num_workers')(0, 1)
105 | 
106 | class RandomSampler(jt.dataset.RandomSampler):
107 |     def __init__(self, dataset, generator=None, **kwargs):
108 |         super().__init__(dataset, **kwargs)
109 | 
110 |     def __iter__(self):
111 |         if getattr(self.dataset, "support_random_access", True):
112 |             return super().__iter__()
113 |         else:
114 |             self.dataset.shuffle()
115 |             return iter(range(self.dataset.__real_len__() if hasattr(self.dataset, "__real_len__") else self.dataset.__len__()))
116 | 
117 | class DistributedSampler(jt.dataset.Sampler):
118 |     def __init__(self, sampler: RandomSampler):
119 |         assert isinstance(sampler, RandomSampler)
120 |         self.sampler = sampler
121 | 
122 |     def set_epoch(self, epoch: int):
123 |         ### do nothing, let jittor's inner dataset handle it
124 |         pass
125 | 
126 |     def __iter__(self):
127 |         return self.sampler.__iter__()
128 | 
129 |     def __len__(self):
130 |         return self.sampler.__len__()
131 | 
132 | BatchSampler = jt.dataset.BatchSampler
133 | Sampler = jt.dataset.Sampler
134 | SequentialSampler = jt.dataset.SequentialSampler
135 | SubsetRandomSampler = jt.dataset.SubsetRandomSampler
136 | 
137 | TensorDataset = Dataset
--------------------------------------------------------------------------------
/python/jtorch/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["FIX_TORCH_ERROR"] = "0"
3 | 
4 | import jittor as jt
5 | from jittor import *
6 | org_int = int = type(1)
7 | org_float = float = type(1.0)
8 | org_bool = bool = type(True)
9 | 
10 | import jtorch.compiler
11 | 
12 | import jtorch_core
13 | from jtorch_core import *
14 | 
15 | device.__reduce__ = lambda self: (device, (self.type,))
16 | device.__module__ = "jtorch"
17 | jt.jittor_core.device = device
18 | 
19 | 
20 | def handle_dtype(args, kw, dtype):
21 |     def convert(x):
22 |         if isinstance(x, jt.Var):
23 |             return x.cast(dtype)
24 |         return x
25 |     if dtype is not None:
26 |         if args is not None:
27 |             if isinstance(args, (tuple, list)):
28 |                 args = [ convert(a) for a in args ]
29 |             else:
30 |                 args = convert(args)
31 |         if kw is not None:
32 |             kw = { k: convert(v) for k, v in kw.items() }
33 |     return args, kw
34 | 
35 | def get_args_names(func):
36 |     import inspect
37 |     spec = inspect.getfullargspec(func)
38 |     return spec[0] + spec[4]
39 | 
40 | def wrapper(func):
41 |     has_dtype = False
42 |     if hasattr(func, "__code__"):
43 |         has_dtype = "dtype" in get_args_names(func)
44 |     def inner(*args, **kw):
45 |         requires_grad = None
46 |         dtype = None
47 |         if "requires_grad" in kw:
48 |             requires_grad = kw["requires_grad"]
49 |             del kw["requires_grad"]
50 |         if not has_dtype and "dtype" in kw:
51 |             dtype = kw["dtype"]
52 |             del kw["dtype"]
53 |         if "device" in kw:
54 |             del kw["device"]
55 |         args, kw = handle_dtype(args, kw, dtype)
56 |         ret = func(*args, **kw)
57 |         if isinstance(ret, jt.Var):
58 |             if requires_grad is not None:
59 |                 ret.requires_grad = requires_grad
60 |             if dtype is not None:
61 |                 ret = ret.astype(dtype)
62 |         return ret
63 |     return inner
64 | 
65 | 
66 | import inspect
67 | _wrapper_keys = set(["shape", "start", "size"])
68 | _wrapper_keys.add("x")
69 | for k, v in list(globals().items()):
70 |     if callable(v) and not isinstance(v, type):
71 |         try:
72 |             spec = inspect.getfullargspec(v)
73 |             args_name = spec[0]
74 |             if len(args_name) and args_name[0] in _wrapper_keys:
75 |                 globals()[k] = wrapper(v)
76 |             elif spec.varargs in _wrapper_keys:
77 |                 globals()[k] = wrapper(v)
78 |         except Exception:
79 |             pass
80 | 
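# Effect of the wrapping above (illustrative sketch): any jittor function
# whose first argument is named "shape", "start", "size", or "x" now
# tolerates PyTorch-style keyword arguments, which are stripped before the
# call and applied to the result afterwards, e.g. (assuming `zeros` was
# picked up by the loop above):
#
#     t = zeros((2, 3), dtype=float32, requires_grad=True)
#     # `device=...` is accepted but dropped: tensor placement is global
#     # in jittor (jt.flags.use_cuda), not per-tensor.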
81 | def empty(*size, dtype=jt.float32, device=None, requires_grad=False):
82 |     if len(size) == 1 and not isinstance(size[0], org_int):
83 |         size = size[0]
84 |     return jt.empty(size, dtype)
85 | 
86 | Tensor = Var
87 | 
88 | Tensor.backward = lambda x: jtorch_core.backward(x)
89 | Tensor.grad = property(grad_get, grad_set, grad_del)
90 | Tensor.retains_grad = property(retain_grad_get, retain_grad_set)
91 | def retain_grad(x: Tensor, value: bool = True):
92 |     x.retains_grad = value
93 |     return value
94 | Tensor.retain_grad = retain_grad
95 | 
96 | Tensor.dim = lambda self: self.ndim
97 | Tensor.ndimension = lambda self: self.ndim
98 | Tensor.nelement = lambda self: self.numel()
99 | Tensor.cuda = lambda self: self
100 | def device_get(x: Tensor):
101 |     return device("cpu") if not jt.has_cuda or not jt.flags.use_cuda else device("cuda")
102 | Tensor.device = property(device_get)
103 | 
104 | def argmax(x: Var, dim=None, keepdim: bool = False):
105 |     return jt.argmax(x, dim, keepdim)[0]
106 | Tensor.argmax = argmax
107 | 
108 | def tensor_type(x: Var, dtype=None, **kwargs):
109 |     if dtype:
110 |         return x.astype(dtype)
111 |     else:
112 |         return x.dtype
113 | Tensor.type = tensor_type
114 | 
115 | def is_floating_point(x: Var):
116 |     return "float" in str(x.dtype)
117 | Tensor.is_floating_point = is_floating_point
118 | 
119 | from . import autograd
120 | from .autograd import *
121 | 
122 | tensor = wrapper(array)
123 | 
124 | def mod_zero_grad(self):
125 |     for p in self.parameters():
126 |         p.grad = None
127 | Module.zero_grad = mod_zero_grad
128 | 
129 | class ModuleMisc:
130 |     def parameters(self):
131 |         return iter(super().parameters())
132 | 
133 |     def load_state_dict(self, state_dict, strict=False):
134 |         return super().load_state_dict(state_dict)
135 | 
136 |     def to(self, device, dtype=None):
137 |         ''' do nothing but return itself '''
138 |         return self
139 |     def register_parameter(self, name, data):
140 |         setattr(self, name, data)
141 | 
142 | def make_module(cls):
143 |     class TMod(ModuleMisc, cls):
144 |         def __init__(self, *args, **kw):
145 |             dtype = None
146 |             if "dtype" in kw:
147 |                 dtype = kw["dtype"]
148 |                 del kw["dtype"]
149 |             self._dtype = dtype
150 |             with jt.flag_scope(th_mode=0):
151 |                 super().__init__(*args, **kw)
152 |                 for k, v in self.__dict__.items():
153 |                     if not k.startswith("_") and isinstance(v, Var) \
154 |                         and v.requires_grad:
155 |                         v.retain_grad()
156 |                     if dtype is not None and isinstance(v, Var):
157 |                         v.assign(v.cast(dtype))
158 |         def __call__(self, *args, **kw):
159 |             args, kw = handle_dtype(args, kw, self._dtype)
160 |             # if forward is overridden by the user, call forward
161 |             if self.__class__.forward is not TMod.forward:
162 |                 return self.forward(*args, **kw)
163 |             return self.execute(*args, **kw)
164 |         def forward(self, *args, **kw):
165 |             args, kw = handle_dtype(args, kw, self._dtype)
166 |             return self.execute(*args, **kw)
167 | 
168 |         @property
169 |         def training(self):
170 |             if not hasattr(self, "is_train"):
171 |                 self.is_train = True
172 |             return self.is_train
173 |         @training.setter
174 |         def training(self, value):
175 |             self.is_train = value
176 | 
177 |     TMod.__name__ = cls.__name__
178 |     return TMod
179 | 
180 | import jtorch.cuda
181 | import jtorch.nn
182 | from jtorch.nn import Module, Parameter
183 | import jtorch.optim
184 | 
185 | from jtorch.utils.dtype import Dtype, get_string_dtype
186 | 
187 | def frombuffer(buffer: bytearray,
188 |                *,
189 |                dtype: Dtype,
190 |                count: int = -1,
191 |                offset: int = 0,
192 |                requires_grad: bool = True) -> Tensor:
193 |     dtype = get_string_dtype(dtype)
194 |     tensor = jt.array(np.frombuffer(buffer, dtype, count=count, offset=offset))
195 |     if requires_grad and tensor.dtype.is_float():
196 |         tensor.requires_grad = True
197 |     return tensor
198 | 
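# frombuffer usage sketch (illustrative):
#
#     import struct
#     buf = bytearray(struct.pack("3f", 1.0, 2.0, 3.0))
#     t = frombuffer(buf, dtype=float32)  # -> float32 Tensor of shape (3,)
#
# Note that, unlike torch.frombuffer, requires_grad defaults to True here,
# so floating-point results are created with gradients enabled.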
199 | def conflict_wrapper(origin_func, new_func):
200 |     def wrapper(*args, **kw):
201 |         if jt.flags.th_mode:
202 |             return new_func(*args, **kw)
203 |         else:
204 |             return origin_func(*args, **kw)
205 |     return wrapper
206 | 
207 | def min(*args, **kw):
208 |     dim = None
209 |     if len(args) >= 2 and isinstance(args[1], org_int):
210 |         dim = args[1]
211 |     elif "dim" in kw and isinstance(kw["dim"], org_int):
212 |         dim = kw["dim"]
213 |     if dim is not None:
214 |         k, v = jt.argmin(*args, **kw)
215 |         return v, k
216 |     elif len(args) == 2 and isinstance(args[1], jt.Var):
217 |         return jt.minimum(args[0], args[1])
218 |     else:
219 |         return jt.min(*args, **kw)
220 | Tensor.min = conflict_wrapper(jt.min, min)
221 | 
222 | def max(*args, **kw):
223 |     dim = None
224 |     # `dim` may be passed positionally (args[1]) or as a keyword;
225 |     # both forms are normalized below.
226 |     if len(args) >= 2 and isinstance(args[1], org_int):
227 |         dim = args[1]
228 |     elif "dim" in kw and isinstance(kw["dim"], org_int):
229 |         dim = kw["dim"]
230 |     if dim is not None:
231 |         k, v = jt.argmax(*args, **kw)
232 |         return v, k
233 |     elif len(args) == 2 and isinstance(args[1], jt.Var):
234 |         return jt.maximum(args[0], args[1])
235 |     else:
236 |         return jt.max(*args, **kw)
237 | Tensor.max = conflict_wrapper(jt.max, max)
238 | 
239 | def argsort(*args, **kw):
240 |     k, v = jt.argsort(*args, **kw)
241 |     return k
242 | Tensor.argsort = conflict_wrapper(jt.argsort, argsort)
243 | 
244 | LongTensor = jt.int64
245 | FloatTensor = jt.float
246 | HalfTensor = jt.float16
247 | BoolTensor = jt.bool
248 | 
249 | class JDType:
250 |     def __init__(self, func, str):
251 |         self.func = func
252 |         self.str = str
253 |         self.__name__ = str.split(".")[-1]
254 |     def __call__(self, *args, **kw):
255 |         return self.func(*args, **kw)
256 |     def __str__(self):
257 |         return self.str
258 |     def is_floating_point(self):
259 |         return "float" in str(self.str)
260 | 
261 | int8 = JDType(jt.int8, "torch.int8")
262 | int16 = JDType(jt.int16, "torch.int16")
263 | int = int32 = JDType(jt.int32, "torch.int32")
264 | long = int64 = JDType(jt.int64, "torch.int64")
265 | 
266 | half = float16 = JDType(jt.float16, "torch.float16")
267 | float = float32 = JDType(jt.float32, "torch.float32")
268 | double = float64 = JDType(jt.float64, "torch.float64")
269 | bfloat16 = "bfloat16"  # TODO
270 | complex64 = "complex64"  # TODO
271 | complex128 = "complex128"  # TODO
272 | 
273 | def load(path, **kwargs):
274 |     def _to_jittor(data):
275 |         if isinstance(data, dict):
276 |             return {k: _to_jittor(d) for k, d in data.items()}
277 |         if isinstance(data, list):
278 |             return [_to_jittor(d) for d in data]
279 |         if isinstance(data, np.ndarray):
280 |             return jt.array(data)
281 |         return data
282 |     data = jt.load(path)
283 | 
284 |     return _to_jittor(data)
285 | 
286 | def is_tensor(x):
287 |     return isinstance(x, Tensor)
288 | 
289 | manual_seed = jt.set_global_seed
290 | jt.flags.amp_level = 3
291 | Size = jt.NanoVector
292 | 
293 | class Generator:
294 |     def manual_seed(self, seed):
295 |         pass
296 | 
297 | 
298 | from . 
import fx 299 | 300 | 301 | _default_type = "float32" 302 | 303 | def get_default_dtype(): 304 | return _default_type 305 | def set_default_dtype(dtype): 306 | global _default_type 307 | _default_type = dtype 308 | 309 | dtype = JDType 310 | 311 | def div(x,y,rounding_mode="floor"): 312 | assert rounding_mode == "floor" 313 | z = (x / y) 314 | if rounding_mode == "floor": 315 | z = z.floor() 316 | if x.dtype == "int32" and (isinstance(y,org_int) or y.dtype == "int32"): 317 | z = z.int32() 318 | return z -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2022 FittenTech. All Rights Reserved 2 | 3 | Apache License 4 | Version 2.0, January 2004 5 | http://www.apache.org/licenses/ 6 | 7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 8 | 9 | 1. Definitions. 10 | 11 | "License" shall mean the terms and conditions for use, reproduction, 12 | and distribution as defined by Sections 1 through 9 of this document. 13 | 14 | "Licensor" shall mean the copyright owner or entity authorized by 15 | the copyright owner that is granting the License. 16 | 17 | "Legal Entity" shall mean the union of the acting entity and all 18 | other entities that control, are controlled by, or are under common 19 | control with that entity. For the purposes of this definition, 20 | "control" means (i) the power, direct or indirect, to cause the 21 | direction or management of such entity, whether by contract or 22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 23 | outstanding shares, or (iii) beneficial ownership of such entity. 24 | 25 | "You" (or "Your") shall mean an individual or Legal Entity 26 | exercising permissions granted by this License. 27 | 28 | "Source" form shall mean the preferred form for making modifications, 29 | including but not limited to software source code, documentation 30 | source, and configuration files. 31 | 32 | "Object" form shall mean any form resulting from mechanical 33 | transformation or translation of a Source form, including but 34 | not limited to compiled object code, generated documentation, 35 | and conversions to other media types. 36 | 37 | "Work" shall mean the work of authorship, whether in Source or 38 | Object form, made available under the License, as indicated by a 39 | copyright notice that is included in or attached to the work 40 | (an example is provided in the Appendix below). 41 | 42 | "Derivative Works" shall mean any work, whether in Source or Object 43 | form, that is based on (or derived from) the Work and for which the 44 | editorial revisions, annotations, elaborations, or other modifications 45 | represent, as a whole, an original work of authorship. For the purposes 46 | of this License, Derivative Works shall not include works that remain 47 | separable from, or merely link (or bind by name) to the interfaces of, 48 | the Work and Derivative Works thereof. 49 | 50 | "Contribution" shall mean any work of authorship, including 51 | the original version of the Work and any modifications or additions 52 | to that Work or Derivative Works thereof, that is intentionally 53 | submitted to Licensor for inclusion in the Work by the copyright owner 54 | or by an individual or Legal Entity authorized to submit on behalf of 55 | the copyright owner. 
For the purposes of this definition, "submitted" 56 | means any form of electronic, verbal, or written communication sent 57 | to the Licensor or its representatives, including but not limited to 58 | communication on electronic mailing lists, source code control systems, 59 | and issue tracking systems that are managed by, or on behalf of, the 60 | Licensor for the purpose of discussing and improving the Work, but 61 | excluding communication that is conspicuously marked or otherwise 62 | designated in writing by the copyright owner as "Not a Contribution." 63 | 64 | "Contributor" shall mean Licensor and any individual or Legal Entity 65 | on behalf of whom a Contribution has been received by Licensor and 66 | subsequently incorporated within the Work. 67 | 68 | 2. Grant of Copyright License. Subject to the terms and conditions of 69 | this License, each Contributor hereby grants to You a perpetual, 70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 71 | copyright license to reproduce, prepare Derivative Works of, 72 | publicly display, publicly perform, sublicense, and distribute the 73 | Work and such Derivative Works in Source or Object form. 74 | 75 | 3. Grant of Patent License. Subject to the terms and conditions of 76 | this License, each Contributor hereby grants to You a perpetual, 77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 78 | (except as stated in this section) patent license to make, have made, 79 | use, offer to sell, sell, import, and otherwise transfer the Work, 80 | where such license applies only to those patent claims licensable 81 | by such Contributor that are necessarily infringed by their 82 | Contribution(s) alone or by combination of their Contribution(s) 83 | with the Work to which such Contribution(s) was submitted. If You 84 | institute patent litigation against any entity (including a 85 | cross-claim or counterclaim in a lawsuit) alleging that the Work 86 | or a Contribution incorporated within the Work constitutes direct 87 | or contributory patent infringement, then any patent licenses 88 | granted to You under this License for that Work shall terminate 89 | as of the date such litigation is filed. 90 | 91 | 4. Redistribution. 
You may reproduce and distribute copies of the 92 | Work or Derivative Works thereof in any medium, with or without 93 | modifications, and in Source or Object form, provided that You 94 | meet the following conditions: 95 | 96 | (a) You must give any other recipients of the Work or 97 | Derivative Works a copy of this License; and 98 | 99 | (b) You must cause any modified files to carry prominent notices 100 | stating that You changed the files; and 101 | 102 | (c) You must retain, in the Source form of any Derivative Works 103 | that You distribute, all copyright, patent, trademark, and 104 | attribution notices from the Source form of the Work, 105 | excluding those notices that do not pertain to any part of 106 | the Derivative Works; and 107 | 108 | (d) If the Work includes a "NOTICE" text file as part of its 109 | distribution, then any Derivative Works that You distribute must 110 | include a readable copy of the attribution notices contained 111 | within such NOTICE file, excluding those notices that do not 112 | pertain to any part of the Derivative Works, in at least one 113 | of the following places: within a NOTICE text file distributed 114 | as part of the Derivative Works; within the Source form or 115 | documentation, if provided along with the Derivative Works; or, 116 | within a display generated by the Derivative Works, if and 117 | wherever such third-party notices normally appear. The contents 118 | of the NOTICE file are for informational purposes only and 119 | do not modify the License. You may add Your own attribution 120 | notices within Derivative Works that You distribute, alongside 121 | or as an addendum to the NOTICE text from the Work, provided 122 | that such additional attribution notices cannot be construed 123 | as modifying the License. 124 | 125 | You may add Your own copyright statement to Your modifications and 126 | may provide additional or different license terms and conditions 127 | for use, reproduction, or distribution of Your modifications, or 128 | for any such Derivative Works as a whole, provided Your use, 129 | reproduction, and distribution of the Work otherwise complies with 130 | the conditions stated in this License. 131 | 132 | 5. Submission of Contributions. Unless You explicitly state otherwise, 133 | any Contribution intentionally submitted for inclusion in the Work 134 | by You to the Licensor shall be under the terms and conditions of 135 | this License, without any additional terms or conditions. 136 | Notwithstanding the above, nothing herein shall supersede or modify 137 | the terms of any separate license agreement you may have executed 138 | with Licensor regarding such Contributions. 139 | 140 | 6. Trademarks. This License does not grant permission to use the trade 141 | names, trademarks, service marks, or product names of the Licensor, 142 | except as required for reasonable and customary use in describing the 143 | origin of the Work and reproducing the content of the NOTICE file. 144 | 145 | 7. Disclaimer of Warranty. Unless required by applicable law or 146 | agreed to in writing, Licensor provides the Work (and each 147 | Contributor provides its Contributions) on an "AS IS" BASIS, 148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 149 | implied, including, without limitation, any warranties or conditions 150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 151 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 152 | appropriateness of using or redistributing the Work and assume any 153 | risks associated with Your exercise of permissions under this License. 154 | 155 | 8. Limitation of Liability. In no event and under no legal theory, 156 | whether in tort (including negligence), contract, or otherwise, 157 | unless required by applicable law (such as deliberate and grossly 158 | negligent acts) or agreed to in writing, shall any Contributor be 159 | liable to You for damages, including any direct, indirect, special, 160 | incidental, or consequential damages of any character arising as a 161 | result of this License or out of the use or inability to use the 162 | Work (including but not limited to damages for loss of goodwill, 163 | work stoppage, computer failure or malfunction, or any and all 164 | other commercial damages or losses), even if such Contributor 165 | has been advised of the possibility of such damages. 166 | 167 | 9. Accepting Warranty or Additional Liability. While redistributing 168 | the Work or Derivative Works thereof, You may choose to offer, 169 | and charge a fee for, acceptance of support, warranty, indemnity, 170 | or other liability obligations and/or rights consistent with this 171 | License. However, in accepting such obligations, You may act only 172 | on Your own behalf and on Your sole responsibility, not on behalf 173 | of any other Contributor, and only if You agree to indemnify, 174 | defend, and hold each Contributor harmless for any liability 175 | incurred by, or claims asserted against, such Contributor by reason 176 | of your accepting any such warranty or additional liability. 177 | 178 | END OF TERMS AND CONDITIONS 179 | 180 | APPENDIX: How to apply the Apache License to your work. 181 | 182 | To apply the Apache License to your work, attach the following 183 | boilerplate notice, with the fields enclosed by brackets "[]" 184 | replaced with your own identifying information. (Don't include 185 | the brackets!) The text should be enclosed in the appropriate 186 | comment syntax for the file format. We also recommend that a 187 | file or class name and description of purpose be included on the 188 | same "printed page" as the copyright notice for easier 189 | identification within third-party archives. 190 | 191 | Copyright (c) 2022 FittenTech. All Rights Reserved. 192 | 193 | Licensed under the Apache License, Version 2.0 (the "License"); 194 | you may not use this file except in compliance with the License. 195 | You may obtain a copy of the License at 196 | 197 | http://www.apache.org/licenses/LICENSE-2.0 198 | 199 | Unless required by applicable law or agreed to in writing, software 200 | distributed under the License is distributed on an "AS IS" BASIS, 201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 202 | See the License for the specific language governing permissions and 203 | limitations under the License. 
204 | -------------------------------------------------------------------------------- /python/jtorch/vision/datasets/utils.py: -------------------------------------------------------------------------------- 1 | import bz2 2 | import contextlib 3 | import gzip 4 | import hashlib 5 | import itertools 6 | import lzma 7 | import os 8 | import os.path 9 | import pathlib 10 | import re 11 | import sys 12 | import tarfile 13 | import urllib 14 | import urllib.error 15 | import urllib.request 16 | import warnings 17 | import zipfile 18 | from typing import Any, Callable, Dict, IO, Iterable, Iterator, List, Optional, Tuple, TypeVar 19 | from urllib.parse import urlparse 20 | 21 | import numpy as np 22 | import requests 23 | import torch 24 | from tqdm import tqdm 25 | 26 | from .._internally_replaced_utils import _download_file_from_remote_location, _is_remote_location_available 27 | 28 | USER_AGENT = "pytorch/vision" 29 | 30 | 31 | def _save_response_content( 32 | content: Iterator[bytes], 33 | destination: str, 34 | length: Optional[int] = None, 35 | ) -> None: 36 | with open(destination, "wb") as fh, tqdm(total=length) as pbar: 37 | for chunk in content: 38 | # filter out keep-alive new chunks 39 | if not chunk: 40 | continue 41 | 42 | fh.write(chunk) 43 | pbar.update(len(chunk)) 44 | 45 | 46 | def _urlretrieve(url: str, filename: str, chunk_size: int = 1024 * 32) -> None: 47 | with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response: 48 | _save_response_content(iter(lambda: response.read(chunk_size), b""), filename, length=response.length) 49 | 50 | 51 | def gen_bar_updater() -> Callable[[int, int, int], None]: 52 | warnings.warn("The function `gen_bar_update` is deprecated since 0.13 and will be removed in 0.15.") 53 | pbar = tqdm(total=None) 54 | 55 | def bar_update(count, block_size, total_size): 56 | if pbar.total is None and total_size: 57 | pbar.total = total_size 58 | progress_bytes = count * block_size 59 | pbar.update(progress_bytes - pbar.n) 60 | 61 | return bar_update 62 | 63 | 64 | def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str: 65 | # Setting the `usedforsecurity` flag does not change anything about the functionality, but indicates that we are 66 | # not using the MD5 checksum for cryptography. This enables its usage in restricted environments like FIPS. Without 67 | # it torchvision.datasets is unusable in these environments since we perform a MD5 check everywhere. 
68 |     if sys.version_info >= (3, 9):
69 |         md5 = hashlib.md5(usedforsecurity=False)
70 |     else:
71 |         md5 = hashlib.md5()
72 |     with open(fpath, "rb") as f:
73 |         for chunk in iter(lambda: f.read(chunk_size), b""):
74 |             md5.update(chunk)
75 |     return md5.hexdigest()
76 | 
77 | 
78 | def check_md5(fpath: str, md5: str, **kwargs: Any) -> bool:
79 |     return md5 == calculate_md5(fpath, **kwargs)
80 | 
81 | 
82 | def check_integrity(fpath: str, md5: Optional[str] = None) -> bool:
83 |     if not os.path.isfile(fpath):
84 |         return False
85 |     if md5 is None:
86 |         return True
87 |     return check_md5(fpath, md5)
88 | 
89 | 
90 | def _get_redirect_url(url: str, max_hops: int = 3) -> str:
91 |     initial_url = url
92 |     headers = {"Method": "HEAD", "User-Agent": USER_AGENT}
93 | 
94 |     for _ in range(max_hops + 1):
95 |         with urllib.request.urlopen(urllib.request.Request(url, headers=headers)) as response:
96 |             if response.url == url or response.url is None:
97 |                 return url
98 | 
99 |             url = response.url
100 |     else:
101 |         raise RecursionError(
102 |             f"Request to {initial_url} exceeded {max_hops} redirects. The last redirect points to {url}."
103 |         )
104 | 
105 | 
106 | def _get_google_drive_file_id(url: str) -> Optional[str]:
107 |     parts = urlparse(url)
108 | 
109 |     if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
110 |         return None
111 | 
112 |     match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
113 |     if match is None:
114 |         return None
115 | 
116 |     return match.group("id")
117 | 
118 | 
119 | def download_url(
120 |     url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None, max_redirect_hops: int = 3
121 | ) -> None:
122 |     """Download a file from a url and place it in root.
123 | 
124 |     Args:
125 |         url (str): URL to download file from
126 |         root (str): Directory to place downloaded file in
127 |         filename (str, optional): Name to save the file under. If None, use the basename of the URL
128 |         md5 (str, optional): MD5 checksum of the download. If None, do not check
129 |         max_redirect_hops (int, optional): Maximum number of redirect hops allowed
130 |     """
131 |     root = os.path.expanduser(root)
132 |     if not filename:
133 |         filename = os.path.basename(url)
134 |     fpath = os.path.join(root, filename)
135 | 
136 |     os.makedirs(root, exist_ok=True)
137 | 
138 |     # check if file is already present locally
139 |     if check_integrity(fpath, md5):
140 |         print("Using downloaded and verified file: " + fpath)
141 |         return
142 | 
143 |     if _is_remote_location_available():
144 |         _download_file_from_remote_location(fpath, url)
145 |     else:
146 |         # expand redirect chain if needed
147 |         url = _get_redirect_url(url, max_hops=max_redirect_hops)
148 | 
149 |         # check if file is located on Google Drive
150 |         file_id = _get_google_drive_file_id(url)
151 |         if file_id is not None:
152 |             return download_file_from_google_drive(file_id, root, filename, md5)
153 | 
154 |         # download the file
155 |         try:
156 |             print("Downloading " + url + " to " + fpath)
157 |             _urlretrieve(url, fpath)
158 |         except (urllib.error.URLError, OSError) as e:  # type: ignore[attr-defined]
159 |             if url[:5] == "https":
160 |                 url = url.replace("https:", "http:")
161 |                 print("Failed download. Trying https -> http instead. Downloading " + url + " to " + fpath)
Downloading " + url + " to " + fpath) 162 | _urlretrieve(url, fpath) 163 | else: 164 | raise e 165 | 166 | # check integrity of downloaded file 167 | if not check_integrity(fpath, md5): 168 | raise RuntimeError("File not found or corrupted.") 169 | 170 | 171 | def list_dir(root: str, prefix: bool = False) -> List[str]: 172 | """List all directories at a given root 173 | 174 | Args: 175 | root (str): Path to directory whose folders need to be listed 176 | prefix (bool, optional): If true, prepends the path to each result, otherwise 177 | only returns the name of the directories found 178 | """ 179 | root = os.path.expanduser(root) 180 | directories = [p for p in os.listdir(root) if os.path.isdir(os.path.join(root, p))] 181 | if prefix is True: 182 | directories = [os.path.join(root, d) for d in directories] 183 | return directories 184 | 185 | 186 | def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]: 187 | """List all files ending with a suffix at a given root 188 | 189 | Args: 190 | root (str): Path to directory whose folders need to be listed 191 | suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png'). 192 | It uses the Python "str.endswith" method and is passed directly 193 | prefix (bool, optional): If true, prepends the path to each result, otherwise 194 | only returns the name of the files found 195 | """ 196 | root = os.path.expanduser(root) 197 | files = [p for p in os.listdir(root) if os.path.isfile(os.path.join(root, p)) and p.endswith(suffix)] 198 | if prefix is True: 199 | files = [os.path.join(root, d) for d in files] 200 | return files 201 | 202 | 203 | def _extract_gdrive_api_response(response, chunk_size: int = 32 * 1024) -> Tuple[bytes, Iterator[bytes]]: 204 | content = response.iter_content(chunk_size) 205 | first_chunk = None 206 | # filter out keep-alive new chunks 207 | while not first_chunk: 208 | first_chunk = next(content) 209 | content = itertools.chain([first_chunk], content) 210 | 211 | try: 212 | match = re.search("Google Drive - (?P<api_response>.+?)", first_chunk.decode()) 213 | api_response = match["api_response"] if match is not None else None 214 | except UnicodeDecodeError: 215 | api_response = None 216 | return api_response, content 217 | 218 | 219 | def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None): 220 | """Download a Google Drive file from and place it in root. 221 | 222 | Args: 223 | file_id (str): id of file to be downloaded 224 | root (str): Directory to place downloaded file in 225 | filename (str, optional): Name to save the file under. If None, use the id of the file. 226 | md5 (str, optional): MD5 checksum of the download. 
227 |     """
228 |     # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
229 | 
230 |     root = os.path.expanduser(root)
231 |     if not filename:
232 |         filename = file_id
233 |     fpath = os.path.join(root, filename)
234 | 
235 |     os.makedirs(root, exist_ok=True)
236 | 
237 |     if check_integrity(fpath, md5):
238 |         print(f"Using downloaded {'and verified ' if md5 else ''}file: {fpath}")
239 |         return
240 | 
241 |     url = "https://drive.google.com/uc"
242 |     params = dict(id=file_id, export="download")
243 |     with requests.Session() as session:
244 |         response = session.get(url, params=params, stream=True)
245 | 
246 |         for key, value in response.cookies.items():
247 |             if key.startswith("download_warning"):
248 |                 token = value
249 |                 break
250 |         else:
251 |             api_response, content = _extract_gdrive_api_response(response)
252 |             token = "t" if api_response == "Virus scan warning" else None
253 | 
254 |         if token is not None:
255 |             response = session.get(url, params=dict(params, confirm=token), stream=True)
256 |             api_response, content = _extract_gdrive_api_response(response)
257 | 
258 |         if api_response == "Quota exceeded":
259 |             raise RuntimeError(
260 |                 f"The daily quota of the file {filename} is exceeded and it "
261 |                 f"can't be downloaded. This is a limitation of Google Drive "
262 |                 f"and can only be overcome by trying again later."
263 |             )
264 | 
265 |         _save_response_content(content, fpath)
266 | 
267 |     # In case we deal with an unhandled GDrive API response, the file should be smaller than 10kB and contain only text
268 |     if os.stat(fpath).st_size < 10 * 1024:
269 |         with contextlib.suppress(UnicodeDecodeError), open(fpath) as fh:
270 |             text = fh.read()
271 |             # Regular expression to detect HTML. Copied from https://stackoverflow.com/a/70585604
272 |             if re.search(r"</?\s*[a-z-][^>]*\s*>|(&(?:[\w\d]+|#\d+|#x[a-f\d]+);)", text):
273 |                 warnings.warn(
274 |                     f"We detected some HTML elements in the downloaded file. "
275 |                     f"This most likely means that the download triggered an unhandled API response by GDrive. "
276 |                     f"Please report this to torchvision at https://github.com/pytorch/vision/issues including "
277 |                     f"the response:\n\n{text}"
278 |                 )
279 | 
280 |     if md5 and not check_md5(fpath, md5):
281 |         raise RuntimeError(
282 |             f"The MD5 checksum of the download file {fpath} does not match the one on record. "
283 |             f"Please delete the file and try again. "
284 |             f"If the issue persists, please report this to torchvision at https://github.com/pytorch/vision/issues."
285 | ) 286 | 287 | 288 | def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None: 289 | with tarfile.open(from_path, f"r:{compression[1:]}" if compression else "r") as tar: 290 | tar.extractall(to_path) 291 | 292 | 293 | _ZIP_COMPRESSION_MAP: Dict[str, int] = { 294 | ".bz2": zipfile.ZIP_BZIP2, 295 | ".xz": zipfile.ZIP_LZMA, 296 | } 297 | 298 | 299 | def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> None: 300 | with zipfile.ZipFile( 301 | from_path, "r", compression=_ZIP_COMPRESSION_MAP[compression] if compression else zipfile.ZIP_STORED 302 | ) as zip: 303 | zip.extractall(to_path) 304 | 305 | 306 | _ARCHIVE_EXTRACTORS: Dict[str, Callable[[str, str, Optional[str]], None]] = { 307 | ".tar": _extract_tar, 308 | ".zip": _extract_zip, 309 | } 310 | _COMPRESSED_FILE_OPENERS: Dict[str, Callable[..., IO]] = { 311 | ".bz2": bz2.open, 312 | ".gz": gzip.open, 313 | ".xz": lzma.open, 314 | } 315 | _FILE_TYPE_ALIASES: Dict[str, Tuple[Optional[str], Optional[str]]] = { 316 | ".tbz": (".tar", ".bz2"), 317 | ".tbz2": (".tar", ".bz2"), 318 | ".tgz": (".tar", ".gz"), 319 | } 320 | 321 | 322 | def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]: 323 | """Detect the archive type and/or compression of a file. 324 | 325 | Args: 326 | file (str): the filename 327 | 328 | Returns: 329 | (tuple): tuple of suffix, archive type, and compression 330 | 331 | Raises: 332 | RuntimeError: if file has no suffix or suffix is not supported 333 | """ 334 | suffixes = pathlib.Path(file).suffixes 335 | if not suffixes: 336 | raise RuntimeError( 337 | f"File '{file}' has no suffixes that could be used to detect the archive type and compression." 338 | ) 339 | suffix = suffixes[-1] 340 | 341 | # check if the suffix is a known alias 342 | if suffix in _FILE_TYPE_ALIASES: 343 | return (suffix, *_FILE_TYPE_ALIASES[suffix]) 344 | 345 | # check if the suffix is an archive type 346 | if suffix in _ARCHIVE_EXTRACTORS: 347 | return suffix, suffix, None 348 | 349 | # check if the suffix is a compression 350 | if suffix in _COMPRESSED_FILE_OPENERS: 351 | # check for suffix hierarchy 352 | if len(suffixes) > 1: 353 | suffix2 = suffixes[-2] 354 | 355 | # check if the suffix2 is an archive type 356 | if suffix2 in _ARCHIVE_EXTRACTORS: 357 | return suffix2 + suffix, suffix2, suffix 358 | 359 | return suffix, None, suffix 360 | 361 | valid_suffixes = sorted(set(_FILE_TYPE_ALIASES) | set(_ARCHIVE_EXTRACTORS) | set(_COMPRESSED_FILE_OPENERS)) 362 | raise RuntimeError(f"Unknown compression or archive type: '{suffix}'.\nKnown suffixes are: '{valid_suffixes}'.") 363 | 364 | 365 | def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str: 366 | r"""Decompress a file. 367 | 368 | The compression is automatically detected from the file name. 369 | 370 | Args: 371 | from_path (str): Path to the file to be decompressed. 372 | to_path (str): Path to the decompressed file. If omitted, ``from_path`` without compression extension is used. 373 | remove_finished (bool): If ``True``, remove the file after the extraction. 374 | 375 | Returns: 376 | (str): Path to the decompressed file. 
377 | """ 378 | suffix, archive_type, compression = _detect_file_type(from_path) 379 | if not compression: 380 | raise RuntimeError(f"Couldn't detect a compression from suffix {suffix}.") 381 | 382 | if to_path is None: 383 | to_path = from_path.replace(suffix, archive_type if archive_type is not None else "") 384 | 385 | # We don't need to check for a missing key here, since this was already done in _detect_file_type() 386 | compressed_file_opener = _COMPRESSED_FILE_OPENERS[compression] 387 | 388 | with compressed_file_opener(from_path, "rb") as rfh, open(to_path, "wb") as wfh: 389 | wfh.write(rfh.read()) 390 | 391 | if remove_finished: 392 | os.remove(from_path) 393 | 394 | return to_path 395 | 396 | 397 | def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str: 398 | """Extract an archive. 399 | 400 | The archive type and a possible compression is automatically detected from the file name. If the file is compressed 401 | but not an archive the call is dispatched to :func:`decompress`. 402 | 403 | Args: 404 | from_path (str): Path to the file to be extracted. 405 | to_path (str): Path to the directory the file will be extracted to. If omitted, the directory of the file is 406 | used. 407 | remove_finished (bool): If ``True``, remove the file after the extraction. 408 | 409 | Returns: 410 | (str): Path to the directory the file was extracted to. 411 | """ 412 | if to_path is None: 413 | to_path = os.path.dirname(from_path) 414 | 415 | suffix, archive_type, compression = _detect_file_type(from_path) 416 | if not archive_type: 417 | return _decompress( 418 | from_path, 419 | os.path.join(to_path, os.path.basename(from_path).replace(suffix, "")), 420 | remove_finished=remove_finished, 421 | ) 422 | 423 | # We don't need to check for a missing key here, since this was already done in _detect_file_type() 424 | extractor = _ARCHIVE_EXTRACTORS[archive_type] 425 | 426 | extractor(from_path, to_path, compression) 427 | if remove_finished: 428 | os.remove(from_path) 429 | 430 | return to_path 431 | 432 | 433 | def download_and_extract_archive( 434 | url: str, 435 | download_root: str, 436 | extract_root: Optional[str] = None, 437 | filename: Optional[str] = None, 438 | md5: Optional[str] = None, 439 | remove_finished: bool = False, 440 | ) -> None: 441 | download_root = os.path.expanduser(download_root) 442 | if extract_root is None: 443 | extract_root = download_root 444 | if not filename: 445 | filename = os.path.basename(url) 446 | 447 | download_url(url, download_root, filename, md5) 448 | 449 | archive = os.path.join(download_root, filename) 450 | print(f"Extracting {archive} to {extract_root}") 451 | extract_archive(archive, extract_root, remove_finished) 452 | 453 | 454 | def iterable_to_str(iterable: Iterable) -> str: 455 | return "'" + "', '".join([str(item) for item in iterable]) + "'" 456 | 457 | 458 | T = TypeVar("T", str, bytes) 459 | 460 | 461 | def verify_str_arg( 462 | value: T, 463 | arg: Optional[str] = None, 464 | valid_values: Optional[Iterable[T]] = None, 465 | custom_msg: Optional[str] = None, 466 | ) -> T: 467 | if not isinstance(value, torch._six.string_classes): 468 | if arg is None: 469 | msg = "Expected type str, but got type {type}." 470 | else: 471 | msg = "Expected type str for argument {arg}, but got type {type}." 
453 | 
454 | def iterable_to_str(iterable: Iterable) -> str:
455 |     return "'" + "', '".join([str(item) for item in iterable]) + "'"
456 | 
457 | 
458 | T = TypeVar("T", str, bytes)
459 | 
460 | 
461 | def verify_str_arg(
462 |     value: T,
463 |     arg: Optional[str] = None,
464 |     valid_values: Optional[Iterable[T]] = None,
465 |     custom_msg: Optional[str] = None,
466 | ) -> T:
467 |     if not isinstance(value, str):
468 |         if arg is None:
469 |             msg = "Expected type str, but got type {type}."
470 |         else:
471 |             msg = "Expected type str for argument {arg}, but got type {type}."
472 |         msg = msg.format(type=type(value), arg=arg)
473 |         raise ValueError(msg)
474 | 
475 |     if valid_values is None:
476 |         return value
477 | 
478 |     if value not in valid_values:
479 |         if custom_msg is not None:
480 |             msg = custom_msg
481 |         else:
482 |             msg = "Unknown value '{value}' for argument {arg}. Valid values are {{{valid_values}}}."
483 |             msg = msg.format(value=value, arg=arg, valid_values=iterable_to_str(valid_values))
484 |         raise ValueError(msg)
485 | 
486 |     return value
487 | 
488 | 
489 | def _read_pfm(file_name: str, slice_channels: int = 2) -> np.ndarray:
490 |     """Read file in .pfm format. Might contain either 1 or 3 channels of data.
491 | 
492 |     Args:
493 |         file_name (str): Path to the file.
494 |         slice_channels (int): Number of channels to slice out of the file.
495 |             Useful for reading different data formats stored in .pfm files: Optical Flows, Stereo Disparity Maps, etc.
496 |     """
497 | 
498 |     with open(file_name, "rb") as f:
499 |         header = f.readline().rstrip()
500 |         if header not in [b"PF", b"Pf"]:
501 |             raise ValueError("Invalid PFM file")
502 | 
503 |         dim_match = re.match(rb"^(\d+)\s(\d+)\s$", f.readline())
504 |         if not dim_match:
505 |             raise Exception("Malformed PFM header.")
506 |         w, h = (int(dim) for dim in dim_match.groups())
507 | 
508 |         scale = float(f.readline().rstrip())
509 |         if scale < 0:  # little-endian
510 |             endian = "<"
511 |             scale = -scale
512 |         else:
513 |             endian = ">"  # big-endian
514 | 
515 |         data = np.fromfile(f, dtype=endian + "f")
516 | 
517 |         pfm_channels = 3 if header == b"PF" else 1
518 | 
519 |         data = data.reshape(h, w, pfm_channels).transpose(2, 0, 1)
520 |         data = np.flip(data, axis=1)  # flip on h dimension
521 |         data = data[:slice_channels, :, :]
522 |         return data.astype(np.float32)
--------------------------------------------------------------------------------
/python/jtorch/vision/datasets/mnist.py:
--------------------------------------------------------------------------------
1 | import codecs
2 | import os
3 | import os.path
4 | import shutil
5 | import string
6 | import sys
7 | import warnings
8 | from typing import Any, Callable, Dict, List, Optional, Tuple
9 | from urllib.error import URLError
10 | 
11 | import numpy as np
12 | import torch
13 | from PIL import Image
14 | 
15 | from .utils import check_integrity, download_and_extract_archive, extract_archive, verify_str_arg
16 | from .vision import VisionDataset
17 | 
18 | 
19 | class MNIST(VisionDataset):
20 |     """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.
21 | 
22 |     Args:
23 |         root (string): Root directory of dataset where ``MNIST/raw/train-images-idx3-ubyte``
24 |             and ``MNIST/raw/t10k-images-idx3-ubyte`` exist.
25 |         train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
26 |             otherwise from ``t10k-images-idx3-ubyte``.
27 |         download (bool, optional): If True, downloads the dataset from the internet and
28 |             puts it in root directory. If dataset is already downloaded, it is not
29 |             downloaded again.
30 |         transform (callable, optional): A function/transform that takes in a PIL image
31 |             and returns a transformed version. E.g., ``transforms.RandomCrop``
32 |         target_transform (callable, optional): A function/transform that takes in the
33 |             target and transforms it.
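
    Example (illustrative; ``download=True`` requires network access)::

        train_set = MNIST(root="./data", train=True, download=True)
        img, label = train_set[0]  # PIL image and integer label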
34 | """ 35 | 36 | mirrors = [ 37 | "http://yann.lecun.com/exdb/mnist/", 38 | "https://ossci-datasets.s3.amazonaws.com/mnist/", 39 | ] 40 | 41 | resources = [ 42 | ("train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"), 43 | ("train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"), 44 | ("t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"), 45 | ("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c"), 46 | ] 47 | 48 | training_file = "training.pt" 49 | test_file = "test.pt" 50 | classes = [ 51 | "0 - zero", 52 | "1 - one", 53 | "2 - two", 54 | "3 - three", 55 | "4 - four", 56 | "5 - five", 57 | "6 - six", 58 | "7 - seven", 59 | "8 - eight", 60 | "9 - nine", 61 | ] 62 | 63 | @property 64 | def train_labels(self): 65 | warnings.warn("train_labels has been renamed targets") 66 | return self.targets 67 | 68 | @property 69 | def test_labels(self): 70 | warnings.warn("test_labels has been renamed targets") 71 | return self.targets 72 | 73 | @property 74 | def train_data(self): 75 | warnings.warn("train_data has been renamed data") 76 | return self.data 77 | 78 | @property 79 | def test_data(self): 80 | warnings.warn("test_data has been renamed data") 81 | return self.data 82 | 83 | def __init__( 84 | self, 85 | root: str, 86 | train: bool = True, 87 | transform: Optional[Callable] = None, 88 | target_transform: Optional[Callable] = None, 89 | download: bool = False, 90 | ) -> None: 91 | super().__init__(root, transform=transform, target_transform=target_transform) 92 | self.train = train # training set or test set 93 | 94 | if self._check_legacy_exist(): 95 | self.data, self.targets = self._load_legacy_data() 96 | return 97 | 98 | if download: 99 | self.download() 100 | 101 | if not self._check_exists(): 102 | raise RuntimeError("Dataset not found. You can use download=True to download it") 103 | 104 | self.data, self.targets = self._load_data() 105 | 106 | def _check_legacy_exist(self): 107 | processed_folder_exists = os.path.exists(self.processed_folder) 108 | if not processed_folder_exists: 109 | return False 110 | 111 | return all( 112 | check_integrity(os.path.join(self.processed_folder, file)) for file in (self.training_file, self.test_file) 113 | ) 114 | 115 | def _load_legacy_data(self): 116 | # This is for BC only. We no longer cache the data in a custom binary, but simply read from the raw data 117 | # directly. 118 | data_file = self.training_file if self.train else self.test_file 119 | return torch.load(os.path.join(self.processed_folder, data_file)) 120 | 121 | def _load_data(self): 122 | image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte" 123 | data = read_image_file(os.path.join(self.raw_folder, image_file)) 124 | 125 | label_file = f"{'train' if self.train else 't10k'}-labels-idx1-ubyte" 126 | targets = read_label_file(os.path.join(self.raw_folder, label_file)) 127 | 128 | return data, targets 129 | 130 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 131 | """ 132 | Args: 133 | index (int): Index 134 | 135 | Returns: 136 | tuple: (image, target) where target is index of the target class. 
137 | """ 138 | img, target = self.data[index], int(self.targets[index]) 139 | 140 | # doing this so that it is consistent with all other datasets 141 | # to return a PIL Image 142 | img = Image.fromarray(img.numpy(), mode="L") 143 | 144 | if self.transform is not None: 145 | img = self.transform(img) 146 | 147 | if self.target_transform is not None: 148 | target = self.target_transform(target) 149 | 150 | return img, target 151 | 152 | def __len__(self) -> int: 153 | return len(self.data) 154 | 155 | @property 156 | def raw_folder(self) -> str: 157 | return os.path.join(self.root, self.__class__.__name__, "raw") 158 | 159 | @property 160 | def processed_folder(self) -> str: 161 | return os.path.join(self.root, self.__class__.__name__, "processed") 162 | 163 | @property 164 | def class_to_idx(self) -> Dict[str, int]: 165 | return {_class: i for i, _class in enumerate(self.classes)} 166 | 167 | def _check_exists(self) -> bool: 168 | return all( 169 | check_integrity(os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])) 170 | for url, _ in self.resources 171 | ) 172 | 173 | def download(self) -> None: 174 | """Download the MNIST data if it doesn't exist already.""" 175 | 176 | if self._check_exists(): 177 | return 178 | 179 | os.makedirs(self.raw_folder, exist_ok=True) 180 | 181 | # download files 182 | for filename, md5 in self.resources: 183 | for mirror in self.mirrors: 184 | url = f"{mirror}{filename}" 185 | try: 186 | print(f"Downloading {url}") 187 | download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5) 188 | except URLError as error: 189 | print(f"Failed to download (trying next):\n{error}") 190 | continue 191 | finally: 192 | print() 193 | break 194 | else: 195 | raise RuntimeError(f"Error downloading {filename}") 196 | 197 | def extra_repr(self) -> str: 198 | split = "Train" if self.train is True else "Test" 199 | return f"Split: {split}" 200 | 201 | 202 | class FashionMNIST(MNIST): 203 | """`Fashion-MNIST `_ Dataset. 204 | 205 | Args: 206 | root (string): Root directory of dataset where ``FashionMNIST/raw/train-images-idx3-ubyte`` 207 | and ``FashionMNIST/raw/t10k-images-idx3-ubyte`` exist. 208 | train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``, 209 | otherwise from ``t10k-images-idx3-ubyte``. 210 | download (bool, optional): If True, downloads the dataset from the internet and 211 | puts it in root directory. If dataset is already downloaded, it is not 212 | downloaded again. 213 | transform (callable, optional): A function/transform that takes in an PIL image 214 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 215 | target_transform (callable, optional): A function/transform that takes in the 216 | target and transforms it. 217 | """ 218 | 219 | mirrors = ["http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"] 220 | 221 | resources = [ 222 | ("train-images-idx3-ubyte.gz", "8d4fb7e6c68d591d4c3dfef9ec88bf0d"), 223 | ("train-labels-idx1-ubyte.gz", "25c81989df183df01b3e8a0aad5dffbe"), 224 | ("t10k-images-idx3-ubyte.gz", "bef4ecab320f06d8554ea6380940ec79"), 225 | ("t10k-labels-idx1-ubyte.gz", "bb300cfdad3c16e7a12a480ee83cd310"), 226 | ] 227 | classes = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"] 228 | 229 | 230 | class KMNIST(MNIST): 231 | """`Kuzushiji-MNIST `_ Dataset. 
232 | 
233 |     Args:
234 |         root (string): Root directory of dataset where ``KMNIST/raw/train-images-idx3-ubyte``
235 |             and ``KMNIST/raw/t10k-images-idx3-ubyte`` exist.
236 |         train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
237 |             otherwise from ``t10k-images-idx3-ubyte``.
238 |         download (bool, optional): If True, downloads the dataset from the internet and
239 |             puts it in root directory. If dataset is already downloaded, it is not
240 |             downloaded again.
241 |         transform (callable, optional): A function/transform that takes in a PIL image
242 |             and returns a transformed version. E.g., ``transforms.RandomCrop``
243 |         target_transform (callable, optional): A function/transform that takes in the
244 |             target and transforms it.
245 |     """
246 | 
247 |     mirrors = ["http://codh.rois.ac.jp/kmnist/dataset/kmnist/"]
248 | 
249 |     resources = [
250 |         ("train-images-idx3-ubyte.gz", "bdb82020997e1d708af4cf47b453dcf7"),
251 |         ("train-labels-idx1-ubyte.gz", "e144d726b3acfaa3e44228e80efcd344"),
252 |         ("t10k-images-idx3-ubyte.gz", "5c965bf0a639b31b8f53240b1b52f4d7"),
253 |         ("t10k-labels-idx1-ubyte.gz", "7320c461ea6c1c855c0b718fb2a4b134"),
254 |     ]
255 |     classes = ["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"]
256 | 
257 | 
258 | class EMNIST(MNIST):
259 |     """`EMNIST <https://www.westernsydney.edu.au/icns/reproducible_research/publication_support_materials/emnist>`_ Dataset.
260 | 
261 |     Args:
262 |         root (string): Root directory of dataset where ``EMNIST/raw/train-images-idx3-ubyte``
263 |             and ``EMNIST/raw/t10k-images-idx3-ubyte`` exist.
264 |         split (string): The dataset has 6 different splits: ``byclass``, ``bymerge``,
265 |             ``balanced``, ``letters``, ``digits`` and ``mnist``. This argument specifies
266 |             which one to use.
267 |         train (bool, optional): If True, creates dataset from ``training.pt``,
268 |             otherwise from ``test.pt``.
269 |         download (bool, optional): If True, downloads the dataset from the internet and
270 |             puts it in root directory. If dataset is already downloaded, it is not
271 |             downloaded again.
272 |         transform (callable, optional): A function/transform that takes in a PIL image
273 |             and returns a transformed version. E.g., ``transforms.RandomCrop``
274 |         target_transform (callable, optional): A function/transform that takes in the
275 |             target and transforms it.
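
    Example (illustrative; ``download=True`` requires network access)::

        letters = EMNIST(root="./data", split="letters", download=True)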
276 | """ 277 | 278 | url = "https://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip" 279 | md5 = "58c8d27c78d21e728a6bc7b3cc06412e" 280 | splits = ("byclass", "bymerge", "balanced", "letters", "digits", "mnist") 281 | # Merged Classes assumes Same structure for both uppercase and lowercase version 282 | _merged_classes = {"c", "i", "j", "k", "l", "m", "o", "p", "s", "u", "v", "w", "x", "y", "z"} 283 | _all_classes = set(string.digits + string.ascii_letters) 284 | classes_split_dict = { 285 | "byclass": sorted(list(_all_classes)), 286 | "bymerge": sorted(list(_all_classes - _merged_classes)), 287 | "balanced": sorted(list(_all_classes - _merged_classes)), 288 | "letters": ["N/A"] + list(string.ascii_lowercase), 289 | "digits": list(string.digits), 290 | "mnist": list(string.digits), 291 | } 292 | 293 | def __init__(self, root: str, split: str, **kwargs: Any) -> None: 294 | self.split = verify_str_arg(split, "split", self.splits) 295 | self.training_file = self._training_file(split) 296 | self.test_file = self._test_file(split) 297 | super().__init__(root, **kwargs) 298 | self.classes = self.classes_split_dict[self.split] 299 | 300 | @staticmethod 301 | def _training_file(split) -> str: 302 | return f"training_{split}.pt" 303 | 304 | @staticmethod 305 | def _test_file(split) -> str: 306 | return f"test_{split}.pt" 307 | 308 | @property 309 | def _file_prefix(self) -> str: 310 | return f"emnist-{self.split}-{'train' if self.train else 'test'}" 311 | 312 | @property 313 | def images_file(self) -> str: 314 | return os.path.join(self.raw_folder, f"{self._file_prefix}-images-idx3-ubyte") 315 | 316 | @property 317 | def labels_file(self) -> str: 318 | return os.path.join(self.raw_folder, f"{self._file_prefix}-labels-idx1-ubyte") 319 | 320 | def _load_data(self): 321 | return read_image_file(self.images_file), read_label_file(self.labels_file) 322 | 323 | def _check_exists(self) -> bool: 324 | return all(check_integrity(file) for file in (self.images_file, self.labels_file)) 325 | 326 | def download(self) -> None: 327 | """Download the EMNIST data if it doesn't exist already.""" 328 | 329 | if self._check_exists(): 330 | return 331 | 332 | os.makedirs(self.raw_folder, exist_ok=True) 333 | 334 | download_and_extract_archive(self.url, download_root=self.raw_folder, md5=self.md5) 335 | gzip_folder = os.path.join(self.raw_folder, "gzip") 336 | for gzip_file in os.listdir(gzip_folder): 337 | if gzip_file.endswith(".gz"): 338 | extract_archive(os.path.join(gzip_folder, gzip_file), self.raw_folder) 339 | shutil.rmtree(gzip_folder) 340 | 341 | 342 | class QMNIST(MNIST): 343 | """`QMNIST `_ Dataset. 344 | 345 | Args: 346 | root (string): Root directory of dataset whose ``raw`` 347 | subdir contains binary files of the datasets. 348 | what (string,optional): Can be 'train', 'test', 'test10k', 349 | 'test50k', or 'nist' for respectively the mnist compatible 350 | training set, the 60k qmnist testing set, the 10k qmnist 351 | examples that match the mnist testing set, the 50k 352 | remaining qmnist testing examples, or all the nist 353 | digits. The default is to select 'train' or 'test' 354 | according to the compatibility argument 'train'. 355 | compat (bool,optional): A boolean that says whether the target 356 | for each example is class number (for compatibility with 357 | the MNIST dataloader) or a torch vector containing the 358 | full qmnist information. Default=True. 359 | download (bool, optional): If True, downloads the dataset from 360 | the internet and puts it in root directory. 
If dataset is 361 | already downloaded, it is not downloaded again. 362 | transform (callable, optional): A function/transform that 363 | takes in a PIL image and returns a transformed 364 | version. E.g. ``transforms.RandomCrop`` 365 | target_transform (callable, optional): A function/transform 366 | that takes in the target and transforms it. 367 | train (bool, optional, compatibility): When argument 'what' is 368 | not specified, this boolean decides whether to load the 369 | training set or the testing set. Default: True. 370 | """ 371 | 372 | subsets = {"train": "train", "test": "test", "test10k": "test", "test50k": "test", "nist": "nist"} 373 | resources: Dict[str, List[Tuple[str, str]]] = { # type: ignore[assignment] 374 | "train": [ 375 | ( 376 | "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-images-idx3-ubyte.gz", 377 | "ed72d4157d28c017586c42bc6afe6370", 378 | ), 379 | ( 380 | "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-labels-idx2-int.gz", 381 | "0058f8dd561b90ffdd0f734c6a30e5e4", 382 | ), 383 | ], 384 | "test": [ 385 | ( 386 | "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-images-idx3-ubyte.gz", 387 | "1394631089c404de565df7b7aeaf9412", 388 | ), 389 | ( 390 | "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-labels-idx2-int.gz", 391 | "5b5b05890a5e13444e108efe57b788aa", 392 | ), 393 | ], 394 | "nist": [ 395 | ( 396 | "https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-images-idx3-ubyte.xz", 397 | "7f124b3b8ab81486c9d8c2749c17f834", 398 | ), 399 | ( 400 | "https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-labels-idx2-int.xz", 401 | "5ed0e788978e45d4a8bd4b7caec3d79d", 402 | ), 403 | ], 404 | } 405 | classes = [ 406 | "0 - zero", 407 | "1 - one", 408 | "2 - two", 409 | "3 - three", 410 | "4 - four", 411 | "5 - five", 412 | "6 - six", 413 | "7 - seven", 414 | "8 - eight", 415 | "9 - nine", 416 | ] 417 | 418 | def __init__( 419 | self, root: str, what: Optional[str] = None, compat: bool = True, train: bool = True, **kwargs: Any 420 | ) -> None: 421 | if what is None: 422 | what = "train" if train else "test" 423 | self.what = verify_str_arg(what, "what", tuple(self.subsets.keys())) 424 | self.compat = compat 425 | self.data_file = what + ".pt" 426 | self.training_file = self.data_file 427 | self.test_file = self.data_file 428 | super().__init__(root, train, **kwargs) 429 | 430 | @property 431 | def images_file(self) -> str: 432 | (url, _), _ = self.resources[self.subsets[self.what]] 433 | return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0]) 434 | 435 | @property 436 | def labels_file(self) -> str: 437 | _, (url, _) = self.resources[self.subsets[self.what]] 438 | return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0]) 439 | 440 | def _check_exists(self) -> bool: 441 | return all(check_integrity(file) for file in (self.images_file, self.labels_file)) 442 | 443 | def _load_data(self): 444 | data = read_sn3_pascalvincent_tensor(self.images_file) 445 | if data.dtype != torch.uint8: 446 | raise TypeError(f"data should be of dtype torch.uint8 instead of {data.dtype}") 447 | if data.ndimension() != 3: 448 | raise ValueError(f"data should have 3 dimensions instead of {data.ndimension()}") 449 | 450 | targets = read_sn3_pascalvincent_tensor(self.labels_file).long() 451 | if targets.ndimension() != 2: 452 | raise ValueError(f"targets should have 2 dimensions instead of
{targets.ndimension()}") 453 | 454 | if self.what == "test10k": 455 | data = data[0:10000, :, :].clone() 456 | targets = targets[0:10000, :].clone() 457 | elif self.what == "test50k": 458 | data = data[10000:, :, :].clone() 459 | targets = targets[10000:, :].clone() 460 | 461 | return data, targets 462 | 463 | def download(self) -> None: 464 | """Download the QMNIST data if it doesn't exist already. 465 | Note that we only download what has been asked for (argument 'what'). 466 | """ 467 | if self._check_exists(): 468 | return 469 | 470 | os.makedirs(self.raw_folder, exist_ok=True) 471 | split = self.resources[self.subsets[self.what]] 472 | 473 | for url, md5 in split: 474 | download_and_extract_archive(url, self.raw_folder, md5=md5) 475 | 476 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 477 | # redefined to handle the compat flag 478 | img, target = self.data[index], self.targets[index] 479 | img = Image.fromarray(img.numpy(), mode="L") 480 | if self.transform is not None: 481 | img = self.transform(img) 482 | if self.compat: 483 | target = int(target[0]) 484 | if self.target_transform is not None: 485 | target = self.target_transform(target) 486 | return img, target 487 | 488 | def extra_repr(self) -> str: 489 | return f"Split: {self.what}" 490 | 491 | 492 | def get_int(b: bytes) -> int: 493 | return int(codecs.encode(b, "hex"), 16) 494 | 495 | 496 | SN3_PASCALVINCENT_BITSMAP = { 497 | 8: torch.uint8, 498 | 9: torch.int8, 499 | 11: torch.int16, 500 | 12: torch.int32, 501 | 13: torch.float32, 502 | 14: torch.float64, 503 | } 504 | 505 | TORCH_TYPE_BITS = { 506 | torch.uint8: 8, 507 | torch.int8: 8, 508 | torch.int16: 16, 509 | torch.int32: 32, 510 | torch.float32: 32, 511 | torch.float64: 64, 512 | } 513 | 514 | 515 | def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tensor: 516 | """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh'). 517 | Argument may be a filename, compressed filename, or file object. 518 | """ 519 | # read 520 | with open(path, "rb") as f: 521 | data = f.read() 522 | # parse 523 | magic = get_int(data[0:4]) 524 | nd = magic % 256 525 | ty = magic // 256 526 | assert 1 <= nd <= 3 527 | assert 8 <= ty <= 14 528 | torch_type = SN3_PASCALVINCENT_BITSMAP[ty] 529 | s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)] 530 | 531 | num_bytes_per_value = TORCH_TYPE_BITS[torch_type] // 8 532 | # The MNIST format uses the big endian byte order. If the system uses little endian byte order by default, 533 | # we need to reverse the bytes before we can read them with torch.frombuffer(). 
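# (Explanatory comment added here; not in upstream torchvision.) As a concrete example of the header parsing above: the standard MNIST images file begins with magic 0x00000803, which get_int() reads big-endian as 2051, giving ty = 2051 // 256 = 8 (torch.uint8 in SN3_PASCALVINCENT_BITSMAP) and nd = 2051 % 256 = 3 header dimensions (count, rows, cols).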
534 | needs_byte_reversal = sys.byteorder == "little" and num_bytes_per_value > 1 535 | parsed = torch.frombuffer(bytearray(data), dtype=torch_type, offset=(4 * (nd + 1))) 536 | if needs_byte_reversal: 537 | parsed = parsed.flip(0) 538 | 539 | assert parsed.shape[0] == np.prod(s) or not strict 540 | return parsed.view(*s) 541 | 542 | 543 | def read_label_file(path: str) -> torch.Tensor: 544 | x = read_sn3_pascalvincent_tensor(path, strict=False) 545 | if x.dtype != torch.uint8: 546 | raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}") 547 | if x.ndimension() != 1: 548 | raise ValueError(f"x should have 1 dimension instead of {x.ndimension()}") 549 | return x.long() 550 | 551 | 552 | def read_image_file(path: str) -> torch.Tensor: 553 | x = read_sn3_pascalvincent_tensor(path, strict=False) 554 | if x.dtype != torch.uint8: 555 | raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}") 556 | if x.ndimension() != 3: 557 | raise ValueError(f"x should have 3 dimensions instead of {x.ndimension()}") 558 | return x -------------------------------------------------------------------------------- /python/jtorch/gradscaler.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict, abc 2 | from enum import Enum 3 | from typing import Any, Dict, List, Optional, Tuple, cast 4 | import inspect 5 | import warnings 6 | 7 | import jittor as jt 8 | # import torch 9 | 10 | def _refresh_per_optimizer_state(): 11 | return {} 12 | 13 | 14 | class GradScaler: 15 | _scale: Optional[jt.Var] 16 | _growth_tracker: Optional[jt.Var] 17 | _per_optimizer_states: Dict[int, Dict[str, Any]] 18 | """ 19 | An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling 20 | conveniently. 21 | 22 | * ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor. 23 | * ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``. 24 | * ``scaler.update()`` updates ``scaler``'s scale factor. 25 | 26 | Example:: 27 | 28 | # Creates a GradScaler once at the beginning of training. 29 | scaler = GradScaler() 30 | 31 | for epoch in epochs: 32 | for input, target in data: 33 | optimizer.zero_grad() 34 | output = model(input) 35 | loss = loss_fn(output, target) 36 | 37 | # Scales loss. Calls backward() on scaled loss to create scaled gradients. 38 | scaler.scale(loss).backward() 39 | 40 | # scaler.step() first unscales gradients of the optimizer's params. 41 | # If gradients don't contain infs/NaNs, optimizer.step() is then called, 42 | # otherwise, optimizer.step() is skipped. 43 | scaler.step(optimizer) 44 | 45 | # Updates the scale for next iteration. 46 | scaler.update() 47 | 48 | See the :ref:`Automatic Mixed Precision examples` for usage 49 | (along with autocasting) in more complex cases like gradient clipping, gradient accumulation, gradient penalty, 50 | and multiple losses/optimizers. 51 | 52 | ``scaler`` dynamically estimates the scale factor each iteration. To minimize gradient underflow, 53 | a large scale factor should be used. However, ``float16`` values can "overflow" (become inf or NaN) if 54 | the scale factor is too large. Therefore, the optimal scale factor is the largest factor that can be used 55 | without incurring inf or NaN gradient values.
56 | ``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every 57 | ``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`). 58 | 59 | * If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params 60 | themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``. 61 | 62 | * If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual. 63 | If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by 64 | ``growth_factor``. 65 | 66 | The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its 67 | value calibrates. ``scaler.step`` will skip the underlying ``optimizer.step()`` for these 68 | iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations). 69 | 70 | Args: 71 | init_scale (float, optional, default=2.**16): Initial scale factor. 72 | growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during 73 | :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations. 74 | backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during 75 | :meth:`update` if inf/NaN gradients occur in an iteration. 76 | growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients 77 | that must occur for the scale to be multiplied by ``growth_factor``. 78 | enabled (bool, optional): If ``False``, disables gradient scaling. :meth:`step` simply 79 | invokes the underlying ``optimizer.step()``, and other methods become no-ops. 80 | Default: ``True`` 81 | """ 82 | def __init__(self, 83 | init_scale=2.**16, 84 | growth_factor=2.0, 85 | backoff_factor=0.5, 86 | growth_interval=2000, 87 | enabled=True): 88 | self._enabled = enabled 89 | 90 | if self._enabled: 91 | assert growth_factor > 1.0, "The growth factor must be > 1.0." 92 | assert backoff_factor < 1.0, "The backoff factor must be < 1.0." 93 | 94 | self._init_scale = init_scale 95 | # self._scale will be lazily initialized during the first call to scale() 96 | self._scale = None 97 | self._growth_factor = growth_factor 98 | self._backoff_factor = backoff_factor 99 | self._growth_interval = growth_interval 100 | self._init_growth_tracker = 0 101 | # self._growth_tracker will be lazily initialized during the first call to scale() 102 | self._growth_tracker = None 103 | self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) 104 | 105 | def _check_scale_growth_tracker(self, funcname) -> Tuple[jt.Var, jt.Var]: 106 | fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration." 107 | assert self._scale is not None, "Attempted {} but _scale is None. ".format(funcname) + fix 108 | assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. ".format(funcname) + fix 109 | return (self._scale, self._growth_tracker) 110 | 111 | def _lazy_init_scale_growth_tracker(self): 112 | assert self._growth_tracker is None, "_growth_tracker initialized before _scale" 113 | self._scale = self._init_scale 114 | self._growth_tracker = self._init_growth_tracker 115 | 116 | def scale(self, outputs): 117 | """ 118 | Multiplies ('scales') a tensor or list of tensors by the scale factor. 119 | 120 | Returns scaled outputs. 
If this instance of :class:`GradScaler` is not enabled, outputs are returned 121 | unmodified. 122 | 123 | Args: 124 | outputs (Tensor or iterable of Tensors): Outputs to scale. 125 | """ 126 | if not self._enabled: 127 | return outputs 128 | 129 | 130 | # Short-circuit for the common case. 131 | if isinstance(outputs, jt.Var): 132 | assert jt.flags.use_cuda == 1 133 | if self._scale is None: 134 | self._lazy_init_scale_growth_tracker() 135 | assert self._scale is not None 136 | return outputs * self._scale 137 | 138 | def apply_scale(val): 139 | if isinstance(val, jt.Var): 140 | assert jt.flags.use_cuda == 1 141 | if self._scale is None: 142 | self._lazy_init_scale_growth_tracker() 143 | assert self._scale is not None 144 | return val * self._scale 145 | elif isinstance(val, abc.Iterable): 146 | iterable = map(apply_scale, val) 147 | if isinstance(val, (list, tuple)): 148 | return type(val)(iterable) 149 | else: 150 | return iterable 151 | else: 152 | raise ValueError("outputs must be a Tensor or an iterable of Tensors") 153 | 154 | return apply_scale(outputs) 155 | 156 | def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): 157 | with jt.no_grad(): 158 | optimizer.pre_step() 159 | for group in optimizer.param_groups: 160 | for to_unscale in group["grads"]: 161 | if to_unscale is None or isinstance(to_unscale,(int,float)): 162 | continue 163 | if (not allow_fp16) and str(to_unscale.dtype) == "float16": 164 | raise ValueError("Attempting to unscale FP16 gradients.") 165 | 166 | if not (to_unscale.isinf().any()): 167 | if inv_scale != 1.0: 168 | to_unscale.update(to_unscale*inv_scale) 169 | else: 170 | found_inf = 1.0 171 | 172 | return found_inf 173 | 174 | def unscale_(self, optimizer): 175 | """ 176 | Divides ("unscales") the optimizer's gradient tensors by the scale factor. 177 | 178 | :meth:`unscale_` is optional, serving cases where you need to 179 | :ref:`modify or inspect gradients` 180 | between the backward pass(es) and :meth:`step`. 181 | If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`. 182 | 183 | Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients:: 184 | 185 | ... 186 | scaler.scale(loss).backward() 187 | scaler.unscale_(optimizer) 188 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) 189 | scaler.step(optimizer) 190 | scaler.update() 191 | 192 | Args: 193 | optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled. 194 | 195 | .. note:: 196 | :meth:`unscale_` does not incur a CPU-GPU sync. 197 | 198 | .. warning:: 199 | :meth:`unscale_` should only be called once per optimizer per :meth:`step` call, 200 | and only after all gradients for that optimizer's assigned parameters have been accumulated. 201 | Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError. 202 | 203 | .. warning:: 204 | :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute. 205 | """ 206 | if not self._enabled: 207 | return 208 | 209 | self._check_scale_growth_tracker("unscale_") 210 | 211 | optimizer_state = self._per_optimizer_states[id(optimizer)] 212 | 213 | if hasattr(optimizer,"get_find_inf"): 214 | return 215 | # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64. 
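# (Added note: in this jittor port self._scale is held as a Python float, which is already an IEEE-754 double, so the plain division below computes the reciprocal in FP64.)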
216 | assert self._scale is not None 217 | inv_scale = 1.0 / self._scale 218 | found_inf = 0.0 219 | optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False) 220 | 221 | 222 | def step(self, optimizer, *args, **kwargs): 223 | """ 224 | :meth:`step` carries out the following two operations: 225 | 226 | 1. Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for ``optimizer`` 227 | earlier in the iteration). As part of the :meth:`unscale_`, gradients are checked for infs/NaNs. 228 | 2. If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled 229 | gradients. Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params. 230 | 231 | ``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``. 232 | 233 | Returns the return value of ``optimizer.step(*args, **kwargs)``. 234 | 235 | Args: 236 | optimizer (torch.optim.Optimizer): Optimizer that applies the gradients. 237 | args: Any arguments. 238 | kwargs: Any keyword arguments. 239 | 240 | .. warning:: 241 | Closure use is not currently supported. 242 | """ 243 | if (not self._enabled): 244 | return optimizer.step(*args, **kwargs) 245 | 246 | if "closure" in kwargs: 247 | raise RuntimeError("Closure use is not currently supported if GradScaler is enabled.") 248 | 249 | self._check_scale_growth_tracker("step") 250 | 251 | optimizer_state = self._per_optimizer_states[id(optimizer)] 252 | retval = None 253 | 254 | if (hasattr(optimizer, "_step_supports_amp_scaling") and optimizer._step_supports_amp_scaling): 255 | # This optimizer has customized scale-handling logic, so we can call optimizer.step() directly. 256 | # The contract with custom optimizers is that their step() should accept an additional, 257 | # optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information: 258 | # it can query its own state, invoke unscale_ on itself, etc 259 | # The contract above is being deprecated to avoid introducing `grad_scaler: GradScaler` argument 260 | # to `Optimizer.step`. The new behavior is going to add two Tensor attributes of `grad_scale` 261 | # and `found_inf` to the passed optimizer so that the optimizer can utilize those 262 | # to skip the parameter updates or unscale gradients before updating parameters in 263 | # the fused kernel, e.g. `FusedAdamMathFunctor`. 264 | # In this behavior, `GradScaler._check_inf_per_device` is called if `OptState.READY`, 265 | # while the method is expected to be called by users side, i.e. their optimizers. 266 | kwargs_ = kwargs 267 | has_grad_scaler_kwarg = "grad_scaler" in inspect.signature(optimizer.step).parameters 268 | if has_grad_scaler_kwarg: 269 | warnings.warn( 270 | "GradScaler is going to stop passing itself as a keyword argument to the passed " 271 | "optimizer. 
In the near future GradScaler registers `grad_scale: Tensor` and " 272 | "`found_inf: Tensor` to the passed optimizer and lets the optimizer use them directly.", 273 | FutureWarning) 274 | kwargs_.update({"grad_scaler": self}) 275 | else: 276 | if optimizer_state["stage"] is OptState.READY: 277 | self._check_inf_per_device(optimizer) 278 | scaler = self._get_scale_async() 279 | found_inf = cast( 280 | jt.Var, 281 | sum([ 282 | t for t in optimizer_state["found_inf_per_device"].values() 283 | ]) 284 | ) 285 | optimizer.grad_scale = None if optimizer_state["stage"] == OptState.UNSCALED else scaler 286 | optimizer.found_inf = found_inf 287 | retval = optimizer.step(*args, **kwargs_) 288 | optimizer_state["stage"] = OptState.STEPPED 289 | if not has_grad_scaler_kwarg: 290 | del optimizer.grad_scale 291 | del optimizer.found_inf 292 | return retval 293 | 294 | if hasattr(optimizer, "get_find_inf"): 295 | optimizer.set_grad_scale(self._scale) 296 | optimizer.step() 297 | optimizer_state["found_inf_per_device"] = optimizer.get_find_inf() 298 | return 299 | 300 | retval = None 301 | if not optimizer_state["found_inf_per_device"]: 302 | retval = optimizer.step(*args, **kwargs) 303 | else: 304 | optimizer.post_step() 305 | 306 | return retval 307 | 308 | 309 | def update(self, new_scale=None): 310 | """ 311 | Updates the scale factor. 312 | 313 | If any optimizer steps were skipped the scale is multiplied by ``backoff_factor`` 314 | to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively, 315 | the scale is multiplied by ``growth_factor`` to increase it. 316 | 317 | Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not 318 | used directly, it's used to fill GradScaler's internal scale tensor. So if 319 | ``new_scale`` was a tensor, later in-place changes to that tensor will not further 320 | affect the scale GradScaler uses internally.) 321 | 322 | Args: 323 | new_scale (float or 1-element :class:`jt.Var`, optional, default=None): New scale factor. 324 | 325 | .. warning:: 326 | :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has 327 | been invoked for all optimizers used this iteration. 328 | """ 329 | if not self._enabled: 330 | return 331 | 332 | _scale, _growth_tracker = self._check_scale_growth_tracker("update") 333 | 334 | if new_scale is not None: 335 | # Accept a new user-defined scale. In this port the scale is kept as a Python float. 336 | if isinstance(new_scale, float): 337 | self._scale = new_scale 338 | else: 339 | reason = "new_scale should be a float or a 1-element jt.Var with requires_grad=False." 340 | assert isinstance(new_scale, jt.Var), reason 341 | assert new_scale.numel() == 1, reason 342 | assert new_scale.requires_grad is False, reason 343 | self._scale = float(new_scale.item()) 344 | else: 345 | # Consume shared inf/nan data collected from optimizers to update the scale. 346 | # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. 347 | found_infs = [state["found_inf_per_device"] 348 | for state in self._per_optimizer_states.values() 349 | ] 350 | 351 | assert len(found_infs) > 0, "No inf checks were recorded prior to update."
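# (Added note: in this port each recorded "found_inf_per_device" entry is typically a plain 0.0/1.0 float flag rather than a per-device tensor dict, so summing the entries below acts as a logical OR across optimizers.)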
352 | 353 | found_inf_combined = found_infs[0] 354 | if len(found_infs) > 1: 355 | for i in range(1, len(found_infs)): 356 | found_inf_combined += found_infs[i] 357 | 358 | 359 | current_scale = _scale 360 | if found_inf_combined: 361 | current_scale *= self._backoff_factor 362 | _growth_tracker = 0 363 | else: 364 | successful = _growth_tracker + 1 365 | if successful == self._growth_interval: 366 | new_scale = current_scale * self._growth_factor 367 | if new_scale < 1e9: 368 | current_scale = new_scale 369 | _growth_tracker = 0 370 | else: 371 | _growth_tracker = successful 372 | 373 | self._scale, self._growth_tracker = current_scale, _growth_tracker 374 | 375 | # To prepare for next iteration, clear the data collected from optimizers this iteration. 376 | self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) 377 | 378 | def _get_scale_async(self): 379 | return self._scale 380 | 381 | def get_scale(self): 382 | """ 383 | Returns a Python float containing the current scale, or 1.0 if scaling is disabled. 384 | 385 | .. warning:: 386 | :meth:`get_scale` incurs a CPU-GPU sync. 387 | """ 388 | if self._enabled: 389 | return self._init_scale if self._scale is None else self._get_scale_async() 390 | else: 391 | return 1.0 392 | 393 | def get_growth_factor(self): 394 | r""" 395 | Returns a Python float containing the scale growth factor. 396 | """ 397 | return self._growth_factor 398 | 399 | def set_growth_factor(self, new_factor): 400 | r""" 401 | Args: 402 | new_factor (float): Value to use as the new scale growth factor. 403 | """ 404 | self._growth_factor = new_factor 405 | 406 | def get_backoff_factor(self): 407 | r""" 408 | Returns a Python float containing the scale backoff factor. 409 | """ 410 | return self._backoff_factor 411 | 412 | def set_backoff_factor(self, new_factor): 413 | r""" 414 | Args: 415 | new_factor (float): Value to use as the new scale backoff factor. 416 | """ 417 | self._backoff_factor = new_factor 418 | 419 | def get_growth_interval(self): 420 | r""" 421 | Returns a Python int containing the growth interval. 422 | """ 423 | return self._growth_interval 424 | 425 | def set_growth_interval(self, new_interval): 426 | r""" 427 | Args: 428 | new_interval (int): Value to use as the new growth interval. 429 | """ 430 | self._growth_interval = new_interval 431 | 432 | def _get_growth_tracker(self): 433 | if self._enabled: 434 | return self._init_growth_tracker if self._growth_tracker is None else self._growth_tracker 435 | else: 436 | return 0 437 | 438 | def is_enabled(self): 439 | r""" 440 | Returns a bool indicating whether this instance is enabled. 441 | """ 442 | return self._enabled 443 | 444 | def state_dict(self): 445 | r""" 446 | Returns the state of the scaler as a :class:`dict`. It contains five entries: 447 | 448 | * ``"scale"`` - a Python float containing the current scale 449 | * ``"growth_factor"`` - a Python float containing the current growth factor 450 | * ``"backoff_factor"`` - a Python float containing the current backoff factor 451 | * ``"growth_interval"`` - a Python int containing the current growth interval 452 | * ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps. 453 | 454 | If this instance is not enabled, returns an empty dict. 455 | 456 | .. note:: 457 | If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict` 458 | should be called after :meth:`update`.
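The returned dict round-trips through :meth:`load_state_dict`. A minimal checkpointing sketch (``checkpoint`` is an illustrative name, not part of the API)::

    checkpoint = {"scaler": scaler.state_dict()}
    # ... later, after constructing a fresh GradScaler ...
    scaler.load_state_dict(checkpoint["scaler"])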
459 | """ 460 | return {"scale": self.get_scale(), 461 | "growth_factor": self._growth_factor, 462 | "backoff_factor": self._backoff_factor, 463 | "growth_interval": self._growth_interval, 464 | "_growth_tracker": self._get_growth_tracker()} if self._enabled else {} 465 | 466 | def load_state_dict(self, state_dict): 467 | r""" 468 | Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op. 469 | 470 | Args: 471 | state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`. 472 | """ 473 | if not self._enabled: 474 | return 475 | 476 | if len(state_dict) == 0: 477 | raise RuntimeError("The source state dict is empty, possibly because it was saved " 478 | "from a disabled instance of GradScaler.") 479 | 480 | self._init_scale = state_dict["scale"] 481 | if self._scale is not None: 482 | self._scale.fill_(state_dict["scale"]) 483 | self._growth_factor = state_dict["growth_factor"] 484 | self._backoff_factor = state_dict["backoff_factor"] 485 | self._growth_interval = state_dict["growth_interval"] 486 | self._init_growth_tracker = state_dict["_growth_tracker"] 487 | if self._growth_tracker is not None: 488 | self._growth_tracker.fill_(state_dict["_growth_tracker"]) 489 | 490 | def __getstate__(self): 491 | state = self.__dict__.copy() 492 | if self._enabled: 493 | assert len(self._per_optimizer_states) == 0, "A GradScaler instance may only be pickled at the beginning "\ 494 | "of an iteration, or at the end after scaler.update()." 495 | # Pickling _scale and _growth_tracker Tensors directly triggers 496 | # "warnings.warn("pickle support for Storage will be removed in 1.5..." 497 | # so instead, we set the unpickled instance up to reinitialize them lazily. 498 | state['_init_scale'] = self.get_scale() 499 | state['_init_growth_tracker'] = self._get_growth_tracker() 500 | state['_scale'] = None 501 | state['_growth_tracker'] = None 502 | return state 503 | 504 | def __setstate__(self, state): 505 | self.__dict__.update(state) 506 | 507 | def _check_inf_per_device(self, optimizer): 508 | _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device") 509 | 510 | dummy_inv_scale = 1.0 511 | found_inf = 0.0 512 | 513 | self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \ 514 | self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True) 515 | 516 | return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] 517 | 518 | def _found_inf_per_device(self, optimizer): 519 | return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] 520 | -------------------------------------------------------------------------------- /python/jtorch/vision/utils.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import math 3 | import pathlib 4 | import warnings 5 | from itertools import repeat 6 | from types import FunctionType 7 | from typing import Any, BinaryIO, List, Optional, Tuple, Union 8 | 9 | import numpy as np 10 | import torch 11 | from PIL import Image, ImageColor, ImageDraw, ImageFont 12 | 13 | __all__ = [ 14 | "make_grid", 15 | "save_image", 16 | "draw_bounding_boxes", 17 | "draw_segmentation_masks", 18 | "draw_keypoints", 19 | "flow_to_image", 20 | ] 21 | 22 | 23 | @torch.no_grad() 24 | def make_grid( 25 | tensor: Union[torch.Tensor, List[torch.Tensor]], 26 | nrow: int = 8, 27 | padding: int = 2, 28 | normalize: bool = False, 29 | value_range: Optional[Tuple[int, int]] = None, 30 | scale_each: bool = False, 31 | pad_value: 
float = 0.0, 32 | **kwargs, 33 | ) -> torch.Tensor: 34 | """ 35 | Make a grid of images. 36 | 37 | Args: 38 | tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W) 39 | or a list of images all of the same size. 40 | nrow (int, optional): Number of images displayed in each row of the grid. 41 | The final grid size is ``(B / nrow, nrow)``. Default: ``8``. 42 | padding (int, optional): amount of padding. Default: ``2``. 43 | normalize (bool, optional): If True, shift the image to the range (0, 1), 44 | by the min and max values specified by ``value_range``. Default: ``False``. 45 | value_range (tuple, optional): tuple (min, max) where min and max are numbers, 46 | then these numbers are used to normalize the image. By default, min and max 47 | are computed from the tensor. 48 | range (tuple, optional): 49 | .. warning:: 50 | This parameter was deprecated in ``0.12`` and will be removed in ``0.14``. Please use ``value_range`` 51 | instead. 52 | scale_each (bool, optional): If ``True``, scale each image in the batch of 53 | images separately rather than the (min, max) over all images. Default: ``False``. 54 | pad_value (float, optional): Value for the padded pixels. Default: ``0``. 55 | 56 | Returns: 57 | grid (Tensor): the tensor containing grid of images. 58 | """ 59 | if not torch.jit.is_scripting() and not torch.jit.is_tracing(): 60 | _log_api_usage_once(make_grid) 61 | if not torch.is_tensor(tensor): 62 | if isinstance(tensor, list): 63 | for t in tensor: 64 | if not torch.is_tensor(t): 65 | raise TypeError(f"tensor or list of tensors expected, got a list containing {type(t)}") 66 | else: 67 | raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}") 68 | 69 | if "range" in kwargs.keys(): 70 | warnings.warn( 71 | "The parameter 'range' is deprecated since 0.12 and will be removed in 0.14. " 72 | "Please use 'value_range' instead." 73 | ) 74 | value_range = kwargs["range"] 75 | 76 | # if list of tensors, convert to a 4D mini-batch Tensor 77 | if isinstance(tensor, list): 78 | tensor = torch.stack(tensor, dim=0) 79 | 80 | if tensor.dim() == 2: # single image H x W 81 | tensor = tensor.unsqueeze(0) 82 | if tensor.dim() == 3: # single image 83 | if tensor.size(0) == 1: # if single-channel, convert to 3-channel 84 | tensor = torch.cat((tensor, tensor, tensor), 0) 85 | tensor = tensor.unsqueeze(0) 86 | 87 | if tensor.dim() == 4 and tensor.size(1) == 1: # single-channel images 88 | tensor = torch.cat((tensor, tensor, tensor), 1) 89 | 90 | if normalize is True: 91 | tensor = tensor.clone() # avoid modifying tensor in-place 92 | if value_range is not None and not isinstance(value_range, tuple): 93 | raise TypeError("value_range has to be a tuple (min, max) if specified.
min and max are numbers") 94 | 95 | def norm_ip(img, low, high): 96 | img.clamp_(min=low, max=high) 97 | img.sub_(low).div_(max(high - low, 1e-5)) 98 | 99 | def norm_range(t, value_range): 100 | if value_range is not None: 101 | norm_ip(t, value_range[0], value_range[1]) 102 | else: 103 | norm_ip(t, float(t.min()), float(t.max())) 104 | 105 | if scale_each is True: 106 | for t in tensor: # loop over mini-batch dimension 107 | norm_range(t, value_range) 108 | else: 109 | norm_range(tensor, value_range) 110 | 111 | if not isinstance(tensor, torch.Tensor): 112 | raise TypeError("tensor should be of type torch.Tensor") 113 | if tensor.size(0) == 1: 114 | return tensor.squeeze(0) 115 | 116 | # make the mini-batch of images into a grid 117 | nmaps = tensor.size(0) 118 | xmaps = min(nrow, nmaps) 119 | ymaps = int(math.ceil(float(nmaps) / xmaps)) 120 | height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding) 121 | num_channels = tensor.size(1) 122 | grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value) 123 | k = 0 124 | for y in range(ymaps): 125 | for x in range(xmaps): 126 | if k >= nmaps: 127 | break 128 | # Tensor.copy_() is a valid method but seems to be missing from the stubs 129 | # https://pytorch.org/docs/stable/tensors.html#torch.Tensor.copy_ 130 | grid.narrow(1, y * height + padding, height - padding).narrow( # type: ignore[attr-defined] 131 | 2, x * width + padding, width - padding 132 | ).copy_(tensor[k]) 133 | k = k + 1 134 | return grid 135 | 136 | 137 | @torch.no_grad() 138 | def save_image( 139 | tensor: Union[torch.Tensor, List[torch.Tensor]], 140 | fp: Union[str, pathlib.Path, BinaryIO], 141 | format: Optional[str] = None, 142 | **kwargs, 143 | ) -> None: 144 | """ 145 | Save a given Tensor into an image file. 146 | 147 | Args: 148 | tensor (Tensor or list): Image to be saved. If given a mini-batch tensor, 149 | saves the tensor as a grid of images by calling ``make_grid``. 150 | fp (string or file object): A filename or a file object 151 | format (Optional): If omitted, the format to use is determined from the filename extension. 152 | If a file object was used instead of a filename, this parameter should always be used. 153 | **kwargs: Other arguments are documented in ``make_grid``. 154 | """ 155 | 156 | if not torch.jit.is_scripting() and not torch.jit.is_tracing(): 157 | _log_api_usage_once(save_image) 158 | grid = make_grid(tensor, **kwargs) 159 | # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer 160 | ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy() 161 | im = Image.fromarray(ndarr) 162 | im.save(fp, format=format) 163 | 164 | 165 | @torch.no_grad() 166 | def draw_bounding_boxes( 167 | image: torch.Tensor, 168 | boxes: torch.Tensor, 169 | labels: Optional[List[str]] = None, 170 | colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None, 171 | fill: Optional[bool] = False, 172 | width: int = 1, 173 | font: Optional[str] = None, 174 | font_size: Optional[int] = None, 175 | ) -> torch.Tensor: 176 | 177 | """ 178 | Draws bounding boxes on given image. 179 | The values of the input image should be uint8 between 0 and 255. 180 | If fill is True, the resulting Tensor should be saved as a PNG image. 181 | 182 | Args: 183 | image (Tensor): Tensor of shape (C x H x W) and dtype uint8. 184 | boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format.
Note that 185 | the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and 186 | `0 <= ymin < ymax < H`. 187 | labels (List[str]): List containing the labels of bounding boxes. 188 | colors (color or list of colors, optional): List containing the colors 189 | of the boxes or single color for all boxes. The color can be represented as 190 | PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. 191 | By default, random colors are generated for boxes. 192 | fill (bool): If `True` fills the bounding box with specified color. 193 | width (int): Width of bounding box. 194 | font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may 195 | also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`, 196 | `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS. 197 | font_size (int): The requested font size in points. 198 | 199 | Returns: 200 | img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted. 201 | """ 202 | 203 | if not torch.jit.is_scripting() and not torch.jit.is_tracing(): 204 | _log_api_usage_once(draw_bounding_boxes) 205 | if not isinstance(image, torch.Tensor): 206 | raise TypeError(f"Tensor expected, got {type(image)}") 207 | elif image.dtype != torch.uint8: 208 | raise ValueError(f"Tensor uint8 expected, got {image.dtype}") 209 | elif image.dim() != 3: 210 | raise ValueError("Pass individual images, not batches") 211 | elif image.size(0) not in {1, 3}: 212 | raise ValueError("Only grayscale and RGB images are supported") 213 | elif (boxes[:, 0] > boxes[:, 2]).any() or (boxes[:, 1] > boxes[:, 3]).any(): 214 | raise ValueError( 215 | "Boxes need to be in (xmin, ymin, xmax, ymax) format. Use torchvision.ops.box_convert to convert them" 216 | ) 217 | 218 | num_boxes = boxes.shape[0] 219 | 220 | if num_boxes == 0: 221 | warnings.warn("boxes doesn't contain any box. No box was drawn") 222 | return image 223 | 224 | if labels is None: 225 | labels: Union[List[str], List[None]] = [None] * num_boxes # type: ignore[no-redef] 226 | elif len(labels) != num_boxes: 227 | raise ValueError( 228 | f"Number of boxes ({num_boxes}) and labels ({len(labels)}) mismatch. Please specify labels for each box." 229 | ) 230 | 231 | if colors is None: 232 | colors = _generate_color_palette(num_boxes) 233 | elif isinstance(colors, list): 234 | if len(colors) < num_boxes: 235 | raise ValueError(f"Number of colors ({len(colors)}) is less than number of boxes ({num_boxes}). 
") 236 | else: # colors specifies a single color for all boxes 237 | colors = [colors] * num_boxes 238 | 239 | colors = [(ImageColor.getrgb(color) if isinstance(color, str) else color) for color in colors] 240 | 241 | if font is None: 242 | if font_size is not None: 243 | warnings.warn("Argument 'font_size' will be ignored since 'font' is not set.") 244 | txt_font = ImageFont.load_default() 245 | else: 246 | txt_font = ImageFont.truetype(font=font, size=font_size or 10) 247 | 248 | # Handle Grayscale images 249 | if image.size(0) == 1: 250 | image = torch.tile(image, (3, 1, 1)) 251 | 252 | ndarr = image.permute(1, 2, 0).cpu().numpy() 253 | img_to_draw = Image.fromarray(ndarr) 254 | img_boxes = boxes.to(torch.int64).tolist() 255 | 256 | if fill: 257 | draw = ImageDraw.Draw(img_to_draw, "RGBA") 258 | else: 259 | draw = ImageDraw.Draw(img_to_draw) 260 | 261 | for bbox, color, label in zip(img_boxes, colors, labels): # type: ignore[arg-type] 262 | if fill: 263 | fill_color = color + (100,) 264 | draw.rectangle(bbox, width=width, outline=color, fill=fill_color) 265 | else: 266 | draw.rectangle(bbox, width=width, outline=color) 267 | 268 | if label is not None: 269 | margin = width + 1 270 | draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=color, font=txt_font) 271 | 272 | return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) 273 | 274 | 275 | @torch.no_grad() 276 | def draw_segmentation_masks( 277 | image: torch.Tensor, 278 | masks: torch.Tensor, 279 | alpha: float = 0.8, 280 | colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None, 281 | ) -> torch.Tensor: 282 | 283 | """ 284 | Draws segmentation masks on given RGB image. 285 | The values of the input image should be uint8 between 0 and 255. 286 | 287 | Args: 288 | image (Tensor): Tensor of shape (3, H, W) and dtype uint8. 289 | masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool. 290 | alpha (float): Float number between 0 and 1 denoting the transparency of the masks. 291 | 0 means full transparency, 1 means no transparency. 292 | colors (color or list of colors, optional): List containing the colors 293 | of the masks or single color for all masks. The color can be represented as 294 | PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. 295 | By default, random colors are generated for each mask. 296 | 297 | Returns: 298 | img (Tensor[C, H, W]): Image Tensor, with segmentation masks drawn on top. 299 | """ 300 | 301 | if not torch.jit.is_scripting() and not torch.jit.is_tracing(): 302 | _log_api_usage_once(draw_segmentation_masks) 303 | if not isinstance(image, torch.Tensor): 304 | raise TypeError(f"The image must be a tensor, got {type(image)}") 305 | elif image.dtype != torch.uint8: 306 | raise ValueError(f"The image dtype must be uint8, got {image.dtype}") 307 | elif image.dim() != 3: 308 | raise ValueError("Pass individual images, not batches") 309 | elif image.size()[0] != 3: 310 | raise ValueError("Pass an RGB image. Other Image formats are not supported") 311 | if masks.ndim == 2: 312 | masks = masks[None, :, :] 313 | if masks.ndim != 3: 314 | raise ValueError("masks must be of shape (H, W) or (batch_size, H, W)") 315 | if masks.dtype != torch.bool: 316 | raise ValueError(f"The masks must be of dtype bool. 
Got {masks.dtype}") 317 | if masks.shape[-2:] != image.shape[-2:]: 318 | raise ValueError("The image and the masks must have the same height and width") 319 | 320 | num_masks = masks.size()[0] 321 | if colors is not None and num_masks > len(colors): 322 | raise ValueError(f"There are more masks ({num_masks}) than colors ({len(colors)})") 323 | 324 | if num_masks == 0: 325 | warnings.warn("masks doesn't contain any mask. No mask was drawn") 326 | return image 327 | 328 | if colors is None: 329 | colors = _generate_color_palette(num_masks) 330 | 331 | if not isinstance(colors, list): 332 | colors = [colors] 333 | if not isinstance(colors[0], (tuple, str)): 334 | raise ValueError("colors must be a tuple or a string, or a list thereof") 335 | if isinstance(colors[0], tuple) and len(colors[0]) != 3: 336 | raise ValueError("It seems that you passed a tuple of colors instead of a list of colors") 337 | 338 | out_dtype = torch.uint8 339 | 340 | colors_ = [] 341 | for color in colors: 342 | if isinstance(color, str): 343 | color = ImageColor.getrgb(color) 344 | colors_.append(torch.tensor(color, dtype=out_dtype)) 345 | 346 | img_to_draw = image.detach().clone() 347 | # TODO: There might be a way to vectorize this 348 | for mask, color in zip(masks, colors_): 349 | img_to_draw[:, mask] = color[:, None] 350 | 351 | out = image * (1 - alpha) + img_to_draw * alpha 352 | return out.to(out_dtype) 353 | 354 | 355 | @torch.no_grad() 356 | def draw_keypoints( 357 | image: torch.Tensor, 358 | keypoints: torch.Tensor, 359 | connectivity: Optional[List[Tuple[int, int]]] = None, 360 | colors: Optional[Union[str, Tuple[int, int, int]]] = None, 361 | radius: int = 2, 362 | width: int = 3, 363 | ) -> torch.Tensor: 364 | 365 | """ 366 | Draws Keypoints on given RGB image. 367 | The values of the input image should be uint8 between 0 and 255. 368 | 369 | Args: 370 | image (Tensor): Tensor of shape (3, H, W) and dtype uint8. 371 | keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoints location for each of the N instances, 372 | in the format [x, y]. 373 | connectivity (List[Tuple[int, int]]]): A List of tuple where, 374 | each tuple contains pair of keypoints to be connected. 375 | colors (str, Tuple): The color can be represented as 376 | PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. 377 | radius (int): Integer denoting radius of keypoint. 378 | width (int): Integer denoting width of line connecting keypoints. 379 | 380 | Returns: 381 | img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn. 382 | """ 383 | 384 | if not torch.jit.is_scripting() and not torch.jit.is_tracing(): 385 | _log_api_usage_once(draw_keypoints) 386 | if not isinstance(image, torch.Tensor): 387 | raise TypeError(f"The image must be a tensor, got {type(image)}") 388 | elif image.dtype != torch.uint8: 389 | raise ValueError(f"The image dtype must be uint8, got {image.dtype}") 390 | elif image.dim() != 3: 391 | raise ValueError("Pass individual images, not batches") 392 | elif image.size()[0] != 3: 393 | raise ValueError("Pass an RGB image. 
Other Image formats are not supported") 394 | 395 | if keypoints.ndim != 3: 396 | raise ValueError("keypoints must be of shape (num_instances, K, 2)") 397 | 398 | ndarr = image.permute(1, 2, 0).cpu().numpy() 399 | img_to_draw = Image.fromarray(ndarr) 400 | draw = ImageDraw.Draw(img_to_draw) 401 | img_kpts = keypoints.to(torch.int64).tolist() 402 | 403 | for inst_id, kpt_inst in enumerate(img_kpts): 404 | for kpt_id, kpt in enumerate(kpt_inst): 405 | x1 = kpt[0] - radius 406 | x2 = kpt[0] + radius 407 | y1 = kpt[1] - radius 408 | y2 = kpt[1] + radius 409 | draw.ellipse([x1, y1, x2, y2], fill=colors, outline=None, width=0) 410 | 411 | if connectivity: 412 | for connection in connectivity: 413 | start_pt_x = kpt_inst[connection[0]][0] 414 | start_pt_y = kpt_inst[connection[0]][1] 415 | 416 | end_pt_x = kpt_inst[connection[1]][0] 417 | end_pt_y = kpt_inst[connection[1]][1] 418 | 419 | draw.line( 420 | ((start_pt_x, start_pt_y), (end_pt_x, end_pt_y)), 421 | width=width, 422 | ) 423 | 424 | return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) 425 | 426 | 427 | # Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization 428 | @torch.no_grad() 429 | def flow_to_image(flow: torch.Tensor) -> torch.Tensor: 430 | 431 | """ 432 | Converts a flow to an RGB image. 433 | 434 | Args: 435 | flow (Tensor): Flow of shape (N, 2, H, W) or (2, H, W) and dtype torch.float. 436 | 437 | Returns: 438 | img (Tensor): Image Tensor of dtype uint8 where each color corresponds 439 | to a given flow direction. Shape is (N, 3, H, W) or (3, H, W) depending on the input. 440 | """ 441 | 442 | if flow.dtype != torch.float: 443 | raise ValueError(f"Flow should be of dtype torch.float, got {flow.dtype}.") 444 | 445 | orig_shape = flow.shape 446 | if flow.ndim == 3: 447 | flow = flow[None] # Add batch dim 448 | 449 | if flow.ndim != 4 or flow.shape[1] != 2: 450 | raise ValueError(f"Input flow should have shape (2, H, W) or (N, 2, H, W), got {orig_shape}.") 451 | 452 | max_norm = torch.sum(flow**2, dim=1).sqrt().max() 453 | epsilon = torch.finfo(flow.dtype).eps 454 | normalized_flow = flow / (max_norm + epsilon) 455 | img = _normalized_flow_to_image(normalized_flow) 456 | 457 | if len(orig_shape) == 3: 458 | img = img[0] # Remove batch dim 459 | return img 460 | 461 | 462 | @torch.no_grad() 463 | def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor: 464 | 465 | """ 466 | Converts a batch of normalized flow to an RGB image. 467 | 468 | Args: 469 | normalized_flow (torch.Tensor): Normalized flow tensor of shape (N, 2, H, W) 470 | Returns: 471 | img (Tensor(N, 3, H, W)): Flow visualization image of dtype uint8.
472 | """ 473 | 474 | N, _, H, W = normalized_flow.shape 475 | device = normalized_flow.device 476 | flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8, device=device) 477 | colorwheel = _make_colorwheel().to(device) # shape [55x3] 478 | num_cols = colorwheel.shape[0] 479 | norm = torch.sum(normalized_flow**2, dim=1).sqrt() 480 | a = torch.atan2(-normalized_flow[:, 1, :, :], -normalized_flow[:, 0, :, :]) / torch.pi 481 | fk = (a + 1) / 2 * (num_cols - 1) 482 | k0 = torch.floor(fk).to(torch.long) 483 | k1 = k0 + 1 484 | k1[k1 == num_cols] = 0 485 | f = fk - k0 486 | 487 | for c in range(colorwheel.shape[1]): 488 | tmp = colorwheel[:, c] 489 | col0 = tmp[k0] / 255.0 490 | col1 = tmp[k1] / 255.0 491 | col = (1 - f) * col0 + f * col1 492 | col = 1 - norm * (1 - col) 493 | flow_image[:, c, :, :] = torch.floor(255 * col) 494 | return flow_image 495 | 496 | 497 | def _make_colorwheel() -> torch.Tensor: 498 | """ 499 | Generates a color wheel for optical flow visualization as presented in: 500 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 501 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf. 502 | 503 | Returns: 504 | colorwheel (Tensor[55, 3]): Colorwheel Tensor. 505 | """ 506 | 507 | RY = 15 508 | YG = 6 509 | GC = 4 510 | CB = 11 511 | BM = 13 512 | MR = 6 513 | 514 | ncols = RY + YG + GC + CB + BM + MR 515 | colorwheel = torch.zeros((ncols, 3)) 516 | col = 0 517 | 518 | # RY 519 | colorwheel[0:RY, 0] = 255 520 | colorwheel[0:RY, 1] = torch.floor(255 * torch.arange(0, RY) / RY) 521 | col = col + RY 522 | # YG 523 | colorwheel[col : col + YG, 0] = 255 - torch.floor(255 * torch.arange(0, YG) / YG) 524 | colorwheel[col : col + YG, 1] = 255 525 | col = col + YG 526 | # GC 527 | colorwheel[col : col + GC, 1] = 255 528 | colorwheel[col : col + GC, 2] = torch.floor(255 * torch.arange(0, GC) / GC) 529 | col = col + GC 530 | # CB 531 | colorwheel[col : col + CB, 1] = 255 - torch.floor(255 * torch.arange(CB) / CB) 532 | colorwheel[col : col + CB, 2] = 255 533 | col = col + CB 534 | # BM 535 | colorwheel[col : col + BM, 2] = 255 536 | colorwheel[col : col + BM, 0] = torch.floor(255 * torch.arange(0, BM) / BM) 537 | col = col + BM 538 | # MR 539 | colorwheel[col : col + MR, 2] = 255 - torch.floor(255 * torch.arange(MR) / MR) 540 | colorwheel[col : col + MR, 0] = 255 541 | return colorwheel 542 | 543 | 544 | def _generate_color_palette(num_objects: int): 545 | palette = torch.tensor([2**25 - 1, 2**15 - 1, 2**21 - 1]) 546 | return [tuple((i * palette) % 255) for i in range(num_objects)] 547 | 548 | 549 | def _log_api_usage_once(obj: Any) -> None: 550 | 551 | """ 552 | Logs API usage(module and name) within an organization. 553 | In a large ecosystem, it's often useful to track the PyTorch and 554 | TorchVision APIs usage. This API provides the similar functionality to the 555 | logging module in the Python stdlib. It can be used for debugging purpose 556 | to log which methods are used and by default it is inactive, unless the user 557 | manually subscribes a logger via the `SetAPIUsageLogger method `_. 558 | Please note it is triggered only once for the same API call within a process. 559 | It does not collect any data from open-source users since it is no-op by default. 
560 | For more information, please refer to 561 | * PyTorch note: https://pytorch.org/docs/stable/notes/large_scale_deployments.html#api-usage-logging; 562 | * Logging policy: https://github.com/pytorch/vision/issues/5052; 563 | 564 | Args: 565 | obj (class instance or method): an object to extract info from. 566 | """ 567 | pass 568 | 569 | 570 | def _make_ntuple(x: Any, n: int) -> Tuple[Any, ...]: 571 | """ 572 | Make n-tuple from input x. If x is an iterable, then we just convert it to tuple. 573 | Otherwise we will make a tuple of length n, all with value of x. 574 | reference: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/utils.py#L8 575 | 576 | Args: 577 | x (Any): input value 578 | n (int): length of the resulting tuple 579 | """ 580 | if isinstance(x, collections.abc.Iterable): 581 | return tuple(x) 582 | return tuple(repeat(x, n)) -------------------------------------------------------------------------------- /python/jtorch/gradscaler_old.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict, abc 2 | from enum import Enum 3 | from typing import Any, Dict, List, Optional, Tuple, cast 4 | import inspect 5 | import warnings 6 | 7 | import jittor as jt 8 | # import torch 9 | 10 | 11 | __all__ = ["OptState", "GradScaler"] 12 | 13 | 14 | # Defines default_factory for GradScaler's _per_optimizer_states defaultdict, 15 | # as well as associated "enum" values. Prefers defining these at top level because 16 | # - Lambdas can't be pickled, so we don't want to supply a lambda as the factory. 17 | # - Defining READY, UNSCALED, STEPPED and _refresh_per_optimizer_state within GradScaler 18 | # causes a circular reference, which we'd rather avoid. 19 | class OptState(Enum): 20 | READY = 0 21 | UNSCALED = 1 22 | STEPPED = 2 23 | 24 | 25 | def _refresh_per_optimizer_state(): 26 | return {"stage": OptState.READY, "found_inf_per_device": {}} 27 | 28 | 29 | class GradScaler: 30 | _scale: Optional[jt.Var] 31 | _grows_tracker: Optional[jt.Var] 32 | _per_optimizer_states: Dict[int, Dict[str, Any]] 33 | """ 34 | An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling 35 | conveniently. 36 | 37 | * ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor. 38 | * ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``. 39 | * ``scaler.update()`` updates ``scaler``'s scale factor. 40 | 41 | Example:: 42 | 43 | # Creates a GradScaler once at the beginning of training. 44 | scaler = GradScaler() 45 | 46 | for epoch in epochs: 47 | for input, target in data: 48 | optimizer.zero_grad() 49 | output = model(input) 50 | loss = loss_fn(output, target) 51 | 52 | # Scales loss. Calls backward() on scaled loss to create scaled gradients. 53 | scaler.scale(loss).backward() 54 | 55 | # scaler.step() first unscales gradients of the optimizer's params. 56 | # If gradients don't contain infs/NaNs, optimizer.step() is then called, 57 | # otherwise, optimizer.step() is skipped. 58 | scaler.step(optimizer) 59 | 60 | # Updates the scale for next iteration. 61 | scaler.update() 62 | 63 | See the :ref:`Automatic Mixed Precision examples` for usage 64 | (along with autocasting) in more complex cases like gradient clipping, gradient accumulation, gradient penalty, 65 | and multiple losses/optimizers. 66 | 67 | ``scaler`` dynamically estimates the scale factor each iteration. To minimize gradient underflow, 68 | a large scale factor should be used. 
--------------------------------------------------------------------------------
/python/jtorch/gradscaler_old.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict, abc
2 | from enum import Enum
3 | from typing import Any, Dict, Optional, Tuple, cast
4 | import inspect
5 | import warnings
6 | 
7 | import jittor as jt
8 | # import torch
9 | 
10 | 
11 | __all__ = ["OptState", "GradScaler"]
12 | 
13 | 
14 | # Defines default_factory for GradScaler's _per_optimizer_states defaultdict,
15 | # as well as associated "enum" values. We prefer defining these at top level because
16 | # - Lambdas can't be pickled, so we don't want to supply a lambda as the factory.
17 | # - Defining READY, UNSCALED, STEPPED and _refresh_per_optimizer_state within GradScaler
18 | #   causes a circular reference, which we'd rather avoid.
19 | class OptState(Enum):
20 |     READY = 0
21 |     UNSCALED = 1
22 |     STEPPED = 2
23 | 
24 | 
25 | def _refresh_per_optimizer_state():
26 |     return {"stage": OptState.READY, "found_inf_per_device": {}}
27 | 
28 | 
29 | class GradScaler:
30 |     """
31 |     An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling
32 |     conveniently.
33 | 
34 |     * ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor.
35 |     * ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``.
36 |     * ``scaler.update()`` updates ``scaler``'s scale factor.
37 | 
38 |     Example::
39 | 
40 |         # Creates a GradScaler once at the beginning of training.
41 |         scaler = GradScaler()
42 | 
43 |         for epoch in epochs:
44 |             for input, target in data:
45 |                 optimizer.zero_grad()
46 |                 output = model(input)
47 |                 loss = loss_fn(output, target)
48 | 
49 |                 # Scales loss. Calls backward() on scaled loss to create scaled gradients.
50 |                 scaler.scale(loss).backward()
51 | 
52 |                 # scaler.step() first unscales gradients of the optimizer's params.
53 |                 # If gradients don't contain infs/NaNs, optimizer.step() is then called,
54 |                 # otherwise, optimizer.step() is skipped.
55 |                 scaler.step(optimizer)
56 | 
57 |                 # Updates the scale for next iteration.
58 |                 scaler.update()
59 | 
60 |     See the :ref:`Automatic Mixed Precision examples` for usage
61 |     (along with autocasting) in more complex cases like gradient clipping, gradient accumulation,
62 |     gradient penalty, and multiple losses/optimizers.
63 | 
64 |     ``scaler`` dynamically estimates the scale factor each iteration. To minimize gradient underflow,
65 |     a large scale factor should be used. However, ``float16`` values can "overflow" (become inf or NaN) if
66 |     the scale factor is too large. Therefore, the optimal scale factor is the largest factor that can be used
67 |     without incurring inf or NaN gradient values.
68 |     ``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and NaNs
69 |     during every ``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`).
70 | 
71 |     * If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params
72 |       themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``.
73 | 
74 |     * If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual.
75 |       If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by
76 |       ``growth_factor``.
77 | 
78 |     The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its
79 |     value calibrates. ``scaler.step`` will skip the underlying ``optimizer.step()`` for these
80 |     iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations).
81 | 
82 |     Args:
83 |         init_scale (float, optional, default=2.**16): Initial scale factor.
84 |         growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during
85 |             :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations.
86 |         backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during
87 |             :meth:`update` if inf/NaN gradients occur in an iteration.
88 |         growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients
89 |             that must occur for the scale to be multiplied by ``growth_factor``.
90 |         enabled (bool, optional): If ``False``, disables gradient scaling. :meth:`step` simply
91 |             invokes the underlying ``optimizer.step()``, and other methods become no-ops.
92 |             Default: ``True``
93 |     """
94 |     _scale: Optional[float]
95 |     _growth_tracker: Optional[int]
96 |     _per_optimizer_states: Dict[int, Dict[str, Any]]
97 |     def __init__(self,
98 |                  init_scale=2.**16,
99 |                  growth_factor=2.0,
100 |                  backoff_factor=0.5,
101 |                  growth_interval=2000,
102 |                  enabled=True):
103 |         self._enabled = enabled
104 | 
105 |         if self._enabled:
106 |             assert growth_factor > 1.0, "The growth factor must be > 1.0."
107 |             assert backoff_factor < 1.0, "The backoff factor must be < 1.0."
108 | 
109 |             self._init_scale = init_scale
110 |             # self._scale will be lazily initialized during the first call to scale()
111 |             self._scale = None
112 |             self._growth_factor = growth_factor
113 |             self._backoff_factor = backoff_factor
114 |             self._growth_interval = growth_interval
115 |             self._init_growth_tracker = 0
116 |             # self._growth_tracker will be lazily initialized during the first call to scale()
117 |             self._growth_tracker = None
118 |             self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
119 | 
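    # Construction sketch (illustrative, not part of the original file): _scale
    # and _growth_tracker stay None until the first scale() call, so building a
    # scaler is cheap and get_scale() falls back to init_scale until then.
    #
    #     scaler = GradScaler(init_scale=2.**10, growth_interval=500)
    #     scaler.get_scale()            # -> 1024.0, from _init_scale
    #     scaler._get_growth_tracker()  # -> 0, lazily initialized later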
".format(funcname) + fix 124 | return (self._scale, self._growth_tracker) 125 | 126 | def _lazy_init_scale_growth_tracker(self): 127 | assert self._growth_tracker is None, "_growth_tracker initialized before _scale" 128 | self._scale = self._init_scale 129 | self._growth_tracker = self._init_growth_tracker 130 | 131 | def scale(self, outputs): 132 | """ 133 | Multiplies ('scales') a tensor or list of tensors by the scale factor. 134 | 135 | Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned 136 | unmodified. 137 | 138 | Args: 139 | outputs (Tensor or iterable of Tensors): Outputs to scale. 140 | """ 141 | print("scale") 142 | if not self._enabled: 143 | return outputs 144 | 145 | 146 | # Short-circuit for the common case. 147 | if isinstance(outputs, jt.Var): 148 | assert jt.flags.use_cuda == 1 149 | if self._scale is None: 150 | self._lazy_init_scale_growth_tracker() 151 | assert self._scale is not None 152 | return outputs * self._scale 153 | 154 | def apply_scale(val): 155 | if isinstance(val, jt.Var): 156 | assert jt.flags.use_cuda == 1 157 | if self._scale is None: 158 | self._lazy_init_scale_growth_tracker() 159 | assert self._scale is not None 160 | return val * self._scale 161 | elif isinstance(val, abc.Iterable): 162 | iterable = map(apply_scale, val) 163 | if isinstance(val, (list, tuple)): 164 | return type(val)(iterable) 165 | else: 166 | return iterable 167 | else: 168 | raise ValueError("outputs must be a Tensor or an iterable of Tensors") 169 | 170 | return apply_scale(outputs) 171 | 172 | def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): 173 | 174 | # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype. 175 | # There could be hundreds of grads, so we'd like to iterate through them just once. 176 | # However, we don't know their devices or dtypes in advance. 177 | 178 | # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict 179 | # Google says mypy struggles with defaultdicts type annotations. 180 | with jt.no_grad(): 181 | optimizer.pre_step() 182 | for group in optimizer.param_groups: 183 | for to_unscale in group["grads"]: 184 | if to_unscale is None or isinstance(to_unscale,(int,float)): 185 | continue 186 | if (not allow_fp16) and str(to_unscale.dtype) == "float16": 187 | raise ValueError("Attempting to unscale FP16 gradients.") 188 | 189 | if not (to_unscale.isinf().any()): 190 | if inv_scale != 1.0: 191 | to_unscale.update(to_unscale*inv_scale) 192 | else: 193 | found_inf = 1.0 194 | 195 | return found_inf 196 | 197 | def unscale_(self, optimizer): 198 | """ 199 | Divides ("unscales") the optimizer's gradient tensors by the scale factor. 200 | 201 | :meth:`unscale_` is optional, serving cases where you need to 202 | :ref:`modify or inspect gradients` 203 | between the backward pass(es) and :meth:`step`. 204 | If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`. 205 | 206 | Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients:: 207 | 208 | ... 209 | scaler.scale(loss).backward() 210 | scaler.unscale_(optimizer) 211 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) 212 | scaler.step(optimizer) 213 | scaler.update() 214 | 215 | Args: 216 | optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled. 217 | 218 | .. note:: 219 | :meth:`unscale_` does not incur a CPU-GPU sync. 220 | 221 | .. 
172 |     def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16):
173 | 
174 |         # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
175 |         # There could be hundreds of grads, so we'd like to iterate through them just once.
176 |         # However, we don't know their devices or dtypes in advance.
177 | 
178 |         # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
179 |         # Google says mypy struggles with defaultdict type annotations.
180 |         with jt.no_grad():
181 |             optimizer.pre_step()
182 |             for group in optimizer.param_groups:
183 |                 for to_unscale in group["grads"]:
184 |                     if to_unscale is None or isinstance(to_unscale, (int, float)):
185 |                         continue
186 |                     if (not allow_fp16) and str(to_unscale.dtype) == "float16":
187 |                         raise ValueError("Attempting to unscale FP16 gradients.")
188 | 
189 |                     if not (to_unscale.isinf().any() or to_unscale.isnan().any()):
190 |                         if inv_scale != 1.0:
191 |                             to_unscale.update(to_unscale * inv_scale)
192 |                     else:
193 |                         found_inf = 1.0
194 | 
195 |         return found_inf
196 | 
197 |     def unscale_(self, optimizer):
198 |         """
199 |         Divides ("unscales") the optimizer's gradient tensors by the scale factor.
200 | 
201 |         :meth:`unscale_` is optional, serving cases where you need to
202 |         :ref:`modify or inspect gradients`
203 |         between the backward pass(es) and :meth:`step`.
204 |         If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.
205 | 
206 |         Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::
207 | 
208 |             ...
209 |             scaler.scale(loss).backward()
210 |             scaler.unscale_(optimizer)
211 |             torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
212 |             scaler.step(optimizer)
213 |             scaler.update()
214 | 
215 |         Args:
216 |             optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.
217 | 
218 |         .. note::
219 |             :meth:`unscale_` does not incur a CPU-GPU sync.
220 | 
221 |         .. warning::
222 |             :meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
223 |             and only after all gradients for that optimizer's assigned parameters have been accumulated.
224 |             Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.
225 | 
226 |         .. warning::
227 |             :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
228 |         """
229 |         if not self._enabled:
230 |             return
231 | 
232 |         self._check_scale_growth_tracker("unscale_")
233 | 
234 |         optimizer_state = self._per_optimizer_states[id(optimizer)]
235 | 
236 |         if optimizer_state["stage"] is OptState.UNSCALED:
237 |             raise RuntimeError("unscale_() has already been called on this optimizer since the last update().")
238 |         elif optimizer_state["stage"] is OptState.STEPPED:
239 |             raise RuntimeError("unscale_() is being called after step().")
240 | 
241 | 
242 |         # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
243 |         assert self._scale is not None
244 |         inv_scale = 1.0 / self._scale
245 |         found_inf = 0.0
246 |         optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False)
247 |         optimizer_state["stage"] = OptState.UNSCALED
248 | 
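    # NOTE: unlike upstream PyTorch, where "found_inf_per_device" is a dict of
    # per-device tensors, this port stores a plain float: _unscale_grads_
    # returns 1.0 if any gradient contained a non-finite value and 0.0
    # otherwise, and _maybe_opt_step below simply tests that value for
    # truthiness to decide whether optimizer.step() runs.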
249 |     def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs):
250 |         retval = None
251 |         if not optimizer_state["found_inf_per_device"]:
252 |             retval = optimizer.step(*args, **kwargs)
253 |         else:
254 |             optimizer.post_step()
255 | 
256 |         return retval
257 | 
258 |     def step(self, optimizer, *args, **kwargs):
259 |         """
260 |         :meth:`step` carries out the following two operations:
261 | 
262 |         1. Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for ``optimizer``
263 |            earlier in the iteration). As part of the :meth:`unscale_`, gradients are checked for infs/NaNs.
264 |         2. If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled
265 |            gradients. Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params.
266 | 
267 |         ``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``.
268 | 
269 |         Returns the return value of ``optimizer.step(*args, **kwargs)``.
270 | 
271 |         Args:
272 |             optimizer (torch.optim.Optimizer): Optimizer that applies the gradients.
273 |             args: Any arguments.
274 |             kwargs: Any keyword arguments.
275 | 
276 |         .. warning::
277 |             Closure use is not currently supported.
278 |         """
279 |         if not self._enabled:
280 |             return optimizer.step(*args, **kwargs)
281 | 
282 |         if "closure" in kwargs:
283 |             raise RuntimeError("Closure use is not currently supported if GradScaler is enabled.")
284 | 
285 |         self._check_scale_growth_tracker("step")
286 | 
287 |         optimizer_state = self._per_optimizer_states[id(optimizer)]
288 | 
289 |         if optimizer_state["stage"] is OptState.STEPPED:
290 |             raise RuntimeError("step() has already been called since the last update().")
291 | 
292 |         retval = None
293 | 
294 |         if (hasattr(optimizer, "_step_supports_amp_scaling") and optimizer._step_supports_amp_scaling):
295 |             # This optimizer has customized scale-handling logic, so we can call optimizer.step() directly.
296 |             # The contract with custom optimizers is that their step() should accept an additional,
297 |             # optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information:
298 |             # it can query its own state, invoke unscale_ on itself, etc.
299 |             # The contract above is being deprecated to avoid introducing `grad_scaler: GradScaler` argument
300 |             # to `Optimizer.step`. The new behavior is going to add two Tensor attributes of `grad_scale`
301 |             # and `found_inf` to the passed optimizer so that the optimizer can utilize those
302 |             # to skip the parameter updates or unscale gradients before updating parameters in
303 |             # the fused kernel, e.g. `FusedAdamMathFunctor`.
304 |             # In this behavior, `GradScaler._check_inf_per_device` is called if `OptState.READY`,
305 |             # while the method is expected to be called by users side, i.e. their optimizers.
306 |             kwargs_ = dict(kwargs)  # copy so the caller's kwargs are not mutated below
307 |             has_grad_scaler_kwarg = "grad_scaler" in inspect.signature(optimizer.step).parameters
308 |             if has_grad_scaler_kwarg:
309 |                 warnings.warn(
310 |                     "GradScaler is going to stop passing itself as a keyword argument to the passed "
311 |                     "optimizer. In the near future GradScaler registers `grad_scale: Tensor` and "
312 |                     "`found_inf: Tensor` to the passed optimizer and let the optimizer use them directly.",
313 |                     FutureWarning)
314 |                 kwargs_.update({"grad_scaler": self})
315 |             else:
316 |                 if optimizer_state["stage"] is OptState.READY:
317 |                     self._check_inf_per_device(optimizer)
318 |                 scaler = self._get_scale_async()
319 |                 # In this port "found_inf_per_device" is a plain float rather than a
320 |                 # dict of per-device tensors, so the upstream sum over its .values()
321 |                 # would fail here; use the value directly.
322 |                 found_inf = cast(
323 |                     float, optimizer_state["found_inf_per_device"]
324 |                 )
325 |                 optimizer.grad_scale = None if optimizer_state["stage"] is OptState.UNSCALED else scaler
326 |                 optimizer.found_inf = found_inf
327 |             retval = optimizer.step(*args, **kwargs_)
328 |             optimizer_state["stage"] = OptState.STEPPED
329 |             if not has_grad_scaler_kwarg:
330 |                 del optimizer.grad_scale
331 |                 del optimizer.found_inf
332 |             return retval
333 | 
334 | 
335 |         if optimizer_state["stage"] is OptState.READY:
336 |             self.unscale_(optimizer)
337 | 
338 |         assert "found_inf_per_device" in optimizer_state, "No inf checks were recorded for this optimizer."
339 | 
340 |         retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs)
341 | 
342 |         optimizer_state["stage"] = OptState.STEPPED
343 | 
344 |         return retval
345 | 
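    # Step-skipping sketch (illustrative, not part of the original file): when a
    # batch produces non-finite gradients, step() skips optimizer.step() and the
    # following update() backs the scale off, so training continues without
    # corrupted parameters.
    #
    #     scaler.scale(loss).backward()
    #     scaler.step(optimizer)   # internally a no-op if infs/NaNs were found
    #     scaler.update()          # scale *= backoff_factor after a skipped step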
346 |     def update(self, new_scale=None):
347 |         """
348 |         Updates the scale factor.
349 | 
350 |         If any optimizer steps were skipped, the scale is multiplied by ``backoff_factor``
351 |         to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
352 |         the scale is multiplied by ``growth_factor`` to increase it.
353 | 
354 |         Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
355 |         used directly; it's used to fill GradScaler's internal scale value. So if
356 |         ``new_scale`` was a tensor, later in-place changes to that tensor will not further
357 |         affect the scale GradScaler uses internally.)
358 | 
359 |         Args:
360 |             new_scale (float or a 1-element ``jt.Var``, optional, default=None): New scale factor.
361 | 
362 |         .. warning::
363 |             :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
364 |             been invoked for all optimizers used this iteration.
365 |         """
366 |         if not self._enabled:
367 |             return
368 | 
369 |         _scale, _growth_tracker = self._check_scale_growth_tracker("update")
370 | 
371 |         if new_scale is not None:
372 |             # Accept a new user-defined scale.
373 |             if isinstance(new_scale, float):
374 |                 self._scale = new_scale
375 |             else:
376 |                 reason = "new_scale should be a float or a 1-element jt.Var with requires_grad=False."
377 |                 assert isinstance(new_scale, jt.Var), reason
378 |                 assert new_scale.numel() == 1, reason
379 |                 assert new_scale.requires_grad is False, reason
380 |                 self._scale = new_scale.item()
381 |         else:
382 |             # Consume shared inf/nan data collected from optimizers to update the scale.
383 |             # In this port each entry is a plain float recorded by unscale_()/step().
384 |             found_infs = [state["found_inf_per_device"]
385 |                 for state in self._per_optimizer_states.values()
386 |             ]
387 | 
388 |             assert len(found_infs) > 0, "No inf checks were recorded prior to update."
389 | 
390 |             found_inf_combined = found_infs[0]
391 |             if len(found_infs) > 1:
392 |                 for i in range(1, len(found_infs)):
393 |                     found_inf_combined += found_infs[i]
394 | 
395 | 
396 |             current_scale = _scale
397 |             if found_inf_combined:
398 |                 current_scale *= self._backoff_factor
399 |                 _growth_tracker = 0
400 |             else:
401 |                 successful = _growth_tracker + 1
402 |                 if successful == self._growth_interval:
403 |                     new_scale = current_scale * self._growth_factor
404 |                     if new_scale < 1e9:
405 |                         current_scale = new_scale
406 |                     _growth_tracker = 0
407 |                 else:
408 |                     _growth_tracker = successful
409 | 
410 |             self._scale, self._growth_tracker = current_scale, _growth_tracker
411 | 
412 |         # To prepare for next iteration, clear the data collected from optimizers this iteration.
413 |         self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
414 | 
415 |     def _get_scale_async(self):
416 |         return self._scale
417 | 
418 |     def get_scale(self):
419 |         """
420 |         Returns a Python float containing the current scale, or 1.0 if scaling is disabled.
421 | 
422 |         .. note::
423 |             In this port the scale is stored as a Python float, so no CPU-GPU sync is incurred.
424 |         """
425 |         if self._enabled:
426 |             return self._init_scale if self._scale is None else self._get_scale_async()
427 |         else:
428 |             return 1.0
429 | 
430 |     def get_growth_factor(self):
431 |         r"""
432 |         Returns a Python float containing the scale growth factor.
433 |         """
434 |         return self._growth_factor
435 | 
436 |     def set_growth_factor(self, new_factor):
437 |         r"""
438 |         Args:
439 |             new_factor (float): Value to use as the new scale growth factor.
440 |         """
441 |         self._growth_factor = new_factor
442 | 
443 |     def get_backoff_factor(self):
444 |         r"""
445 |         Returns a Python float containing the scale backoff factor.
446 |         """
447 |         return self._backoff_factor
448 | 
449 |     def set_backoff_factor(self, new_factor):
450 |         r"""
451 |         Args:
452 |             new_factor (float): Value to use as the new scale backoff factor.
453 |         """
454 |         self._backoff_factor = new_factor
455 | 
456 |     def get_growth_interval(self):
457 |         r"""
458 |         Returns a Python int containing the growth interval.
459 |         """
460 |         return self._growth_interval
461 | 
462 |     def set_growth_interval(self, new_interval):
463 |         r"""
464 |         Args:
465 |             new_interval (int): Value to use as the new growth interval.
466 |         """
467 |         self._growth_interval = new_interval
468 | 
469 |     def _get_growth_tracker(self):
470 |         if self._enabled:
471 |             return self._init_growth_tracker if self._growth_tracker is None else int(self._growth_tracker)
472 |         else:
473 |             return 0
474 | 
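    # Worked example of the update() arithmetic above, with the defaults
    # init_scale=2.**16, growth_factor=2.0, backoff_factor=0.5,
    # growth_interval=2000:
    #
    #   * an overflow at scale 65536.0 backs off to 65536.0 * 0.5 = 32768.0 and
    #     resets the growth tracker to 0;
    #   * 2000 consecutive clean iterations then grow the scale back to
    #     32768.0 * 2.0 = 65536.0 (growth is skipped once the scale would
    #     reach 1e9).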
478 | """ 479 | return self._enabled 480 | 481 | def state_dict(self): 482 | r""" 483 | Returns the state of the scaler as a :class:`dict`. It contains five entries: 484 | 485 | * ``"scale"`` - a Python float containing the current scale 486 | * ``"growth_factor"`` - a Python float containing the current growth factor 487 | * ``"backoff_factor"`` - a Python float containing the current backoff factor 488 | * ``"growth_interval"`` - a Python int containing the current growth interval 489 | * ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps. 490 | 491 | If this instance is not enabled, returns an empty dict. 492 | 493 | .. note:: 494 | If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict` 495 | should be called after :meth:`update`. 496 | """ 497 | return {"scale": self.get_scale(), 498 | "growth_factor": self._growth_factor, 499 | "backoff_factor": self._backoff_factor, 500 | "growth_interval": self._growth_interval, 501 | "_growth_tracker": self._get_growth_tracker()} if self._enabled else {} 502 | 503 | def load_state_dict(self, state_dict): 504 | r""" 505 | Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op. 506 | 507 | Args: 508 | state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`. 509 | """ 510 | if not self._enabled: 511 | return 512 | 513 | if len(state_dict) == 0: 514 | raise RuntimeError("The source state dict is empty, possibly because it was saved " 515 | "from a disabled instance of GradScaler.") 516 | 517 | self._init_scale = state_dict["scale"] 518 | if self._scale is not None: 519 | self._scale.fill_(state_dict["scale"]) 520 | self._growth_factor = state_dict["growth_factor"] 521 | self._backoff_factor = state_dict["backoff_factor"] 522 | self._growth_interval = state_dict["growth_interval"] 523 | self._init_growth_tracker = state_dict["_growth_tracker"] 524 | if self._growth_tracker is not None: 525 | self._growth_tracker.fill_(state_dict["_growth_tracker"]) 526 | 527 | def __getstate__(self): 528 | state = self.__dict__.copy() 529 | if self._enabled: 530 | assert len(self._per_optimizer_states) == 0, "A GradScaler instance may only be pickled at the beginning "\ 531 | "of an iteration, or at the end after scaler.update()." 532 | # Pickling _scale and _growth_tracker Tensors directly triggers 533 | # "warnings.warn("pickle support for Storage will be removed in 1.5..." 534 | # so instead, we set the unpickled instance up to reinitialize them lazily. 535 | state['_init_scale'] = self.get_scale() 536 | state['_init_growth_tracker'] = self._get_growth_tracker() 537 | state['_scale'] = None 538 | state['_growth_tracker'] = None 539 | return state 540 | 541 | def __setstate__(self, state): 542 | self.__dict__.update(state) 543 | 544 | def _check_inf_per_device(self, optimizer): 545 | _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device") 546 | 547 | dummy_inv_scale = 1.0 548 | found_inf = 0.0 549 | 550 | self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \ 551 | self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True) 552 | 553 | return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] 554 | 555 | def _found_inf_per_device(self, optimizer): 556 | return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] 557 | --------------------------------------------------------------------------------