├── README.md ├── tensorfn ├── vision │ ├── __init__.py │ ├── mix.py │ ├── transforms.py │ └── autoaugment.py ├── nn │ ├── carafe │ │ ├── __init__.py │ │ ├── carafe_cuda.cpp │ │ ├── carafe.py │ │ └── carafe_cuda_kernel.cu │ ├── __init__.py │ ├── model_util.py │ ├── dropblock.py │ ├── loss.py │ └── interpolate_spline.py ├── checker │ ├── __init__.py │ ├── checker.py │ └── backend.py ├── optim │ ├── __init__.py │ ├── lamb.py │ ├── lr_scheduler.py │ └── rmsprop_tf.py ├── data │ ├── __init__.py │ ├── s3reader.py │ ├── batch.py │ ├── grouped_sampler.py │ ├── lmdb_reader.py │ └── cluegen.py ├── nsml_wrapper.py ├── distributed │ ├── __init__.py │ ├── distributed.py │ └── launch.py ├── __init__.py ├── util │ ├── lazy_extension.py │ ├── __init__.py │ ├── ensure.py │ ├── logger.py │ └── config.py ├── config │ ├── __init__.py │ ├── data.py │ ├── checker.py │ ├── optimizer.py │ ├── lr_scheduler.py │ ├── builder.py │ └── config.py ├── test_util.py └── trainer.py ├── setup.cfg ├── requirements.txt ├── PKG-INFO ├── test └── test_carafe.py ├── setup.py ├── .gitignore └── LICENSE /README.md: -------------------------------------------------------------------------------- 1 | # tensorfn -------------------------------------------------------------------------------- /tensorfn/vision/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [egg_info] 2 | tag_build = 3 | tag_date = 0 4 | 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.1 2 | pydantic>=1.5 3 | pyhocon>=0.3.54 -------------------------------------------------------------------------------- /tensorfn/nn/carafe/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .carafe import CARAFE, carafe 2 | -------------------------------------------------------------------------------- /tensorfn/checker/__init__.py: -------------------------------------------------------------------------------- 1 | from tensorfn.checker.checker import Checker 2 | from tensorfn.checker.backend import Local, Logger, S3, NSML, WandB 3 | -------------------------------------------------------------------------------- /tensorfn/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from tensorfn.optim import lr_scheduler 2 | from tensorfn.optim.lamb import LAMB 3 | from tensorfn.optim.rmsprop_tf import RMSpropTF 4 | -------------------------------------------------------------------------------- /tensorfn/data/__init__.py: -------------------------------------------------------------------------------- 1 | from tensorfn.data.lmdb_reader import LMDBReader 2 | from tensorfn.data.s3reader import S3Reader 3 | from tensorfn.data.grouped_sampler import create_groups, GroupedBatchSampler 4 | from tensorfn.data.batch import batch 5 | -------------------------------------------------------------------------------- /tensorfn/nsml_wrapper.py: -------------------------------------------------------------------------------- 1 | SESSION_ID = "" 2 | SESSION_NAME = "" 3 | GPU_NUM = 0 4 | 5 | IS_DATASET = False 6 | HAS_DATASET = False 7 | DATASET_PATH = "" 8 | DATASET_NAME = "" 9 | IS_ON_NSML = False 10 | NSML_NFS_OUTPUT = "" 11 | 12 | -------------------------------------------------------------------------------- /PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 1.0 2 | Name: tensorfn 3 | Version: 0.1.10 4 | Summary: Non-opionated utility library for PyTorch 5 | Home-page: https://github.com/rosinality/tensorfn 6 | Author: Kim Seonghyeon 7 | Author-email: 
kim.seonghyeon@navercorp.com 8 | License: MIT 9 | Description: UNKNOWN 10 | Platform: UNKNOWN 11 | -------------------------------------------------------------------------------- /tensorfn/distributed/__init__.py: -------------------------------------------------------------------------------- 1 | from .distributed import ( 2 | get_rank, 3 | get_local_rank, 4 | is_primary, 5 | synchronize, 6 | get_world_size, 7 | all_reduce, 8 | all_gather, 9 | reduce_dict, 10 | data_sampler, 11 | LOCAL_PROCESS_GROUP, 12 | ) 13 | from .launch import launch, run 14 | -------------------------------------------------------------------------------- /tensorfn/nn/__init__.py: -------------------------------------------------------------------------------- 1 | # from .carafe import CARAFE, carafe 2 | from .dropblock import DropBlock2d 3 | from .loss import LabelSmoothingLoss, MixLoss 4 | from .interpolate_spline import ( 5 | regular_control, 6 | regular_grid, 7 | interpolate_spline, 8 | InterpolateSpline, 9 | ) 10 | from .model_util import repeat 11 | -------------------------------------------------------------------------------- /tensorfn/nn/model_util.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | 4 | def repeat(object, times): 5 | copied = [] 6 | 7 | for _ in range(times): 8 | if isinstance(object, (list, tuple)): 9 | for obj in object: 10 | copied.append(copy.deepcopy(obj)) 11 | 12 | else: 13 | copied.append(copy.deepcopy(object)) 14 | 15 | return copied 16 | -------------------------------------------------------------------------------- /tensorfn/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | import nsml 3 | 4 | except: 5 | from tensorfn import nsml_wrapper as nsml 6 | 7 | from tensorfn.util import ( 8 | read_config, 9 | preset_argparser, 10 | load_config, 11 | load_arg_config, 12 | add_distributed_args, 13 | load_wandb, 14 | ensure_tuple, 15 | get_logger, 16 
| create_small_table, 17 | ) 18 | from tensorfn.checker import Checker 19 | from tensorfn.trainer import Trainer 20 | -------------------------------------------------------------------------------- /tensorfn/util/lazy_extension.py: -------------------------------------------------------------------------------- 1 | from torch.utils import cpp_extension as cpp_ext 2 | 3 | 4 | class LazyExtension: 5 | def __init__(self, name, sources): 6 | self.name = name 7 | self.sources = sources 8 | self.loaded = False 9 | self.extension = None 10 | 11 | def get(self): 12 | if self.extension is None: 13 | self.extension = cpp_ext.load(self.name, sources=self.sources) 14 | 15 | return self.extension 16 | -------------------------------------------------------------------------------- /tensorfn/config/__init__.py: -------------------------------------------------------------------------------- 1 | from tensorfn.config.config import ( 2 | Config, 3 | TypedConfig, 4 | MainConfig, 5 | get_models, 6 | get_model, 7 | config_model, 8 | override, 9 | Instance, 10 | instantiate, 11 | ) 12 | from tensorfn.config.optimizer import Optimizer, make_optimizer 13 | from tensorfn.config.lr_scheduler import Scheduler 14 | from tensorfn.config.data import DataLoader, make_dataloader 15 | from tensorfn.config.checker import Checker 16 | -------------------------------------------------------------------------------- /tensorfn/util/__init__.py: -------------------------------------------------------------------------------- 1 | from tensorfn.util.config import ( 2 | read_config, 3 | preset_argparser, 4 | load_config, 5 | load_arg_config, 6 | add_distributed_args, 7 | ) 8 | from tensorfn.util.ensure import ensure_tuple 9 | from tensorfn.util.lazy_extension import LazyExtension 10 | from tensorfn.util.logger import get_logger, create_small_table 11 | 12 | 13 | def load_wandb(): 14 | try: 15 | import wandb 16 | 17 | except ImportError: 18 | wandb = None 19 | 20 | return wandb 21 | 
-------------------------------------------------------------------------------- /test/test_carafe.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import gradcheck 3 | 4 | from tensorfn.nn import carafe 5 | 6 | 7 | def test_carafe_gradcheck(): 8 | feat = torch.randn(2, 64, 3, 3, requires_grad=True, device='cuda:0').double() 9 | mask = ( 10 | torch.randn(2, 100, 6, 6, requires_grad=True, device='cuda:0') 11 | .sigmoid() 12 | .double() 13 | ) 14 | 15 | assert gradcheck( 16 | lambda feat, mask: carafe(feat, mask, 5, 4, 2), 17 | (feat, mask), 18 | atol=1e-4, 19 | eps=1e-4, 20 | ) 21 | 22 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name="tensorfn", 5 | version="0.1.23", 6 | description="Non-opionated utility library for PyTorch", 7 | url="https://github.com/rosinality/tensorfn", 8 | author="Kim Seonghyeon", 9 | author_email="kim.seonghyeon@navercorp.com", 10 | license="MIT", 11 | install_requires=[ 12 | "torch>=1.1", 13 | "pydantic>=1.8", 14 | "pyhocon>=0.3.54", 15 | "termcolor", 16 | "tabulate", 17 | "boto3", 18 | "rich", 19 | ], 20 | packages=find_packages(), 21 | ) 22 | -------------------------------------------------------------------------------- /tensorfn/util/ensure.py: -------------------------------------------------------------------------------- 1 | from collections import abc 2 | from itertools import repeat 3 | 4 | 5 | def ensure_tuple(x, n_item): 6 | if isinstance(x, abc.Iterable): 7 | try: 8 | if len(x) != n_item: 9 | raise ValueError( 10 | f"length of {x} (length: {len(x)}) does not match with the condition. 
def rand_bbox(size, lam):
    """Sample a random box covering roughly ``1 - lam`` of the image area.

    Args:
        size: Indexable whose items 2 and 3 are the image height and width
            (``cutmix`` passes ``input.shape``).
        lam: Mixing ratio; both box sides are scaled by ``sqrt(1 - lam)``.

    Returns:
        ``(bby1, bby2, bbx1, bbx2)``: row bounds followed by column bounds,
        clipped to the image, so the box may be smaller than requested.
    """
    H = size[2]
    W = size[3]
    cut_rat = np.sqrt(1.0 - lam)
    # The builtin ``int`` replaces the ``np.int`` alias, which NumPy removed.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # The box centre is drawn uniformly over the image.
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bby1, bby2, bbx1, bbx2


def cutmix(input, target, alpha):
    """Paste a random box from shuffled samples into each sample of a batch.

    ``input`` is modified in place: the same box is overwritten in every
    sample with the contents of a randomly permuted sample.

    Args:
        input: Batch tensor indexed as ``[sample, channel, row, column]``.
        target: Targets indexable by a permutation tensor.
        alpha: Both parameters of the Beta distribution used to draw the
            initial mixing ratio.

    Returns:
        ``(input, target_a, target_b, lam)`` where ``target_b`` holds the
        targets of the pasted samples and ``lam`` is the fraction of the image
        that still belongs to ``target_a``.
    """
    lam = np.random.beta(alpha, alpha)
    rand_i = torch.randperm(input.shape[0], device=input.device)
    target_a = target
    target_b = target[rand_i]
    by1, by2, bx1, bx2 = rand_bbox(input.shape, lam)
    input[:, :, by1:by2, bx1:bx2] = input[rand_i, :, by1:by2, bx1:bx2]

    # Recompute lam from the clipped box. The box area must be divided by the
    # whole image area (width * height), hence the parentheses.
    box_area = (bx2 - bx1) * (by2 - by1)
    image_area = input.shape[-1] * input.shape[-2]
    lam = 1 - box_area / image_area

    return input, target_a, target_b, lam
class Trainer:
    """Attribute container that can snapshot and restore its members.

    Two registries are kept:

    * attributes assigned normally whose value has a ``state_dict``
      attribute are tracked so ``state_dict()`` can serialize them;
    * objects added with ``register`` are stored by name and exposed as
      read-only attributes through ``__getattr__``.
    """

    def __init__(self):
        self.__has_state_dict = {}
        self.__additional_obj = {}

    def register(self, name, value):
        """Store ``value`` under ``name`` so it is reachable as ``self.<name>``."""
        self.__additional_obj[name] = value

    def __getattr__(self, name):
        # Only called when normal lookup fails. Read the registry directly from
        # the instance dict: going through ``self.__additional_obj`` would
        # re-enter __getattr__ and recurse forever on an instance whose
        # __init__ has not run (e.g. during copy.copy or unpickling).
        registered = self.__dict__.get("_Trainer__additional_obj", {})

        if name not in registered:
            raise AttributeError(f"Cannot find attribute {name}")

        return registered[name]

    def __setattr__(self, name, value):
        if hasattr(value, "state_dict"):
            self.__has_state_dict[name] = value

        else:
            # The attribute is being replaced by something that cannot be
            # serialized; forget any previously tracked object of that name so
            # state_dict() does not keep returning the stale one.
            self.__dict__.get("_Trainer__has_state_dict", {}).pop(name, None)

        super().__setattr__(name, value)

    def state_dict(self):
        """Return a dict with the state of every tracked and registered object.

        Tracked attributes contribute ``value.state_dict()``. Registered
        objects contribute themselves, or ``value.dict()`` when they are
        pydantic ``BaseModel`` instances.
        """
        result = {}

        for k, v in self.__has_state_dict.items():
            result[k] = v.state_dict()

        for k, v in self.__additional_obj.items():
            if isinstance(v, BaseModel):
                v = v.dict()

            result[k] = v

        return result

    def load_state_dict(self, state_dict):
        """Restore members from a dict produced by ``state_dict``.

        Keys that match neither registry are ignored.
        """
        for k, v in state_dict.items():
            if k in self.__has_state_dict:
                self.__has_state_dict[k].load_state_dict(v)

            if k in self.__additional_obj:
                if isinstance(self.__additional_obj[k], BaseModel):
                    self.__additional_obj[k] = self.__additional_obj[k].parse_obj(v)

                else:
                    self.__additional_obj[k] = v
class Checker:
    """Fan out saves, checkpoints and log records to several backends.

    Every public method acts only when ``dist.is_primary()`` is true.

    Args:
        storages: Storage backend or list of backends; each must provide
            ``save(data, name)`` and ``checkpoint(obj, name)``.
        reporters: Reporter or list of reporters; each must provide
            ``log(step, **kwargs)``.
    """

    def __init__(self, storages=None, reporters=None):
        self.storages = self._as_list(storages)
        self.reporters = self._as_list(reporters)

    @staticmethod
    def _as_list(backends):
        """Normalize ``None`` / single backend / sequence into an iterable."""
        # None used to be stored as-is and then iterated, raising TypeError.
        if backends is None:
            return []

        if isinstance(backends, (list, tuple)):
            return backends

        return [backends]

    def catalog(self, conf):
        """Save the launch command line and formatted config as ``catalog.txt``.

        ``conf`` may be a dict or any object exposing ``.dict()``.
        """
        if not dist.is_primary():
            return

        if not isinstance(conf, dict):
            conf = conf.dict()

        conf = pformat(conf)

        argvs = " ".join([os.path.basename(sys.executable)] + sys.argv)

        # Command line, a blank line, then the formatted configuration.
        template = f"{argvs}\n\n{conf}"
        template = template.encode("utf-8")

        for storage in self.storages:
            storage.save(template, "catalog.txt")

    def save(self, data, name):
        """Pass ``data`` to ``save`` of every storage backend."""
        if dist.is_primary():
            for storage in self.storages:
                storage.save(data, name)

    def checkpoint(self, obj, name):
        """Pass ``obj`` to ``checkpoint`` of every storage backend."""
        if dist.is_primary():
            for storage in self.storages:
                storage.checkpoint(obj, name)

    def log(self, step=None, **kwargs):
        """Pass ``step`` and the keyword values to every reporter."""
        if dist.is_primary():
            for reporter in self.reporters:
                reporter.log(step, **kwargs)
class S3Reader:
    """Read dataset records stored as individual objects in an S3 bucket.

    Records are addressed by key; integer indexing maps index ``i`` to the key
    ``str(i)``. The record count is read lazily from the ``length`` key.

    Args:
        bucket: Bucket name.
        path: Optional key prefix; keys become ``f"{path}/{key}"``.
        reader: Name or callable accepted by ``get_reader``; used to decode
            the raw bytes of each object.
        access_key, secret_key, endpoint: Forwarded to ``boto3.client``.

    Raises:
        ImportError: If ``boto3`` could not be imported.
    """

    def __init__(
        self,
        bucket,
        path=None,
        reader="torch",
        access_key=None,
        secret_key=None,
        endpoint=None,
    ):
        if boto3 is None:
            raise ImportError("boto3 should be installed for S3 storage")

        self.s3 = boto3.client(
            "s3",
            aws_access_key_id=access_key,
            aws_secret_access_key=secret_key,
            endpoint_url=endpoint,
        )

        self.bucket = bucket
        self.path = path
        self.reader = self.get_reader(reader)
        self.length = None

    def open(self):
        """Load the record count from the ``length`` key (0 when absent)."""
        try:
            self.length = int(self.get("length", "str"))

        except KeyError:
            self.length = 0

    def get_reader(self, reader):
        """Resolve ``reader`` to a decoding callable."""
        return get_reader(reader)

    def get(self, key, reader=None):
        """Fetch and decode the object stored under ``key`` below ``path``."""
        if self.path is None:
            path_key = key

        else:
            path_key = f"{self.path}/{key}"

        return self.get_path(path_key, reader)

    def get_path(self, path_key, reader=None):
        """Fetch and decode the object stored under the full key ``path_key``.

        Raises:
            KeyError: If the bucket has no such key.
        """
        if reader is not None:
            read_fn = self.get_reader(reader)

        else:
            read_fn = self.reader

        try:
            value = self.s3.get_object(Bucket=self.bucket, Key=path_key)["Body"].read()

        except self.s3.exceptions.NoSuchKey as e:
            # Chain the original error so the S3 response stays in the traceback.
            raise KeyError(
                f"S3 bucket {self.bucket} does not have key {path_key}"
            ) from e

        return read_fn(value)

    def __len__(self):
        if self.length is None:
            self.open()

        return self.length

    def __iter__(self):
        # Go through len() so the count is loaded on first use; self.length is
        # still None until open() has run.
        for i in range(len(self)):
            yield self[i]

    def __getitem__(self, index):
        return self.get(str(index))
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /tensorfn/data/batch.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | 3 | from .cluegen import Datum, cluegen, all_clues 4 | 5 | 6 | def _set_attr(cls, name, value): 7 | if name in cls.__dict__: 8 | return True 9 | 10 | setattr(cls, name, value) 11 | 12 | return False 13 | 14 | 15 | def _to_fn(self, *args, **kwargs): 16 | fields = dataclasses.fields(self) 17 | new_batch = {} 18 | 19 | for field in fields: 20 | value = getattr(self, field.name) 21 | 22 | if hasattr(value, "to"): 23 | new_value = value.to(*args, **kwargs) 24 | 25 | else: 26 | new_value = value 27 | 28 | new_batch[field.name] = new_value 29 | 30 | return self.__class__(**new_batch) 31 | 32 | 33 | def to(self, *args, **kwargs): 34 | batch = {} 35 | for name in (c1, c2, c3): 36 | value = getattr(self, name) 37 | if hasattr(value, "to"): 38 | batch[name] = value.to(*args, **kwargs) 39 | 40 | else: 41 | batch[name] = value 42 | return self.__class__(**batch) 43 | 44 | 45 | class Batch(Datum): 46 | @cluegen 47 | def to(cls): 48 | clues = all_clues(cls) 49 | 50 | params = ", ".join(f'"{c}"' for c in clues) 51 | head = ( 52 | "def to(self, *args, **kwargs):\n" 53 | " batch = {}\n" 54 | f" for name in ({params}):\n" 55 | " value = getattr(self, name)\n" 56 | ' if hasattr(value, 
"to"):\n' 57 | " batch[name] = value.to(*args, **kwargs)\n" 58 | " else:\n" 59 | " batch[name] = value\n" 60 | " return self.__class__(**batch)" 61 | ) 62 | 63 | return head 64 | 65 | 66 | def batch(cls): 67 | _set_attr(cls, "to", _to_fn) 68 | 69 | return dataclasses.dataclass(cls) 70 | 71 | 72 | if __name__ == "__main__": 73 | import torch 74 | from typing import List 75 | 76 | @batch 77 | class Test: 78 | input: torch.Tensor 79 | label: List[int] 80 | 81 | abc = Test(input=torch.tensor([1, 2]), label=[0, 1]) 82 | abc_cuda = abc.to("cuda") 83 | print(abc.to("cuda")) 84 | print(id(abc.input), id(abc.label)) 85 | print(id(abc_cuda.input), id(abc_cuda.label)) 86 | 87 | class Test(Batch): 88 | input: torch.Tensor 89 | label: List[int] 90 | 91 | abc = Test(input=torch.tensor([1, 2]), label=[0, 1]) 92 | abc_cuda = abc.to("cuda") 93 | print(abc.to("cuda")) 94 | print(id(abc.input), id(abc.label)) 95 | print(id(abc_cuda.input), id(abc_cuda.label)) 96 | -------------------------------------------------------------------------------- /tensorfn/nn/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class LabelSmoothingLoss(nn.Module): 7 | def __init__(self, ignore_index, eps=0.1, reduction="mean"): 8 | super().__init__() 9 | 10 | self.ignore_index = ignore_index 11 | self.eps = eps 12 | self.reduction = reduction 13 | 14 | def forward(self, output, target): 15 | n_class = output.shape[-1] 16 | output = F.log_softmax(output, -1) 17 | 18 | if self.ignore_index > -1: 19 | n_class -= 1 20 | 21 | true_dist = torch.full_like(output, self.eps / n_class) 22 | true_dist.scatter_( 23 | 1, target.data.unsqueeze(1), 1 - self.eps + self.eps / n_class 24 | ) 25 | 26 | if self.ignore_index > -1: 27 | true_dist[:, self.ignore_index] = 0 28 | padding_mat = target.data == self.ignore_index 29 | mask = torch.nonzero(padding_mat, as_tuple=False) 30 | 31 | if 
class MixLoss(nn.Module):
    """KL-divergence loss against a blend of two (smoothed) one-hot targets.

    Args:
        eps: Label-smoothing amount; every class receives ``eps / n_class``
            and the target class ``1 - eps + eps / n_class``.
        reduction: ``"none"`` returns one value per sample, ``"mean"`` divides
            the summed loss by the batch size, any other value returns the sum.
    """

    def __init__(self, eps=0, reduction="mean"):
        super().__init__()

        self.eps = eps
        self.reduction = reduction

    def forward(self, output, target1, target2, interpolation):
        """Compute the loss.

        Args:
            output: Logits indexed as ``[sample, class]``.
            target1, target2: Class-index tensors with one entry per sample.
            interpolation: Weight of ``target1`` (scalar or one value per
                sample); ``target2`` gets ``1 - interpolation``.
        """
        n_class = output.shape[-1]
        output = F.log_softmax(output, -1)

        on_value = 1 - self.eps + self.eps / n_class
        true_dist = torch.full_like(output, self.eps / n_class)
        true1 = true_dist.scatter(1, target1.data.unsqueeze(1), on_value)
        true2 = true_dist.scatter(1, target2.data.unsqueeze(1), on_value)

        # Build the weight with the logits' dtype and device. A bare
        # as_tensor() yields a CPU tensor in its own dtype (float32 for Python
        # floats, float64 for NumPy scalars), which loses precision or changes
        # the target dtype, and cannot be combined with tensors on another
        # device once it has a dimension.
        inter = torch.as_tensor(
            interpolation, dtype=output.dtype, device=output.device
        ).unsqueeze(-1)
        true_dist = inter * true1 + (1 - inter) * true2

        loss = F.kl_div(
            output,
            true_dist.detach(),
            reduction="sum" if self.reduction != "none" else "none",
        )

        if self.reduction == "none":
            # Collapse the class axis to get one loss value per sample.
            loss = loss.sum(1)

        elif self.reduction == "mean":
            loss = loss / target1.shape[0]

        return loss
def create_groups(sizes, bins):
    """Assign each size to a bin and report the bin edges and populations.

    Args:
        sizes: Iterable of values to quantize.
        bins: Iterable of bin boundaries, in any order.

    Returns:
        ``(groups, fbins, counts)``: the group id of every size, the sorted
        boundaries extended with ``0`` and ``inf``, and the number of sizes in
        each group id that actually occurs.
    """
    # Sort once here so the returned edges match the ordering that quantize
    # uses for the group ids; this also accepts tuples and other iterables,
    # which the list concatenation below would otherwise reject.
    bins = sorted(bins)
    groups = quantize(sizes, bins)
    counts = np.unique(groups, return_counts=True)[1]
    fbins = [0] + bins + [np.inf]

    return groups, fbins, counts
class LMDBReader:
    """Read records from an LMDB database opened read-only on first use.

    Integer indexing maps index ``i`` to the key ``str(i).encode("utf-8")``;
    the record count is read from the ``b"length"`` key.

    Args:
        path: Database path passed to ``lmdb.open``.
        reader: Name or callable accepted by ``get_reader``; decodes the raw
            bytes of each value.
        map_size: Passed to ``lmdb.open``.
        max_readers: Passed to ``lmdb.open``.
        lazy: Accepted for compatibility; not used by this class.
    """

    def __init__(
        self, path, reader="torch", map_size=1024 ** 4, max_readers=126, lazy=True
    ):
        self.path = path
        self.map_size = map_size
        self.max_readers = max_readers

        self.env = None
        self.length = None

        self.reader = self.get_reader(reader)

    def open(self):
        """Open the environment and load the record count (0 when absent).

        Raises:
            IOError: If ``lmdb.open`` returns a falsy environment.
        """
        self.env = lmdb.open(
            self.path,
            self.map_size,
            readonly=True,
            create=False,
            readahead=False,
            lock=False,
            max_readers=self.max_readers,
        )

        if not self.env:
            raise IOError(f"Cannot open lmdb dataset {self.path}")

        try:
            self.length = int(self.get(b"length", "str"))

        except KeyError:
            self.length = 0

    def get_reader(self, reader):
        """Resolve ``reader`` to a decoding callable."""
        return get_reader(reader)

    def get(self, key, reader=None):
        """Fetch and decode the value stored under the bytes key ``key``.

        Opens the environment first when it is not open yet.

        Raises:
            KeyError: If the database has no such key.
        """
        if self.env is None:
            self.open()

        if reader is not None:
            read_fn = self.get_reader(reader)

        else:
            read_fn = self.reader

        with self.env.begin(write=False) as txn:
            value = txn.get(key)

        if value is None:
            raise KeyError(f"lmdb dataset does not have key {key}")

        return read_fn(value)

    def __len__(self):
        if self.length is None:
            # Open only long enough to read the count, then release the
            # environment again; it is reopened on the next get().
            self.open()
            self.close()

        return self.length

    def __iter__(self):
        # Go through len() so the count is loaded on first use; self.length is
        # still None until open() has run.
        for i in range(len(self)):
            yield self[i]

    def __getitem__(self, index):
        return self.get(str(index).encode("utf-8"))

    def close(self):
        """Close the environment if it is open; safe to call repeatedly."""
        # getattr guards against instances whose __init__ never ran, because
        # __del__ calls this on every instance.
        env = getattr(self, "env", None)

        if env is not None:
            env.close()
            self.env = None

    def __del__(self):
        self.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
bucket: StrictStr 22 | path: StrictStr 23 | access_key: StrictStr 24 | secret_key: StrictStr 25 | keep: StrictInt = -1 26 | endpoint: Optional[StrictStr] 27 | show_progress: StrictBool 28 | 29 | def make(self, **kwargs): 30 | argument = override( 31 | kwargs, 32 | bucket=self.bucket, 33 | path=self.path, 34 | access_key=self.access_key, 35 | secret_key=self.secret_key, 36 | keep=self.keep, 37 | endpoint=self.endpoint, 38 | show_progress=self.show_progress, 39 | ) 40 | 41 | return checker.S3(**argument) 42 | 43 | 44 | class Logger(TypedConfig): 45 | __type__ = "logger" 46 | 47 | def make(self, formatter=None): 48 | return checker.Logger(formatter) 49 | 50 | 51 | class WandB(TypedConfig): 52 | __type__ = "wandb" 53 | project: StrictStr 54 | group: Optional[StrictStr] = None 55 | name: Optional[StrictStr] = None 56 | notes: Optional[StrictStr] = None 57 | resume: Optional[Union[StrictBool, StrictStr]] = None 58 | tags: Optional[List[StrictStr]] = None 59 | id: Optional[StrictStr] = None 60 | 61 | def make(self, **kwargs): 62 | argument = override( 63 | kwargs, 64 | project=self.project, 65 | group=self.group, 66 | name=self.name, 67 | notes=self.notes, 68 | resume=self.resume, 69 | tags=self.tags, 70 | id=self.id, 71 | ) 72 | 73 | return checker.WandB(**argument) 74 | 75 | 76 | class NSML(TypedConfig): 77 | __type__ = "nsml" 78 | 79 | def make(self): 80 | return checker.NSML() 81 | 82 | 83 | Storage = Union[Local, S3] 84 | Reporter = Union[Logger, NSML, WandB] 85 | 86 | 87 | class Checker(Config): 88 | storage: Union[Storage, List[Storage]] = Local(type="local", path="experiment") 89 | reporter: Union[Reporter, List[Reporter]] = Logger(type="logger") 90 | 91 | def make(self, storage=None, reporter=None): 92 | if storage is None: 93 | if not isinstance(self.storage, list): 94 | storage_list = [self.storage] 95 | 96 | else: 97 | storage_list = self.storage 98 | 99 | storages = [] 100 | 101 | for storage in storage_list: 102 | storages.append(storage.make()) 103 | 104 | 
if reporter is None: 105 | if not isinstance(self.reporter, list): 106 | reporter_list = [self.reporter] 107 | 108 | else: 109 | reporter_list = self.reporter 110 | 111 | reporters = [] 112 | 113 | for reporter in reporter_list: 114 | reporters.append(reporter.make()) 115 | 116 | return checker.Checker(storages, reporters) 117 | -------------------------------------------------------------------------------- /tensorfn/distributed/distributed.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | 4 | import torch 5 | from torch import distributed as dist 6 | from torch.utils import data 7 | 8 | 9 | LOCAL_PROCESS_GROUP = None 10 | 11 | 12 | def is_primary(): 13 | return get_rank() == 0 14 | 15 | 16 | def get_rank(): 17 | if not dist.is_available(): 18 | return 0 19 | 20 | if not dist.is_initialized(): 21 | return 0 22 | 23 | return dist.get_rank() 24 | 25 | 26 | def get_local_rank(): 27 | if not dist.is_available(): 28 | return 0 29 | 30 | if not dist.is_initialized(): 31 | return 0 32 | 33 | if "LOCAL_RANK" in os.environ: 34 | return int(os.environ["LOCAL_RANK"]) 35 | 36 | if LOCAL_PROCESS_GROUP is None: 37 | raise ValueError("tensorfn.distributed.LOCAL_PROCESS_GROUP is None") 38 | 39 | return dist.get_rank(group=LOCAL_PROCESS_GROUP) 40 | 41 | 42 | def synchronize(): 43 | if not dist.is_available(): 44 | return 45 | 46 | if not dist.is_initialized(): 47 | return 48 | 49 | world_size = dist.get_world_size() 50 | 51 | if world_size == 1: 52 | return 53 | 54 | dist.barrier() 55 | 56 | 57 | def get_world_size(): 58 | if not dist.is_available(): 59 | return 1 60 | 61 | if not dist.is_initialized(): 62 | return 1 63 | 64 | return dist.get_world_size() 65 | 66 | 67 | def all_reduce(tensor, op=dist.ReduceOp.SUM): 68 | world_size = get_world_size() 69 | 70 | if world_size == 1: 71 | return tensor 72 | 73 | dist.all_reduce(tensor, op=op) 74 | 75 | return tensor 76 | 77 | 78 | def all_gather(data): 79 | 
world_size = get_world_size() 80 | 81 | if world_size == 1: 82 | return [data] 83 | 84 | buffer = pickle.dumps(data) 85 | storage = torch.ByteStorage.from_buffer(buffer) 86 | tensor = torch.ByteTensor(storage).to("cuda") 87 | 88 | local_size = torch.IntTensor([tensor.numel()]).to("cuda") 89 | size_list = [torch.IntTensor([1]).to("cuda") for _ in range(world_size)] 90 | dist.all_gather(size_list, local_size) 91 | size_list = [int(size.item()) for size in size_list] 92 | max_size = max(size_list) 93 | 94 | tensor_list = [] 95 | for _ in size_list: 96 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda")) 97 | 98 | if local_size != max_size: 99 | padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda") 100 | tensor = torch.cat((tensor, padding), 0) 101 | 102 | dist.all_gather(tensor_list, tensor) 103 | 104 | data_list = [] 105 | 106 | for size, tensor in zip(size_list, tensor_list): 107 | buffer = tensor.cpu().numpy().tobytes()[:size] 108 | data_list.append(pickle.loads(buffer)) 109 | 110 | return data_list 111 | 112 | 113 | def reduce_dict(input_dict, average=True): 114 | world_size = get_world_size() 115 | 116 | if world_size < 2: 117 | return input_dict 118 | 119 | with torch.no_grad(): 120 | keys = [] 121 | values = [] 122 | 123 | for k in sorted(input_dict.keys()): 124 | keys.append(k) 125 | values.append(input_dict[k]) 126 | 127 | values = torch.stack(values, 0) 128 | dist.reduce(values, dst=0) 129 | 130 | if dist.get_rank() == 0 and average: 131 | values /= world_size 132 | 133 | reduced_dict = {k: v for k, v in zip(keys, values)} 134 | 135 | return reduced_dict 136 | 137 | 138 | def data_sampler(dataset, shuffle, distributed): 139 | if distributed: 140 | return data.distributed.DistributedSampler(dataset, shuffle=shuffle) 141 | 142 | if shuffle: 143 | return data.RandomSampler(dataset) 144 | 145 | else: 146 | return data.SequentialSampler(dataset) 147 | -------------------------------------------------------------------------------- 
# /tensorfn/data/cluegen.py:
# --------------------------------------------------------------------------------
# cluegen.py
#
# Classes generated from type clues.
#
#     https://github.com/dabeaz/cluegen
#
# Author: David Beazley (@dabeaz).
#         http://www.dabeaz.com
#
# Copyright (C) 2018-2020.
#
# Permission is granted to use, copy, and modify this code in any
# manner as long as this copyright message and disclaimer remain in
# the source code.  There is no warranty.  Try to use the code for the
# greater good.

import types

# Collect all type clues from a class and base classes.
def all_clues(cls):
    clues = { }
    for c in reversed(cls.__mro__):
        clues.update(getattr(c, '__annotations__', {}))
    return clues

# Decorator to define methods of a class as a code generator.
# The decorated function receives the class and returns source text; the
# method is compiled on first attribute access and then stored on the class,
# replacing the descriptor.
def cluegen(func):
    def __get__(self, instance, cls):
        locs = { }
        code = func(cls)
        exec(code, locs)
        meth = locs[func.__name__]
        setattr(cls, func.__name__, meth)
        return meth.__get__(instance, cls)

    def __set_name__(self, cls, name):
        methods = cls.__dict__.get('_methods', list(cls._methods))
        if '_methods' not in cls.__dict__:
            cls._methods = methods
        cls._methods.append((name, self))

    return type(f'ClueGen_{func.__name__}', (), dict(__get__=__get__,
                                                     __set_name__=__set_name__))()

# Base class for defining data structures
class DatumBase:
    __slots__ = ()
    _methods = []

    @classmethod
    def __init_subclass__(cls):
        # Re-install the code generators on each subclass unless the subclass
        # defines the method itself.
        submethods = []
        for name, val in cls._methods:
            if name not in cls.__dict__:
                setattr(cls, name, val)
                submethods.append((name, val))
            elif val is cls.__dict__[name]:
                submethods.append((name, val))

        if submethods != cls._methods:
            cls._methods = submethods

# Source text of a tuple holding every clue attribute of `owner`.
# An empty clue set must yield '()': the text '(,)' is not valid Python.
def _tuple_source(owner, clues):
    if not clues:
        return '()'
    return '(' + ','.join(f'{owner}.{name}' for name in clues) + ',)'

# NOTE: the indentation inside the generated source strings below was lost in
# this listing and is reconstructed with 4-space levels.
class Datum(DatumBase):
    __slots__ = ()
    @cluegen
    def __init__(cls):
        clues = all_clues(cls)
        args = ', '.join(f'{name}={getattr(cls,name)!r}'
                         if hasattr(cls, name) and not isinstance(getattr(cls, name), types.MemberDescriptorType) else name
                         for name in clues)
        # A class without clues still needs a function body.
        body = '\n'.join(f'    self.{name} = {name}'
                         for name in clues) or '    pass'
        return f'def __init__(self, {args}):\n{body}\n'

    @cluegen
    def __repr__(cls):
        clues = all_clues(cls)
        fmt = ', '.join('%s={self.%s!r}' % (name, name) for name in clues)
        return 'def __repr__(self):\n' \
               '    return f"{type(self).__name__}(%s)"' % fmt

    @cluegen
    def __iter__(cls):
        clues = all_clues(cls)
        # With no clues, return an empty iterator instead of emitting a
        # body-less (invalid) function.
        values = '\n'.join(f'    yield self.{name}' for name in clues) or '    return iter(())'
        return 'def __iter__(self):\n' + values


    @cluegen
    def __eq__(cls):
        clues = all_clues(cls)
        selfvals = _tuple_source('self', clues)
        othervals = _tuple_source('other', clues)
        return 'def __eq__(self, other):\n' \
               '    if self.__class__ is other.__class__:\n' \
              f'        return {selfvals} == {othervals}\n' \
               '    else:\n' \
               '        return NotImplemented\n'

    @cluegen
    def __hash__(cls):
        clues = all_clues(cls)
        self_tuple = _tuple_source('self', clues)
        return 'def __hash__(self):\n' \
              f'    return hash({self_tuple})\n'

# Example use
if __name__ == '__main__':
    # Start defining classes
    class Coordinates(Datum):
        x: int
        y: int

# --------------------------------------------------------------------------------
# /tensorfn/config/optimizer.py:
# --------------------------------------------------------------------------------
from typing import Tuple, Union

from pydantic import BaseModel, validator, StrictStr, StrictBool
from torch import optim

from tensorfn.config import Config, TypedConfig, override
from tensorfn import optim as tensor_optim


class SGD(Config):
    """Configuration for ``torch.optim.SGD`` (``type`` must be "sgd")."""

    type: StrictStr

    lr: float
    momentum: float = 0.0
    dampening: float = 0.0
    weight_decay: float = 0.0
    nesterov: StrictBool = False

    @validator("type")
    def check_type(cls, v):
        if v != "sgd":
            raise ValueError("Optimizer type not match for sgd")

        return v

    def make(self, params, **kwargs):
        """Build the optimizer for *params*; config values and *kwargs* are merged via ``override``."""
        argument = override(
            kwargs,
            lr=self.lr,
            momentum=self.momentum,
            dampening=self.dampening,
            weight_decay=self.weight_decay,
            nesterov=self.nesterov,
        )

        return optim.SGD(params, **argument)


class Adam(Config):
    """Configuration for ``torch.optim.Adam`` (``type`` must be "adam")."""

    type: StrictStr

    lr: float = 0.001
    betas: Tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 0
    amsgrad: StrictBool = False

    @validator("type")
    def check_type(cls, v):
        if v != "adam":
            raise ValueError("Optimizer type not match for adam")

        return v

    def make(self, params, **kwargs):
        """Build the optimizer for *params*; config values and *kwargs* are merged via ``override``."""
        argument = override(
            kwargs,
            lr=self.lr,
            betas=self.betas,
            eps=self.eps,
            weight_decay=self.weight_decay,
            amsgrad=self.amsgrad,
        )

        return optim.Adam(params, **argument)


class AdamW(Config):
    """Configuration for ``torch.optim.AdamW`` (``type`` must be "adamw")."""

    type: StrictStr

    lr: float = 0.001
    betas: Tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 0
    amsgrad: StrictBool = False

    @validator("type")
    def check_type(cls, v):
        if v != "adamw":
            # The message used to name "adam", hiding which candidate failed.
            raise ValueError("Optimizer type not match for adamw")

        return v

    def make(self, params, **kwargs):
        """Build the optimizer for *params*; config values and *kwargs* are merged via ``override``."""
        argument = override(
            kwargs,
            lr=self.lr,
            betas=self.betas,
            eps=self.eps,
            weight_decay=self.weight_decay,
            amsgrad=self.amsgrad,
        )

        return optim.AdamW(params, **argument)


class LAMB(Config):
    """Configuration for ``tensorfn.optim.LAMB`` (``type`` must be "lamb")."""

    type: StrictStr

    lr: float = 0.001
    betas: Tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-6
    weight_decay: float = 0

    @validator("type")
    def check_type(cls, v):
        if v != "lamb":
            # The message used to name "adam", hiding which candidate failed.
            raise ValueError("Optimizer type not match for lamb")

        return v

    def make(self, params, **kwargs):
        """Build the optimizer for *params*; config values and *kwargs* are merged via ``override``."""
        argument = override(
            kwargs,
            lr=self.lr,
            betas=self.betas,
            eps=self.eps,
            weight_decay=self.weight_decay,
        )

        return tensor_optim.LAMB(params, **argument)


class RMSpropTF(TypedConfig):
    """Configuration for ``tensorfn.optim.RMSpropTF``."""

    __type__ = "rmsprop_tf"

    lr: float = 0.01
    alpha: float = 0.9
    eps: float = 1e-10
    weight_decay: float = 0.0
    momentum: float = 0.0
    centered: StrictBool = False
    decoupled_decay: StrictBool = False
    lr_in_momentum: StrictBool = True

    def make(self, params, **kwargs):
        """Build the optimizer for *params*; config values and *kwargs* are merged via ``override``."""
        argument = override(
            kwargs,
            lr=self.lr,
            alpha=self.alpha,
            eps=self.eps,
            weight_decay=self.weight_decay,
            momentum=self.momentum,
            centered=self.centered,
            decoupled_decay=self.decoupled_decay,
            lr_in_momentum=self.lr_in_momentum,
        )

        return tensor_optim.RMSpropTF(params, **argument)


def make_optimizer(config, params):
    """Build the optimizer described by *config* for *params*."""
    return config.make(params)


Optimizer = Union[SGD, Adam, AdamW, LAMB, RMSpropTF]

# --------------------------------------------------------------------------------
# /tensorfn/distributed/launch.py:
# --------------------------------------------------------------------------------
import os

import torch
from torch import distributed as dist
from torch import multiprocessing as mp
from torch.distributed.launcher.api import elastic_launch

from tensorfn import distributed as dist_fn


def find_free_port():
    """Ask the OS for a currently unused TCP port and return its number.

    The socket is closed before returning, so another process may still grab
    the port before the caller binds it.
    """
    import socket

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("", 0))

        return sock.getsockname()[1]


def run(conf, fn, args=()):
    """Call ``launch`` with the launch settings read from *conf*."""
    launch(
        fn,
        conf.n_gpu,
        conf.n_machine,
        conf.machine_rank,
        conf.dist_url,
        conf.launch_config,
        args=args,
    )


def launch(
    fn,
    n_gpu_per_machine,
    n_machine=1,
    machine_rank=0,
    dist_url=None,
    launch_config=None,
    args=(),
):
    """Run ``fn(*args)``, spawning one worker per GPU when more than one is used.

    Args:
        fn: the function to run in every worker.
        n_gpu_per_machine: number of worker processes on this machine.
        n_machine: number of machines in the job.
        machine_rank: index of this machine.
        dist_url: init method for ``torch.distributed``; "auto" picks a free
            local TCP port (single-machine only).
        launch_config: when given, workers are started through
            ``elastic_launch`` and *dist_url* is not used.
        args: positional arguments for *fn*.

    Raises:
        ValueError: for ``dist_url="auto"`` or a ``file://`` URL in a
            multi-machine job.
    """
    world_size = n_machine * n_gpu_per_machine

    if world_size > 1:
        if "OMP_NUM_THREADS" not in os.environ:
            os.environ["OMP_NUM_THREADS"] = "1"

        if launch_config is not None:
            elastic_launch(config=launch_config, entrypoint=elastic_worker)(fn, args)

            return

        if dist_url == "auto":
            if n_machine != 1:
                raise ValueError('dist_url="auto" not supported in multi-machine jobs')

            port = find_free_port()
            dist_url = f"tcp://127.0.0.1:{port}"

        # dist_url defaults to None; guard before calling a str method on it.
        if n_machine > 1 and dist_url is not None and dist_url.startswith("file://"):
            raise ValueError(
                "file:// is not a reliable init method in multi-machine jobs. Prefer tcp://"
            )

        mp.spawn(
            distributed_worker,
            nprocs=n_gpu_per_machine,
            args=(fn, world_size, n_gpu_per_machine, machine_rank, dist_url, args),
            daemon=False,
        )

    else:
        fn(*args)


def elastic_worker(fn, args):
    """Worker entry point under ``elastic_launch``.

    Reads ``LOCAL_RANK`` and ``LOCAL_WORLD_SIZE`` from the environment,
    initializes the NCCL process group, selects the CUDA device and calls
    ``fn(*args)``.

    Raises:
        OSError: if CUDA is unavailable or the process group cannot be created.
        ValueError: if more workers than devices were requested, or a local
            process group is already set.
    """
    if not torch.cuda.is_available():
        raise OSError("CUDA is not available. Please check your environments")

    local_rank = int(os.environ["LOCAL_RANK"])
    n_gpu_per_machine = int(os.environ["LOCAL_WORLD_SIZE"])

    try:
        dist.init_process_group(
            backend="NCCL",
        )

    except Exception as e:
        raise OSError("failed to initialize NCCL groups") from e

    dist_fn.synchronize()

    if n_gpu_per_machine > torch.cuda.device_count():
        raise ValueError(
            f"specified n_gpu_per_machine larger than available device ({torch.cuda.device_count()})"
        )

    torch.cuda.set_device(local_rank)

    # Read the same module global that distributed_worker assigns.
    if dist_fn.distributed.LOCAL_PROCESS_GROUP is not None:
        raise ValueError("tensorfn.distributed.LOCAL_PROCESS_GROUP is not None")

    fn(*args)


def distributed_worker(
    local_rank, fn, world_size, n_gpu_per_machine, machine_rank, dist_url, args
):
    """Worker entry point under ``mp.spawn``.

    Initializes the NCCL process group with an explicit global rank, selects
    the CUDA device, creates one process group per machine and records this
    machine's group in ``tensorfn.distributed.distributed.LOCAL_PROCESS_GROUP``
    before calling ``fn(*args)``.

    Raises:
        OSError: if CUDA is unavailable or the process group cannot be created.
        ValueError: if more workers than devices were requested, or a local
            process group is already set.
    """
    if not torch.cuda.is_available():
        raise OSError("CUDA is not available. Please check your environments")

    if local_rank == "ENV":
        local_rank = int(os.environ["LOCAL_RANK"])

    global_rank = machine_rank * n_gpu_per_machine + local_rank

    try:
        dist.init_process_group(
            backend="NCCL",
            init_method=dist_url,
            world_size=world_size,
            rank=global_rank,
        )

    except Exception as e:
        raise OSError("failed to initialize NCCL groups") from e

    dist_fn.synchronize()

    if n_gpu_per_machine > torch.cuda.device_count():
        raise ValueError(
            f"specified n_gpu_per_machine larger than available device ({torch.cuda.device_count()})"
        )

    torch.cuda.set_device(local_rank)

    # Read the same module global that is assigned below.
    if dist_fn.distributed.LOCAL_PROCESS_GROUP is not None:
        raise ValueError("tensorfn.distributed.LOCAL_PROCESS_GROUP is not None")

    n_machine = world_size // n_gpu_per_machine

    # Every rank creates every group, in the same order; each keeps only the
    # group of its own machine.
    for i in range(n_machine):
        ranks_on_i = list(range(i * n_gpu_per_machine, (i + 1) * n_gpu_per_machine))
        pg = dist.new_group(ranks_on_i)

        if i == machine_rank:
            dist_fn.distributed.LOCAL_PROCESS_GROUP = pg

    fn(*args)

# --------------------------------------------------------------------------------
# /tensorfn/optim/lamb.py:
# --------------------------------------------------------------------------------
import math
import torch
from torch.optim import Optimizer


class LAMB(Optimizer):
    r"""Implements LAMB algorithm.

    It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-6)
        weight_decay (float, optional): weight decay coefficient; the decay term
            is added to the update before the trust ratio is computed
            (default: 0)

    .. _Large Batch Optimization for Deep Learning\: Training BERT in 76 minutes:
        https://arxiv.org/abs/1904.00962
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super(LAMB, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(LAMB, self).__setstate__(state)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or None when no closure is given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad; the closure needs autograd enabled to
            # rebuild the graph and call backward().
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError("LAMB does not support sparse gradients")

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                beta1, beta2 = group["betas"]

                state["step"] += 1

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                denom = exp_avg_sq.sqrt().add_(group["eps"])

                bias_correction1 = 1 - beta1 ** state["step"]
                bias_correction2 = 1 - beta2 ** state["step"]

                weight_norm = p.norm().item()
                update = exp_avg.div(denom).mul_(
                    math.sqrt(bias_correction2) / bias_correction1
                )

                # Weight decay goes into the update (not the gradient), so it
                # is covered by the trust ratio below.
                if group["weight_decay"] != 0:
                    update.add_(p, alpha=group["weight_decay"])

                update_norm = update.norm().item()

                # Layer-wise trust ratio; 1 when either norm is zero.
                if weight_norm > 0 and update_norm > 0:
                    trust_ratio = weight_norm / update_norm

                else:
                    trust_ratio = 1

                lr = group["lr"] * trust_ratio

                p.data.add_(update, alpha=-lr)

        return loss

# --------------------------------------------------------------------------------
# /tensorfn/config/lr_scheduler.py:
# --------------------------------------------------------------------------------
from typing import Tuple, Union, Sequence

from pydantic import BaseModel, validator, StrictStr, StrictInt, StrictBool

from tensorfn.config import Config, TypedConfig, override
from tensorfn.optim import lr_scheduler


class Constant(Config):
    """Configuration for the constant scheduler (``type`` must be "constant")."""

    type: StrictStr

    @validator("type")
    def check_type(cls, v):
        if v != "constant":
            raise ValueError("Optimizer options not match for constant scheduler")

        return v

    def make(self, optimizer):
        return lr_scheduler.ConstantScheduler(optimizer)


class Cycle(Config):
    """Configuration for ``lr_scheduler.cycle_scheduler`` (``type`` must be "cycle")."""

    type: StrictStr

    lr: float
    n_iter: StrictInt = 0
    initial_multiplier: float = 4e-2
    final_multiplier: float = 1e-5
    warmup: StrictInt = 0
    plateau: StrictInt = 0
    decay: Sequence[StrictStr] = ("linear", "cos")

    @validator("type")
    def check_type(cls, v):
        if v != "cycle":
            raise ValueError("Optimizer options not match for cycle")

        return v

    def make(self, optimizer, **kwargs):
        """Build the scheduler; config values and *kwargs* are merged via ``override``."""
        argument = override(
            kwargs,
            lr=self.lr,
            n_iter=self.n_iter,
            initial_multiplier=self.initial_multiplier,
            final_multiplier=self.final_multiplier,
            warmup=self.warmup,
            plateau=self.plateau,
            decay=self.decay,
        )

        return lr_scheduler.cycle_scheduler(optimizer, **argument)


class Step(Config):
    """Configuration for ``lr_scheduler.step_scheduler`` (``type`` must be "step")."""

    type: StrictStr

    lr: float
    milestones: Sequence[StrictInt]
    gamma: float = 0.1
    warmup: StrictInt = 0
    # Annotated like the same field on Exp/ExpEpoch; it was a bare assignment.
    warmup_multiplier: float = 4e-2

    @validator("type")
    def check_type(cls, v):
        if v != "step":
            # The message used to name "cycle", hiding which candidate failed.
            raise ValueError("Optimizer options not match for step")

        return v

    def make(self, optimizer, **kwargs):
        """Build the scheduler; config values and *kwargs* are merged via ``override``."""
        argument = override(
            kwargs,
            lr=self.lr,
            milestones=self.milestones,
            gamma=self.gamma,
            warmup=self.warmup,
            warmup_multiplier=self.warmup_multiplier,
        )

        return lr_scheduler.step_scheduler(optimizer, **argument)


class Exp(TypedConfig):
    """Configuration for ``lr_scheduler.exp_scheduler`` with the decay period in iterations."""

    __type__ = "exp"

    lr: float
    step: StrictInt
    max_iter: StrictInt = 0
    gamma: float = 0.97
    warmup: StrictInt = 0
    warmup_multiplier: float = 4e-2

    def make(self, optimizer, **kwargs):
        """Build the scheduler; config values and *kwargs* are merged via ``override``."""
        argument = override(
            kwargs,
            lr=self.lr,
            step=self.step,
            max_iter=self.max_iter,
            gamma=self.gamma,
            warmup=self.warmup,
            warmup_multiplier=self.warmup_multiplier,
        )

        return lr_scheduler.exp_scheduler(optimizer, **argument)


class ExpEpoch(TypedConfig):
    """Configuration for ``lr_scheduler.exp_scheduler`` with the decay period in epochs."""

    __type__ = "exp_epoch"

    lr: float
    epoch: float
    max_iter: StrictInt = 0
    gamma: float = 0.97
    warmup: StrictInt = 0
    warmup_multiplier: float = 4e-2

    def make(self, optimizer, epoch_step, **kwargs):
        """Build the scheduler.

        Args:
            optimizer: the optimizer to schedule.
            epoch_step: multiplied by ``self.epoch`` and truncated to int to get
                the scheduler's ``step``.
        """
        argument = override(
            kwargs,
            lr=self.lr,
            max_iter=self.max_iter,
            gamma=self.gamma,
            warmup=self.warmup,
            warmup_multiplier=self.warmup_multiplier,
        )

        return lr_scheduler.exp_scheduler(
            optimizer, step=int(epoch_step * self.epoch), **argument
        )


class LRFind(Config):
    """Configuration for ``lr_scheduler.lr_finder`` (``type`` must be "lr_find")."""

    type: StrictStr

    lr_min: float
    lr_max: float
    n_iter: StrictInt
    linear: StrictBool = False

    @validator("type")
    def check_type(cls, v):
        if v != "lr_find":
            # The message used to name "cycle", hiding which candidate failed.
            raise ValueError("Optimizer options not match for lr_find")

        return v

    def make(self, optimizer, **kwargs):
        """Build the scheduler; config values and *kwargs* are merged via ``override``."""
        argument = override(
            kwargs,
            lr_min=self.lr_min,
            lr_max=self.lr_max,
            n_iter=self.n_iter,
            linear=self.linear,
        )

        return lr_scheduler.lr_finder(optimizer, **argument)


Scheduler = Union[Constant, Cycle, Step, ExpEpoch, Exp, LRFind]

# --------------------------------------------------------------------------------
# /tensorfn/util/logger.py:
# --------------------------------------------------------------------------------
# shamelessly took from detectron2
# https://github.com/facebookresearch/detectron2/blob/master/detectron2/utils/logger.py

import functools
import logging
import sys
import pprint

from termcolor import colored
from tabulate import tabulate

try:
    from rich.logging import RichHandler

except ImportError:
    RichHandler = None

from tensorfn import distributed as dist


class ColorfulFormatter(logging.Formatter):
    """Formatter that abbreviates the root logger name and prefixes warnings and errors in red."""

    def __init__(self, *args, **kwargs):
        self._root_name = kwargs.pop("root_name") + "."
        self._abbrev_name = kwargs.pop("abbrev_name", "")
        if len(self._abbrev_name):
            self._abbrev_name = self._abbrev_name + "."
        super().__init__(*args, **kwargs)

    def formatMessage(self, record):
        record.name = record.name.replace(self._root_name, self._abbrev_name)
        log = super().formatMessage(record)
        if record.levelno == logging.WARNING:
            prefix = colored("WARNING", "red", attrs=["blink"])
        elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
            prefix = colored("ERROR", "red", attrs=["blink", "underline"])
        else:
            return log
        return prefix + " " + log


def wrap_log_record_factory(factory):
    """Wrap a log-record factory so non-string messages are pretty-printed first."""

    def wrapper(
        name, level, fn, lno, msg, args, exc_info, func=None, sinfo=None, **kwargs
    ):
        if not isinstance(msg, str):
            msg = pprint.pformat(msg)

        return factory(name, level, fn, lno, msg, args, exc_info, func, sinfo, **kwargs)

    # Marker that lets get_logger avoid stacking wrappers.
    wrapper._tensorfn_pformat = True

    return wrapper


@functools.lru_cache()  # so that calling get_logger multiple times won't add many handlers
def get_logger(distributed_rank=None, *, mode="rich", name="main", abbrev_name=None):
    """
    Return the logger called `name`, set to "DEBUG", with a console handler on rank 0.

    Args:
        distributed_rank (int): rank of this process; handlers are attached only
            on rank 0. If None, `dist.get_rank()` is used.
        mode (str): "rich", "color" or "plain". "rich" falls back to "color"
            when `rich` is not installed. Any other value attaches no handler.
        name (str): the root module name of this logger
        abbrev_name (str): an abbreviation of the module, to avoid long names in logs.
            Only used in "color" mode. Defaults to `name`.

    Returns:
        logging.Logger: a logger
    """
    if distributed_rank is None:
        distributed_rank = dist.get_rank()

    # Install the pretty-printing factory once; every uncached call used to
    # wrap the factory again.
    record_factory = logging.getLogRecordFactory()

    if not getattr(record_factory, "_tensorfn_pformat", False):
        logging.setLogRecordFactory(wrap_log_record_factory(record_factory))

    logger = logging.getLogger(name)

    if logger.handlers:
        return logger

    logger.setLevel(logging.DEBUG)
    logger.propagate = False

    if abbrev_name is None:
        abbrev_name = name

    if mode == "rich" and RichHandler is None:
        mode = "color"

    if distributed_rank == 0:
        if mode == "color":
            ch = logging.StreamHandler(stream=sys.stdout)
            ch.setLevel(logging.DEBUG)
            formatter = ColorfulFormatter(
                colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
                datefmt="%m/%d %H:%M:%S",
                root_name=name,
                abbrev_name=str(abbrev_name),
            )
            ch.setFormatter(formatter)
            logger.addHandler(ch)

        elif mode == "rich":
            logger.addHandler(
                RichHandler(level=logging.DEBUG, log_time_format="%m/%d %H:%M:%S")
            )

        elif mode == "plain":
            ch = logging.StreamHandler(stream=sys.stdout)
            ch.setLevel(logging.DEBUG)
            formatter = logging.Formatter(
                "[%(asctime)s] %(name)s %(levelname)s: %(message)s",
                datefmt="%m/%d %H:%M:%S",
            )
            ch.setFormatter(formatter)
            logger.addHandler(ch)

    return logger


def create_small_table(small_dict):
    """
    Create a small table using the keys of small_dict as headers. This is only
    suitable for small dictionaries.

    Args:
        small_dict (dict): a result dictionary of only a few items.

    Returns:
        str: the table as a string.
    """

    # Read keys and values directly: unpacking zip(*items) raised ValueError
    # for an empty dict.
    keys = tuple(small_dict.keys())
    values = tuple(small_dict.values())
    table = tabulate(
        [values],
        headers=keys,
        tablefmt="pipe",
        floatfmt=".3f",
        stralign="center",
        numalign="center",
    )

    return table

# --------------------------------------------------------------------------------
# /tensorfn/vision/transforms.py:
# --------------------------------------------------------------------------------
import random

import torch
from PIL import Image, ImageOps, ImageEnhance

PIL_INTER_MAP = {
    Image.NEAREST: "PIL.Image.NEAREST",
    Image.BILINEAR: "PIL.Image.BILINEAR",
    Image.BICUBIC: "PIL.Image.BICUBIC",
    Image.LANCZOS: "PIL.Image.LANCZOS",
    Image.HAMMING: "PIL.Image.HAMMING",
    Image.BOX: "PIL.Image.BOX",
}

IMAGENET_EIGVAL = (0.2175, 0.0188, 0.0045)
IMAGENET_EIGVEC = (
    (-0.5675, 0.7192, 0.4009),
    (-0.5808, -0.0045, -0.8140),
    (-0.5836, -0.6948, 0.4203),
)


def check_prob(p):
    """Return True with probability *p*; ``p == 1.0`` never draws a random number."""
    return p == 1.0 or random.random() < p


class RandomTransform:
    """Base class for transforms applied with probability ``p``.

    Subclasses implement ``_apply_img(img, **params)``.
    """

    def __init__(self, p):
        self.p = p

    def apply_img(self, img, **params):
        """Return the transformed image, or *img* unchanged when the draw fails."""
        if not check_prob(self.p):
            return img

        return self._apply_img(img, **params)

    def apply_img_check(self, img, **params):
        """Like ``apply_img`` but also return whether the transform was applied."""
        if not check_prob(self.p):
            return img, False

        return self._apply_img(img, **params), True

    def _repr_params(self):
        params = dict(self.__dict__)

        return params

    def __repr__(self):
        params = []

        for k, v in self._repr_params().items():
            params.append(f"{k}={v}")

        param_str = ", ".join(params)
        repr_str = f"{self.__class__.__name__}({param_str})"

        return repr_str


class Lighting(RandomTransform):
    """Add PCA-based colour noise to a tensor image.

    NOTE(review): the final ``rgb.view(3, 1, 1)`` broadcast implies a floating
    point tensor with 3 channels in the third-from-last dimension -- confirm
    against callers.
    """

    def __init__(
        self, alpha_std, eigval=IMAGENET_EIGVAL, eigvec=IMAGENET_EIGVEC, p=1.0
    ):
        super().__init__(p)

        self.alpha_std = alpha_std
        self.eigval = torch.as_tensor(eigval)
        self.eigvec = torch.as_tensor(eigvec)

    def __call__(self, img):
        # ``p`` was accepted but never consulted; the default p=1.0 keeps the
        # previous always-apply behaviour.
        if not check_prob(self.p):
            return img

        alpha = img.new_empty(3).normal_(0, self.alpha_std)
        # Move eigval to img's dtype/device as well; only eigvec was moved,
        # which breaks for images that are not CPU tensors.
        rgb = (
            self.eigvec.to(img)
            .mul(alpha.view(1, 3).expand(3, 3))
            .mul(self.eigval.to(img).view(1, 3).expand(3, 3))
            .sum(1)
            .squeeze()
        )

        return img + rgb.view(3, 1, 1)


class Affine(RandomTransform):
    """Placeholder for an affine transform; only the parameters are stored."""

    def __init__(self, degrees, translate, scale, shear, p=1.0):
        super().__init__(p)

        self.degrees = degrees
        self.translate = translate
        self.scale = scale
        self.shear = shear

    def _apply_img(self, img, degrees, translate, scale, shear):
        # The stub used to return None, silently replacing the image.
        raise NotImplementedError("Affine._apply_img is not implemented")


class Posterize(RandomTransform):
    """Reduce each channel to ``bits`` bits (PIL images)."""

    def __init__(self, bits, p=1.0):
        super().__init__(p)

        self.bits = int(bits)

    def sample(self):
        return {"bits": self.bits}

    def _apply_img(self, img, bits):
        return ImageOps.posterize(img, bits)


class Invert(RandomTransform):
    """Invert the image (PIL images)."""

    def __init__(self, p=1.0):
        super().__init__(p)

    def sample(self):
        return {}

    def _apply_img(self, img):
        return ImageOps.invert(img)


class AutoContrast(RandomTransform):
    """Apply ``ImageOps.autocontrast`` (PIL images)."""

    def __init__(self, p=1.0):
        super().__init__(p)

    def sample(self):
        return {}

    def _apply_img(self, img):
        return ImageOps.autocontrast(img)


class Equalize(RandomTransform):
    """Apply ``ImageOps.equalize`` (PIL images)."""

    def __init__(self, p=1.0):
        super().__init__(p)

    def sample(self):
        return {}

    def _apply_img(self, img):
        return ImageOps.equalize(img)


class Solarize(RandomTransform):
    """Apply ``ImageOps.solarize`` with the given threshold (PIL images)."""

    def __init__(self, threshold, p=1.0):
        super().__init__(p)

        self.threshold = int(threshold)

    def sample(self):
        return {"threshold": self.threshold}

    def _apply_img(self, img, threshold):
        return ImageOps.solarize(img, threshold)


class Saturation(RandomTransform):
    """Scale colour saturation by a fixed factor (PIL images)."""

    def __init__(self, saturation, p=1.0):
        super().__init__(p)

        self.saturation = saturation

    def sample(self):
        return {"saturation": self.saturation}

    def _apply_img(self, img, saturation):
        return ImageEnhance.Color(img).enhance(saturation)


class Contrast(RandomTransform):
    """Scale contrast by a fixed factor (PIL images)."""

    def __init__(self, contrast, p=1.0):
        super().__init__(p)

        self.contrast = contrast

    def sample(self):
        return {"contrast": self.contrast}

    def _apply_img(self, img, contrast):
        return ImageEnhance.Contrast(img).enhance(contrast)


class Brightness(RandomTransform):
    """Scale brightness by a fixed factor (PIL images)."""

    def __init__(self, brightness, p=1.0):
        super().__init__(p)

        self.brightness = brightness

    def sample(self):
        return {"brightness": self.brightness}

    def _apply_img(self, img, brightness):
        return ImageEnhance.Brightness(img).enhance(brightness)

# --------------------------------------------------------------------------------
# /tensorfn/nn/carafe/carafe_cuda.cpp:
# --------------------------------------------------------------------------------
# (C++ source; it continues past this chunk, so its visible header lines are
# carried through verbatim below.)
# // Copyright 2018-2019 Open-MMLab
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at 6 | // 7 | // https://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #include 16 | #include 17 | 18 | #include 19 | #include 20 | 21 | int CARAFEForwardLaucher(const at::Tensor features, const at::Tensor masks, 22 | const int kernel_size, const int group_size, 23 | const int scale_factor, const int batch_size, 24 | const int channels, const int input_height, 25 | const int input_width, const int output_height, 26 | const int output_width, const int mask_channels, 27 | at::Tensor rfeatures, at::Tensor routput, 28 | at::Tensor rmasks, at::Tensor output); 29 | 30 | int CARAFEBackwardLaucher(const at::Tensor top_grad, const at::Tensor rfeatures, 31 | const at::Tensor masks, const int kernel_size, 32 | const int group_size, const int scale_factor, 33 | const int batch_size, const int channels, 34 | const int input_height, const int input_width, 35 | const int output_height, const int output_width, 36 | const int mask_channels, at::Tensor rtop_grad, 37 | at::Tensor rbottom_grad_hs, at::Tensor rbottom_grad, 38 | at::Tensor rmask_grad, at::Tensor bottom_grad, 39 | at::Tensor mask_grad); 40 | 41 | #define CHECK_CUDA(x) \ 42 | TORCH_CHECK(x.device().is_cuda(), #x, " must be a CUDAtensor ") 43 | #define CHECK_CONTIGUOUS(x) \ 44 | TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ") 45 | #define CHECK_INPUT(x) \ 46 | CHECK_CUDA(x); \ 47 | CHECK_CONTIGUOUS(x) 48 | 49 | int carafe_forward_cuda(at::Tensor features, at::Tensor rfeatures, 50 | at::Tensor masks, at::Tensor rmasks, int kernel_size, 51 | int group_size, int scale_factor, at::Tensor routput, 52 | at::Tensor output) { 53 | 
CHECK_INPUT(features); 54 | CHECK_INPUT(rfeatures); 55 | CHECK_INPUT(masks); 56 | CHECK_INPUT(rmasks); 57 | CHECK_INPUT(output); 58 | CHECK_INPUT(routput); 59 | at::DeviceGuard guard(features.device()); 60 | 61 | const int batch_size = output.size(0); 62 | const int num_channels = output.size(1); 63 | const int output_height = output.size(2); 64 | const int output_width = output.size(3); 65 | 66 | const int input_height = features.size(2); 67 | const int input_width = features.size(3); 68 | 69 | const int mask_channels = masks.size(1); 70 | 71 | rfeatures.resize_({batch_size, input_height, input_width, num_channels}); 72 | routput.resize_({batch_size, output_height, output_width, num_channels}); 73 | rmasks.resize_({batch_size, output_height, output_width, mask_channels}); 74 | 75 | CARAFEForwardLaucher(features, masks, kernel_size, group_size, scale_factor, 76 | batch_size, num_channels, input_height, input_width, 77 | output_height, output_width, mask_channels, rfeatures, 78 | routput, rmasks, output); 79 | 80 | return 1; 81 | } 82 | 83 | int carafe_backward_cuda(at::Tensor top_grad, at::Tensor rfeatures, 84 | at::Tensor masks, int kernel_size, int group_size, 85 | int scale_factor, at::Tensor rtop_grad, 86 | at::Tensor rbottom_grad_hs, at::Tensor rbottom_grad, 87 | at::Tensor rmask_grad, at::Tensor bottom_grad, 88 | at::Tensor mask_grad) { 89 | CHECK_INPUT(top_grad); 90 | CHECK_INPUT(rfeatures); 91 | CHECK_INPUT(masks); 92 | CHECK_INPUT(rtop_grad); 93 | CHECK_INPUT(rbottom_grad_hs); 94 | CHECK_INPUT(rbottom_grad); 95 | CHECK_INPUT(rmask_grad); 96 | CHECK_INPUT(bottom_grad); 97 | CHECK_INPUT(mask_grad); 98 | at::DeviceGuard guard(top_grad.device()); 99 | 100 | const int batch_size = top_grad.size(0); 101 | const int num_channels = top_grad.size(1); 102 | const int output_height = top_grad.size(2); 103 | const int output_width = top_grad.size(3); 104 | 105 | const int input_height = bottom_grad.size(2); 106 | const int input_width = bottom_grad.size(3); 107 | 108 
| const int mask_channels = masks.size(1); 109 | 110 | rtop_grad.resize_({batch_size, output_height, output_width, num_channels}); 111 | rbottom_grad.resize_({batch_size, input_height, input_width, num_channels}); 112 | rbottom_grad_hs.resize_( 113 | {batch_size, output_height, output_width, num_channels}); 114 | rmask_grad.resize_({batch_size, output_height, output_width, mask_channels}); 115 | 116 | CARAFEBackwardLaucher(top_grad, rfeatures, masks, kernel_size, group_size, 117 | scale_factor, batch_size, num_channels, input_height, 118 | input_width, output_height, output_width, mask_channels, 119 | rtop_grad, rbottom_grad_hs, rbottom_grad, rmask_grad, 120 | bottom_grad, mask_grad); 121 | 122 | return 1; 123 | } 124 | 125 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 126 | m.def("forward", &carafe_forward_cuda, "carafe forward (CUDA)"); 127 | m.def("backward", &carafe_backward_cuda, "carafe backward (CUDA)"); 128 | } 129 | -------------------------------------------------------------------------------- /tensorfn/optim/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | from math import cos, pi, tanh 2 | from functools import partial 3 | 4 | import torch 5 | 6 | 7 | __all__ = ["ConstantScheduler", "cycle_scheduler", "step_scheduler", "lr_finder"] 8 | 9 | 10 | def anneal_linear(start, end, proportion): 11 | return start + proportion * (end - start) 12 | 13 | 14 | def anneal_cos(start, end, proportion): 15 | cos_val = cos(pi * proportion) + 1 16 | 17 | return end + (start - end) / 2 * cos_val 18 | 19 | 20 | def anneal_cospow(start, end, proportion): 21 | power = 5 22 | 23 | cos_val = 0.5 * (cos(pi * proportion) + 1) + 1 24 | cos_val = power ** cos_val - power 25 | cos_val = cos_val / (power ** 2 - power) 26 | 27 | return end + (start - end) * cos_val 28 | 29 | 30 | def anneal_poly(start, end, proportion, power=0.9): 31 | return (start - end) * (1 - proportion) ** power + end 32 | 33 | 34 | def anneal_tanh(start, end, 
class PhaseScheduler:
    """Piecewise learning-rate schedule made of consecutive phases.

    Each phase is ``(name, lr_from, lr_to, n_iter)`` or
    ``(name, lr_from, lr_to, n_iter, kwargs)``. ``name`` selects one of the
    module-level ``anneal_*`` functions and ``kwargs`` (when given) are bound
    to it with ``functools.partial``.

    Every call to :meth:`step` evaluates the current phase at
    ``phase_step / n_iter``, writes the result into all param groups of the
    optimizer and advances the counters. A phase is left after the step whose
    proportion reaches 1, so a phase with ``n_iter`` iterations spans
    ``n_iter + 1`` calls.
    """

    def __init__(self, optimizer, phases):
        self.optimizer = optimizer

        # Raw phase tuples are kept so the schedule can be rebuilt from a
        # state dict.
        self.phase_param = phases
        self.lr_phase = self.make_phase(phases)

        self.phase = 0
        self.phase_step = 0

        self.latest_lr = None
        self.loss_log = []

    def make_phase(self, phases):
        """Resolve phase tuples into ``(lr_from, lr_to, n_iter, fn)`` tuples.

        Raises:
            KeyError: if a phase name is not one of the known annealers.
        """
        phase_map = {
            "linear": anneal_linear,
            "cos": anneal_cos,
            "cospow": anneal_cospow,
            "poly": anneal_poly,
            "tanh": anneal_tanh,
            "exp": anneal_exp,
            "flat": anneal_flat,
        }

        lr_phase = []

        for phase in phases:
            if len(phase) == 4:
                phase_name, lr_from, lr_to, phase_iter = phase
                phase_fn = phase_map[phase_name]

            else:
                phase_name, lr_from, lr_to, phase_iter, phase_args = phase
                phase_fn = partial(phase_map[phase_name], **phase_args)

            lr_phase.append((lr_from, lr_to, phase_iter, phase_fn))

        return lr_phase

    def state_dict(self):
        """Return the counters, raw phases and loss log as a plain dict."""
        return {
            "phase": self.phase,
            "phase_param": self.phase_param,
            "phase_step": self.phase_step,
            "latest_lr": self.latest_lr,
            "loss_log": self.loss_log,
        }

    def load_state_dict(self, state_dict):
        """Restore the scheduler from a dict produced by :meth:`state_dict`."""
        self.phase_param = state_dict["phase_param"]
        self.lr_phase = self.make_phase(state_dict["phase_param"])
        self.phase = state_dict["phase"]
        self.phase_step = state_dict["phase_step"]
        self.latest_lr = state_dict["latest_lr"]
        self.loss_log = state_dict["loss_log"]

    def __repr__(self):
        return f"PhaseScheduler(phases={self.lr_phase})"

    def step(self):
        """Advance one iteration and apply the new learning rate.

        Returns:
            The learning rate written to the optimizer, or ``None`` once all
            phases are exhausted (the optimizer is left untouched then).
        """
        if self.phase >= len(self.lr_phase):
            return None

        lr_from, lr_to, phase_iter, phase_fn = self.lr_phase[self.phase]

        # A zero-length phase used to raise ZeroDivisionError here. It lasts
        # exactly one step, so evaluate it at its starting point, like the
        # first step of every other phase.
        if phase_iter == 0:
            proportion = 0.0
        else:
            proportion = self.phase_step / phase_iter

        lr = phase_fn(lr_from, lr_to, proportion)
        self.phase_step += 1

        for group in self.optimizer.param_groups:
            group["lr"] = lr

        self.latest_lr = lr

        if self.phase_step > phase_iter:
            self.phase += 1
            self.phase_step = 0

        return lr

    def record_loss(self, loss):
        """Append ``(latest_lr, loss)`` to the log; tensors are unwrapped."""
        if isinstance(loss, torch.Tensor):
            loss = loss.item()

        self.loss_log.append((self.latest_lr, loss))

    def write_log(self, filename):
        """Write the loss log to ``filename`` as ``lr,loss`` lines."""
        with open(filename, "w") as f:
            f.writelines(f"{lr},{loss}\n" for lr, loss in self.loss_log)
phases.append(("linear", current_lr, current_lr, forward - current)) 189 | 190 | current_lr *= gamma 191 | steps = current 192 | 193 | return PhaseScheduler(optimizer, phases) 194 | 195 | 196 | def exp_scheduler( 197 | optimizer, lr, step, max_iter, gamma=0.97, warmup=0, warmup_multiplier=4e-2 198 | ): 199 | milestones = [int(step) * i + warmup - 1 for i in range(1, max_iter)] 200 | 201 | return step_scheduler(optimizer, lr, milestones, gamma, warmup, warmup_multiplier) 202 | 203 | 204 | def lr_finder(optimizer, lr_min, lr_max, n_iter, linear=False): 205 | decay = "linear" if linear else "exp" 206 | 207 | phases = [(decay, lr_min, lr_max, n_iter)] 208 | 209 | return PhaseScheduler(optimizer, phases) 210 | -------------------------------------------------------------------------------- /tensorfn/nn/interpolate_spline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | EPS = 1e-10 7 | 8 | 9 | def regular_control(n_control): 10 | control = torch.linspace(-1, 1, n_control) 11 | c1 = control.view(1, n_control).expand(n_control, -1).reshape(-1) 12 | c2 = control.view(n_control, 1).expand(-1, n_control).reshape(-1) 13 | control = torch.stack((c1, c2), -1) 14 | 15 | return control 16 | 17 | 18 | def regular_grid(height, width): 19 | xx, yy = torch.meshgrid(torch.linspace(-1, 1, height), torch.linspace(-1, 1, width)) 20 | 21 | return torch.stack([yy, xx], 2).view(-1, 2) 22 | 23 | 24 | def cross_sq_dist_mat(x, y): 25 | x_norm_sq = x.pow(2).sum(2) 26 | y_norm_sq = y.pow(2).sum(2) 27 | 28 | x_norm_sq_tile = x_norm_sq.unsqueeze(2) 29 | y_norm_sq_tile = y_norm_sq.unsqueeze(1) 30 | 31 | x_y_t = x @ y.transpose(1, 2) 32 | 33 | sq_dist = x_norm_sq_tile - 2 * x_y_t + y_norm_sq_tile 34 | 35 | return sq_dist 36 | 37 | 38 | def pairwise_sq_dist_mat(x): 39 | x_x_t = x @ x.transpose(1, 2) 40 | x_norm_sq = torch.diagonal(x_x_t, dim1=1, dim2=2) 41 | x_norm_sq_tile 
def solve_interpolation(train_points, train_values, order, regularization_weight):
    """Solve the linear system for the spline weights.

    Builds the block system ``[[A, B], [B^T, 0]]`` where ``A`` is ``phi`` of
    the pairwise squared distances between training points (plus
    ``regularization_weight`` on the diagonal when positive) and ``B`` is the
    training points with a column of ones appended, then solves it against
    ``train_values`` padded with ``d + 1`` rows of zeros.

    Args:
        train_points: tensor unpacked as ``(b, n, d)``.
        train_values: tensor whose third dimension ``k`` is the value size.
        order: spline order forwarded to ``phi``.
        regularization_weight: added to the diagonal of ``A`` when ``> 0``.

    Returns:
        ``(w, v)``: the first ``n`` rows and the remaining rows of the
        solution along dimension 1.
    """
    b, n, d = train_points.shape
    k = train_values.shape[2]

    mat_a = phi(pairwise_sq_dist_mat(train_points), order)

    if regularization_weight > 0:
        batch_eye_mat = torch.eye(
            n, dtype=train_points.dtype, device=train_points.device
        ).unsqueeze(0)
        mat_a += regularization_weight * batch_eye_mat

    ones = train_points.new_ones(b, n, 1)
    mat_b = torch.cat([train_points, ones], 2)

    left = torch.cat([mat_a, mat_b.transpose(1, 2)], 1)

    n_b_cols = mat_b.shape[2]
    lhs_zeros = train_points.new_zeros(b, n_b_cols, n_b_cols)
    right = torch.cat([mat_b, lhs_zeros], 1)
    lhs = torch.cat([left, right], 2)

    rhs_zeros = train_points.new_zeros(b, d + 1, k)
    rhs = torch.cat([train_values, rhs_zeros], 1)

    # torch.solve(B, A) is deprecated and removed in current PyTorch; its
    # replacement torch.linalg.solve(A, B) takes the arguments in the
    # opposite order. Fall back to the legacy call on builds without it.
    linalg_solve = getattr(getattr(torch, "linalg", None), "solve", None)
    if linalg_solve is not None:
        w_v = linalg_solve(lhs, rhs)
    else:
        w_v = torch.solve(rhs, lhs).solution

    w = w_v[:, :n, :]
    v = w_v[:, n:, :]

    return w, v
def apply_interpolation_precomputed(phi_pairwise_query_points, train_values_pad, w_inv):
    """Solve for the spline weights and evaluate them at the query points.

    Args:
        phi_pairwise_query_points: matrix that is right-multiplied by the
            solved weights.
        train_values_pad: right-hand side of the linear system.
        w_inv: coefficient matrix of the linear system. Despite the name it
            is used as the matrix to solve against, not as an inverse that
            is multiplied in.

    Returns:
        ``phi_pairwise_query_points @ w`` where ``w`` solves
        ``w_inv @ w = train_values_pad``.
    """
    # torch.solve(B, A) is deprecated and removed in current PyTorch; its
    # replacement torch.linalg.solve(A, B) takes the arguments in the
    # opposite order. Fall back to the legacy call on builds without it.
    linalg_solve = getattr(getattr(torch, "linalg", None), "solve", None)
    if linalg_solve is not None:
        w = linalg_solve(w_inv, train_values_pad)
    else:
        w = torch.solve(train_values_pad, w_inv).solution

    return phi_pairwise_query_points @ w
solve_interpolation_precomputed( 179 | train_points, order, regularization_weight 180 | ) 181 | 182 | self.train_values_pad = train_points.shape[2] + 1 183 | 184 | pairwise_dist = cross_sq_dist_mat(query_points, train_points) 185 | phi_pairwise_dist = phi(pairwise_dist, order) 186 | query_points_pad = F.pad(query_points, (0, 1), value=1) 187 | phi_pairwise_query_points = torch.cat([phi_pairwise_dist, query_points_pad], 2) 188 | 189 | self.register_buffer("kernel", kernel) 190 | self.register_buffer("phi_pairwise_query_points", phi_pairwise_query_points) 191 | 192 | def forward(self, train_values): 193 | train_values_pad = F.pad(train_values, (0, 0, 0, self.train_values_pad)) 194 | query_values = apply_interpolation_precomputed( 195 | self.phi_pairwise_query_points, train_values_pad, self.kernel 196 | ) 197 | 198 | return query_values 199 | -------------------------------------------------------------------------------- /tensorfn/nn/carafe/carafe.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018-2019 Open-MMLab 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
class CARAFEFunction(Function):
    """Autograd function wrapping the CARAFE CUDA forward/backward kernels."""

    @staticmethod
    def forward(ctx, features, masks, kernel_size, group_size, scale_factor):
        """Upsample ``features`` by ``scale_factor`` using ``masks``.

        The assertions below state the accepted inputs: ``masks`` must have
        ``kernel_size**2 * group_size`` channels and a spatial size of
        ``scale_factor`` times that of ``features``; the feature channel
        count must be divisible by ``group_size``; ``kernel_size`` must be
        odd and at least 1.

        Raises:
            AssertionError: if an input violates the constraints above.
            NotImplementedError: if ``features`` is not a CUDA tensor.
        """
        assert scale_factor >= 1
        assert masks.size(1) == kernel_size * kernel_size * group_size
        assert masks.size(-1) == features.size(-1) * scale_factor
        assert masks.size(-2) == features.size(-2) * scale_factor
        assert features.size(1) % group_size == 0
        assert (kernel_size - 1) % 2 == 0 and kernel_size >= 1
        ctx.kernel_size = kernel_size
        ctx.group_size = group_size
        ctx.scale_factor = scale_factor
        ctx.feature_size = features.size()
        ctx.mask_size = masks.size()

        n, c, h, w = features.size()
        output = features.new_zeros((n, c, h * scale_factor, w * scale_factor))
        # Scratch buffers handed to the extension alongside the real output.
        routput = features.new_zeros(output.size(), requires_grad=False)
        rfeatures = features.new_zeros(features.size(), requires_grad=False)
        rmasks = masks.new_zeros(masks.size(), requires_grad=False)
        if features.is_cuda:
            carafe_cuda.get().forward(
                features,
                rfeatures,
                masks,
                rmasks,
                kernel_size,
                group_size,
                scale_factor,
                routput,
                output,
            )
        else:
            raise NotImplementedError

        if features.requires_grad or masks.requires_grad:
            ctx.save_for_backward(features, masks, rfeatures)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Compute gradients w.r.t. ``features`` and ``masks``.

        Returns one entry per ``forward`` input: the two tensor gradients
        followed by ``None`` for the three non-tensor arguments.
        """
        assert grad_output.is_cuda

        features, masks, rfeatures = ctx.saved_tensors
        kernel_size = ctx.kernel_size
        group_size = ctx.group_size
        scale_factor = ctx.scale_factor

        rgrad_output = torch.zeros_like(grad_output, requires_grad=False)
        rgrad_input_hs = torch.zeros_like(grad_output, requires_grad=False)
        rgrad_input = torch.zeros_like(features, requires_grad=False)
        rgrad_masks = torch.zeros_like(masks, requires_grad=False)
        grad_input = torch.zeros_like(features, requires_grad=False)
        grad_masks = torch.zeros_like(masks, requires_grad=False)
        carafe_cuda.get().backward(
            grad_output.contiguous(),
            rfeatures,
            masks,
            kernel_size,
            group_size,
            scale_factor,
            rgrad_output,
            rgrad_input_hs,
            rgrad_input,
            rgrad_masks,
            grad_input,
            grad_masks,
        )
        # forward takes five arguments after ctx, so exactly five gradients
        # are returned (the original returned a surplus sixth None).
        return grad_input, grad_masks, None, None, None
dilation=self.encoder_dilation, 144 | groups=1, 145 | ) 146 | self.init_weights() 147 | 148 | def init_weights(self): 149 | for m in self.modules(): 150 | if isinstance(m, nn.Conv2d): 151 | nn.init.xavier_uniform_(m.weight) 152 | nn.init.zeros_(m.bias) 153 | 154 | nn.init.normal_(self.content_encoder.weight, std=0.001) 155 | nn.init.zeros_(self.content_encoder.bias) 156 | 157 | def kernel_normalizer(self, mask): 158 | mask = F.pixel_shuffle(mask, self.scale_factor) 159 | n, mask_c, h, w = mask.size() 160 | mask_channel = int(mask_c / (self.up_kernel * self.up_kernel)) 161 | mask = mask.view(n, mask_channel, -1, h, w) 162 | 163 | mask = F.softmax(mask, dim=2) 164 | mask = mask.view(n, mask_c, h, w).contiguous() 165 | 166 | return mask 167 | 168 | def feature_reassemble(self, x, mask): 169 | x = carafe_fn(x, mask, self.up_kernel, self.up_group, self.scale_factor) 170 | return x 171 | 172 | def forward(self, x): 173 | compressed_x = self.channel_compressor(x) 174 | mask = self.content_encoder(compressed_x) 175 | mask = self.kernel_normalizer(mask) 176 | 177 | x = self.feature_reassemble(x, mask) 178 | return x 179 | -------------------------------------------------------------------------------- /tensorfn/optim/rmsprop_tf.py: -------------------------------------------------------------------------------- 1 | # Shamelessly took from here: 2 | 3 | """ RMSProp modified to behave like Tensorflow impl 4 | 5 | Originally cut & paste from PyTorch RMSProp 6 | https://github.com/pytorch/pytorch/blob/063946d2b3f3f1e953a2a3b54e0b34f1393de295/torch/optim/rmsprop.py 7 | Licensed under BSD-Clause 3 (ish), https://github.com/pytorch/pytorch/blob/master/LICENSE 8 | 9 | Modifications Copyright 2020 Ross Wightman 10 | """ 11 | 12 | import torch 13 | from torch.optim import Optimizer 14 | 15 | 16 | class RMSpropTF(Optimizer): 17 | """Implements RMSprop algorithm (TensorFlow style epsilon) 18 | 19 | NOTE: This is a direct cut-and-paste of PyTorch RMSprop with eps applied before sqrt 20 
| and a few other modifications to closer match Tensorflow for matching hyper-params. 21 | 22 | Noteworthy changes include: 23 | 1. Epsilon applied inside square-root 24 | 2. square_avg initialized to ones 25 | 3. LR scaling of update accumulated in momentum buffer 26 | 27 | Proposed by G. Hinton in his 28 | `course `_. 29 | 30 | The centered version first appears in `Generating Sequences 31 | With Recurrent Neural Networks `_. 32 | 33 | Arguments: 34 | params (iterable): iterable of parameters to optimize or dicts defining 35 | parameter groups 36 | lr (float, optional): learning rate (default: 1e-2) 37 | momentum (float, optional): momentum factor (default: 0) 38 | alpha (float, optional): smoothing (decay) constant (default: 0.9) 39 | eps (float, optional): term added to the denominator to improve 40 | numerical stability (default: 1e-10) 41 | centered (bool, optional) : if ``True``, compute the centered RMSProp, 42 | the gradient is normalized by an estimation of its variance 43 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 44 | decoupled_decay (bool, optional): decoupled weight decay as per https://arxiv.org/abs/1711.05101 45 | lr_in_momentum (bool, optional): learning rate scaling is included in the momentum buffer 46 | update as per defaults in Tensorflow 47 | 48 | """ 49 | 50 | def __init__( 51 | self, 52 | params, 53 | lr=1e-2, 54 | alpha=0.9, 55 | eps=1e-10, 56 | weight_decay=0, 57 | momentum=0.0, 58 | centered=False, 59 | decoupled_decay=False, 60 | lr_in_momentum=True, 61 | ): 62 | if not 0.0 <= lr: 63 | raise ValueError("Invalid learning rate: {}".format(lr)) 64 | if not 0.0 <= eps: 65 | raise ValueError("Invalid epsilon value: {}".format(eps)) 66 | if not 0.0 <= momentum: 67 | raise ValueError("Invalid momentum value: {}".format(momentum)) 68 | if not 0.0 <= weight_decay: 69 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 70 | if not 0.0 <= alpha: 71 | raise ValueError("Invalid alpha value: 
{}".format(alpha)) 72 | 73 | defaults = dict( 74 | lr=lr, 75 | momentum=momentum, 76 | alpha=alpha, 77 | eps=eps, 78 | centered=centered, 79 | weight_decay=weight_decay, 80 | decoupled_decay=decoupled_decay, 81 | lr_in_momentum=lr_in_momentum, 82 | ) 83 | super(RMSpropTF, self).__init__(params, defaults) 84 | 85 | def __setstate__(self, state): 86 | super(RMSpropTF, self).__setstate__(state) 87 | for group in self.param_groups: 88 | group.setdefault("momentum", 0) 89 | group.setdefault("centered", False) 90 | 91 | def step(self, closure=None): 92 | """Performs a single optimization step. 93 | 94 | Arguments: 95 | closure (callable, optional): A closure that reevaluates the model 96 | and returns the loss. 97 | """ 98 | loss = None 99 | if closure is not None: 100 | loss = closure() 101 | 102 | for group in self.param_groups: 103 | for p in group["params"]: 104 | if p.grad is None: 105 | continue 106 | grad = p.grad.data 107 | if grad.is_sparse: 108 | raise RuntimeError("RMSprop does not support sparse gradients") 109 | state = self.state[p] 110 | 111 | # State initialization 112 | if len(state) == 0: 113 | state["step"] = 0 114 | state["square_avg"] = torch.ones_like( 115 | p.data 116 | ) # PyTorch inits to zero 117 | if group["momentum"] > 0: 118 | state["momentum_buffer"] = torch.zeros_like(p.data) 119 | if group["centered"]: 120 | state["grad_avg"] = torch.zeros_like(p.data) 121 | 122 | square_avg = state["square_avg"] 123 | one_minus_alpha = 1.0 - group["alpha"] 124 | 125 | state["step"] += 1 126 | 127 | if group["weight_decay"] != 0: 128 | if "decoupled_decay" in group and group["decoupled_decay"]: 129 | p.data.add_(-group["weight_decay"], p.data) 130 | else: 131 | grad = grad.add(p.data, alpha=group["weight_decay"]) 132 | 133 | # Tensorflow order of ops for updating squared avg 134 | square_avg.add_(grad.pow(2) - square_avg, alpha=one_minus_alpha) 135 | # square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad) # PyTorch original 136 | 137 | if 
class Storage:
    """Base class for checkpoint storage backends with optional retention.

    Subclasses are expected to provide ``save(data, name)`` and
    ``_remove(name)``; both are called from :meth:`checkpoint` but are not
    defined here.

    Args:
        keep: maximum number of checkpoints to retain in each of the two
            tracked lists (plain and value-ranked). Pruning only happens
            when ``keep > 0``.
    """

    def __init__(self, keep=-1):
        self.keep = keep
        self._saved_checkpoints = []
        self._saved_checkpoints_value = []

    def checkpoint(self, obj, path, value=None):
        """Serialize ``obj`` and save it under ``path``, pruning old files.

        When ``value`` is given, every ``<...>`` placeholder in ``path`` that
        mentions ``value`` is rewritten to a ``{...}`` format field and the
        path is formatted with ``value`` (for example ``"m-<value:.3f>.pt"``).
        Checkpoints saved with a value are ranked by it and the lowest are
        pruned first; checkpoints saved without one are pruned oldest first.
        """
        if value is not None:
            for exp in re.findall("<.+?>", path):
                if "value" in exp:
                    path = path.replace(exp, f"{{{exp[1:-1]}}}", 1)
            path = path.format(value=value)

        if self.keep > 0:
            # Trim to keep - 1 entries so exactly `keep` exist after saving.
            retain = self.keep - 1
            self._prune_recent(retain)
            self._prune_by_value(retain)

        binary = torch_serialize(obj)
        self.save(binary, path)

        if value is None:
            self._saved_checkpoints.append(path)

        else:
            self._saved_checkpoints_value.append((path, value))

    def _prune_recent(self, retain):
        """Remove the oldest plain checkpoints, leaving at most ``retain``."""
        # An explicit drop count avoids the `[:-0]` / `[-0:]` trap that made
        # keep=1 (retain == 0) prune nothing and keep everything.
        n_drop = len(self._saved_checkpoints) - retain
        if n_drop <= 0:
            return

        for stale in self._saved_checkpoints[:n_drop]:
            self._remove(stale)

        self._saved_checkpoints = self._saved_checkpoints[n_drop:]

    def _prune_by_value(self, retain):
        """Remove the lowest-valued checkpoints, leaving at most ``retain``."""
        records = self._saved_checkpoints_value
        n_drop = len(records) - retain
        if n_drop <= 0:
            return

        ranked = sorted(enumerate(records), key=lambda item: item[1][1])

        for _, (stale, _) in ranked[:n_drop]:
            self._remove(stale)

        survivors = {index for index, _ in ranked[n_drop:]}
        # Survivors stay in their original insertion order.
        self._saved_checkpoints_value = [
            record for index, record in enumerate(records) if index in survivors
        ]

    def get_directory(self, path):
        """Return ``path`` with a timestamp sub-directory appended.

        The key is the local ISO-8601 timestamp with ``:`` replaced by ``.``.
        """
        key = datetime.now().astimezone().isoformat().replace(":", ".")
        path = f"{path}/{key}"

        return path
97 | 98 | path = os.path.join(root, child) 99 | 100 | self.path = self.get_directory(path) 101 | 102 | def list(self, path): 103 | try: 104 | dirs = os.listdir(path) 105 | 106 | except FileNotFoundError: 107 | dirs = [] 108 | 109 | return dirs 110 | 111 | def save(self, data, name): 112 | if isinstance(data, bytes): 113 | flag = "wb" 114 | 115 | else: 116 | flag = "w" 117 | 118 | target_path = os.path.join(self.path, name) 119 | 120 | os.makedirs(os.path.split(target_path)[0], exist_ok=True) 121 | 122 | with open(target_path, flag) as f: 123 | f.write(data) 124 | 125 | def load(self, name): 126 | pass 127 | 128 | 129 | class NSMLV2: 130 | def __init__(self, path): 131 | self.root_path = path 132 | self.name = os.environ["NSML_RUN_NAME"] 133 | self.prev_run = os.getenv("NSML_PREV_RUN_NAME", "") 134 | self.rerun_count = int(os.getenv("NSML_RERUN_COUNT", "0")) 135 | self.path = os.path.join(self.root_path, self.name.replace("/", "_")) 136 | 137 | def resume(self, name="resume.pt"): 138 | if self.rerun_count == 0: 139 | return None 140 | 141 | target_path = os.path.join( 142 | self.root_path, self.prev_run.replace("/", "_"), name 143 | ) 144 | 145 | if not os.path.exists(target_path): 146 | return None 147 | 148 | return torch.load(target_path) 149 | 150 | def save(self, data, name): 151 | if isinstance(data, bytes): 152 | flag = "wb" 153 | 154 | else: 155 | flag = "w" 156 | 157 | target_path = os.path.join(self.path, name) 158 | 159 | os.makedirs(os.path.split(target_path)[0], exist_ok=True) 160 | 161 | with open(target_path, flag) as f: 162 | f.write(data) 163 | 164 | 165 | def progress_callback(pbar): 166 | def wrap(bytes_amount): 167 | pbar.update(bytes_amount) 168 | 169 | return wrap 170 | 171 | 172 | class S3(Storage): 173 | def __init__( 174 | self, 175 | bucket, 176 | path, 177 | access_key, 178 | secret_key, 179 | keep=-1, 180 | endpoint=None, 181 | show_progress=True, 182 | ): 183 | super().__init__(keep) 184 | 185 | if boto3 is None: 186 | raise 
ImportError("boto3 should be installed for S3 storage") 187 | 188 | self.s3 = boto3.client( 189 | "s3", 190 | aws_access_key_id=access_key, 191 | aws_secret_access_key=secret_key, 192 | endpoint_url=endpoint, 193 | ) 194 | self.bucket = bucket 195 | self.path = self.get_directory(path) 196 | self.show_progress = show_progress 197 | 198 | def list(self, path): 199 | if path[-1] != "/": 200 | path += "/" 201 | 202 | resp = self.s3.list_objects_v2(Bucket=self.bucket, Prefix=path, Delimiter="/") 203 | 204 | try: 205 | prefixes = [] 206 | 207 | for prefix in resp["CommonPrefixes"]: 208 | prefixes.append(prefix["Prefix"]) 209 | 210 | except KeyError: 211 | prefixes = [] 212 | 213 | return prefixes 214 | 215 | def save(self, data, name): 216 | buf = io.BytesIO(data) 217 | size = len(data) 218 | 219 | self._save(buf, name, size) 220 | 221 | def _remove(self, name): 222 | target_path = f"{self.path}/{name}" 223 | 224 | self.s3.delete_object(Bucket=self.bucket, Key=target_path) 225 | 226 | def _save(self, buf, name, size): 227 | target_path = f"{self.path}/{name}" 228 | 229 | if self.show_progress: 230 | with tqdm(total=size, unit="B", unit_scale=True, desc=target_path) as pbar: 231 | self.s3.upload_fileobj( 232 | buf, self.bucket, target_path, Callback=progress_callback(pbar) 233 | ) 234 | 235 | else: 236 | self.s3.upload_fileobj(buf, self.bucket, target_path) 237 | 238 | 239 | def get_decimal(value): 240 | for i in range(10): 241 | if value >= 10 ** (-i) - 1e-10: 242 | return i 243 | 244 | return 10 245 | 246 | 247 | def default_formatter(step, **kwargs): 248 | panels = [f"step: {step}"] 249 | 250 | for k, v in kwargs.items(): 251 | if isinstance(v, float): 252 | decimal = get_decimal(v) + 2 253 | v = round(v, decimal) 254 | panels.append(f"{k}: {v}") 255 | 256 | else: 257 | panels.append(f"{k}: {v}") 258 | 259 | return "; ".join(panels) 260 | 261 | 262 | class Logger: 263 | def __init__(self, formatter=None): 264 | if formatter is None: 265 | formatter = default_formatter 
266 | 267 | self.logger = get_logger() 268 | self.formatter = formatter 269 | 270 | def log(self, step, **kwargs): 271 | self.logger.info(self.formatter(step, **kwargs)) 272 | 273 | 274 | class WandB: 275 | def __init__( 276 | self, 277 | project, 278 | group=None, 279 | name=None, 280 | notes=None, 281 | resume=None, 282 | tags=None, 283 | id=None, 284 | ): 285 | if dist.is_primary(): 286 | import wandb 287 | 288 | wandb.init( 289 | project=project, 290 | group=group, 291 | name=name, 292 | notes=notes, 293 | resume=resume, 294 | tags=tags, 295 | id=id, 296 | ) 297 | 298 | self.wandb = wandb 299 | 300 | def log(self, step, **kwargs): 301 | self.wandb.log(kwargs, step=step) 302 | 303 | def __del__(self): 304 | if dist.is_primary(): 305 | self.wandb.finish() 306 | 307 | 308 | class NSML: 309 | def log(self, step, **kwargs): 310 | nsml.report(summary=True, step=step, **kwargs) 311 | -------------------------------------------------------------------------------- /tensorfn/util/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from argparse import Action 3 | import os 4 | from pprint import pprint 5 | import sys 6 | import json 7 | import uuid 8 | 9 | import torch 10 | from torch.distributed.elastic.multiprocessing import Std 11 | from torch.distributed.elastic.rendezvous.utils import _parse_rendezvous_config 12 | from torch.distributed.launcher.api import LaunchConfig 13 | from pyhocon import ConfigFactory, ConfigTree 14 | 15 | try: 16 | import _jsonnet 17 | except ImportError: 18 | _jsonnet = None 19 | 20 | from tensorfn.distributed import is_primary 21 | 22 | 23 | def read_config(config_file, overrides=(), strict=False): 24 | if config_file.endswith(".jsonnet"): 25 | json_str = _jsonnet.evaluate_file(config_file) 26 | json_obj = json.loads(json_str) 27 | conf = ConfigFactory.from_dict(json_obj) 28 | 29 | elif config_file.endswith(".py"): 30 | from tensorfn.config.builder import PyConfig 31 | 32 | 
conf = ConfigFactory.from_dict(PyConfig.load(config_file)) 33 | 34 | else: 35 | conf = ConfigFactory.parse_file(config_file) 36 | 37 | if len(overrides) > 0: 38 | for override in overrides: 39 | conf_overrides = ConfigFactory.parse_string(override) 40 | conf = ConfigTree.merge_configs(conf, conf_overrides) 41 | 42 | return conf.as_plain_ordered_dict() 43 | 44 | 45 | def preset_argparser(elastic=False): 46 | parser = argparse.ArgumentParser() 47 | 48 | parser.add_argument("--conf", type=str, required=True) 49 | parser.add_argument("--ckpt", type=str) 50 | 51 | if elastic: 52 | parser = add_elastic_args(parser) 53 | 54 | else: 55 | parser = add_distributed_args(parser) 56 | 57 | parser.add_argument("opts", default=None, nargs=argparse.REMAINDER) 58 | 59 | return parser 60 | 61 | 62 | def add_distributed_args(parser): 63 | parser.add_argument("--n_gpu", type=int, default=1) 64 | parser.add_argument("--n_machine", type=int, default=1) 65 | parser.add_argument("--machine_rank", type=int, default=0) 66 | 67 | port = ( 68 | 2 ** 15 69 | + 2 ** 14 70 | + hash(os.getuid() if sys.platform != "win32" else 1) % 2 ** 14 71 | ) 72 | parser.add_argument("--dist_url", default=f"tcp://127.0.0.1:{port}") 73 | 74 | return parser 75 | 76 | 77 | class env(Action): 78 | def __init__(self, dest, default=None, required=False, **kwargs): 79 | env_name = f"PET_{dest.upper()}" 80 | default = os.environ.get(env_name, default) 81 | 82 | if default: 83 | required = False 84 | 85 | super().__init__(dest=dest, default=default, required=required, **kwargs) 86 | 87 | def __call__(self, parser, namespace, values, option_string=None): 88 | setattr(namespace, self.dest, values) 89 | 90 | 91 | class check_env(Action): 92 | def __init__(self, dest, default=None, **kwargs): 93 | env_name = f"PET_{dest.upper()}" 94 | default = bool(int(os.environ.get(env_name, "1" if default else "0"))) 95 | 96 | super().__init__(dest=dest, const=True, default=default, nargs=0, **kwargs) 97 | 98 | def __call__(self, 
parser, namespace, values, option_string=None): 99 | setattr(namespace, self.dest, self.const) 100 | 101 | 102 | def parse_min_max_nodes(n_node): 103 | ar = n_node.split(":") 104 | 105 | if len(ar) == 1: 106 | min_node = max_node = int(ar[0]) 107 | 108 | elif len(ar) == 2: 109 | min_node, max_node = int(ar[0]), int(ar[1]) 110 | 111 | else: 112 | raise ValueError(f'n_node={n_node} is not in "MIN:MAX" format') 113 | 114 | return min_node, max_node 115 | 116 | 117 | def local_world_size(n_gpu): 118 | try: 119 | return int(n_gpu) 120 | 121 | except ValueError: 122 | if n_gpu == "cpu": 123 | n_proc = os.cpu_count() 124 | 125 | elif n_gpu == "gpu": 126 | if not torch.cuda.is_available(): 127 | raise ValueError("CUDA is not available") 128 | 129 | n_proc = torch.cuda.device_count() 130 | 131 | elif n_gpu == "auto": 132 | if torch.cuda.is_available(): 133 | n_proc = torch.cuda.device_count() 134 | 135 | else: 136 | n_proc = os.cpu_count() 137 | 138 | else: 139 | raise ValueError(f"Unsupported n_proc value: {n_gpu}") 140 | 141 | return n_proc 142 | 143 | 144 | def find_free_port(): 145 | import socket 146 | 147 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 148 | 149 | sock.bind(("", 0)) 150 | port = sock.getsockname()[1] 151 | sock.close() 152 | 153 | return port 154 | 155 | 156 | def get_rdzv_endpoint(args, max_node): 157 | if args.rdzv_backend == "static" and not args.rdzv_endpoint: 158 | dist_url = args.dist_url 159 | 160 | if dist_url == "auto": 161 | if max_node != 1: 162 | raise ValueError('dist_url="auto" not supported in multi-machine jobs') 163 | 164 | port = find_free_port() 165 | dist_url = f"127.0.0.1:{port}" 166 | 167 | return dist_url 168 | 169 | return args.rdzv_endpoint 170 | 171 | 172 | def elastic_config(args): 173 | min_node, max_node = parse_min_max_nodes(args.n_node) 174 | n_proc = local_world_size(args.n_proc) 175 | 176 | rdzv_configs = _parse_rendezvous_config(args.rdzv_conf) 177 | 178 | if args.rdzv_backend == "static": 179 | 
rdzv_configs["rank"] = args.node_rank 180 | 181 | rdzv_endpoint = get_rdzv_endpoint(args, max_node) 182 | 183 | config = LaunchConfig( 184 | min_nodes=min_node, 185 | max_nodes=max_node, 186 | nproc_per_node=n_proc, 187 | run_id=args.rdzv_id, 188 | role=args.role, 189 | rdzv_endpoint=rdzv_endpoint, 190 | rdzv_backend=args.rdzv_backend, 191 | rdzv_configs=rdzv_configs, 192 | max_restarts=args.max_restarts, 193 | monitor_interval=args.monitor_interval, 194 | start_method=args.start_method, 195 | redirects=Std.from_str(args.redirects), 196 | tee=Std.from_str(args.tee), 197 | log_dir=args.log_dir, 198 | ) 199 | 200 | return config 201 | 202 | 203 | def add_elastic_args(parser): 204 | parser.add_argument("--n_proc", type=str, default="1") 205 | parser.add_argument("--n_node", type=str, default="1:1") 206 | parser.add_argument("--node_rank", type=int, default=0) 207 | 208 | port = ( 209 | 2 ** 15 210 | + 2 ** 14 211 | + hash(os.getuid() if sys.platform != "win32" else 1) % 2 ** 14 212 | ) 213 | parser.add_argument("--dist_url", default=f"127.0.0.1:{port}") 214 | 215 | parser.add_argument("--rdzv_backend", action=env, type=str, default="static") 216 | parser.add_argument( 217 | "--rdzv_endpoint", 218 | action=env, 219 | type=str, 220 | default="", 221 | help="Rendezvous backend endpoint; usually in form :.", 222 | ) 223 | parser.add_argument( 224 | "--rdzv_id", action=env, type=str, default="none", help="User-defined group id." 
225 | ) 226 | parser.add_argument( 227 | "--rdzv_conf", 228 | action=env, 229 | type=str, 230 | default="", 231 | help="Additional rendezvous configuration (=, =, ...).", 232 | ) 233 | parser.add_argument( 234 | "--standalone", 235 | action=check_env, 236 | help="Start a local standalone rendezvous backend", 237 | ) 238 | 239 | parser.add_argument("--max_restarts", action=env, type=int, default=0) 240 | parser.add_argument("--monitor_interval", action=env, type=float, default=5) 241 | parser.add_argument( 242 | "--start_method", 243 | action=env, 244 | type=str, 245 | default="spawn", 246 | choices=["spawn", "fork", "forkserver"], 247 | ) 248 | parser.add_argument("--role", action=env, type=str, default="default") 249 | parser.add_argument("--log_dir", action=env, type=str, default=None) 250 | parser.add_argument("-r", "--redirects", action=env, type=str, default="0") 251 | parser.add_argument("-t", "--tee", action=env, type=str, default="0") 252 | 253 | return parser 254 | 255 | 256 | def load_config(config_model, config, overrides=(), show=True): 257 | conf = config_model(**read_config(config, overrides=overrides)) 258 | 259 | if show and is_primary(): 260 | pprint(conf.dict()) 261 | 262 | return conf 263 | 264 | 265 | def load_arg_config(config_model, show=False, elastic=False, parser=None): 266 | if parser is None: 267 | parser = preset_argparser(elastic=elastic) 268 | 269 | args = parser.parse_args() 270 | 271 | conf = load_config(config_model, args.conf, args.opts, show) 272 | 273 | if elastic: 274 | if args.standalone: 275 | args.rdzv_backend = "c10d" 276 | args.rdzv_endpoint = "localhost:29400" 277 | args.rdzv_id = str(uuid.uuid4()) 278 | 279 | launch_config = elastic_config(args) 280 | 281 | conf.n_gpu = launch_config.nproc_per_node 282 | conf.n_machine = launch_config.max_nodes 283 | conf.machine_rank = args.node_rank 284 | conf.dist_url = args.dist_url 285 | conf.ckpt = args.ckpt 286 | conf.launch_config = launch_config 287 | 288 | else: 289 | conf.n_gpu 
= args.n_gpu 290 | conf.n_machine = args.n_machine 291 | conf.machine_rank = args.machine_rank 292 | conf.dist_url = args.dist_url 293 | conf.ckpt = args.ckpt 294 | 295 | return conf 296 | -------------------------------------------------------------------------------- /tensorfn/config/builder.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import builtins 3 | import os 4 | import pydoc 5 | import uuid 6 | import importlib 7 | from contextlib import contextmanager 8 | from collections.abc import Mapping, Sequence 9 | from typing import Union, Tuple 10 | 11 | from tensorfn.config.config import resolve_module 12 | 13 | CFG_PACKAGE_NAME = "tensorfn._conf_loader" 14 | 15 | 16 | def str_to_import(name): 17 | obj = pydoc.locate(name) 18 | 19 | if obj is None: 20 | obj = resolve_module(name) 21 | 22 | return obj 23 | 24 | 25 | def validate_syntax(filename): 26 | with open(filename) as f: 27 | code = f.read() 28 | 29 | try: 30 | ast.parse(code) 31 | 32 | except SyntaxError as e: 33 | raise SyntaxError(f"{filename} has syntax error") from e 34 | 35 | 36 | def random_package_name(filename): 37 | # generate a random package name when loading config files 38 | return CFG_PACKAGE_NAME + str(uuid.uuid4())[:4] + "." 
+ os.path.basename(filename) 39 | 40 | 41 | def import_to_str(obj): 42 | module, qualname = obj.__module__, obj.__qualname__ 43 | 44 | module_parts = module.split(".") 45 | 46 | for i in range(1, len(module_parts)): 47 | prefix = ".".join(module_parts[:i]) 48 | candid = f"{prefix}.{qualname}" 49 | 50 | try: 51 | if str_to_import(candid) is obj: 52 | return candid 53 | 54 | except ImportError: 55 | pass 56 | 57 | return f"{module}.{qualname}" 58 | 59 | 60 | def build(__key, __name, *args, **kwargs): 61 | node = {__key: __name} 62 | 63 | if len(args) > 0: 64 | node["__args"] = list(args) 65 | 66 | node = {**node, **kwargs} 67 | 68 | return node 69 | 70 | 71 | def build_init(__name, *args, **kwargs): 72 | return build("__init", __name, *args, **kwargs) 73 | 74 | 75 | def build_fn(__name, *args, **kwargs): 76 | return build("__fn", __name, *args, **kwargs) 77 | 78 | 79 | class Init: 80 | def __init__(self, name, fn=False, key=None): 81 | self.name = name 82 | self.fn = fn 83 | self.key = key 84 | 85 | def __call__(self, *args, **kwargs): 86 | if self.fn: 87 | return build_fn(self.name, *args, **kwargs) 88 | 89 | res = build_init(self.name, *args, **kwargs) 90 | if self.key is not None: 91 | res["__key"] = self.key 92 | 93 | return res 94 | 95 | 96 | class Single: 97 | def __init__(self): 98 | self.counter = 0 99 | 100 | def __getitem__(self, obj): 101 | fn = False 102 | 103 | if not isinstance(obj, str): 104 | obj = import_to_str(obj) 105 | 106 | key = f"{obj}#{self.counter}" 107 | self.counter += 1 108 | 109 | return Init(obj, fn, key=key) 110 | 111 | 112 | class LazyCall: 113 | def __getitem__(self, obj): 114 | fn = False 115 | 116 | if isinstance(obj, tuple): 117 | obj, fn = obj 118 | 119 | if not isinstance(obj, str): 120 | obj = import_to_str(obj) 121 | 122 | return Init(obj, fn) 123 | 124 | 125 | class LazyFn: 126 | def __getitem__(self, obj): 127 | if not isinstance(obj, str): 128 | obj = import_to_str(obj) 129 | 130 | return Init(obj, True) 131 | 132 | 133 | 
@contextmanager 134 | def patch_import(): 135 | """ 136 | Enhance relative import statements in config files, so that they: 137 | 1. locate files purely based on relative location, regardless of packages. 138 | e.g. you can import file without having __init__ 139 | 2. do not cache modules globally; modifications of module states has no side effect 140 | 3. support other storage system through PathManager 141 | 4. imported dict are turned into omegaconf.DictConfig automatically 142 | """ 143 | old_import = builtins.__import__ 144 | 145 | def find_relative_file(original_file, relative_import_path, level): 146 | cur_file = os.path.dirname(original_file) 147 | 148 | for _ in range(level - 1): 149 | cur_file = os.path.dirname(cur_file) 150 | 151 | cur_name = relative_import_path.lstrip(".") 152 | 153 | for part in cur_name.split("."): 154 | cur_file = os.path.join(cur_file, part) 155 | 156 | # NOTE: directory import is not handled. Because then it's unclear 157 | # if such import should produce python module or DictConfig. This can 158 | # be discussed further if needed. 
159 | if not cur_file.endswith(".py"): 160 | cur_file += ".py" 161 | 162 | if not os.path.isfile(cur_file): 163 | raise ImportError( 164 | f"cannot import name {relative_import_path} from " 165 | f"{original_file}: {cur_file} has to exist" 166 | ) 167 | 168 | return cur_file 169 | 170 | def new_import(name, globals=None, locals=None, fromlist=(), level=0): 171 | if ( 172 | # Only deal with relative imports inside config files 173 | level != 0 174 | and globals is not None 175 | and (globals.get("__package__", "") or "").startswith(CFG_PACKAGE_NAME) 176 | ): 177 | cur_file = find_relative_file(globals["__file__"], name, level) 178 | validate_syntax(cur_file) 179 | spec = importlib.machinery.ModuleSpec( 180 | random_package_name(cur_file), None, origin=cur_file 181 | ) 182 | module = importlib.util.module_from_spec(spec) 183 | module.__file__ = cur_file 184 | 185 | with open(cur_file) as f: 186 | content = f.read() 187 | exec(compile(content, cur_file, "exec"), module.__dict__) 188 | 189 | # for name in fromlist: # turn imported dict into DictConfig automatically 190 | # val = _cast_to_config(module.__dict__[name]) 191 | # module.__dict__[name] = val 192 | 193 | return module 194 | 195 | return old_import(name, globals, locals, fromlist=fromlist, level=level) 196 | 197 | builtins.__import__ = new_import 198 | yield new_import 199 | builtins.__import__ = old_import 200 | 201 | 202 | class PyConfig: 203 | @staticmethod 204 | def load(filename: str, keys: Union[None, str, Tuple[str, ...]] = None): 205 | """ 206 | Load a config file. 207 | Args: 208 | filename: absolute path or relative path w.r.t. the current working directory 209 | keys: keys to load and return. If not given, return all keys 210 | (whose values are config objects) in a dict. 
211 | """ 212 | has_keys = keys is not None 213 | filename = filename.replace("/./", "/") # redundant 214 | 215 | if os.path.splitext(filename)[1] not in [".py", ".yaml", ".yml"]: 216 | raise ValueError(f"Config file {filename} has to be a python or yaml file.") 217 | 218 | if filename.endswith(".py"): 219 | validate_syntax(filename) 220 | 221 | with patch_import(): 222 | # Record the filename 223 | module_namespace = { 224 | "__file__": filename, 225 | "__package__": random_package_name(filename), 226 | } 227 | with open(filename) as f: 228 | content = f.read() 229 | # Compile first with filename to: 230 | # 1. make filename appears in stacktrace 231 | # 2. make load_rel able to find its parent's (possibly remote) location 232 | exec(compile(content, filename, "exec"), module_namespace) 233 | 234 | ret = module_namespace 235 | 236 | ret = ret["conf"].to_dict() 237 | 238 | if has_keys: 239 | return tuple(ret[a] for a in keys) 240 | 241 | return ret 242 | 243 | 244 | def unfold_field(x): 245 | if isinstance(x, Sequence) and not isinstance(x, str): 246 | return [unfold_field(i) for i in x] 247 | 248 | if isinstance(x, Mapping): 249 | res = {} 250 | 251 | for k, v in x.items(): 252 | res[k] = unfold_field(v) 253 | 254 | return res 255 | 256 | return x 257 | 258 | 259 | class Field(dict): 260 | def __init__(self, *args, **kwargs): 261 | self.update(*args, **kwargs) 262 | 263 | def __getattr__(self, key): 264 | try: 265 | return object.__getattribute__(self, key) 266 | 267 | except AttributeError: 268 | try: 269 | return self[key] 270 | 271 | except KeyError: 272 | raise AttributeError(key) 273 | 274 | def __setattr__(self, key, value): 275 | try: 276 | object.__getattribute__(self, key) 277 | 278 | except AttributeError: 279 | try: 280 | self[key] = value 281 | 282 | except: 283 | raise AttributeError(key) 284 | 285 | else: 286 | object.__setattr__(self, key, value) 287 | 288 | def __delattr__(self, key): 289 | try: 290 | object.__getattribute__(self, key) 291 | 292 | 
except AttributeError: 293 | try: 294 | del self[key] 295 | 296 | except KeyError: 297 | raise AttributeError(key) 298 | 299 | else: 300 | object.__delattr__(self, key) 301 | 302 | def __repr__(self): 303 | return f"{self.__class__.__name__}({dict.__repr__(self)})" 304 | 305 | def to_dict(self): 306 | return unfold_field(self) 307 | 308 | 309 | L = LazyCall() 310 | F = LazyFn() 311 | field = Field 312 | single = Single() 313 | -------------------------------------------------------------------------------- /tensorfn/vision/autoaugment.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from PIL import Image, ImageOps, ImageEnhance 4 | 5 | from tensorfn.vision import transforms 6 | from tensorfn.vision.transforms import check_prob, PIL_INTER_MAP, RandomTransform 7 | 8 | 9 | def rescale_float(level, max_val, param_max=10): 10 | return float(level) * max_val / param_max 11 | 12 | 13 | def rescale_int(level, max_val, param_max=10): 14 | return int(level * max_val / param_max) 15 | 16 | 17 | class AutoAugmentAffine(RandomTransform): 18 | def __init__(self, mirror=True, resample=Image.NEAREST, p=1.0): 19 | super().__init__(p) 20 | 21 | self.mirror = mirror 22 | self.resample = resample 23 | 24 | def _mirror(self, val): 25 | if self.mirror and check_prob(0.5): 26 | val *= -1 27 | 28 | return val 29 | 30 | def _repr_params(self): 31 | params = dict(self.__dict__) 32 | params["resample"] = PIL_INTER_MAP[self.resample] 33 | 34 | return params 35 | 36 | def _apply_img_fn(self, img, translate, shear): 37 | trans_x, trans_y = translate 38 | shear_x, shear_y = shear 39 | 40 | return img.transform( 41 | img.size, 42 | Image.AFFINE, 43 | (1, shear_x, trans_x, shear_y, 1, trans_y), 44 | self.resample, 45 | ) 46 | 47 | 48 | class ShearX(AutoAugmentAffine): 49 | def __init__(self, shear_x, mirror=True, resample=Image.NEAREST, p=1.0): 50 | super().__init__(mirror=mirror, resample=resample, p=p) 51 | 52 | self.shear_x = shear_x 
53 | 54 | def sample(self): 55 | shear_x = self._mirror(self.shear_x) 56 | 57 | return {"shear_x": shear_x} 58 | 59 | def _apply_img(self, img, shear_x): 60 | return self._apply_img_fn(img, (0, 0), (shear_x, 0)) 61 | 62 | 63 | class ShearY(AutoAugmentAffine): 64 | def __init__(self, shear_y, mirror=True, resample=Image.NEAREST, p=1.0): 65 | super().__init__(mirror=mirror, resample=resample, p=p) 66 | 67 | self.shear_y = shear_y 68 | 69 | def sample(self): 70 | shear_y = self._mirror(self.shear_y) 71 | 72 | return {"shear_y": shear_y} 73 | 74 | def _apply_img(self, img, shear_y): 75 | return self._apply_img_fn(img, (0, 0), (0, shear_y)) 76 | 77 | 78 | class TranslateX(AutoAugmentAffine): 79 | def __init__(self, translate_x, mirror=True, resample=Image.NEAREST, p=1.0): 80 | super().__init__(mirror=mirror, resample=resample, p=p) 81 | 82 | self.translate_x = translate_x 83 | 84 | def sample(self): 85 | trans_x = self._mirror(self.translate_x) 86 | 87 | return {"translate_x": trans_x} 88 | 89 | def _apply_img(self, img, translate_x): 90 | return self._apply_img_fn(img, (translate_x, 0), (0, 0)) 91 | 92 | 93 | class TranslateY(AutoAugmentAffine): 94 | def __init__(self, translate_y, mirror=True, resample=Image.NEAREST, p=1.0): 95 | super().__init__(mirror=mirror, resample=resample, p=p) 96 | 97 | self.translate_y = translate_y 98 | 99 | def sample(self): 100 | trans_y = self._mirror(self.translate_y) 101 | 102 | return {"translate_y": trans_x} 103 | 104 | def _apply_img(self, img, translate_y): 105 | return self._apply_img_fn(img, (0, translate_y), (0, 0)) 106 | 107 | 108 | class Rotate(AutoAugmentAffine): 109 | def __init__(self, rotate, mirror=True, resample=Image.NEAREST, p=1.0): 110 | super().__init__(mirror=mirror, resample=resample, p=p) 111 | 112 | self.rotate = rotate 113 | 114 | def sample(self): 115 | rotate = self._mirror(self.rotate) 116 | 117 | return {"rotate": rotate} 118 | 119 | def _apply_img(self, img, rotate): 120 | return img.rotate(rotate, 
resample=self.resample) 121 | 122 | 123 | class Posterize(RandomTransform): 124 | def __init__(self, bits, p=1.0): 125 | super().__init__(p) 126 | 127 | self.bits = bits 128 | 129 | def sample(self): 130 | return {"bits": self.bits} 131 | 132 | def _apply_img(self, img, bits): 133 | return ImageOps.posterize(img, bits) 134 | 135 | 136 | class Solarize(RandomTransform): 137 | def __init__(self, threshold, p=1.0): 138 | super().__init__(p) 139 | 140 | self.threshold = threshold 141 | 142 | def sample(self): 143 | return {"threshold": self.threshold} 144 | 145 | def _apply_img(self, img, threshold): 146 | return ImageOps.solarize(img, threshold) 147 | 148 | 149 | class Saturation(RandomTransform): 150 | def __init__(self, saturation, p=1.0): 151 | super().__init__(p) 152 | 153 | self.saturation = saturation 154 | 155 | def sample(self): 156 | return {"saturation": self.saturation} 157 | 158 | def _apply_img(self, img, saturation): 159 | return ImageEnhance.Color(img).enhance(saturation) 160 | 161 | 162 | class Contrast(RandomTransform): 163 | def __init__(self, contrast, p=1.0): 164 | super().__init__(p) 165 | 166 | self.contrast = contrast 167 | 168 | def sample(self): 169 | return {"contrast": self.contrast} 170 | 171 | def _apply_img(self, img, contrast): 172 | return ImageEnhance.Contrast(img).enhance(contrast) 173 | 174 | 175 | class Brightness(RandomTransform): 176 | def __init__(self, brightness, p=1.0): 177 | super().__init__(p) 178 | 179 | self.brightness = brightness 180 | 181 | def sample(self): 182 | return {"brightness": self.brightness} 183 | 184 | def _apply_img(self, img, brightness): 185 | return ImageEnhance.Brightness(img).enhance(brightness) 186 | 187 | 188 | class Sharpness(RandomTransform): 189 | def __init__(self, sharpness, p=1.0): 190 | super().__init__(p) 191 | 192 | self.sharpness = sharpness 193 | 194 | def sample(self): 195 | return {"sharpness": self.sharpness} 196 | 197 | def _apply_img(self, img, sharpness): 198 | return 
ImageEnhance.Sharpness(img).enhance(sharpness) 199 | 200 | 201 | def autoaugment_policy(): 202 | autoaugment_map = { 203 | "ShearX": (ShearX, lambda level: rescale_float(level, 0.3)), 204 | "ShearY": (ShearY, lambda level: rescale_float(level, 0.3)), 205 | "TranslateX": (TranslateX, lambda level: rescale_int(level, 10)), 206 | "TranslateY": (TranslateY, lambda level: rescale_int(level, 10)), 207 | "Rotate": (Rotate, lambda level: rescale_int(level, 30)), 208 | "Solarize": (Solarize, lambda level: 256 - rescale_int(level, 256)), 209 | "Posterize": (Posterize, lambda level: 4 - rescale_int(level, 4)), 210 | "Contrast": (Contrast, lambda level: rescale_float(level, 1.8) + 0.1), 211 | "Color": (Saturation, lambda level: rescale_float(level, 1.8) + 0.1), 212 | "Brightness": (Brightness, lambda level: rescale_float(level, 1.8) + 0.1), 213 | "Sharpness": (Sharpness, lambda level: rescale_float(level, 1.8) + 0.1), 214 | "Invert": (transforms.Invert, None), 215 | "AutoContrast": (transforms.AutoContrast, None), 216 | "Equalize": (transforms.Equalize, None), 217 | } 218 | 219 | policy_list = [ 220 | [("Posterize", 0.4, 8), ("Rotate", 0.6, 9)], 221 | [("Solarize", 0.6, 5), ("AutoContrast", 0.6, 5)], 222 | [("Equalize", 0.8, 8), ("Equalize", 0.6, 3)], 223 | [("Posterize", 0.6, 7), ("Posterize", 0.6, 6)], 224 | [("Equalize", 0.4, 7), ("Solarize", 0.2, 4)], 225 | [("Equalize", 0.4, 4), ("Rotate", 0.8, 8)], 226 | [("Solarize", 0.6, 3), ("Equalize", 0.6, 7)], 227 | [("Posterize", 0.8, 5), ("Equalize", 1.0, 2)], 228 | [("Rotate", 0.2, 3), ("Solarize", 0.6, 8)], 229 | [("Equalize", 0.6, 8), ("Posterize", 0.4, 6)], 230 | [("Rotate", 0.8, 8), ("Color", 0.4, 0)], 231 | [("Rotate", 0.4, 9), ("Equalize", 0.6, 2)], 232 | [("Equalize", 0.0, 7), ("Equalize", 0.8, 8)], 233 | [("Invert", 0.6, 4), ("Equalize", 1.0, 8)], 234 | [("Color", 0.6, 4), ("Contrast", 1.0, 8)], 235 | [("Rotate", 0.8, 8), ("Color", 1.0, 0)], 236 | [("Color", 0.8, 8), ("Solarize", 0.8, 7)], 237 | [("Sharpness", 0.4, 7), 
("Invert", 0.6, 8)], 238 | [("ShearX", 0.6, 5), ("Equalize", 1.0, 9)], 239 | [("Color", 0.4, 0), ("Equalize", 0.6, 3)], 240 | [("Equalize", 0.4, 7), ("Solarize", 0.2, 4)], 241 | [("Solarize", 0.6, 5), ("AutoContrast", 0.6, 5)], 242 | [("Invert", 0.6, 4), ("Equalize", 1.0, 8)], 243 | [("Color", 0.6, 4), ("Contrast", 1.0, 8)], 244 | [("Equalize", 0.8, 8), ("Equalize", 0.6, 3)], 245 | ] 246 | 247 | reparam_policy = [] 248 | 249 | for policy in policy_list: 250 | sub_pol = [] 251 | 252 | for pol in policy: 253 | augment, prob, magnitude = pol 254 | augment_fn, reparam_fn = autoaugment_map[augment] 255 | 256 | if reparam_fn is not None: 257 | magnitude = reparam_fn(magnitude) 258 | sub_pol.append(augment_fn(magnitude, p=prob)) 259 | 260 | else: 261 | sub_pol.append(augment_fn(p=prob)) 262 | 263 | reparam_policy.append(sub_pol) 264 | 265 | return reparam_policy 266 | 267 | 268 | class AutoAugment: 269 | def __init__(self, policy): 270 | self.policy = policy 271 | 272 | def __call__(self, img): 273 | selected_policy = random.choice(self.policy) 274 | 275 | for pol in selected_policy: 276 | sample = pol.sample() 277 | img = pol.apply_img(img, **sample) 278 | 279 | return img 280 | 281 | def __repr__(self): 282 | policy_str = ",\n ".join(repr(p) for p in self.policy) 283 | policy_str = f" [{policy_str}]" 284 | return f"{self.__class__.__name__}(\n{policy_str}\n)" 285 | 286 | def check(self, img): 287 | log = [] 288 | 289 | selected_policy = random.choice(self.policy) 290 | 291 | for pol in selected_policy: 292 | sample = pol.sample() 293 | img, check = pol.apply_img_check(img, **sample) 294 | log.append((pol, sample, check)) 295 | 296 | return img, log 297 | -------------------------------------------------------------------------------- /tensorfn/config/config.py: -------------------------------------------------------------------------------- 1 | import re 2 | import sys 3 | import collections 4 | import inspect 5 | import functools 6 | import typing 7 | from typing 
class Config(BaseModel):
    """Base pydantic model that rejects fields it does not declare."""

    class Config:
        extra = "forbid"


class MainConfig(BaseModel):
    """Top-level run options (GPU/machine counts, dist URL, checkpoint, ...)."""

    class Config:
        extra = "forbid"

    n_gpu: Optional[StrictInt]
    n_machine: Optional[StrictInt]
    machine_rank: Optional[StrictInt]
    dist_url: Optional[StrictStr]
    distributed: Optional[StrictBool]
    ckpt: Optional[StrictStr]
    launch_config: typing.Any


class TypedConfig(BaseModel):
    """Config carrying a ``type`` tag that must equal the class's ``__type__``."""

    class Config:
        extra = "forbid"

    type: StrictStr

    @validator("type")
    def check_type(cls, v):
        # Subclasses are expected to define ``__type__``.
        if v != cls.__type__:
            raise ValueError("Options does not match for " + cls.__type__)

        return v


# namespace -> name -> (use_type, fn, signature, exclude)
CONFIG_REGISTRY = {}


def config_model(name=None, namespace=None, exclude=(), use_type=False):
    """Return a decorator that registers a callable in ``CONFIG_REGISTRY``.

    Args:
        name: Registry key; defaults to the callable's ``__name__``.
        namespace: Registry namespace the entry is stored under.
        exclude: Parameter names to leave out of the generated model.
        use_type: Stored with the entry; ``get_model`` uses it to decide
            whether the model gets a ``type`` field.

    Raises:
        KeyError: If the name is already registered in the namespace.
    """

    def _decorate(fn):
        fn_name = fn.__name__ if name is None else name
        registry = CONFIG_REGISTRY.setdefault(namespace, {})

        if fn_name in registry:
            prev_fn = registry[fn_name][1]
            raise KeyError(f"Conflict occured in config registry: {prev_fn} vs {fn}")

        registry[fn_name] = (use_type, fn, inspect.signature(fn), exclude)

        return fn

    return _decorate


def _check_type(type_name):
    """Build a reusable validator asserting the ``type`` field equals ``type_name``."""

    @validator("type", allow_reuse=True)
    def check_type(cls, v):
        if v != type_name:
            raise ValueError(f"Type does not match for {type_name}")

        return v

    return check_type


class StrictConfig:
    """Model config used for strict models: no extra fields, arbitrary types ok."""

    extra = "forbid"
    arbitrary_types_allowed = True


def make_model_from_signature(
    name, init_fn, signature, exclude, type_name=None, strict=True
):
    """Create a pydantic model whose fields mirror a callable's signature.

    Args:
        name: Name of the generated model class.
        init_fn: Callable invoked by the model's ``make`` method.
        signature: ``inspect.Signature`` of ``init_fn``.
        exclude: Parameter names that must not become model fields.
        type_name: When given, a required ``type`` field is added and
            validated against this value.
        strict: Use ``StrictConfig``; forced off when the signature has
            ``*args`` or ``**kwargs``.

    Returns:
        The generated model class. It is also set as an attribute on this
        module under ``name``.
    """
    empty = inspect.Parameter.empty
    params = {}

    if type_name is not None:
        params["type"] = (StrictStr, ...)

    for k, v in signature.parameters.items():
        if k in exclude:
            continue

        if v.kind == v.VAR_POSITIONAL or v.kind == v.VAR_KEYWORD:
            # Variadic parameters cannot be modelled as fields, so unknown
            # keys must be tolerated.
            strict = False

            continue

        annotation = typing.Any if v.annotation is empty else v.annotation

        if v.default is empty:
            params[k] = (annotation, ...)

        else:
            params[k] = (annotation, v.default)

    def _params(self):
        # Field values as a dict, without the ``type`` tag.
        values = self.dict()

        if type_name is not None:
            values.pop("type")

        return values

    @functools.wraps(init_fn)
    def _init_fn(self, *args, **kwargs):
        params = self.params()
        params.update(kwargs)
        pos_replace = list(signature.parameters.keys())[: len(args)]
        for pos in pos_replace:
            # A positionally supplied parameter may have been excluded from
            # the model, in which case there is nothing to drop.
            params.pop(pos, None)

        return init_fn(*args, **params)

    validators = {"params": _params, "make": _init_fn}

    if type_name is not None:
        validators["check_type"] = _check_type(type_name)

    config = StrictConfig if strict else None

    model = create_model(
        name,
        __config__=config,
        __validators__=validators,
        __module__=__name__,
        **params,
    )

    setattr(sys.modules[__name__], name, model)

    return model


# namespace -> name -> generated model (cache for get_model)
CONFIG_MODEL_REGISTRY = {}


def get_models(namespace):
    """Return the ``Union`` of the models of every entry in ``namespace``.

    Raises:
        KeyError: If the namespace is not registered.
        ValueError: If the namespace has no entries.
    """
    models = tuple(get_model(name, namespace) for name in CONFIG_REGISTRY[namespace])

    if not models:
        raise ValueError(f"No config models registered in namespace {namespace!r}")

    return Union[models]
def get_model(name, namespace=None):
    """Return the model for a registered entry, creating and caching it once."""
    cache = CONFIG_MODEL_REGISTRY.setdefault(namespace, {})

    if name in cache:
        return cache[name]

    use_type, init_fn, signature, exclude = CONFIG_REGISTRY[namespace][name]
    model = make_model_from_signature(
        name, init_fn, signature, exclude, name if use_type else None
    )
    cache[name] = model

    return model


def override(overrides, **defaults):
    """Return ``defaults`` with values replaced by same-named ``overrides`` keys.

    Keys of ``overrides`` that are not in ``defaults`` are ignored.
    """
    return {k: overrides.get(k, v) for k, v in defaults.items()}


def resolve_module(path):
    """Resolve a dotted path such as ``"pkg.mod.attr"`` to the object it names.

    The longest importable proper prefix of ``path`` is imported and the
    remaining components are looked up with ``getattr``; a component that is
    missing as an attribute is tried as a submodule import first. A path
    with a single component is imported as a module.

    Raises:
        ImportError: If no prefix of ``path`` can be imported, or a missing
            component cannot be imported as a submodule.
        AttributeError: If a component is still missing after importing.
    """
    from importlib import import_module

    sub_path = path.split(".")
    module = None
    mod = ""
    split = 0

    # Stop at one component: an empty module name is never importable, and a
    # one-component path has no proper prefix, so it is tried itself.
    for i in range(max(len(sub_path) - 1, 1), 0, -1):
        candidate = ".".join(sub_path[:i])

        try:
            module = import_module(candidate)

        except (ImportError, ValueError):
            # ValueError covers an empty candidate (e.g. path == "").
            continue

        mod, split = candidate, i

        break

    if module is None:
        raise ImportError(f"Could not import any module prefix of '{path}'")

    obj = module

    for sub in sub_path[split:]:
        mod = f"{mod}.{sub}"

        if not hasattr(obj, sub):
            try:
                import_module(mod)

            except ImportError as e:
                raise ImportError(
                    f"Encountered error: '{e}' when loading module '{path}'"
                ) from e

        obj = getattr(obj, sub)

    return obj


def flatten_tree(node):
    """Collect every mapping found in a nested structure, parents first.

    Non-string sequences are searched element by element; a mapping is
    collected and its values searched. Anything else yields nothing.
    """
    if isinstance(node, collections.abc.Sequence) and not isinstance(node, str):
        return [found for child in node for found in flatten_tree(child)]

    if isinstance(node, collections.abc.Mapping):
        res = [node]

        for v in node.values():
            res.extend(flatten_tree(v))

        return res

    return []
# "__key" value -> instantiated object shared by every node using that key
SINGLETON = {}


def init_singleton(nodes):
    """Instantiate every node carrying a ``"__key"`` once into ``SINGLETON``.

    Nodes without ``"__key"``, or whose key is already present, are skipped.
    """
    key_key = "__key"

    for node in nodes:
        if key_key not in node:
            continue

        node_key = node[key_key]

        if node_key in SINGLETON:
            continue

        # The key is stripped so the instantiating traversal builds the
        # object instead of looking it up in SINGLETON.
        restrict_node = {k: v for k, v in node.items() if k != key_key}
        instance_traverse(restrict_node)  # non-instantiating pass first
        SINGLETON[node_key] = instance_traverse(restrict_node, instantiate=True)


def instance_traverse(
    node, *args, recursive=True, instantiate=False, keyword_args=None
):
    """Walk a nested config, validating or instantiating target nodes.

    A mapping is a target node when it has ``"__target"``, ``"__init"`` or
    ``"__fn"``; the value is a dotted path resolved with ``resolve_module``.
    Other recognised keys are ``"__args"`` (positional arguments),
    ``"__partial"``, ``"__validate"`` and ``"__key"`` (singleton lookup).

    Args:
        node: Config value: mapping, non-string sequence, or leaf.
        *args: Positional arguments for the root target; ``"__args"`` entries
            beyond ``len(args)`` are appended.
        recursive: Forwarded to nested calls; not otherwise read here.
        instantiate: When true, build objects; otherwise validate the
            arguments against the target's signature and return a dict.
        keyword_args: Values that replace same-named keys of the root node
            when instantiating.

    Returns:
        Sequences become lists, plain mappings become dicts, leaves are
        returned unchanged. A target node yields the built object (or the
        callable / a ``functools.partial``) when instantiating, else a dict.

    Raises:
        TypeError: If an argument is given both in ``"__args"`` and by name.
        ValueError: If validation fails for a reason other than an
            arbitrary-type error.
    """
    if isinstance(node, collections.abc.Sequence) and not isinstance(node, str):
        return [
            instance_traverse(i, recursive=recursive, instantiate=instantiate)
            for i in node
        ]

    if not isinstance(node, collections.abc.Mapping):
        return node

    target_key = "__target"
    init_key = "__init"
    fn_key = "__fn"
    validate_key = "__validate"
    partial_key = "__partial"
    args_key = "__args"
    key_key = "__key"

    exclude_keys = {
        target_key,
        init_key,
        fn_key,
        validate_key,
        partial_key,
        args_key,
        key_key,
    }

    if not (target_key in node or init_key in node or fn_key in node):
        return {
            k: instance_traverse(v, recursive=recursive, instantiate=instantiate)
            for k, v in node.items()
        }

    return_fn = False
    partial = node.get(partial_key, False)
    do_validate = node.get(validate_key, True)

    if init_key in node:
        target = node.get(init_key)

    elif fn_key in node:
        target = node.get(fn_key)

        # "__fn" with extra keys binds them via partial; without any it
        # returns the bare callable and skips validation.
        if len([k for k in node if k not in exclude_keys]) > 0:
            partial = True

        else:
            return_fn = True
            do_validate = False

    else:
        target = node.get(target_key)

    obj = resolve_module(target)
    signature = inspect.signature(obj)

    if instantiate:
        if key_key in node:
            return SINGLETON[node[key_key]]

        if args_key in node:
            args_node = node[args_key]

            if len(args_node) > len(args):
                # Caller-supplied positionals win; only the tail of
                # "__args" is used.
                args_init = [
                    instance_traverse(
                        a, recursive=recursive, instantiate=instantiate
                    )
                    for a in args_node[len(args) :]
                ]

                args = list(args) + args_init

        pos_replace = list(signature.parameters.keys())[: len(args)]

        kwargs = {}
        for k, v in node.items():
            if k in exclude_keys:
                continue

            if k in pos_replace:
                # Already supplied positionally.
                continue

            if keyword_args is not None and k in keyword_args:
                kwargs[k] = keyword_args[k]

                continue

            kwargs[k] = instance_traverse(
                v, recursive=recursive, instantiate=instantiate
            )

        if return_fn:
            return obj

        if partial:
            return functools.partial(obj, *args, **kwargs)

        return obj(*args, **kwargs)

    rest = {}

    args_replaced = []
    if args_key in node:
        # Name the positional arguments so they can be validated as fields.
        for arg, k in zip(node[args_key], signature.parameters.keys()):
            rest[k] = arg
            args_replaced.append(k)

    for k, v in node.items():
        if k in exclude_keys:
            continue

        rest[k] = instance_traverse(v, recursive=recursive, instantiate=instantiate)

        if k in args_replaced:
            raise TypeError(f"{target} got multiple values for argument '{k}'")

    if do_validate:
        name = "instance." + target

        if partial:
            # Only the arguments actually supplied are validated.
            rest_key = list(rest.keys())
            exclude = [k for k in signature.parameters.keys() if k not in rest_key]

            model = make_model_from_signature(
                name, obj, signature, exclude, strict=False
            )

        else:
            model = make_model_from_signature(name, obj, signature, ())

        try:
            model.validate(rest)

        except ValidationError as e:
            # Errors that are all "arbitrary type" are tolerated.
            only_arbitrary = all(
                error["type"] == "type_error.arbitrary_type" for error in e.errors()
            )

            if not only_arbitrary:
                # Report the validated arguments, not a stale loop variable.
                raise ValueError(
                    f"Validation for {target} with {rest} is failed:\n{e}"
                ) from e

    # NOTE(review): indentation was lost in this copy of the file; this
    # cleanup is placed after the validation block — confirm against upstream.
    for arg in args_replaced:
        del rest[arg]

    return {**node, **rest}


class Instance(dict):
    """Dict-based config node that pydantic validates via ``instance_traverse``."""

    @classmethod
    def __get_validators__(cls):
        yield cls.validate

    @classmethod
    def validate(cls, v):
        v_new = instance_traverse(v)
        instance = cls(v_new)

        return instance

    def make(self, *args, **kwargs):
        """Initialise keyed singletons, then instantiate this node."""
        init_singleton(flatten_tree(self))

        return instance_traverse(self, *args, instantiate=True, keyword_args=kwargs)

    def instantiate(self, *args, **kwargs):
        return self.make(*args, **kwargs)


def instantiate(instance, *args, **kwargs):
    """Instantiate ``instance`` via its ``make`` method, or by traversal.

    Objects without a ``make`` attribute are treated as raw config
    structures. An ``AttributeError`` raised inside ``make`` propagates
    instead of silently triggering the fallback.
    """
    make = getattr(instance, "make", None)

    if make is not None:
        return make(*args, **kwargs)

    init_singleton(flatten_tree(instance))

    return instance_traverse(instance, *args, instantiate=True, keyword_args=kwargs)
// Copyright 2018-2019 Open-MMLab
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// NOTE(review): the original #include targets were lost in this copy of the
// file. The headers below are the ones this translation unit visibly needs
// (at::Tensor / AT_DISPATCH_*, at::cuda::getCurrentCUDAStream, ceilf,
// fprintf, exit) -- confirm against upstream.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cmath>
#include <cstdio>
#include <cstdlib>

using namespace at;

#define CUDA_1D_KERNEL_LOOP(i, n)                            \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
       i += blockDim.x * gridDim.x)

#define THREADS_PER_BLOCK 1024 // 32 * 32
#define WARP_SIZE 32
#define THREADS_PER_PIXEL 32
#define MAX_SHARED_MEMORY 49152
#define MAX_SHARED_SCALAR_T 6144 // 49152 / 8 = 6144
#define MAXIMIZE_KERNEL_SIZE true
#define kTileDim 32
#define kBlockRows 8
#define FULL_MASK 0xffffffff

// Integer division of x by y, rounded up.
inline int divideUP(const int x, const int y) { return (((x) + (y)-1) / (y)); }

// Flat offset of element (n, c, h, w) in a contiguous 4-D array whose last
// three extents are (channel_num, height, width). Callers in this file also
// use it for channel-last data by passing (n, h, w, c, H, W, C).
__device__ inline int Loc2Index(const int n, const int c, const int h,
                                const int w, const int channel_num,
                                const int height, const int width) {
  int index = w + (h + (c + n * channel_num) * height) * width;
  return index;
}
/* TODO: move this to a common place */
// NOTE(review): every `template` parameter list in this file was stripped in
// this copy; `<typename scalar_t>` is restored from the use of scalar_t.
template <typename scalar_t>
__device__ inline scalar_t min(scalar_t a, scalar_t b) {
  return a < b ? a : b;
}

template <typename scalar_t>
__device__ inline scalar_t max(scalar_t a, scalar_t b) {
  return a > b ? a : b;
}

// Adds `val` across the lanes of a warp with shuffle-down steps of
// 16, 8, 4, 2, 1; the caller reads the total from lane 0.
template <typename scalar_t>
__device__ __forceinline__ scalar_t warpReduceSum(scalar_t val) {
  for (int offset = 16; offset > 0; offset /= 2)
    val += __shfl_down_sync(FULL_MASK, val, offset);
  return val;
}

// Splits the original matrix into submatrices with size 32 * 32.
// Each block transposes one submatrix by loading it into shared memory.
// Reference https://devblogs.nvidia.com/efficient-matrix-transpose-cuda-cc/
//
// X holds N matrices of H x W; Y receives N matrices of W x H. dh and dw are
// the tile counts along H and W; blockIdx.x encodes (n, tile row, tile col).
template <typename scalar_t>
__global__ void
BatchTranspose2DCUDAKernel(const int N, const int H, const int W, const int dh,
                           const int dw, const scalar_t *__restrict__ X,
                           scalar_t *__restrict__ Y) {
  // One extra column of padding per row of the tile.
  __shared__ scalar_t tile[kTileDim][kTileDim + 1];
  const int n = blockIdx.x / (dh * dw);
  const int k = blockIdx.x % (dh * dw);
  const int r = k / dw;
  const int c = k % dw;
  const int offset = n * H * W;
  int x = c * kTileDim + threadIdx.x;
  int y = r * kTileDim + threadIdx.y;
  if (x < W) {
    // Each thread loads rows threadIdx.y, +kBlockRows, ... of the tile.
    for (int i = 0; threadIdx.y + i < kTileDim && y + i < H; i += kBlockRows) {
      tile[threadIdx.y + i][threadIdx.x] = X[offset + (y + i) * W + x];
    }
  }
  __syncthreads();
  // Swap the roles of the tile row/column for the write-out.
  x = r * kTileDim + threadIdx.x;
  y = c * kTileDim + threadIdx.y;
  if (x < H) {
    for (int i = 0; threadIdx.y + i < kTileDim && y + i < W; i += kBlockRows) {
      Y[offset + (y + i) * H + x] = tile[threadIdx.x][threadIdx.y + i];
    }
  }
}
// Forward kernel. THREADS_PER_PIXEL threads cooperate on one output pixel
// (n, ph, pw): they first stage that pixel's mask_channels mask values in
// shared memory, then each thread accumulates the kernel_size x kernel_size
// weighted sum over the low-resolution neighbourhood for its channels.
// All three buffers are indexed channel-last through Loc2Index.
template <typename scalar_t>
__global__ void
CARAFEForward(const int num_kernels, const scalar_t *__restrict__ bottom_data,
              const scalar_t *__restrict__ bottom_masks, const int kernel_size,
              const int group_size, const int scale_factor, const int channels,
              const int down_height, const int down_width, const int height,
              const int width, const int mask_channels,
              scalar_t *__restrict__ top_data) {
#if MAXIMIZE_KERNEL_SIZE
  __shared__ float shared_mask[MAX_SHARED_SCALAR_T * 2];
#else
  __shared__ scalar_t shared_mask[MAX_SHARED_SCALAR_T];
#endif

  int index = threadIdx.x + blockIdx.x * blockDim.x;
  if (index > num_kernels - 1) {
    return;
  }
  // pixel_id: which pixel of this block; split_id: this thread's slot
  // within the pixel's thread group.
  const int pixel_id = threadIdx.x / THREADS_PER_PIXEL;
  const int split_id = threadIdx.x % THREADS_PER_PIXEL;
  index = index / THREADS_PER_PIXEL;
  const int pw = index % width;
  const int ph = (index / width) % height;
  const int n = index / width / height;

  // Corresponding location in the low-resolution input.
  const int down_pw = pw / scale_factor;
  const int down_ph = ph / scale_factor;

  const int start_w = down_pw - (kernel_size - 1) / 2;
  const int end_w = down_pw + (kernel_size - 1) / 2 + 1;
  const int start_h = down_ph - (kernel_size - 1) / 2;
  const int end_h = down_ph + (kernel_size - 1) / 2 + 1;
  for (int c = split_id; c < mask_channels; c += THREADS_PER_PIXEL) {
    int mask_index = Loc2Index(n, ph, pw, c, height, width, mask_channels);
    shared_mask[c * WARP_SIZE + pixel_id] = bottom_masks[mask_index];
  }
  __syncthreads();

  const int channels_per_group = ceilf(channels / (float)group_size);
#pragma unroll
  for (int c = split_id; c < channels; c += THREADS_PER_PIXEL) {
    int mask_group = c / channels_per_group;
    scalar_t output_val = 0;
#pragma unroll
    for (int iy = start_h; iy < end_h; iy++) {
#pragma unroll
      for (int ix = start_w; ix < end_w; ix++) {
        // Neighbours outside the input contribute nothing.
        if (iy < 0 || iy > down_height - 1 || ix < 0 || ix > down_width - 1) {
          continue;
        }
        int mask_iy = iy - down_ph + (kernel_size - 1) / 2;
        int mask_ix = ix - down_pw + (kernel_size - 1) / 2;
        int mask_c =
            (mask_group * kernel_size + mask_iy) * kernel_size + mask_ix;
        int feat_index =
            Loc2Index(n, iy, ix, c, down_height, down_width, channels);

        output_val += bottom_data[feat_index] *
                      shared_mask[mask_c * WARP_SIZE + pixel_id];
      }
    }

    int top_index = Loc2Index(n, ph, pw, c, height, width, channels);
    top_data[top_index] = output_val;
  }
}

// Host entry point for the forward pass. Transposes features and masks into
// the scratch tensors rfeatures / rmasks, runs CARAFEForward into routput,
// and transposes routput back into `output`. Prints the CUDA error and
// exits the process if any launch failed; otherwise returns 1.
//
// NOTE(review): the `<scalar_t>` arguments and the `<<<...>>>` launch
// configurations were stripped in this copy. They are reconstructed from how
// the kernels decode blockIdx/threadIdx (grid = N * dh * dw blocks of
// kTileDim x kBlockRows threads for the transposes; 1-D blocks of
// THREADS_PER_BLOCK threads for CARAFEForward) -- confirm against upstream.
int CARAFEForwardLaucher(const at::Tensor features, const at::Tensor masks,
                         const int kernel_size, const int group_size,
                         const int scale_factor, const int batch_size,
                         const int channels, const int input_height,
                         const int input_width, const int output_height,
                         const int output_width, const int mask_channels,
                         at::Tensor rfeatures, at::Tensor routput,
                         at::Tensor rmasks, at::Tensor output) {
  // one warp per pixel
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      features.scalar_type(), "NCHW2NHWC_Feature", ([&] {
        const scalar_t *bottom_data = features.data_ptr<scalar_t>();
        scalar_t *top_data = rfeatures.data_ptr<scalar_t>();
        const int dh = divideUP(channels, kTileDim);
        const int dw = divideUP(input_height * input_width, kTileDim);
        BatchTranspose2DCUDAKernel<scalar_t>
            <<<batch_size * dh * dw, dim3(kTileDim, kBlockRows), 0, stream>>>(
                batch_size, channels, input_height * input_width, dh, dw,
                bottom_data, top_data);
      }));
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      features.scalar_type(), "NCHW2NHWC_Masks", ([&] {
        const scalar_t *bottom_data = masks.data_ptr<scalar_t>();
        scalar_t *top_data = rmasks.data_ptr<scalar_t>();
        const int dh = divideUP(mask_channels, kTileDim);
        const int dw = divideUP(output_height * output_width, kTileDim);
        BatchTranspose2DCUDAKernel<scalar_t>
            <<<batch_size * dh * dw, dim3(kTileDim, kBlockRows), 0, stream>>>(
                batch_size, mask_channels, output_height * output_width, dh, dw,
                bottom_data, top_data);
      }));
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      features.scalar_type(), "CARAFELaucherForward", ([&] {
        const int num_kernels =
            batch_size * output_height * output_width * THREADS_PER_PIXEL;
        const scalar_t *bottom_data = rfeatures.data_ptr<scalar_t>();
        const scalar_t *bottom_masks = rmasks.data_ptr<scalar_t>();
        scalar_t *top_data = routput.data_ptr<scalar_t>();

        CARAFEForward<scalar_t>
            <<<divideUP(num_kernels, THREADS_PER_BLOCK), THREADS_PER_BLOCK, 0,
               stream>>>(
                num_kernels, bottom_data, bottom_masks, kernel_size, group_size,
                scale_factor, channels, input_height, input_width,
                output_height, output_width, mask_channels, top_data);
      }));
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      features.scalar_type(), "NHWC2NCHW", ([&] {
        const scalar_t *bottom_data = routput.data_ptr<scalar_t>();
        scalar_t *top_data = output.data_ptr<scalar_t>();
        const int dh = divideUP(output_height * output_width, kTileDim);
        const int dw = divideUP(channels, kTileDim);
        BatchTranspose2DCUDAKernel<scalar_t>
            <<<batch_size * dh * dw, dim3(kTileDim, kBlockRows), 0, stream>>>(
                batch_size, output_height * output_width, channels, dh, dw,
                bottom_data, top_data);
      }));
  cudaError_t err = cudaGetLastError();
  if (cudaSuccess != err) {
    fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err));
    exit(-1);
  }

  return 1;
}
// Backward kernel w.r.t. the features, computed at output resolution.
// THREADS_PER_PIXEL threads cooperate on one pixel (n, ph, pw): they stage
// mask values gathered from neighbouring pixels (stride scale_factor) in
// shared memory, zero where the neighbour is out of bounds, then accumulate
// mask * top_diff per channel into bottom_diff. top_diff and bottom_diff are
// indexed channel-last; bottom_masks is indexed as (n, mask_c, y, x).
// NOTE(review): the `template` parameter list was stripped in this copy and
// is restored here.
template <typename scalar_t>
__global__ void CARAFEBackward_Feature(
    const int num_kernels, const scalar_t *__restrict__ top_diff,
    const scalar_t *__restrict__ bottom_masks, const int kernel_size,
    const int group_size, const int scale_factor, const int channels,
    const int down_height, const int down_width, const int height,
    const int width, const int mask_channels,
    scalar_t *__restrict__ bottom_diff) {
#if MAXIMIZE_KERNEL_SIZE
  __shared__ float shared_mask[MAX_SHARED_SCALAR_T * 2];
#else
  __shared__ scalar_t shared_mask[MAX_SHARED_SCALAR_T];
#endif

  int index = threadIdx.x + blockIdx.x * blockDim.x;
  if (index > num_kernels - 1) {
    return;
  }

  const int pixel_id = threadIdx.x / THREADS_PER_PIXEL;
  const int split_id = threadIdx.x % THREADS_PER_PIXEL;
  // (n, c, ph, pw) is an element in the bottom_data
  index = index / THREADS_PER_PIXEL;
  const int pw = index % width;
  const int ph = (index / width) % height;
  const int n = index / width / height;

  const int start_w = pw - (kernel_size - 1) * scale_factor / 2;
  const int end_w = pw + (kernel_size - 1) * scale_factor / 2 + 1;
  const int start_h = ph - (kernel_size - 1) * scale_factor / 2;
  const int end_h = ph + (kernel_size - 1) * scale_factor / 2 + 1;
  for (int c = split_id; c < mask_channels; c += THREADS_PER_PIXEL) {
    const int mask_w = (c % kernel_size) * scale_factor;
    const int mask_h = (c / kernel_size % kernel_size) * scale_factor;
    const int mask_x = start_w + mask_w;
    const int mask_y = start_h + mask_h;
    if (mask_y < 0 || mask_y > height - 1 || mask_x < 0 || mask_x > width - 1) {
      shared_mask[c * WARP_SIZE + pixel_id] = 0;
      continue;
    }
    const int mask_group = c / (kernel_size * kernel_size);
    // Reverses the kernel position within the group:
    // position p maps to kernel_size^2 - 1 - p.
    const int mask_c = (2 * mask_group + 1) * kernel_size * kernel_size - c - 1;
    int mask_index =
        Loc2Index(n, mask_c, mask_y, mask_x, mask_channels, height, width);
    shared_mask[c * WARP_SIZE + pixel_id] = bottom_masks[mask_index];
  }
  __syncthreads();
  const int channels_per_group = ceilf(channels / (float)group_size);
#pragma unroll
  for (int c = split_id; c < channels; c += THREADS_PER_PIXEL) {
    int mask_group = c / channels_per_group;
    int top_index = Loc2Index(n, ph, pw, c, height, width, channels);
    scalar_t output_val = 0;
#pragma unroll
    for (int iy = start_h; iy < end_h; iy += scale_factor) {
#pragma unroll
      for (int ix = start_w; ix < end_w; ix += scale_factor) {
        if (iy < 0 || iy > height - 1 || ix < 0 || ix > width - 1) {
          continue;
        }
        int mask_iy =
            (iy - ph + (kernel_size - 1) * scale_factor / 2) / scale_factor;
        int mask_ix =
            (ix - pw + (kernel_size - 1) * scale_factor / 2) / scale_factor;
        int mask_c =
            (mask_group * kernel_size + mask_iy) * kernel_size + mask_ix;
        int feat_index = Loc2Index(n, iy, ix, c, height, width, channels);
        output_val +=
            shared_mask[mask_c * WARP_SIZE + pixel_id] * top_diff[feat_index];
      }
    }
    bottom_diff[top_index] = output_val;
  }
}

// Sums each scale_factor x scale_factor window of the high-resolution
// input into one low-resolution output element, per channel. `height` and
// `width` are the output (low-resolution) extents; both buffers are
// indexed channel-last.
template <typename scalar_t>
__global__ void
FeatureSum(const int num_kernels, const scalar_t *__restrict__ input_data,
           const int scale_factor, const int channels, const int height,
           const int width, scalar_t *__restrict__ output_data) {
  int index = threadIdx.x + blockIdx.x * blockDim.x;
  if (index > num_kernels - 1) {
    return;
  }
  const int split_id = threadIdx.x % THREADS_PER_PIXEL;
  index = index / THREADS_PER_PIXEL;
  const int pw = index % width;
  const int ph = (index / width) % height;
  const int n = index / width / height;
  for (int c = split_id; c < channels; c += THREADS_PER_PIXEL) {
    scalar_t output_val = 0;
    for (int iy = ph * scale_factor; iy < (ph + 1) * scale_factor; iy++) {
      for (int ix = pw * scale_factor; ix < (pw + 1) * scale_factor; ix++) {
        int input_id = Loc2Index(n, iy, ix, c, height * scale_factor,
                                 width * scale_factor, channels);
        output_val += input_data[input_id];
      }
    }
    const int output_id = Loc2Index(n, ph, pw, c, height, width, channels);
    output_data[output_id] = output_val;
  }
}

// Backward kernel w.r.t. the masks. WARP_SIZE consecutive threads handle
// one (n, ph, pw, mask_c) element: each lane accumulates
// top_diff * bottom_data over a strided share of the mask group's channels,
// the partial sums are combined with warpReduceSum, and lane 0 writes the
// result.
template <typename scalar_t>
__global__ void CARAFEBackward_Mask(const int num_kernels,
                                    const scalar_t *__restrict__ top_diff,
                                    const scalar_t *__restrict__ bottom_data,
                                    const int kernel_size, const int group_size,
                                    const int scale_factor, const int channels,
                                    const int down_height, const int down_width,
                                    const int height, const int width,
                                    const int mask_channels,
                                    scalar_t *__restrict__ mask_diff) {
  int index = threadIdx.x + blockIdx.x * blockDim.x;
  if (index > num_kernels - 1) {
    return;
  }

  const int lane_id = index % WARP_SIZE;
  index = index / WARP_SIZE;
  const int mask_c = index % mask_channels;
  // (n, c, ph, pw) is an element in the bottom_data
  index = index / mask_channels;
  const int pw = index % width;
  const int ph = (index / width) % height;
  const int n = index / width / height;

  const int down_pw = pw / scale_factor;
  const int down_ph = ph / scale_factor;

  const int mask_group = mask_c / (kernel_size * kernel_size);
  const int mask_loc = mask_c % (kernel_size * kernel_size);

  // Offset of this mask entry from the kernel centre.
  const int offset_x = mask_loc % kernel_size - (kernel_size - 1) / 2;
  const int offset_y =
      mask_loc / kernel_size % kernel_size - (kernel_size - 1) / 2;

  const int down_x = down_pw + offset_x;
  const int down_y = down_ph + offset_y;

  scalar_t output_val = 0;

  // Out-of-bounds neighbours leave the gradient at zero.
  if (down_y >= 0 && down_y <= down_height - 1 && down_x >= 0 &&
      down_x <= down_width - 1) {
    const int channels_per_mask = ceilf(channels / (float)group_size);
    const int start = channels_per_mask * mask_group;
    const int end = min(channels_per_mask * (mask_group + 1), channels);
    for (int c = start + lane_id; c < end; c += WARP_SIZE) {
      int bottom_id =
          Loc2Index(n, down_y, down_x, c, down_height, down_width, channels);
      int top_id = Loc2Index(n, ph, pw, c, height, width, channels);
      output_val += top_diff[top_id] * bottom_data[bottom_id];
    }
  }
  __syncwarp();
  output_val = warpReduceSum(output_val);
  if (lane_id == 0) {
    const int mask_id =
        Loc2Index(n, ph, pw, mask_c, height, width, mask_channels);
    mask_diff[mask_id] = output_val;
  }
}
// Host entry point for the backward pass. In order:
//   1. transpose top_grad into rtop_grad;
//   2. CARAFEBackward_Feature writes the high-resolution feature gradient
//      into rbottom_grad_hs;
//   3. FeatureSum reduces it to input resolution in rbottom_grad;
//   4. transpose rbottom_grad into bottom_grad;
//   5. CARAFEBackward_Mask writes the mask gradient into rmask_grad;
//   6. transpose rmask_grad into mask_grad.
// Prints the CUDA error and exits the process if any launch failed;
// otherwise returns 1.
//
// NOTE(review): the `<scalar_t>` arguments and `<<<...>>>` launch
// configurations were stripped in this copy and are reconstructed from how
// the kernels decode blockIdx/threadIdx -- confirm against upstream. The
// accessors that appeared as `.data()` are written as data_ptr<scalar_t>()
// to match the rest of the file.
// NOTE(review): the mask-gradient kernel is dispatched without the half
// type, unlike every other stage; kept as in the source.
int CARAFEBackwardLaucher(const at::Tensor top_grad, const at::Tensor rfeatures,
                          const at::Tensor masks, const int kernel_size,
                          const int group_size, const int scale_factor,
                          const int batch_size, const int channels,
                          const int input_height, const int input_width,
                          const int output_height, const int output_width,
                          const int mask_channels, at::Tensor rtop_grad,
                          at::Tensor rbottom_grad_hs, at::Tensor rbottom_grad,
                          at::Tensor rmask_grad, at::Tensor bottom_grad,
                          at::Tensor mask_grad) {
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      top_grad.scalar_type(), "NCHW2NHWC_Top_Grad", ([&] {
        const scalar_t *bottom_data = top_grad.data_ptr<scalar_t>();
        scalar_t *top_data = rtop_grad.data_ptr<scalar_t>();
        const int dh = divideUP(channels, kTileDim);
        const int dw = divideUP(output_height * output_width, kTileDim);
        BatchTranspose2DCUDAKernel<scalar_t>
            <<<batch_size * dh * dw, dim3(kTileDim, kBlockRows), 0, stream>>>(
                batch_size, channels, output_height * output_width, dh, dw,
                bottom_data, top_data);
      }));

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      top_grad.scalar_type(), "CARAFELaucherBackward_Feature", ([&] {
        const int num_kernels =
            batch_size * output_height * output_width * THREADS_PER_PIXEL;
        const scalar_t *top_diff = rtop_grad.data_ptr<scalar_t>();
        const scalar_t *bottom_masks = masks.data_ptr<scalar_t>();
        scalar_t *bottom_diff = rbottom_grad_hs.data_ptr<scalar_t>();

        CARAFEBackward_Feature<scalar_t>
            <<<divideUP(num_kernels, THREADS_PER_BLOCK), THREADS_PER_BLOCK, 0,
               stream>>>(
                num_kernels, top_diff, bottom_masks, kernel_size, group_size,
                scale_factor, channels, input_height, input_width,
                output_height, output_width, mask_channels, bottom_diff);
      }));
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      top_grad.scalar_type(), "FeatureSum", ([&] {
        const int num_kernels =
            batch_size * input_height * input_width * THREADS_PER_PIXEL;
        const scalar_t *bottom_diff_hs = rbottom_grad_hs.data_ptr<scalar_t>();
        scalar_t *bottom_diff = rbottom_grad.data_ptr<scalar_t>();

        FeatureSum<scalar_t>
            <<<divideUP(num_kernels, THREADS_PER_BLOCK), THREADS_PER_BLOCK, 0,
               stream>>>(
                num_kernels, bottom_diff_hs, scale_factor, channels,
                input_height, input_width, bottom_diff);
      }));
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      top_grad.scalar_type(), "NHWC2NCHW_Bottom_Grad", ([&] {
        const scalar_t *bottom_data = rbottom_grad.data_ptr<scalar_t>();
        scalar_t *top_data = bottom_grad.data_ptr<scalar_t>();
        const int dh = divideUP(input_height * input_width, kTileDim);
        const int dw = divideUP(channels, kTileDim);
        BatchTranspose2DCUDAKernel<scalar_t>
            <<<batch_size * dh * dw, dim3(kTileDim, kBlockRows), 0, stream>>>(
                batch_size, input_height * input_width, channels, dh, dw,
                bottom_data, top_data);
      }));

  AT_DISPATCH_FLOATING_TYPES(
      top_grad.scalar_type(), "CARAFELaucherBackward_Mask", ([&] {
        const int num_kernels = batch_size * output_height * output_width *
                                mask_channels * WARP_SIZE;
        const scalar_t *top_diff = rtop_grad.data_ptr<scalar_t>();
        const scalar_t *bottom_data = rfeatures.data_ptr<scalar_t>();
        scalar_t *mask_diff = rmask_grad.data_ptr<scalar_t>();

        CARAFEBackward_Mask<scalar_t>
            <<<divideUP(num_kernels, THREADS_PER_BLOCK), THREADS_PER_BLOCK, 0,
               stream>>>(
                num_kernels, top_diff, bottom_data, kernel_size, group_size,
                scale_factor, channels, input_height, input_width,
                output_height, output_width, mask_channels, mask_diff);
      }));
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      top_grad.scalar_type(), "NHWC2NCHW_Mask_Grad", ([&] {
        const scalar_t *bottom_data = rmask_grad.data_ptr<scalar_t>();
        scalar_t *top_data = mask_grad.data_ptr<scalar_t>();
        const int dh = divideUP(output_height * output_width, kTileDim);
        const int dw = divideUP(mask_channels, kTileDim);
        BatchTranspose2DCUDAKernel<scalar_t>
            <<<batch_size * dh * dw, dim3(kTileDim, kBlockRows), 0, stream>>>(
                batch_size, output_height * output_width, mask_channels, dh, dw,
                bottom_data, top_data);
      }));
  cudaError_t err = cudaGetLastError();
  if (cudaSuccess != err) {
    fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err));
    exit(-1);
  }

  return 1;
}
hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | 24 | # CARAFE 25 | 26 | Copyright 2018-2019 Open-MMLab. All rights reserved. 27 | 28 | Apache License 29 | Version 2.0, January 2004 30 | http://www.apache.org/licenses/ 31 | 32 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 33 | 34 | 1. Definitions. 35 | 36 | "License" shall mean the terms and conditions for use, reproduction, 37 | and distribution as defined by Sections 1 through 9 of this document. 38 | 39 | "Licensor" shall mean the copyright owner or entity authorized by 40 | the copyright owner that is granting the License. 41 | 42 | "Legal Entity" shall mean the union of the acting entity and all 43 | other entities that control, are controlled by, or are under common 44 | control with that entity. 
For the purposes of this definition, 45 | "control" means (i) the power, direct or indirect, to cause the 46 | direction or management of such entity, whether by contract or 47 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 48 | outstanding shares, or (iii) beneficial ownership of such entity. 49 | 50 | "You" (or "Your") shall mean an individual or Legal Entity 51 | exercising permissions granted by this License. 52 | 53 | "Source" form shall mean the preferred form for making modifications, 54 | including but not limited to software source code, documentation 55 | source, and configuration files. 56 | 57 | "Object" form shall mean any form resulting from mechanical 58 | transformation or translation of a Source form, including but 59 | not limited to compiled object code, generated documentation, 60 | and conversions to other media types. 61 | 62 | "Work" shall mean the work of authorship, whether in Source or 63 | Object form, made available under the License, as indicated by a 64 | copyright notice that is included in or attached to the work 65 | (an example is provided in the Appendix below). 66 | 67 | "Derivative Works" shall mean any work, whether in Source or Object 68 | form, that is based on (or derived from) the Work and for which the 69 | editorial revisions, annotations, elaborations, or other modifications 70 | represent, as a whole, an original work of authorship. For the purposes 71 | of this License, Derivative Works shall not include works that remain 72 | separable from, or merely link (or bind by name) to the interfaces of, 73 | the Work and Derivative Works thereof. 
74 | 75 | "Contribution" shall mean any work of authorship, including 76 | the original version of the Work and any modifications or additions 77 | to that Work or Derivative Works thereof, that is intentionally 78 | submitted to Licensor for inclusion in the Work by the copyright owner 79 | or by an individual or Legal Entity authorized to submit on behalf of 80 | the copyright owner. For the purposes of this definition, "submitted" 81 | means any form of electronic, verbal, or written communication sent 82 | to the Licensor or its representatives, including but not limited to 83 | communication on electronic mailing lists, source code control systems, 84 | and issue tracking systems that are managed by, or on behalf of, the 85 | Licensor for the purpose of discussing and improving the Work, but 86 | excluding communication that is conspicuously marked or otherwise 87 | designated in writing by the copyright owner as "Not a Contribution." 88 | 89 | "Contributor" shall mean Licensor and any individual or Legal Entity 90 | on behalf of whom a Contribution has been received by Licensor and 91 | subsequently incorporated within the Work. 92 | 93 | 2. Grant of Copyright License. Subject to the terms and conditions of 94 | this License, each Contributor hereby grants to You a perpetual, 95 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 96 | copyright license to reproduce, prepare Derivative Works of, 97 | publicly display, publicly perform, sublicense, and distribute the 98 | Work and such Derivative Works in Source or Object form. 99 | 100 | 3. Grant of Patent License. 
Subject to the terms and conditions of 101 | this License, each Contributor hereby grants to You a perpetual, 102 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 103 | (except as stated in this section) patent license to make, have made, 104 | use, offer to sell, sell, import, and otherwise transfer the Work, 105 | where such license applies only to those patent claims licensable 106 | by such Contributor that are necessarily infringed by their 107 | Contribution(s) alone or by combination of their Contribution(s) 108 | with the Work to which such Contribution(s) was submitted. If You 109 | institute patent litigation against any entity (including a 110 | cross-claim or counterclaim in a lawsuit) alleging that the Work 111 | or a Contribution incorporated within the Work constitutes direct 112 | or contributory patent infringement, then any patent licenses 113 | granted to You under this License for that Work shall terminate 114 | as of the date such litigation is filed. 115 | 116 | 4. Redistribution. 
You may reproduce and distribute copies of the 117 | Work or Derivative Works thereof in any medium, with or without 118 | modifications, and in Source or Object form, provided that You 119 | meet the following conditions: 120 | 121 | (a) You must give any other recipients of the Work or 122 | Derivative Works a copy of this License; and 123 | 124 | (b) You must cause any modified files to carry prominent notices 125 | stating that You changed the files; and 126 | 127 | (c) You must retain, in the Source form of any Derivative Works 128 | that You distribute, all copyright, patent, trademark, and 129 | attribution notices from the Source form of the Work, 130 | excluding those notices that do not pertain to any part of 131 | the Derivative Works; and 132 | 133 | (d) If the Work includes a "NOTICE" text file as part of its 134 | distribution, then any Derivative Works that You distribute must 135 | include a readable copy of the attribution notices contained 136 | within such NOTICE file, excluding those notices that do not 137 | pertain to any part of the Derivative Works, in at least one 138 | of the following places: within a NOTICE text file distributed 139 | as part of the Derivative Works; within the Source form or 140 | documentation, if provided along with the Derivative Works; or, 141 | within a display generated by the Derivative Works, if and 142 | wherever such third-party notices normally appear. The contents 143 | of the NOTICE file are for informational purposes only and 144 | do not modify the License. You may add Your own attribution 145 | notices within Derivative Works that You distribute, alongside 146 | or as an addendum to the NOTICE text from the Work, provided 147 | that such additional attribution notices cannot be construed 148 | as modifying the License. 
149 | 150 | You may add Your own copyright statement to Your modifications and 151 | may provide additional or different license terms and conditions 152 | for use, reproduction, or distribution of Your modifications, or 153 | for any such Derivative Works as a whole, provided Your use, 154 | reproduction, and distribution of the Work otherwise complies with 155 | the conditions stated in this License. 156 | 157 | 5. Submission of Contributions. Unless You explicitly state otherwise, 158 | any Contribution intentionally submitted for inclusion in the Work 159 | by You to the Licensor shall be under the terms and conditions of 160 | this License, without any additional terms or conditions. 161 | Notwithstanding the above, nothing herein shall supersede or modify 162 | the terms of any separate license agreement you may have executed 163 | with Licensor regarding such Contributions. 164 | 165 | 6. Trademarks. This License does not grant permission to use the trade 166 | names, trademarks, service marks, or product names of the Licensor, 167 | except as required for reasonable and customary use in describing the 168 | origin of the Work and reproducing the content of the NOTICE file. 169 | 170 | 7. Disclaimer of Warranty. Unless required by applicable law or 171 | agreed to in writing, Licensor provides the Work (and each 172 | Contributor provides its Contributions) on an "AS IS" BASIS, 173 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 174 | implied, including, without limitation, any warranties or conditions 175 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 176 | PARTICULAR PURPOSE. You are solely responsible for determining the 177 | appropriateness of using or redistributing the Work and assume any 178 | risks associated with Your exercise of permissions under this License. 179 | 180 | 8. Limitation of Liability. 
In no event and under no legal theory, 181 | whether in tort (including negligence), contract, or otherwise, 182 | unless required by applicable law (such as deliberate and grossly 183 | negligent acts) or agreed to in writing, shall any Contributor be 184 | liable to You for damages, including any direct, indirect, special, 185 | incidental, or consequential damages of any character arising as a 186 | result of this License or out of the use or inability to use the 187 | Work (including but not limited to damages for loss of goodwill, 188 | work stoppage, computer failure or malfunction, or any and all 189 | other commercial damages or losses), even if such Contributor 190 | has been advised of the possibility of such damages. 191 | 192 | 9. Accepting Warranty or Additional Liability. While redistributing 193 | the Work or Derivative Works thereof, You may choose to offer, 194 | and charge a fee for, acceptance of support, warranty, indemnity, 195 | or other liability obligations and/or rights consistent with this 196 | License. However, in accepting such obligations, You may act only 197 | on Your own behalf and on Your sole responsibility, not on behalf 198 | of any other Contributor, and only if You agree to indemnify, 199 | defend, and hold each Contributor harmless for any liability 200 | incurred by, or claims asserted against, such Contributor by reason 201 | of your accepting any such warranty or additional liability. 202 | 203 | END OF TERMS AND CONDITIONS 204 | 205 | APPENDIX: How to apply the Apache License to your work. 206 | 207 | To apply the Apache License to your work, attach the following 208 | boilerplate notice, with the fields enclosed by brackets "[]" 209 | replaced with your own identifying information. (Don't include 210 | the brackets!) The text should be enclosed in the appropriate 211 | comment syntax for the file format. 
We also recommend that a 212 | file or class name and description of purpose be included on the 213 | same "printed page" as the copyright notice for easier 214 | identification within third-party archives. 215 | 216 | Copyright 2018-2019 Open-MMLab. 217 | 218 | Licensed under the Apache License, Version 2.0 (the "License"); 219 | you may not use this file except in compliance with the License. 220 | You may obtain a copy of the License at 221 | 222 | http://www.apache.org/licenses/LICENSE-2.0 223 | 224 | Unless required by applicable law or agreed to in writing, software 225 | distributed under the License is distributed on an "AS IS" BASIS, 226 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 227 | See the License for the specific language governing permissions and 228 | limitations under the License. 229 | 230 | 231 | # RMSpropTF 232 | 233 | Apache License 234 | Version 2.0, January 2004 235 | http://www.apache.org/licenses/ 236 | 237 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 238 | 239 | 1. Definitions. 240 | 241 | "License" shall mean the terms and conditions for use, reproduction, 242 | and distribution as defined by Sections 1 through 9 of this document. 243 | 244 | "Licensor" shall mean the copyright owner or entity authorized by 245 | the copyright owner that is granting the License. 246 | 247 | "Legal Entity" shall mean the union of the acting entity and all 248 | other entities that control, are controlled by, or are under common 249 | control with that entity. For the purposes of this definition, 250 | "control" means (i) the power, direct or indirect, to cause the 251 | direction or management of such entity, whether by contract or 252 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 253 | outstanding shares, or (iii) beneficial ownership of such entity. 254 | 255 | "You" (or "Your") shall mean an individual or Legal Entity 256 | exercising permissions granted by this License. 
257 | 258 | "Source" form shall mean the preferred form for making modifications, 259 | including but not limited to software source code, documentation 260 | source, and configuration files. 261 | 262 | "Object" form shall mean any form resulting from mechanical 263 | transformation or translation of a Source form, including but 264 | not limited to compiled object code, generated documentation, 265 | and conversions to other media types. 266 | 267 | "Work" shall mean the work of authorship, whether in Source or 268 | Object form, made available under the License, as indicated by a 269 | copyright notice that is included in or attached to the work 270 | (an example is provided in the Appendix below). 271 | 272 | "Derivative Works" shall mean any work, whether in Source or Object 273 | form, that is based on (or derived from) the Work and for which the 274 | editorial revisions, annotations, elaborations, or other modifications 275 | represent, as a whole, an original work of authorship. For the purposes 276 | of this License, Derivative Works shall not include works that remain 277 | separable from, or merely link (or bind by name) to the interfaces of, 278 | the Work and Derivative Works thereof. 279 | 280 | "Contribution" shall mean any work of authorship, including 281 | the original version of the Work and any modifications or additions 282 | to that Work or Derivative Works thereof, that is intentionally 283 | submitted to Licensor for inclusion in the Work by the copyright owner 284 | or by an individual or Legal Entity authorized to submit on behalf of 285 | the copyright owner. 
For the purposes of this definition, "submitted" 286 | means any form of electronic, verbal, or written communication sent 287 | to the Licensor or its representatives, including but not limited to 288 | communication on electronic mailing lists, source code control systems, 289 | and issue tracking systems that are managed by, or on behalf of, the 290 | Licensor for the purpose of discussing and improving the Work, but 291 | excluding communication that is conspicuously marked or otherwise 292 | designated in writing by the copyright owner as "Not a Contribution." 293 | 294 | "Contributor" shall mean Licensor and any individual or Legal Entity 295 | on behalf of whom a Contribution has been received by Licensor and 296 | subsequently incorporated within the Work. 297 | 298 | 2. Grant of Copyright License. Subject to the terms and conditions of 299 | this License, each Contributor hereby grants to You a perpetual, 300 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 301 | copyright license to reproduce, prepare Derivative Works of, 302 | publicly display, publicly perform, sublicense, and distribute the 303 | Work and such Derivative Works in Source or Object form. 304 | 305 | 3. Grant of Patent License. Subject to the terms and conditions of 306 | this License, each Contributor hereby grants to You a perpetual, 307 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 308 | (except as stated in this section) patent license to make, have made, 309 | use, offer to sell, sell, import, and otherwise transfer the Work, 310 | where such license applies only to those patent claims licensable 311 | by such Contributor that are necessarily infringed by their 312 | Contribution(s) alone or by combination of their Contribution(s) 313 | with the Work to which such Contribution(s) was submitted. 
If You 314 | institute patent litigation against any entity (including a 315 | cross-claim or counterclaim in a lawsuit) alleging that the Work 316 | or a Contribution incorporated within the Work constitutes direct 317 | or contributory patent infringement, then any patent licenses 318 | granted to You under this License for that Work shall terminate 319 | as of the date such litigation is filed. 320 | 321 | 4. Redistribution. You may reproduce and distribute copies of the 322 | Work or Derivative Works thereof in any medium, with or without 323 | modifications, and in Source or Object form, provided that You 324 | meet the following conditions: 325 | 326 | (a) You must give any other recipients of the Work or 327 | Derivative Works a copy of this License; and 328 | 329 | (b) You must cause any modified files to carry prominent notices 330 | stating that You changed the files; and 331 | 332 | (c) You must retain, in the Source form of any Derivative Works 333 | that You distribute, all copyright, patent, trademark, and 334 | attribution notices from the Source form of the Work, 335 | excluding those notices that do not pertain to any part of 336 | the Derivative Works; and 337 | 338 | (d) If the Work includes a "NOTICE" text file as part of its 339 | distribution, then any Derivative Works that You distribute must 340 | include a readable copy of the attribution notices contained 341 | within such NOTICE file, excluding those notices that do not 342 | pertain to any part of the Derivative Works, in at least one 343 | of the following places: within a NOTICE text file distributed 344 | as part of the Derivative Works; within the Source form or 345 | documentation, if provided along with the Derivative Works; or, 346 | within a display generated by the Derivative Works, if and 347 | wherever such third-party notices normally appear. The contents 348 | of the NOTICE file are for informational purposes only and 349 | do not modify the License. 
You may add Your own attribution 350 | notices within Derivative Works that You distribute, alongside 351 | or as an addendum to the NOTICE text from the Work, provided 352 | that such additional attribution notices cannot be construed 353 | as modifying the License. 354 | 355 | You may add Your own copyright statement to Your modifications and 356 | may provide additional or different license terms and conditions 357 | for use, reproduction, or distribution of Your modifications, or 358 | for any such Derivative Works as a whole, provided Your use, 359 | reproduction, and distribution of the Work otherwise complies with 360 | the conditions stated in this License. 361 | 362 | 5. Submission of Contributions. Unless You explicitly state otherwise, 363 | any Contribution intentionally submitted for inclusion in the Work 364 | by You to the Licensor shall be under the terms and conditions of 365 | this License, without any additional terms or conditions. 366 | Notwithstanding the above, nothing herein shall supersede or modify 367 | the terms of any separate license agreement you may have executed 368 | with Licensor regarding such Contributions. 369 | 370 | 6. Trademarks. This License does not grant permission to use the trade 371 | names, trademarks, service marks, or product names of the Licensor, 372 | except as required for reasonable and customary use in describing the 373 | origin of the Work and reproducing the content of the NOTICE file. 374 | 375 | 7. Disclaimer of Warranty. Unless required by applicable law or 376 | agreed to in writing, Licensor provides the Work (and each 377 | Contributor provides its Contributions) on an "AS IS" BASIS, 378 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 379 | implied, including, without limitation, any warranties or conditions 380 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 381 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 382 | appropriateness of using or redistributing the Work and assume any 383 | risks associated with Your exercise of permissions under this License. 384 | 385 | 8. Limitation of Liability. In no event and under no legal theory, 386 | whether in tort (including negligence), contract, or otherwise, 387 | unless required by applicable law (such as deliberate and grossly 388 | negligent acts) or agreed to in writing, shall any Contributor be 389 | liable to You for damages, including any direct, indirect, special, 390 | incidental, or consequential damages of any character arising as a 391 | result of this License or out of the use or inability to use the 392 | Work (including but not limited to damages for loss of goodwill, 393 | work stoppage, computer failure or malfunction, or any and all 394 | other commercial damages or losses), even if such Contributor 395 | has been advised of the possibility of such damages. 396 | 397 | 9. Accepting Warranty or Additional Liability. While redistributing 398 | the Work or Derivative Works thereof, You may choose to offer, 399 | and charge a fee for, acceptance of support, warranty, indemnity, 400 | or other liability obligations and/or rights consistent with this 401 | License. However, in accepting such obligations, You may act only 402 | on Your own behalf and on Your sole responsibility, not on behalf 403 | of any other Contributor, and only if You agree to indemnify, 404 | defend, and hold each Contributor harmless for any liability 405 | incurred by, or claims asserted against, such Contributor by reason 406 | of your accepting any such warranty or additional liability. 407 | 408 | END OF TERMS AND CONDITIONS 409 | 410 | APPENDIX: How to apply the Apache License to your work. 411 | 412 | To apply the Apache License to your work, attach the following 413 | boilerplate notice, with the fields enclosed by brackets "{}" 414 | replaced with your own identifying information. 
(Don't include 415 | the brackets!) The text should be enclosed in the appropriate 416 | comment syntax for the file format. We also recommend that a 417 | file or class name and description of purpose be included on the 418 | same "printed page" as the copyright notice for easier 419 | identification within third-party archives. 420 | 421 | Copyright 2019 Ross Wightman 422 | 423 | Licensed under the Apache License, Version 2.0 (the "License"); 424 | you may not use this file except in compliance with the License. 425 | You may obtain a copy of the License at 426 | 427 | http://www.apache.org/licenses/LICENSE-2.0 428 | 429 | Unless required by applicable law or agreed to in writing, software 430 | distributed under the License is distributed on an "AS IS" BASIS, 431 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 432 | See the License for the specific language governing permissions and 433 | limitations under the License. 434 | --------------------------------------------------------------------------------