├── src
│   ├── facefusion
│   │   ├── __init__.py
│   │   ├── affine.py
│   │   └── gfpgan_onnx.py
│   ├── midas
│   │   ├── backbones
│   │   │   ├── __init__.py
│   │   │   ├── swin.py
│   │   │   ├── swin2.py
│   │   │   ├── next_vit.py
│   │   │   ├── swin_common.py
│   │   │   └── levit.py
│   │   ├── base_model.py
│   │   ├── midas_net.py
│   │   └── __init__.py
│   ├── timm
│   │   ├── version.py
│   │   ├── data
│   │   │   ├── parsers
│   │   │   │   ├── __init__.py
│   │   │   │   ├── parser.py
│   │   │   │   ├── class_map.py
│   │   │   │   ├── parser_factory.py
│   │   │   │   ├── img_extensions.py
│   │   │   │   ├── parser_image_tar.py
│   │   │   │   └── parser_image_folder.py
│   │   │   ├── constants.py
│   │   │   ├── __init__.py
│   │   │   ├── real_labels.py
│   │   │   └── config.py
│   │   ├── utils
│   │   │   ├── random.py
│   │   │   ├── misc.py
│   │   │   ├── __init__.py
│   │   │   ├── clip_grad.py
│   │   │   ├── distributed.py
│   │   │   ├── metrics.py
│   │   │   ├── log.py
│   │   │   ├── summary.py
│   │   │   ├── agc.py
│   │   │   ├── cuda.py
│   │   │   ├── decay_batch.py
│   │   │   └── jit.py
│   │   ├── loss
│   │   │   ├── __init__.py
│   │   │   ├── cross_entropy.py
│   │   │   ├── jsd.py
│   │   │   ├── binary_cross_entropy.py
│   │   │   └── asymmetric_loss.py
│   │   ├── __init__.py
│   │   ├── scheduler
│   │   │   ├── __init__.py
│   │   │   ├── step_lr.py
│   │   │   ├── multistep_lr.py
│   │   │   └── plateau_lr.py
│   │   ├── models
│   │   │   ├── layers
│   │   │   │   ├── trace_utils.py
│   │   │   │   ├── linear.py
│   │   │   │   ├── helpers.py
│   │   │   │   ├── conv2d_same.py
│   │   │   │   ├── blur_pool.py
│   │   │   │   ├── create_conv2d.py
│   │   │   │   ├── patch_embed.py
│   │   │   │   ├── median_pool.py
│   │   │   │   ├── space_to_depth.py
│   │   │   │   ├── create_norm.py
│   │   │   │   ├── mixed_conv2d.py
│   │   │   │   ├── test_time_pool.py
│   │   │   │   ├── padding.py
│   │   │   │   ├── classifier.py
│   │   │   │   ├── global_context.py
│   │   │   │   ├── fast_norm.py
│   │   │   │   ├── __init__.py
│   │   │   │   ├── filter_response_norm.py
│   │   │   │   ├── activations_jit.py
│   │   │   │   ├── separable_conv.py
│   │   │   │   ├── squeeze_excite.py
│   │   │   │   ├── pool2d_same.py
│   │   │   │   ├── split_attn.py
│   │   │   │   ├── conv_bn_act.py
│   │   │   │   ├── config.py
│   │   │   │   ├── inplace_abn.py
│   │   │   │   ├── split_batchnorm.py
│   │   │   │   ├── create_attn.py
│   │   │   │   ├── create_norm_act.py
│   │   │   │   └── gather_excite.py
│   │   │   ├── __init__.py
│   │   │   └── factory.py
│   │   └── optim
│   │       ├── __init__.py
│   │       ├── sgdp.py
│   │       ├── lookahead.py
│   │       ├── radam.py
│   │       ├── adamp.py
│   │       └── nadam.py
│   ├── dwpose
│   │   ├── dw_onnx
│   │   │   ├── __init__.py
│   │   │   └── cv_ox_yolo_nas.py
│   │   ├── dw_torchscript
│   │   │   └── __init__.py
│   │   ├── types.py
│   │   └── hand.py
│   ├── canny
│   │   └── __init__.py
│   └── densepose
│       └── __init__.py
├── .gitignore
├── README.md
├── nodes
│   ├── face
│   │   ├── __init__.py
│   │   └── face_enhance_node.py
│   ├── video
│   │   ├── video_formats
│   │   │   ├── h264-mp4.json
│   │   │   ├── nvenc_h265-mp4.json
│   │   │   ├── webm.json
│   │   │   ├── h265-mp4.json
│   │   │   └── nvenc_h264-mp4.json
│   │   ├── info_node.py
│   │   └── batch_node.py
│   ├── preprocessor
│   │   ├── canny_node.py
│   │   ├── densepose_node.py
│   │   ├── midas_node.py
│   │   └── lineart_node.py
│   ├── other
│   │   └── vram_node.py
│   ├── image
│   │   ├── watermark_node.py
│   │   ├── load_node.py
│   │   └── save_node.py
│   └── mask
│       └── mask_node.py
├── requirements.txt
├── config.yaml
├── web
│   └── fix_touchpad_pan_and_zoom.js
├── utils.py
└── __init__.py
/src/facefusion/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/midas/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | **/__pycache__
3 | 
4 | 
--------------------------------------------------------------------------------
/src/timm/version.py:
--------------------------------------------------------------------------------
1 | __version__ = '0.6.13'
2 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Image Node:
2 | 
3 | ## Video Node
4 | 
5 | ## ControlNet PreProcessor
--------------------------------------------------------------------------------
/src/dwpose/dw_onnx/__init__.py:
--------------------------------------------------------------------------------
1 | #Dummy file ensuring this package will be recognized
--------------------------------------------------------------------------------
/src/dwpose/dw_torchscript/__init__.py:
--------------------------------------------------------------------------------
1 | #Dummy file ensuring this package will be recognized
--------------------------------------------------------------------------------
/src/timm/data/parsers/__init__.py:
--------------------------------------------------------------------------------
1 | from .parser_factory import create_parser
2 | from .img_extensions import *
3 | 
--------------------------------------------------------------------------------
/nodes/face/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | import os
3 | import folder_paths
4 | 
5 | 
6 | model_path = os.path.join(folder_paths.models_dir, "facefusion")
7 | folder_paths.add_model_folder_path('facefusion', model_path)
8 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | opencv-python-headless>=4.7.0.72
4 | scikit-learn
5 | scikit-image
6 | insightface
7 | ultralytics
8 | onnxruntime-gpu==1.18.0
9 | onnxruntime==1.18.0
10 | imageio_ffmpeg
11 | pykalman
--------------------------------------------------------------------------------
/src/timm/utils/random.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import torch
4 | 
5 | 
6 | def random_seed(seed=42, rank=0):
7 |     torch.manual_seed(seed + rank)
8 |     np.random.seed(seed + rank)
9 |     random.seed(seed + rank)
10 | 
--------------------------------------------------------------------------------
/config.yaml:
--------------------------------------------------------------------------------
1 | annotator_ckpts_path: "../../models/annotator"
2 | custom_temp_path: "../../temp"
3 | USE_SYMLINKS: False
4 | EP_list: ["CUDAExecutionProvider", "DirectMLExecutionProvider", "OpenVINOExecutionProvider", "ROCMExecutionProvider", "CPUExecutionProvider"]
--------------------------------------------------------------------------------
/nodes/video/video_formats/h264-mp4.json:
--------------------------------------------------------------------------------
1 | {
2 |     "main_pass":
3 |     [
4 |         "-y", "-c:v", "libx264",
5 |         "-pix_fmt", "yuv420p",
6 |         "-crf", "20"
7 |     ],
8 |     "audio_pass": ["-c:a", "aac"],
9 |     "extension": "mp4"
10 | }
--------------------------------------------------------------------------------
/nodes/video/video_formats/nvenc_h265-mp4.json:
--------------------------------------------------------------------------------
1 | {
2 |     "main_pass":
3 |     [
4 |         "-y", "-c:v", "hevc_nvenc",
5 |         "-vtag", "hvc1",
6 |         "-qp", "22"
7 |     ],
8 |     "audio_pass": ["-c:a", "aac"],
9 |     "extension": "mp4"
10 | }
--------------------------------------------------------------------------------
/nodes/video/video_formats/webm.json:
-------------------------------------------------------------------------------- 1 | { 2 | "main_pass": 3 | [ 4 | "-y", 5 | "-crf", 20, 6 | "-pix_fmt", "yuv420p", 7 | "-b:v", "0" 8 | ], 9 | "audio_pass": ["-c:a", "libvorbis"], 10 | "extension": "webm" 11 | } -------------------------------------------------------------------------------- /src/timm/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .asymmetric_loss import AsymmetricLossMultiLabel, AsymmetricLossSingleLabel 2 | from .binary_cross_entropy import BinaryCrossEntropy 3 | from .cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 4 | from .jsd import JsdCrossEntropy 5 | -------------------------------------------------------------------------------- /src/timm/__init__.py: -------------------------------------------------------------------------------- 1 | from .version import __version__ 2 | from .models import create_model, list_models, is_model, list_modules, model_entrypoint, \ 3 | is_scriptable, is_exportable, set_scriptable, set_exportable, has_pretrained_cfg_key, is_pretrained_cfg_key, \ 4 | get_pretrained_cfg_value, is_model_pretrained 5 | -------------------------------------------------------------------------------- /src/timm/scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | from .cosine_lr import CosineLRScheduler 2 | from .multistep_lr import MultiStepLRScheduler 3 | from .plateau_lr import PlateauLRScheduler 4 | from .poly_lr import PolyLRScheduler 5 | from .step_lr import StepLRScheduler 6 | from .tanh_lr import TanhLRScheduler 7 | 8 | from .scheduler_factory import create_scheduler 9 | -------------------------------------------------------------------------------- /nodes/video/video_formats/h265-mp4.json: -------------------------------------------------------------------------------- 1 | { 2 | "main_pass": 3 | [ 4 | "-y", "-c:v", "libx265", 5 | "-vtag", "hvc1", 6 | "-pix_fmt", "yuv420p10le", 7 | "-crf", "22", 8 | "-preset", "medium", 9 | "-x265-params", "log-level=quiet" 10 | ], 11 | "audio_pass": ["-c:a", "aac"], 12 | "extension": "mp4" 13 | } -------------------------------------------------------------------------------- /nodes/video/video_formats/nvenc_h264-mp4.json: -------------------------------------------------------------------------------- 1 | { 2 | "main_pass": 3 | [ 4 | "-y", "-c:v", "h264_nvenc", 5 | "-pix_fmt", "yuv420p", 6 | "-qp", "20" 7 | ], 8 | "audio_pass": ["-c:a", "aac"], 9 | "bitrate": ["bitrate","INT", {"default": 10, "min": 1, "max": 999, "step": 1 }], 10 | "megabit": ["megabit","BOOLEAN", {"default": true}], 11 | "extension": "mp4" 12 | } -------------------------------------------------------------------------------- /src/midas/backbones/swin.py: -------------------------------------------------------------------------------- 1 | import timm 2 | 3 | from .swin_common import _make_swin_backbone 4 | 5 | 6 | def _make_pretrained_swinl12_384(pretrained, hooks=None): 7 | model = timm.create_model("swin_large_patch4_window12_384", pretrained=pretrained) 8 | 9 | hooks = [1, 1, 17, 1] if hooks == None else hooks 10 | return _make_swin_backbone( 11 | model, 12 | hooks=hooks 13 | ) 14 | -------------------------------------------------------------------------------- /src/timm/models/layers/trace_utils.py: -------------------------------------------------------------------------------- 1 | try: 2 | from torch import _assert 3 | except ImportError: 4 | def 
_assert(condition: bool, message: str): 5 | assert condition, message 6 | 7 | 8 | def _float_to_int(x: float) -> int: 9 | """ 10 | Symbolic tracing helper to substitute for inbuilt `int`. 11 | Hint: Inbuilt `int` can't accept an argument of type `Proxy` 12 | """ 13 | return int(x) 14 | -------------------------------------------------------------------------------- /src/midas/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseModel(torch.nn.Module): 5 | def load(self, path): 6 | """Load model from file. 7 | 8 | Args: 9 | path (str): file path 10 | """ 11 | parameters = torch.load(path, map_location=torch.device('cpu')) 12 | 13 | if "optimizer" in parameters: 14 | parameters = parameters["model"] 15 | 16 | self.load_state_dict(parameters) 17 | -------------------------------------------------------------------------------- /src/timm/data/constants.py: -------------------------------------------------------------------------------- 1 | DEFAULT_CROP_PCT = 0.875 2 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) 3 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) 4 | IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) 5 | IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) 6 | IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255) 7 | IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3) 8 | OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073) 9 | OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711) 10 | -------------------------------------------------------------------------------- /src/timm/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from .adabelief import AdaBelief 2 | from .adafactor import Adafactor 3 | from .adahessian import Adahessian 4 | from .adamp import AdamP 5 | from .adamw import AdamW 6 | from .lamb import Lamb 7 | from .lars import Lars 8 | from .lookahead import Lookahead 9 | from .madgrad import MADGRAD 10 | from .nadam import Nadam 11 | from .nvnovograd import NvNovoGrad 12 | from .radam import RAdam 13 | from .rmsprop_tf import RMSpropTF 14 | from .sgdp import SGDP 15 | from .optim_factory import create_optimizer, create_optimizer_v2, optimizer_kwargs 16 | -------------------------------------------------------------------------------- /src/timm/data/parsers/parser.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | 4 | class Parser: 5 | def __init__(self): 6 | pass 7 | 8 | @abstractmethod 9 | def _filename(self, index, basename=False, absolute=False): 10 | pass 11 | 12 | def filename(self, index, basename=False, absolute=False): 13 | return self._filename(index, basename=basename, absolute=absolute) 14 | 15 | def filenames(self, basename=False, absolute=False): 16 | return [self._filename(index, basename=basename, absolute=absolute) for index in range(len(self))] 17 | 18 | -------------------------------------------------------------------------------- /src/timm/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\ 2 | rand_augment_transform, auto_augment_transform 3 | from .config import resolve_data_config 4 | from .constants import * 5 | from .dataset import ImageDataset, IterableImageDataset, AugMixDataset 6 | from .dataset_factory import create_dataset 7 | from .loader import create_loader 8 | from .mixup import Mixup, FastCollateMixup 9 
| from .parsers import create_parser,\ 10 | get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions 11 | from .real_labels import RealLabelsImagenet 12 | from .transforms import * 13 | from .transforms_factory import create_transform 14 | -------------------------------------------------------------------------------- /src/timm/utils/misc.py: -------------------------------------------------------------------------------- 1 | """ Misc utils 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import re 6 | 7 | 8 | def natural_key(string_): 9 | """See http://www.codinghorror.com/blog/archives/001018.html""" 10 | return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] 11 | 12 | 13 | def add_bool_arg(parser, name, default=False, help=''): 14 | dest_name = name.replace('-', '_') 15 | group = parser.add_mutually_exclusive_group(required=False) 16 | group.add_argument('--' + name, dest=dest_name, action='store_true', help=help) 17 | group.add_argument('--no-' + name, dest=dest_name, action='store_false', help=help) 18 | parser.set_defaults(**{dest_name: default}) 19 | -------------------------------------------------------------------------------- /src/timm/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .agc import adaptive_clip_grad 2 | from .checkpoint_saver import CheckpointSaver 3 | from .clip_grad import dispatch_clip_grad 4 | from .cuda import ApexScaler, NativeScaler 5 | from .decay_batch import decay_batch_step, check_batch_size_retry 6 | from .distributed import distribute_bn, reduce_tensor 7 | from .jit import set_jit_legacy, set_jit_fuser 8 | from .log import setup_default_logging, FormatterNoInfo 9 | from .metrics import AverageMeter, accuracy 10 | from .misc import natural_key, add_bool_arg 11 | from .model import unwrap_model, get_state_dict, freeze, unfreeze 12 | from .model_ema import ModelEma, ModelEmaV2 13 | from .random import random_seed 14 | from .summary import update_summary, get_outdir 15 | -------------------------------------------------------------------------------- /src/timm/models/layers/linear.py: -------------------------------------------------------------------------------- 1 | """ Linear layer (alternate definition) 2 | """ 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn as nn 6 | 7 | 8 | class Linear(nn.Linear): 9 | r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b` 10 | 11 | Wraps torch.nn.Linear to support AMP + torchscript usage by manually casting 12 | weight & bias to input.dtype to work around an issue w/ torch.addmm in this use case. 
13 | """ 14 | def forward(self, input: torch.Tensor) -> torch.Tensor: 15 | if torch.jit.is_scripting(): 16 | bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None 17 | return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias) 18 | else: 19 | return F.linear(input, self.weight, self.bias) 20 | -------------------------------------------------------------------------------- /src/canny/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import cv2 3 | import numpy as np 4 | from PIL import Image 5 | from common import resize_image_with_pad, common_input_validate, HWC3 6 | 7 | class CannyDetector: 8 | def __call__(self, input_image=None, low_threshold=100, high_threshold=200, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs): 9 | input_image, output_type = common_input_validate(input_image, output_type, **kwargs) 10 | detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method) 11 | detected_map = cv2.Canny(detected_map, low_threshold, high_threshold) 12 | detected_map = HWC3(remove_pad(detected_map)) 13 | 14 | if output_type == "pil": 15 | detected_map = Image.fromarray(detected_map) 16 | 17 | return detected_map -------------------------------------------------------------------------------- /src/timm/utils/clip_grad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from timm.utils.agc import adaptive_clip_grad 4 | 5 | 6 | def dispatch_clip_grad(parameters, value: float, mode: str = 'norm', norm_type: float = 2.0): 7 | """ Dispatch to gradient clipping method 8 | 9 | Args: 10 | parameters (Iterable): model parameters to clip 11 | value (float): clipping value/factor/norm, mode dependant 12 | mode (str): clipping mode, one of 'norm', 'value', 'agc' 13 | norm_type (float): p-norm, default 2.0 14 | """ 15 | if mode == 'norm': 16 | torch.nn.utils.clip_grad_norm_(parameters, value, norm_type=norm_type) 17 | elif mode == 'value': 18 | torch.nn.utils.clip_grad_value_(parameters, value) 19 | elif mode == 'agc': 20 | adaptive_clip_grad(parameters, value, norm_type=norm_type) 21 | else: 22 | assert False, f"Unknown clip mode ({mode})." 23 | 24 | -------------------------------------------------------------------------------- /src/dwpose/types.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple, List, Optional 2 | 3 | class Keypoint(NamedTuple): 4 | x: float 5 | y: float 6 | score: float = 1.0 7 | id: int = -1 8 | 9 | 10 | class BodyResult(NamedTuple): 11 | # Note: Using `Optional` instead of `|` operator as the ladder is a Python 12 | # 3.10 feature. 13 | # Annotator code should be Python 3.8 Compatible, as controlnet repo uses 14 | # Python 3.8 environment. 
15 |     # https://github.com/lllyasviel/ControlNet/blob/d3284fcd0972c510635a4f5abe2eeb71dc0de524/environment.yaml#L6
16 |     keypoints: List[Optional[Keypoint]]
17 |     total_score: float = 0.0
18 |     total_parts: int = 0
19 | 
20 | 
21 | HandResult = List[Keypoint]
22 | FaceResult = List[Keypoint]
23 | AnimalPoseResult = List[Keypoint]
24 | 
25 | 
26 | class PoseResult(NamedTuple):
27 |     body: BodyResult
28 |     left_hand: Optional[HandResult]
29 |     right_hand: Optional[HandResult]
30 |     face: Optional[FaceResult]
31 | 
--------------------------------------------------------------------------------
/nodes/preprocessor/canny_node.py:
--------------------------------------------------------------------------------
1 | from ..utils import common_annotator_call, create_node_input_types
2 | import comfy.model_management as model_management
3 | import nodes
4 | 
5 | class Canny_Preprocessor:
6 |     @classmethod
7 |     def INPUT_TYPES(s):
8 |         return create_node_input_types(
9 |             low_threshold=("INT", {"default": 100, "min": 0, "max": 255}),
10 |             high_threshold=("INT", {"default": 200, "min": 0, "max": 255}),
11 |             resolution=("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 64})
12 |         )
13 | 
14 |     RETURN_TYPES = ("IMAGE",)
15 |     FUNCTION = "execute"
16 | 
17 |     CATEGORY = "tbox/ControlNet Preprocessors"
18 | 
19 |     def execute(self, image, low_threshold=100, high_threshold=200, resolution=512, **kwargs):
20 |         from canny import CannyDetector
21 | 
22 |         return (common_annotator_call(CannyDetector(), image, low_threshold=low_threshold, high_threshold=high_threshold, resolution=resolution), )
23 | 
--------------------------------------------------------------------------------
/src/timm/data/parsers/class_map.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | 
4 | def load_class_map(map_or_filename, root=''):
5 |     if isinstance(map_or_filename, dict):
6 |         assert map_or_filename, 'class_map dict must be non-empty'
7 |         return map_or_filename
8 |     class_map_path = map_or_filename
9 |     if not os.path.exists(class_map_path):
10 |         class_map_path = os.path.join(root, class_map_path)
11 |         assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % map_or_filename
12 |     class_map_ext = os.path.splitext(map_or_filename)[-1].lower()
13 |     if class_map_ext == '.txt':
14 |         with open(class_map_path) as f:
15 |             class_to_idx = {v.strip(): k for k, v in enumerate(f)}
16 |     elif class_map_ext == '.pkl':
17 |         with open(class_map_path,'rb') as f:
18 |             class_to_idx = pickle.load(f)
19 |     else:
20 |         assert False, f'Unsupported class map file extension ({class_map_ext}).'
21 | return class_to_idx 22 | 23 | -------------------------------------------------------------------------------- /src/timm/utils/distributed.py: -------------------------------------------------------------------------------- 1 | """ Distributed training/validation utils 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | from torch import distributed as dist 7 | 8 | from .model import unwrap_model 9 | 10 | 11 | def reduce_tensor(tensor, n): 12 | rt = tensor.clone() 13 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 14 | rt /= n 15 | return rt 16 | 17 | 18 | def distribute_bn(model, world_size, reduce=False): 19 | # ensure every node has the same running bn stats 20 | for bn_name, bn_buf in unwrap_model(model).named_buffers(recurse=True): 21 | if ('running_mean' in bn_name) or ('running_var' in bn_name): 22 | if reduce: 23 | # average bn stats across whole group 24 | torch.distributed.all_reduce(bn_buf, op=dist.ReduceOp.SUM) 25 | bn_buf /= float(world_size) 26 | else: 27 | # broadcast bn stats from rank 0 to whole group 28 | torch.distributed.broadcast(bn_buf, 0) 29 | -------------------------------------------------------------------------------- /src/timm/utils/metrics.py: -------------------------------------------------------------------------------- 1 | """ Eval metrics and related 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | 6 | 7 | class AverageMeter: 8 | """Computes and stores the average and current value""" 9 | def __init__(self): 10 | self.reset() 11 | 12 | def reset(self): 13 | self.val = 0 14 | self.avg = 0 15 | self.sum = 0 16 | self.count = 0 17 | 18 | def update(self, val, n=1): 19 | self.val = val 20 | self.sum += val * n 21 | self.count += n 22 | self.avg = self.sum / self.count 23 | 24 | 25 | def accuracy(output, target, topk=(1,)): 26 | """Computes the accuracy over the k top predictions for the specified values of k""" 27 | maxk = min(max(topk), output.size()[1]) 28 | batch_size = target.size(0) 29 | _, pred = output.topk(maxk, 1, True, True) 30 | pred = pred.t() 31 | correct = pred.eq(target.reshape(1, -1).expand_as(pred)) 32 | return [correct[:min(k, maxk)].reshape(-1).float().sum(0) * 100. 
/ batch_size for k in topk] 33 | -------------------------------------------------------------------------------- /nodes/preprocessor/densepose_node.py: -------------------------------------------------------------------------------- 1 | from ..utils import common_annotator_call, create_node_input_types 2 | import comfy.model_management as model_management 3 | 4 | class DensePose_Preprocessor: 5 | @classmethod 6 | def INPUT_TYPES(s): 7 | return create_node_input_types( 8 | model=(["densepose_r50_fpn_dl.torchscript", "densepose_r101_fpn_dl.torchscript"], {"default": "densepose_r50_fpn_dl.torchscript"}), 9 | cmap=(["Viridis (MagicAnimate)", "Parula (CivitAI)"], {"default": "Viridis (MagicAnimate)"}) 10 | ) 11 | 12 | RETURN_TYPES = ("IMAGE",) 13 | FUNCTION = "execute" 14 | 15 | CATEGORY = "tbox/ControlNet Preprocessors" 16 | 17 | def execute(self, image, model, cmap, resolution=512): 18 | from densepose import DenseposeDetector 19 | model = DenseposeDetector \ 20 | .from_pretrained(filename=model) \ 21 | .to(model_management.get_torch_device()) 22 | return (common_annotator_call(model, image, cmap="viridis" if "Viridis" in cmap else "parula", resolution=resolution), ) 23 | -------------------------------------------------------------------------------- /src/midas/backbones/swin2.py: -------------------------------------------------------------------------------- 1 | import timm 2 | 3 | from .swin_common import _make_swin_backbone 4 | 5 | 6 | def _make_pretrained_swin2l24_384(pretrained, hooks=None): 7 | model = timm.create_model("swinv2_large_window12to24_192to384_22kft1k", pretrained=pretrained) 8 | 9 | hooks = [1, 1, 17, 1] if hooks == None else hooks 10 | return _make_swin_backbone( 11 | model, 12 | hooks=hooks 13 | ) 14 | 15 | 16 | def _make_pretrained_swin2b24_384(pretrained, hooks=None): 17 | model = timm.create_model("swinv2_base_window12to24_192to384_22kft1k", pretrained=pretrained) 18 | 19 | hooks = [1, 1, 17, 1] if hooks == None else hooks 20 | return _make_swin_backbone( 21 | model, 22 | hooks=hooks 23 | ) 24 | 25 | 26 | def _make_pretrained_swin2t16_256(pretrained, hooks=None): 27 | model = timm.create_model("swinv2_tiny_window16_256", pretrained=pretrained) 28 | 29 | hooks = [1, 1, 5, 1] if hooks == None else hooks 30 | return _make_swin_backbone( 31 | model, 32 | hooks=hooks, 33 | patch_grid=[64, 64] 34 | ) 35 | -------------------------------------------------------------------------------- /nodes/preprocessor/midas_node.py: -------------------------------------------------------------------------------- 1 | from ..utils import common_annotator_call, create_node_input_types 2 | import comfy.model_management as model_management 3 | import numpy as np 4 | 5 | class MIDAS_Depth_Map_Preprocessor: 6 | @classmethod 7 | def INPUT_TYPES(s): 8 | return create_node_input_types( 9 | a = ("FLOAT", {"default": np.pi * 2.0, "min": 0.0, "max": np.pi * 5.0, "step": 0.05}), 10 | bg_threshold = ("FLOAT", {"default": 0.1, "min": 0, "max": 1, "step": 0.05}) 11 | ) 12 | 13 | RETURN_TYPES = ("IMAGE",) 14 | FUNCTION = "execute" 15 | 16 | CATEGORY = "tbox/ControlNet Preprocessors" 17 | 18 | def execute(self, image, a, bg_threshold, resolution=512, **kwargs): 19 | from midas import MidasDetector 20 | 21 | # Ref: https://github.com/lllyasviel/ControlNet/blob/main/gradio_depth2image.py 22 | model = MidasDetector.from_pretrained().to(model_management.get_torch_device()) 23 | out = common_annotator_call(model, image, resolution=resolution, a=a, bg_th=bg_threshold) 24 | del model 25 | return (out, ) 26 | 
-------------------------------------------------------------------------------- /src/timm/utils/log.py: -------------------------------------------------------------------------------- 1 | """ Logging helpers 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import logging 6 | import logging.handlers 7 | 8 | 9 | class FormatterNoInfo(logging.Formatter): 10 | def __init__(self, fmt='%(levelname)s: %(message)s'): 11 | logging.Formatter.__init__(self, fmt) 12 | 13 | def format(self, record): 14 | if record.levelno == logging.INFO: 15 | return str(record.getMessage()) 16 | return logging.Formatter.format(self, record) 17 | 18 | 19 | def setup_default_logging(default_level=logging.INFO, log_path=''): 20 | console_handler = logging.StreamHandler() 21 | console_handler.setFormatter(FormatterNoInfo()) 22 | logging.root.addHandler(console_handler) 23 | logging.root.setLevel(default_level) 24 | if log_path: 25 | file_handler = logging.handlers.RotatingFileHandler(log_path, maxBytes=(1024 ** 2 * 2), backupCount=3) 26 | file_formatter = logging.Formatter("%(asctime)s - %(name)20s: [%(levelname)8s] - %(message)s") 27 | file_handler.setFormatter(file_formatter) 28 | logging.root.addHandler(file_handler) 29 | -------------------------------------------------------------------------------- /src/timm/data/parsers/parser_factory.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .parser_image_folder import ParserImageFolder 4 | from .parser_image_in_tar import ParserImageInTar 5 | 6 | 7 | def create_parser(name, root, split='train', **kwargs): 8 | name = name.lower() 9 | name = name.split('/', 2) 10 | prefix = '' 11 | if len(name) > 1: 12 | prefix = name[0] 13 | name = name[-1] 14 | 15 | # FIXME improve the selection right now just tfds prefix or fallback path, will need options to 16 | # explicitly select other options shortly 17 | if prefix == 'tfds': 18 | from .parser_tfds import ParserTfds # defer tensorflow import 19 | parser = ParserTfds(root, name, split=split, **kwargs) 20 | else: 21 | assert os.path.exists(root) 22 | # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder 23 | # FIXME support split here, in parser? 
24 | if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar': 25 | parser = ParserImageInTar(root, **kwargs) 26 | else: 27 | parser = ParserImageFolder(root, **kwargs) 28 | return parser 29 | -------------------------------------------------------------------------------- /nodes/video/info_node.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | class VideoInfoNode: 4 | @classmethod 5 | def INPUT_TYPES(s): 6 | return { 7 | "required": { 8 | "video_info": ("VHS_VIDEOINFO",), 9 | } 10 | } 11 | 12 | CATEGORY = "tbox/Video" 13 | 14 | RETURN_TYPES = ("FLOAT", "INT", "FLOAT", "INT", "INT", "FLOAT","INT", "FLOAT", "INT", "INT") 15 | RETURN_NAMES = ( 16 | "source_fps", 17 | "source_frame_count", 18 | "source_duration", 19 | "source_width", 20 | "source_height", 21 | "loaded_fps", 22 | "loaded_frame_count", 23 | "loaded_duration", 24 | "loaded_width", 25 | "loaded_height", 26 | ) 27 | FUNCTION = "get_video_info" 28 | 29 | def get_video_info(self, video_info): 30 | keys = ["fps", "frame_count", "duration", "width", "height"] 31 | 32 | source_info = [] 33 | loaded_info = [] 34 | 35 | for key in keys: 36 | source_info.append(video_info[f"source_{key}"]) 37 | loaded_info.append(video_info[f"loaded_{key}"]) 38 | 39 | return (*source_info, *loaded_info) -------------------------------------------------------------------------------- /src/midas/backbones/next_vit.py: -------------------------------------------------------------------------------- 1 | import timm 2 | 3 | import torch.nn as nn 4 | 5 | from pathlib import Path 6 | from .utils import activations, forward_default, get_activation 7 | 8 | from ..external.next_vit.classification.nextvit import * 9 | 10 | 11 | def forward_next_vit(pretrained, x): 12 | return forward_default(pretrained, x, "forward") 13 | 14 | 15 | def _make_next_vit_backbone( 16 | model, 17 | hooks=[2, 6, 36, 39], 18 | ): 19 | pretrained = nn.Module() 20 | 21 | pretrained.model = model 22 | pretrained.model.features[hooks[0]].register_forward_hook(get_activation("1")) 23 | pretrained.model.features[hooks[1]].register_forward_hook(get_activation("2")) 24 | pretrained.model.features[hooks[2]].register_forward_hook(get_activation("3")) 25 | pretrained.model.features[hooks[3]].register_forward_hook(get_activation("4")) 26 | 27 | pretrained.activations = activations 28 | 29 | return pretrained 30 | 31 | 32 | def _make_pretrained_next_vit_large_6m(hooks=None): 33 | model = timm.create_model("nextvit_large") 34 | 35 | hooks = [2, 6, 36, 39] if hooks == None else hooks 36 | return _make_next_vit_backbone( 37 | model, 38 | hooks=hooks, 39 | ) 40 | -------------------------------------------------------------------------------- /src/timm/models/layers/helpers.py: -------------------------------------------------------------------------------- 1 | """ Layer/Module Helpers 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | from itertools import repeat 6 | import collections.abc 7 | 8 | 9 | # From PyTorch internals 10 | def _ntuple(n): 11 | def parse(x): 12 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): 13 | return x 14 | return tuple(repeat(x, n)) 15 | return parse 16 | 17 | 18 | to_1tuple = _ntuple(1) 19 | to_2tuple = _ntuple(2) 20 | to_3tuple = _ntuple(3) 21 | to_4tuple = _ntuple(4) 22 | to_ntuple = _ntuple 23 | 24 | 25 | def make_divisible(v, divisor=8, min_value=None, round_limit=.9): 26 | min_value = min_value or divisor 27 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 
28 | # Make sure that round down does not go down by more than 10%. 29 | if new_v < round_limit * v: 30 | new_v += divisor 31 | return new_v 32 | 33 | 34 | def extend_tuple(x, n): 35 | # pdas a tuple to specified n by padding with last value 36 | if not isinstance(x, (tuple, list)): 37 | x = (x,) 38 | else: 39 | x = tuple(x) 40 | pad_n = n - len(x) 41 | if pad_n <= 0: 42 | return x[:n] 43 | return x + (x[-1],) * pad_n 44 | -------------------------------------------------------------------------------- /src/timm/loss/cross_entropy.py: -------------------------------------------------------------------------------- 1 | """ Cross Entropy w/ smoothing or soft targets 2 | 3 | Hacked together by / Copyright 2021 Ross Wightman 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class LabelSmoothingCrossEntropy(nn.Module): 12 | """ NLL loss with label smoothing. 13 | """ 14 | def __init__(self, smoothing=0.1): 15 | super(LabelSmoothingCrossEntropy, self).__init__() 16 | assert smoothing < 1.0 17 | self.smoothing = smoothing 18 | self.confidence = 1. - smoothing 19 | 20 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 21 | logprobs = F.log_softmax(x, dim=-1) 22 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) 23 | nll_loss = nll_loss.squeeze(1) 24 | smooth_loss = -logprobs.mean(dim=-1) 25 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 26 | return loss.mean() 27 | 28 | 29 | class SoftTargetCrossEntropy(nn.Module): 30 | 31 | def __init__(self): 32 | super(SoftTargetCrossEntropy, self).__init__() 33 | 34 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 35 | loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1) 36 | return loss.mean() 37 | -------------------------------------------------------------------------------- /nodes/other/vram_node.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import torch.cuda 3 | import comfy.model_management 4 | 5 | class AnyType(str): 6 | """A special class that is always equal in not equal comparisons. 
Credit to pythongosssss"""
7 |     def __eq__(self, __value: object) -> bool:
8 |         return True
9 |     def __ne__(self, __value: object) -> bool:
10 |         return False
11 | 
12 | any = AnyType("*")
13 | 
14 | class PurgeVRAMNode:
15 | 
16 |     def __init__(self):
17 |         pass
18 | 
19 |     @classmethod
20 |     def INPUT_TYPES(cls):
21 |         return {
22 |             "required": {
23 |                 "anything": (any, {}),
24 |                 "purge_cache": ("BOOLEAN", {"default": True}),
25 |                 "purge_models": ("BOOLEAN", {"default": True}),
26 |             },
27 |             "optional": {
28 |             }
29 |         }
30 | 
31 |     RETURN_TYPES = ()
32 |     FUNCTION = "purge_vram"
33 |     CATEGORY = "tbox/other"
34 |     OUTPUT_NODE = True
35 | 
36 |     def purge_vram(self, anything, purge_cache, purge_models):
37 |         if purge_cache:
38 |             gc.collect()
39 |             if torch.cuda.is_available():
40 |                 torch.cuda.empty_cache()
41 |                 torch.cuda.ipc_collect()
42 |         if purge_models:
43 |             comfy.model_management.unload_all_models()
44 |             comfy.model_management.soft_empty_cache()
45 |         return (None,)
46 | 
--------------------------------------------------------------------------------
/src/timm/utils/summary.py:
--------------------------------------------------------------------------------
1 | """ Summary utilities
2 | 
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | import csv
6 | import os
7 | from collections import OrderedDict
8 | try:
9 |     import wandb
10 | except ImportError:
11 |     pass
12 | 
13 | def get_outdir(path, *paths, inc=False):
14 |     outdir = os.path.join(path, *paths)
15 |     if not os.path.exists(outdir):
16 |         os.makedirs(outdir)
17 |     elif inc:
18 |         count = 1
19 |         outdir_inc = outdir + '-' + str(count)
20 |         while os.path.exists(outdir_inc):
21 |             count = count + 1
22 |             outdir_inc = outdir + '-' + str(count)
23 |         assert count < 100
24 |         outdir = outdir_inc
25 |         os.makedirs(outdir)
26 |     return outdir
27 | 
28 | 
29 | def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=False, log_wandb=False):
30 |     rowd = OrderedDict(epoch=epoch)
31 |     rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
32 |     rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
33 |     if log_wandb:
34 |         wandb.log(rowd)
35 |     with open(filename, mode='a') as cf:
36 |         dw = csv.DictWriter(cf, fieldnames=rowd.keys())
37 |         if write_header:  # first iteration (epoch == 1 can't be used)
38 |             dw.writeheader()
39 |         dw.writerow(rowd)
40 | 
--------------------------------------------------------------------------------
/web/fix_touchpad_pan_and_zoom.js:
--------------------------------------------------------------------------------
1 | // @ts-check
2 | /** @type {any} */
3 | const { self } = window
4 | 
5 | /** @type {import("../../../web/types/litegraph")} */
6 | const { LGraphCanvas } = self
7 | 
8 | // @ts-ignore
9 | import * as ComfyUI_module from "../../../scripts/app.js"
10 | /** @type { import("../../../web/scripts/app.js") } */
11 | const { app } = ComfyUI_module
12 | 
13 | //////////////////////////////
14 | 
15 | /**
16 |  * Smooth scrolling for touchpad
17 |  */
18 | LGraphCanvas.prototype.processMouseWheel = function (/** @type {WheelEvent}*/ event) {
19 |     if (!this.graph || !this.allow_dragcanvas) return
20 | 
21 |     const { clientX: x, clientY: y } = event
22 |     if (this.viewport) {
23 |         const [viewportX, viewportY, width, height] = this.viewport
24 |         const isInsideX = x >= viewportX && x < viewportX + width
25 |         const isInsideY = y >= viewportY && y < viewportY + height
26 |         if (!(isInsideX && isInsideY)) return
27 |     }
28 | 
29 |     let scale = this.ds.scale
30 |     let { deltaX, deltaY } = event
31 | 
32 |     if (event.metaKey ||
event.ctrlKey) { 33 | let SCALE = event.ctrlKey ? 150 : 100 34 | if (event.metaKey) SCALE *= -1 / 0.5 35 | this.ds.changeScale(scale - deltaY / SCALE, [event.clientX, event.clientY]) 36 | } else { 37 | this.ds.mouseDrag(-deltaX, -deltaY) 38 | } 39 | this.graph.change() 40 | 41 | event.preventDefault() 42 | return false // prevent default 43 | } -------------------------------------------------------------------------------- /src/timm/models/layers/conv2d_same.py: -------------------------------------------------------------------------------- 1 | """ Conv2d w/ Same Padding 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from typing import Tuple, Optional 9 | 10 | from .padding import pad_same, get_padding_value 11 | 12 | 13 | def conv2d_same( 14 | x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1), 15 | padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1): 16 | x = pad_same(x, weight.shape[-2:], stride, dilation) 17 | return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) 18 | 19 | 20 | class Conv2dSame(nn.Conv2d): 21 | """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions 22 | """ 23 | 24 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 25 | padding=0, dilation=1, groups=1, bias=True): 26 | super(Conv2dSame, self).__init__( 27 | in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) 28 | 29 | def forward(self, x): 30 | return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 31 | 32 | 33 | def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): 34 | padding = kwargs.pop('padding', '') 35 | kwargs.setdefault('bias', False) 36 | padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs) 37 | if is_dynamic: 38 | return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) 39 | else: 40 | return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) 41 | 42 | 43 | -------------------------------------------------------------------------------- /src/timm/data/parsers/img_extensions.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | __all__ = ['get_img_extensions', 'is_img_extension', 'set_img_extensions', 'add_img_extensions', 'del_img_extensions'] 4 | 5 | 6 | IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg') # singleton, kept public for bwd compat use 7 | _IMG_EXTENSIONS_SET = set(IMG_EXTENSIONS) # set version, private, kept in sync 8 | 9 | 10 | def _set_extensions(extensions): 11 | global IMG_EXTENSIONS 12 | global _IMG_EXTENSIONS_SET 13 | dedupe = set() # NOTE de-duping tuple while keeping original order 14 | IMG_EXTENSIONS = tuple(x for x in extensions if x not in dedupe and not dedupe.add(x)) 15 | _IMG_EXTENSIONS_SET = set(extensions) 16 | 17 | 18 | def _valid_extension(x: str): 19 | return x and isinstance(x, str) and len(x) >= 2 and x.startswith('.') 20 | 21 | 22 | def is_img_extension(ext): 23 | return ext in _IMG_EXTENSIONS_SET 24 | 25 | 26 | def get_img_extensions(as_set=False): 27 | return deepcopy(_IMG_EXTENSIONS_SET if as_set else IMG_EXTENSIONS) 28 | 29 | 30 | def set_img_extensions(extensions): 31 | assert len(extensions) 32 | for x in extensions: 33 | assert _valid_extension(x) 34 | _set_extensions(extensions) 35 | 36 | 37 | def add_img_extensions(ext): 38 | if not isinstance(ext, (list, tuple, 
set)): 39 | ext = (ext,) 40 | for x in ext: 41 | assert _valid_extension(x) 42 | extensions = IMG_EXTENSIONS + tuple(ext) 43 | _set_extensions(extensions) 44 | 45 | 46 | def del_img_extensions(ext): 47 | if not isinstance(ext, (list, tuple, set)): 48 | ext = (ext,) 49 | extensions = tuple(x for x in IMG_EXTENSIONS if x not in ext) 50 | _set_extensions(extensions) 51 | -------------------------------------------------------------------------------- /src/timm/loss/jsd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .cross_entropy import LabelSmoothingCrossEntropy 6 | 7 | 8 | class JsdCrossEntropy(nn.Module): 9 | """ Jensen-Shannon Divergence + Cross-Entropy Loss 10 | 11 | Based on impl here: https://github.com/google-research/augmix/blob/master/imagenet.py 12 | From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - 13 | https://arxiv.org/abs/1912.02781 14 | 15 | Hacked together by / Copyright 2020 Ross Wightman 16 | """ 17 | def __init__(self, num_splits=3, alpha=12, smoothing=0.1): 18 | super().__init__() 19 | self.num_splits = num_splits 20 | self.alpha = alpha 21 | if smoothing is not None and smoothing > 0: 22 | self.cross_entropy_loss = LabelSmoothingCrossEntropy(smoothing) 23 | else: 24 | self.cross_entropy_loss = torch.nn.CrossEntropyLoss() 25 | 26 | def __call__(self, output, target): 27 | split_size = output.shape[0] // self.num_splits 28 | assert split_size * self.num_splits == output.shape[0] 29 | logits_split = torch.split(output, split_size) 30 | 31 | # Cross-entropy is only computed on clean images 32 | loss = self.cross_entropy_loss(logits_split[0], target[:split_size]) 33 | probs = [F.softmax(logits, dim=1) for logits in logits_split] 34 | 35 | # Clamp mixture distribution to avoid exploding KL divergence 36 | logp_mixture = torch.clamp(torch.stack(probs).mean(axis=0), 1e-7, 1).log() 37 | loss += self.alpha * sum([F.kl_div( 38 | logp_mixture, p_split, reduction='batchmean') for p_split in probs]) / len(probs) 39 | return loss 40 | -------------------------------------------------------------------------------- /src/timm/data/real_labels.py: -------------------------------------------------------------------------------- 1 | """ Real labels evaluator for ImageNet 2 | Paper: `Are we done with ImageNet?` - https://arxiv.org/abs/2006.07159 3 | Based on Numpy example at https://github.com/google-research/reassessed-imagenet 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import os 8 | import json 9 | import numpy as np 10 | 11 | 12 | class RealLabelsImagenet: 13 | 14 | def __init__(self, filenames, real_json='real.json', topk=(1, 5)): 15 | with open(real_json) as real_labels: 16 | real_labels = json.load(real_labels) 17 | real_labels = {f'ILSVRC2012_val_{i + 1:08d}.JPEG': labels for i, labels in enumerate(real_labels)} 18 | self.real_labels = real_labels 19 | self.filenames = filenames 20 | assert len(self.filenames) == len(self.real_labels) 21 | self.topk = topk 22 | self.is_correct = {k: [] for k in topk} 23 | self.sample_idx = 0 24 | 25 | def add_result(self, output): 26 | maxk = max(self.topk) 27 | _, pred_batch = output.topk(maxk, 1, True, True) 28 | pred_batch = pred_batch.cpu().numpy() 29 | for pred in pred_batch: 30 | filename = self.filenames[self.sample_idx] 31 | filename = os.path.basename(filename) 32 | if self.real_labels[filename]: 33 | for k in self.topk: 34 | 
self.is_correct[k].append( 35 | any([p in self.real_labels[filename] for p in pred[:k]])) 36 | self.sample_idx += 1 37 | 38 | def get_accuracy(self, k=None): 39 | if k is None: 40 | return {k: float(np.mean(self.is_correct[k])) * 100 for k in self.topk} 41 | else: 42 | return float(np.mean(self.is_correct[k])) * 100 43 | -------------------------------------------------------------------------------- /src/timm/models/layers/blur_pool.py: -------------------------------------------------------------------------------- 1 | """ 2 | BlurPool layer inspired by 3 | - Kornia's Max_BlurPool2d 4 | - Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar` 5 | 6 | Hacked together by Chris Ha and Ross Wightman 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import numpy as np 13 | from .padding import get_padding 14 | 15 | 16 | class BlurPool2d(nn.Module): 17 | r"""Creates a module that computes blurs and downsample a given feature map. 18 | See :cite:`zhang2019shiftinvar` for more details. 19 | Corresponds to the Downsample class, which does blurring and subsampling 20 | 21 | Args: 22 | channels = Number of input channels 23 | filt_size (int): binomial filter size for blurring. currently supports 3 (default) and 5. 24 | stride (int): downsampling filter stride 25 | 26 | Returns: 27 | torch.Tensor: the transformed tensor. 28 | """ 29 | def __init__(self, channels, filt_size=3, stride=2) -> None: 30 | super(BlurPool2d, self).__init__() 31 | assert filt_size > 1 32 | self.channels = channels 33 | self.filt_size = filt_size 34 | self.stride = stride 35 | self.padding = [get_padding(filt_size, stride, dilation=1)] * 4 36 | coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs.astype(np.float32)) 37 | blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :].repeat(self.channels, 1, 1, 1) 38 | self.register_buffer('filt', blur_filter, persistent=False) 39 | 40 | def forward(self, x: torch.Tensor) -> torch.Tensor: 41 | x = F.pad(x, self.padding, 'reflect') 42 | return F.conv2d(x, self.filt, stride=self.stride, groups=self.channels) 43 | -------------------------------------------------------------------------------- /src/timm/models/layers/create_conv2d.py: -------------------------------------------------------------------------------- 1 | """ Create Conv2d Factory Method 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | 6 | from .mixed_conv2d import MixedConv2d 7 | from .cond_conv2d import CondConv2d 8 | from .conv2d_same import create_conv2d_pad 9 | 10 | 11 | def create_conv2d(in_channels, out_channels, kernel_size, **kwargs): 12 | """ Select a 2d convolution implementation based on arguments 13 | Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d. 14 | 15 | Used extensively by EfficientNet, MobileNetv3 and related networks. 16 | """ 17 | if isinstance(kernel_size, list): 18 | assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently 19 | if 'groups' in kwargs: 20 | groups = kwargs.pop('groups') 21 | if groups == in_channels: 22 | kwargs['depthwise'] = True 23 | else: 24 | assert groups == 1 25 | # We're going to use only lists for defining the MixedConv2d kernel groups, 26 | # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. 
27 | m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs) 28 | else: 29 | depthwise = kwargs.pop('depthwise', False) 30 | # for DW out_channels must be multiple of in_channels as must have out_channels % groups == 0 31 | groups = in_channels if depthwise else kwargs.pop('groups', 1) 32 | if 'num_experts' in kwargs and kwargs['num_experts'] > 0: 33 | m = CondConv2d(in_channels, out_channels, kernel_size, groups=groups, **kwargs) 34 | else: 35 | m = create_conv2d_pad(in_channels, out_channels, kernel_size, groups=groups, **kwargs) 36 | return m 37 | -------------------------------------------------------------------------------- /src/timm/utils/agc.py: -------------------------------------------------------------------------------- 1 | """ Adaptive Gradient Clipping 2 | 3 | An impl of AGC, as per (https://arxiv.org/abs/2102.06171): 4 | 5 | @article{brock2021high, 6 | author={Andrew Brock and Soham De and Samuel L. Smith and Karen Simonyan}, 7 | title={High-Performance Large-Scale Image Recognition Without Normalization}, 8 | journal={arXiv preprint arXiv:}, 9 | year={2021} 10 | } 11 | 12 | Code references: 13 | * Official JAX impl (paper authors): https://github.com/deepmind/deepmind-research/tree/master/nfnets 14 | * Phil Wang's PyTorch gist: https://gist.github.com/lucidrains/0d6560077edac419ab5d3aa29e674d5c 15 | 16 | Hacked together by / Copyright 2021 Ross Wightman 17 | """ 18 | import torch 19 | 20 | 21 | def unitwise_norm(x, norm_type=2.0): 22 | if x.ndim <= 1: 23 | return x.norm(norm_type) 24 | else: 25 | # works for nn.ConvNd and nn,Linear where output dim is first in the kernel/weight tensor 26 | # might need special cases for other weights (possibly MHA) where this may not be true 27 | return x.norm(norm_type, dim=tuple(range(1, x.ndim)), keepdim=True) 28 | 29 | 30 | def adaptive_clip_grad(parameters, clip_factor=0.01, eps=1e-3, norm_type=2.0): 31 | if isinstance(parameters, torch.Tensor): 32 | parameters = [parameters] 33 | for p in parameters: 34 | if p.grad is None: 35 | continue 36 | p_data = p.detach() 37 | g_data = p.grad.detach() 38 | max_norm = unitwise_norm(p_data, norm_type=norm_type).clamp_(min=eps).mul_(clip_factor) 39 | grad_norm = unitwise_norm(g_data, norm_type=norm_type) 40 | clipped_grad = g_data * (max_norm / grad_norm.clamp(min=1e-6)) 41 | new_grads = torch.where(grad_norm < max_norm, g_data, clipped_grad) 42 | p.grad.detach().copy_(new_grads) 43 | -------------------------------------------------------------------------------- /src/midas/backbones/swin_common.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import torch.nn as nn 4 | import numpy as np 5 | 6 | from .utils import activations, forward_default, get_activation, Transpose 7 | 8 | 9 | def forward_swin(pretrained, x): 10 | return forward_default(pretrained, x) 11 | 12 | 13 | def _make_swin_backbone( 14 | model, 15 | hooks=[1, 1, 17, 1], 16 | patch_grid=[96, 96] 17 | ): 18 | pretrained = nn.Module() 19 | 20 | pretrained.model = model 21 | pretrained.model.layers[0].blocks[hooks[0]].register_forward_hook(get_activation("1")) 22 | pretrained.model.layers[1].blocks[hooks[1]].register_forward_hook(get_activation("2")) 23 | pretrained.model.layers[2].blocks[hooks[2]].register_forward_hook(get_activation("3")) 24 | pretrained.model.layers[3].blocks[hooks[3]].register_forward_hook(get_activation("4")) 25 | 26 | pretrained.activations = activations 27 | 28 | if hasattr(model, "patch_grid"): 29 | used_patch_grid = model.patch_grid 30 | 
else: 31 | used_patch_grid = patch_grid 32 | 33 | patch_grid_size = np.array(used_patch_grid, dtype=int) 34 | 35 | pretrained.act_postprocess1 = nn.Sequential( 36 | Transpose(1, 2), 37 | nn.Unflatten(2, torch.Size(patch_grid_size.tolist())) 38 | ) 39 | pretrained.act_postprocess2 = nn.Sequential( 40 | Transpose(1, 2), 41 | nn.Unflatten(2, torch.Size((patch_grid_size // 2).tolist())) 42 | ) 43 | pretrained.act_postprocess3 = nn.Sequential( 44 | Transpose(1, 2), 45 | nn.Unflatten(2, torch.Size((patch_grid_size // 4).tolist())) 46 | ) 47 | pretrained.act_postprocess4 = nn.Sequential( 48 | Transpose(1, 2), 49 | nn.Unflatten(2, torch.Size((patch_grid_size // 8).tolist())) 50 | ) 51 | 52 | return pretrained 53 | -------------------------------------------------------------------------------- /src/timm/models/layers/patch_embed.py: -------------------------------------------------------------------------------- 1 | """ Image to Patch Embedding using Conv2d 2 | 3 | A convolution based approach to patchifying a 2D image w/ embedding projection. 4 | 5 | Based on the impl in https://github.com/google-research/vision_transformer 6 | 7 | Hacked together by / Copyright 2020 Ross Wightman 8 | """ 9 | from torch import nn as nn 10 | 11 | from .helpers import to_2tuple 12 | from .trace_utils import _assert 13 | 14 | 15 | class PatchEmbed(nn.Module): 16 | """ 2D Image to Patch Embedding 17 | """ 18 | def __init__( 19 | self, 20 | img_size=224, 21 | patch_size=16, 22 | in_chans=3, 23 | embed_dim=768, 24 | norm_layer=None, 25 | flatten=True, 26 | bias=True, 27 | ): 28 | super().__init__() 29 | img_size = to_2tuple(img_size) 30 | patch_size = to_2tuple(patch_size) 31 | self.img_size = img_size 32 | self.patch_size = patch_size 33 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 34 | self.num_patches = self.grid_size[0] * self.grid_size[1] 35 | self.flatten = flatten 36 | 37 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) 38 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 39 | 40 | def forward(self, x): 41 | B, C, H, W = x.shape 42 | _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") 43 | _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") 44 | x = self.proj(x) 45 | if self.flatten: 46 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 47 | x = self.norm(x) 48 | return x 49 | -------------------------------------------------------------------------------- /nodes/preprocessor/lineart_node.py: -------------------------------------------------------------------------------- 1 | from ..utils import common_annotator_call, create_node_input_types 2 | import comfy.model_management as model_management 3 | import nodes 4 | 5 | class LineArt_Preprocessor: 6 | @classmethod 7 | def INPUT_TYPES(s): 8 | return create_node_input_types( 9 | coarse=(["disable", "enable"], {"default": "enable"}), 10 | resolution=("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 64}) 11 | ) 12 | 13 | RETURN_TYPES = ("IMAGE",) 14 | FUNCTION = "execute" 15 | 16 | CATEGORY = "tbox/ControlNet Preprocessors" 17 | 18 | 19 | def execute(self, image, resolution=512, **kwargs): 20 | from lineart import LineartDetector 21 | 22 | model = LineartDetector.from_pretrained().to(model_management.get_torch_device()) 23 | out = common_annotator_call(model, image, resolution=resolution, coarse = kwargs["coarse"] == "enable") 24 | del model 25 | 
return (out, ) 26 | 27 | class Lineart_Standard_Preprocessor: 28 | @classmethod 29 | def INPUT_TYPES(s): 30 | return create_node_input_types( 31 | guassian_sigma=("FLOAT", {"default":6.0, "max": 100.0}), 32 | intensity_threshold=("INT", {"default": 8, "max": 16}), 33 | resolution=("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 64}) 34 | ) 35 | 36 | RETURN_TYPES = ("IMAGE",) 37 | FUNCTION = "execute" 38 | 39 | CATEGORY = "tbox/ControlNet Preprocessors" 40 | 41 | 42 | def execute(self, image, guassian_sigma=6, intensity_threshold=8, resolution=512, **kwargs): 43 | from lineart import LineartStandardDetector 44 | return (common_annotator_call(LineartStandardDetector(), image, guassian_sigma=guassian_sigma, intensity_threshold=intensity_threshold, resolution=resolution), ) -------------------------------------------------------------------------------- /src/timm/utils/cuda.py: -------------------------------------------------------------------------------- 1 | """ CUDA / AMP utils 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | 7 | try: 8 | from apex import amp 9 | has_apex = True 10 | except ImportError: 11 | amp = None 12 | has_apex = False 13 | 14 | from .clip_grad import dispatch_clip_grad 15 | 16 | 17 | class ApexScaler: 18 | state_dict_key = "amp" 19 | 20 | def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False): 21 | with amp.scale_loss(loss, optimizer) as scaled_loss: 22 | scaled_loss.backward(create_graph=create_graph) 23 | if clip_grad is not None: 24 | dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode) 25 | optimizer.step() 26 | 27 | def state_dict(self): 28 | if 'state_dict' in amp.__dict__: 29 | return amp.state_dict() 30 | 31 | def load_state_dict(self, state_dict): 32 | if 'load_state_dict' in amp.__dict__: 33 | amp.load_state_dict(state_dict) 34 | 35 | 36 | class NativeScaler: 37 | state_dict_key = "amp_scaler" 38 | 39 | def __init__(self): 40 | self._scaler = torch.cuda.amp.GradScaler() 41 | 42 | def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False): 43 | self._scaler.scale(loss).backward(create_graph=create_graph) 44 | if clip_grad is not None: 45 | assert parameters is not None 46 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 47 | dispatch_clip_grad(parameters, clip_grad, mode=clip_mode) 48 | self._scaler.step(optimizer) 49 | self._scaler.update() 50 | 51 | def state_dict(self): 52 | return self._scaler.state_dict() 53 | 54 | def load_state_dict(self, state_dict): 55 | self._scaler.load_state_dict(state_dict) 56 | -------------------------------------------------------------------------------- /src/timm/utils/decay_batch.py: -------------------------------------------------------------------------------- 1 | """ Batch size decay and retry helpers. 
2 | 3 | Copyright 2022 Ross Wightman 4 | """ 5 | import math 6 | 7 | 8 | def decay_batch_step(batch_size, num_intra_steps=2, no_odd=False): 9 | """ power of two batch-size decay with intra steps 10 | 11 | Decay by stepping between powers of 2: 12 | * determine power-of-2 floor of current batch size (base batch size) 13 | * divide above value by num_intra_steps to determine step size 14 | * floor batch_size to nearest multiple of step_size (from base batch size) 15 | Examples: 16 | num_steps == 4 --> 64, 56, 48, 40, 32, 28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1 17 | num_steps (no_odd=True) == 4 --> 64, 56, 48, 40, 32, 28, 24, 20, 16, 14, 12, 10, 8, 6, 4, 2 18 | num_steps == 2 --> 64, 48, 32, 24, 16, 12, 8, 6, 4, 3, 2, 1 19 | num_steps == 1 --> 64, 32, 16, 8, 4, 2, 1 20 | """ 21 | if batch_size <= 1: 22 | # return 0 for stopping value so easy to use in loop 23 | return 0 24 | base_batch_size = int(2 ** (math.log(batch_size - 1) // math.log(2))) 25 | step_size = max(base_batch_size // num_intra_steps, 1) 26 | batch_size = base_batch_size + ((batch_size - base_batch_size - 1) // step_size) * step_size 27 | if no_odd and batch_size % 2: 28 | batch_size -= 1 29 | return batch_size 30 | 31 | 32 | def check_batch_size_retry(error_str): 33 | """ check failure error string for conditions where batch decay retry should not be attempted 34 | """ 35 | error_str = error_str.lower() 36 | if 'required rank' in error_str: 37 | # Errors involving phrase 'required rank' typically happen when a conv is used that's 38 | # not compatible with channels_last memory format. 39 | return False 40 | if 'illegal' in error_str: 41 | # 'Illegal memory access' errors in CUDA typically leave process in unusable state 42 | return False 43 | return True 44 | -------------------------------------------------------------------------------- /src/timm/models/layers/median_pool.py: -------------------------------------------------------------------------------- 1 | """ Median Pool 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from .helpers import to_2tuple, to_4tuple 7 | 8 | 9 | class MedianPool2d(nn.Module): 10 | """ Median pool (usable as median filter when stride=1) module. 
11 | 12 | Args: 13 | kernel_size: size of pooling kernel, int or 2-tuple 14 | stride: pool stride, int or 2-tuple 15 | padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad 16 | same: override padding and enforce same padding, boolean 17 | """ 18 | def __init__(self, kernel_size=3, stride=1, padding=0, same=False): 19 | super(MedianPool2d, self).__init__() 20 | self.k = to_2tuple(kernel_size) 21 | self.stride = to_2tuple(stride) 22 | self.padding = to_4tuple(padding) # convert to l, r, t, b 23 | self.same = same 24 | 25 | def _padding(self, x): 26 | if self.same: 27 | ih, iw = x.size()[2:] 28 | if ih % self.stride[0] == 0: 29 | ph = max(self.k[0] - self.stride[0], 0) 30 | else: 31 | ph = max(self.k[0] - (ih % self.stride[0]), 0) 32 | if iw % self.stride[1] == 0: 33 | pw = max(self.k[1] - self.stride[1], 0) 34 | else: 35 | pw = max(self.k[1] - (iw % self.stride[1]), 0) 36 | pl = pw // 2 37 | pr = pw - pl 38 | pt = ph // 2 39 | pb = ph - pt 40 | padding = (pl, pr, pt, pb) 41 | else: 42 | padding = self.padding 43 | return padding 44 | 45 | def forward(self, x): 46 | x = F.pad(x, self._padding(x), mode='reflect') 47 | x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1]) 48 | x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0] 49 | return x 50 | -------------------------------------------------------------------------------- /src/timm/models/layers/space_to_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class SpaceToDepth(nn.Module): 6 | def __init__(self, block_size=4): 7 | super().__init__() 8 | assert block_size == 4 9 | self.bs = block_size 10 | 11 | def forward(self, x): 12 | N, C, H, W = x.size() 13 | x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs) 14 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) 15 | x = x.view(N, C * (self.bs ** 2), H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs) 16 | return x 17 | 18 | 19 | @torch.jit.script 20 | class SpaceToDepthJit(object): 21 | def __call__(self, x: torch.Tensor): 22 | # assuming hard-coded that block_size==4 for acceleration 23 | N, C, H, W = x.size() 24 | x = x.view(N, C, H // 4, 4, W // 4, 4) # (N, C, H//bs, bs, W//bs, bs) 25 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) 26 | x = x.view(N, C * 16, H // 4, W // 4) # (N, C*bs^2, H//bs, W//bs) 27 | return x 28 | 29 | 30 | class SpaceToDepthModule(nn.Module): 31 | def __init__(self, no_jit=False): 32 | super().__init__() 33 | if not no_jit: 34 | self.op = SpaceToDepthJit() 35 | else: 36 | self.op = SpaceToDepth() 37 | 38 | def forward(self, x): 39 | return self.op(x) 40 | 41 | 42 | class DepthToSpace(nn.Module): 43 | 44 | def __init__(self, block_size): 45 | super().__init__() 46 | self.bs = block_size 47 | 48 | def forward(self, x): 49 | N, C, H, W = x.size() 50 | x = x.view(N, self.bs, self.bs, C // (self.bs ** 2), H, W) # (N, bs, bs, C//bs^2, H, W) 51 | x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs) 52 | x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs) # (N, C//bs^2, H * bs, W * bs) 53 | return x 54 | -------------------------------------------------------------------------------- /src/timm/models/layers/create_norm.py: -------------------------------------------------------------------------------- 1 | """ Norm Layer Factory 2 | 3 | Create norm modules by string (to mirror create_act and 
creat_norm-act fns) 4 | 5 | Copyright 2022 Ross Wightman 6 | """ 7 | import types 8 | import functools 9 | 10 | import torch.nn as nn 11 | 12 | from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d 13 | 14 | _NORM_MAP = dict( 15 | batchnorm=nn.BatchNorm2d, 16 | batchnorm2d=nn.BatchNorm2d, 17 | batchnorm1d=nn.BatchNorm1d, 18 | groupnorm=GroupNorm, 19 | groupnorm1=GroupNorm1, 20 | layernorm=LayerNorm, 21 | layernorm2d=LayerNorm2d, 22 | ) 23 | _NORM_TYPES = {m for n, m in _NORM_MAP.items()} 24 | 25 | 26 | def create_norm_layer(layer_name, num_features, act_layer=None, apply_act=True, **kwargs): 27 | layer = get_norm_layer(layer_name, act_layer=act_layer) 28 | layer_instance = layer(num_features, apply_act=apply_act, **kwargs) 29 | return layer_instance 30 | 31 | 32 | def get_norm_layer(norm_layer): 33 | assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial)) 34 | norm_kwargs = {} 35 | 36 | # unbind partial fn, so args can be rebound later 37 | if isinstance(norm_layer, functools.partial): 38 | norm_kwargs.update(norm_layer.keywords) 39 | norm_layer = norm_layer.func 40 | 41 | if isinstance(norm_layer, str): 42 | layer_name = norm_layer.replace('_', '') 43 | norm_layer = _NORM_MAP.get(layer_name, None) 44 | elif norm_layer in _NORM_TYPES: 45 | norm_layer = norm_layer 46 | elif isinstance(norm_layer, types.FunctionType): 47 | # if function type, assume it is a lambda/fn that creates a norm layer 48 | norm_layer = norm_layer 49 | else: 50 | type_name = norm_layer.__name__.lower().replace('_', '') 51 | norm_layer = _NORM_MAP.get(type_name, None) 52 | assert norm_layer is not None, f"No equivalent norm layer for {type_name}" 53 | 54 | if norm_kwargs: 55 | norm_layer = functools.partial(norm_layer, **norm_kwargs) # bind/rebind args 56 | return norm_layer 57 | -------------------------------------------------------------------------------- /src/timm/models/layers/mixed_conv2d.py: -------------------------------------------------------------------------------- 1 | """ PyTorch Mixed Convolution 2 | 3 | Paper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595) 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | 8 | import torch 9 | from torch import nn as nn 10 | 11 | from .conv2d_same import create_conv2d_pad 12 | 13 | 14 | def _split_channels(num_chan, num_groups): 15 | split = [num_chan // num_groups for _ in range(num_groups)] 16 | split[0] += num_chan - sum(split) 17 | return split 18 | 19 | 20 | class MixedConv2d(nn.ModuleDict): 21 | """ Mixed Grouped Convolution 22 | 23 | Based on MDConv and GroupedConv in MixNet impl: 24 | https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py 25 | """ 26 | def __init__(self, in_channels, out_channels, kernel_size=3, 27 | stride=1, padding='', dilation=1, depthwise=False, **kwargs): 28 | super(MixedConv2d, self).__init__() 29 | 30 | kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] 31 | num_groups = len(kernel_size) 32 | in_splits = _split_channels(in_channels, num_groups) 33 | out_splits = _split_channels(out_channels, num_groups) 34 | self.in_channels = sum(in_splits) 35 | self.out_channels = sum(out_splits) 36 | for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)): 37 | conv_groups = in_ch if depthwise else 1 38 | # use add_module to keep key space clean 39 | self.add_module( 40 | str(idx), 41 | create_conv2d_pad( 42 | in_ch, out_ch, k, stride=stride, 43 | padding=padding, 
dilation=dilation, groups=conv_groups, **kwargs) 44 | ) 45 | self.splits = in_splits 46 | 47 | def forward(self, x): 48 | x_split = torch.split(x, self.splits, 1) 49 | x_out = [c(x_split[i]) for i, c in enumerate(self.values())] 50 | x = torch.cat(x_out, 1) 51 | return x 52 | -------------------------------------------------------------------------------- /src/timm/scheduler/step_lr.py: -------------------------------------------------------------------------------- 1 | """ Step Scheduler 2 | 3 | Basic step LR schedule with warmup, noise. 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import math 8 | import torch 9 | 10 | from .scheduler import Scheduler 11 | 12 | 13 | class StepLRScheduler(Scheduler): 14 | """ 15 | """ 16 | 17 | def __init__(self, 18 | optimizer: torch.optim.Optimizer, 19 | decay_t: float, 20 | decay_rate: float = 1., 21 | warmup_t=0, 22 | warmup_lr_init=0, 23 | t_in_epochs=True, 24 | noise_range_t=None, 25 | noise_pct=0.67, 26 | noise_std=1.0, 27 | noise_seed=42, 28 | initialize=True, 29 | ) -> None: 30 | super().__init__( 31 | optimizer, param_group_field="lr", 32 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 33 | initialize=initialize) 34 | 35 | self.decay_t = decay_t 36 | self.decay_rate = decay_rate 37 | self.warmup_t = warmup_t 38 | self.warmup_lr_init = warmup_lr_init 39 | self.t_in_epochs = t_in_epochs 40 | if self.warmup_t: 41 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 42 | super().update_groups(self.warmup_lr_init) 43 | else: 44 | self.warmup_steps = [1 for _ in self.base_values] 45 | 46 | def _get_lr(self, t): 47 | if t < self.warmup_t: 48 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 49 | else: 50 | lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values] 51 | return lrs 52 | 53 | def get_epoch_values(self, epoch: int): 54 | if self.t_in_epochs: 55 | return self._get_lr(epoch) 56 | else: 57 | return None 58 | 59 | def get_update_values(self, num_updates: int): 60 | if not self.t_in_epochs: 61 | return self._get_lr(num_updates) 62 | else: 63 | return None 64 | -------------------------------------------------------------------------------- /src/facefusion/affine.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | ffhq_512 = np.array([ 5 | [ 0.37691676, 0.46864664 ], 6 | [ 0.62285697, 0.46912813 ], 7 | [ 0.50123859, 0.61331904 ], 8 | [ 0.39308822, 0.72541100 ], 9 | [ 0.61150205, 0.72490465 ] 10 | ]) 11 | 12 | 13 | def warp_face_by_landmark(image , face_landmark_5, crop_size ): 14 | normed_warp_template = ffhq_512 * crop_size 15 | affine_matrix = cv2.estimateAffinePartial2D(face_landmark_5, normed_warp_template, method = cv2.RANSAC, ransacReprojThreshold = 100)[0] 16 | cropped = cv2.warpAffine(image, affine_matrix, crop_size, borderMode = cv2.BORDER_REPLICATE, flags = cv2.INTER_AREA) 17 | return cropped, affine_matrix 18 | 19 | 20 | def create_box_mask(crop_size, face_mask_blur, face_mask_padding): 21 | blur_amount = int(crop_size[0] * 0.5 * face_mask_blur) 22 | blur_area = max(blur_amount // 2, 1) 23 | box_mask = np.ones(crop_size, np.float32) 24 | box_mask[:max(blur_area, int(crop_size[1] * face_mask_padding[0] / 100)), :] = 0 25 | box_mask[-max(blur_area, int(crop_size[1] * face_mask_padding[2] / 100)):, :] = 0 26 | box_mask[:, :max(blur_area, int(crop_size[0] * face_mask_padding[3] / 100))] = 0 27 | box_mask[:, -max(blur_area, 
int(crop_size[0] * face_mask_padding[1] / 100)):] = 0 28 | if blur_amount > 0: 29 | box_mask = cv2.GaussianBlur(box_mask, (0, 0), blur_amount * 0.25) 30 | return box_mask 31 | 32 | def paste_back(image, cropped, crop_mask, affine_matrix): 33 | inverse_matrix = cv2.invertAffineTransform(affine_matrix) 34 | temp_size = image.shape[:2][::-1] 35 | inverse_mask = cv2.warpAffine(crop_mask, inverse_matrix, temp_size).clip(0, 1) 36 | inverse_vision_frame = cv2.warpAffine(cropped, inverse_matrix, temp_size, borderMode = cv2.BORDER_REPLICATE) 37 | paste_vision_frame = image.copy() 38 | paste_vision_frame[:, :, 0] = inverse_mask * inverse_vision_frame[:, :, 0] + (1 - inverse_mask) * image[:, :, 0] 39 | paste_vision_frame[:, :, 1] = inverse_mask * inverse_vision_frame[:, :, 1] + (1 - inverse_mask) * image[:, :, 1] 40 | paste_vision_frame[:, :, 2] = inverse_mask * inverse_vision_frame[:, :, 2] + (1 - inverse_mask) * image[:, :, 2] 41 | return paste_vision_frame 42 | -------------------------------------------------------------------------------- /src/timm/models/layers/test_time_pool.py: -------------------------------------------------------------------------------- 1 | """ Test Time Pooling (Average-Max Pool) 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | 6 | import logging 7 | from torch import nn 8 | import torch.nn.functional as F 9 | 10 | from .adaptive_avgmax_pool import adaptive_avgmax_pool2d 11 | 12 | 13 | _logger = logging.getLogger(__name__) 14 | 15 | 16 | class TestTimePoolHead(nn.Module): 17 | def __init__(self, base, original_pool=7): 18 | super(TestTimePoolHead, self).__init__() 19 | self.base = base 20 | self.original_pool = original_pool 21 | base_fc = self.base.get_classifier() 22 | if isinstance(base_fc, nn.Conv2d): 23 | self.fc = base_fc 24 | else: 25 | self.fc = nn.Conv2d( 26 | self.base.num_features, self.base.num_classes, kernel_size=1, bias=True) 27 | self.fc.weight.data.copy_(base_fc.weight.data.view(self.fc.weight.size())) 28 | self.fc.bias.data.copy_(base_fc.bias.data.view(self.fc.bias.size())) 29 | self.base.reset_classifier(0) # delete original fc layer 30 | 31 | def forward(self, x): 32 | x = self.base.forward_features(x) 33 | x = F.avg_pool2d(x, kernel_size=self.original_pool, stride=1) 34 | x = self.fc(x) 35 | x = adaptive_avgmax_pool2d(x, 1) 36 | return x.view(x.size(0), -1) 37 | 38 | 39 | def apply_test_time_pool(model, config, use_test_size=False): 40 | test_time_pool = False 41 | if not hasattr(model, 'default_cfg') or not model.default_cfg: 42 | return model, False 43 | if use_test_size and 'test_input_size' in model.default_cfg: 44 | df_input_size = model.default_cfg['test_input_size'] 45 | else: 46 | df_input_size = model.default_cfg['input_size'] 47 | if config['input_size'][-1] > df_input_size[-1] and config['input_size'][-2] > df_input_size[-2]: 48 | _logger.info('Target input size %s > pretrained default %s, using test time pooling' % 49 | (str(config['input_size'][-2:]), str(df_input_size[-2:]))) 50 | model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size']) 51 | test_time_pool = True 52 | return model, test_time_pool 53 | -------------------------------------------------------------------------------- /src/timm/loss/binary_cross_entropy.py: -------------------------------------------------------------------------------- 1 | """ Binary Cross Entropy w/ a few extras 2 | 3 | Hacked together by / Copyright 2021 Ross Wightman 4 | """ 5 | from typing import Optional 6 | 7 | import torch 8 | import torch.nn as nn 9 | 
import torch.nn.functional as F 10 | 11 | 12 | class BinaryCrossEntropy(nn.Module): 13 | """ BCE with optional one-hot from dense targets, label smoothing, thresholding 14 | NOTE for experiments comparing CE to BCE /w label smoothing, may remove 15 | """ 16 | def __init__( 17 | self, smoothing=0.1, target_threshold: Optional[float] = None, weight: Optional[torch.Tensor] = None, 18 | reduction: str = 'mean', pos_weight: Optional[torch.Tensor] = None): 19 | super(BinaryCrossEntropy, self).__init__() 20 | assert 0. <= smoothing < 1.0 21 | self.smoothing = smoothing 22 | self.target_threshold = target_threshold 23 | self.reduction = reduction 24 | self.register_buffer('weight', weight) 25 | self.register_buffer('pos_weight', pos_weight) 26 | 27 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 28 | assert x.shape[0] == target.shape[0] 29 | if target.shape != x.shape: 30 | # NOTE currently assume smoothing or other label softening is applied upstream if targets are already sparse 31 | num_classes = x.shape[-1] 32 | # FIXME should off/on be different for smoothing w/ BCE? Other impl out there differ 33 | off_value = self.smoothing / num_classes 34 | on_value = 1. - self.smoothing + off_value 35 | target = target.long().view(-1, 1) 36 | target = torch.full( 37 | (target.size()[0], num_classes), 38 | off_value, 39 | device=x.device, dtype=x.dtype).scatter_(1, target, on_value) 40 | if self.target_threshold is not None: 41 | # Make target 0, or 1 if threshold set 42 | target = target.gt(self.target_threshold).to(dtype=target.dtype) 43 | return F.binary_cross_entropy_with_logits( 44 | x, target, 45 | self.weight, 46 | pos_weight=self.pos_weight, 47 | reduction=self.reduction) 48 | -------------------------------------------------------------------------------- /nodes/image/watermark_node.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import numpy as np 4 | from PIL import Image, ImageSequence, ImageOps 5 | from ..utils import tensor2pil, pil2tensor 6 | 7 | PADDING = 4 8 | 9 | class WatermarkNode: 10 | 11 | @classmethod 12 | def INPUT_TYPES(cls): 13 | return { 14 | "required": { 15 | "images": ("IMAGE",), 16 | "logo_list": ("IMAGE",), 17 | }, 18 | "optional": { 19 | "logo_mask": ("MASK",), 20 | "enabled": ("BOOLEAN", {"default": True}),} 21 | } 22 | RETURN_TYPES = ("IMAGE",) 23 | FUNCTION = "watermark" 24 | CATEGORY = "tbox/Image" 25 | 26 | def watermark(self, images, logo_list, logo_mask, enabled): 27 | outputs = [] 28 | if enabled == False: 29 | return(images,) 30 | print(f'logo shape: {logo_list.shape}') 31 | print(f'images shape: {images.shape}') 32 | logo = tensor2pil(logo_list[0]) 33 | if logo_mask is not None: 34 | logo_mask = tensor2pil(logo_mask) 35 | for i, image in enumerate(images): 36 | img = tensor2pil(image) #Image.fromarray(np.clip(255. 
* image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)) 37 | dst = self.add_watermark2(img, logo, logo_mask, 85) 38 | result = pil2tensor(dst) 39 | outputs.append(result) 40 | base_image = torch.stack([tensor.squeeze() for tensor in outputs]) 41 | return (base_image,) 42 | 43 | def add_watermark2(self, image, logo, logo_mask, opacity=None): 44 | logo_width, logo_height = logo.size 45 | image_width, image_height = image.size 46 | if image_height <= logo_height + PADDING * 2 or image_width <= logo_width + PADDING * 2: 47 | return image 48 | y = image_height - logo_height - PADDING * 1 49 | x = PADDING 50 | logo = logo.convert('RGBA') 51 | opacity = int(opacity / 100 * 255) 52 | logo.putalpha(Image.new("L", logo.size, opacity)) 53 | if logo_mask is not None: 54 | logo.putalpha(ImageOps.invert(logo_mask)) 55 | 56 | position = (x, y) 57 | image.paste(logo, position, logo) 58 | return image 59 | -------------------------------------------------------------------------------- /src/timm/scheduler/multistep_lr.py: -------------------------------------------------------------------------------- 1 | """ MultiStep LR Scheduler 2 | 3 | Basic multi step LR schedule with warmup, noise. 4 | """ 5 | import torch 6 | import bisect 7 | from timm.scheduler.scheduler import Scheduler 8 | from typing import List 9 | 10 | class MultiStepLRScheduler(Scheduler): 11 | """ 12 | """ 13 | 14 | def __init__(self, 15 | optimizer: torch.optim.Optimizer, 16 | decay_t: List[int], 17 | decay_rate: float = 1., 18 | warmup_t=0, 19 | warmup_lr_init=0, 20 | t_in_epochs=True, 21 | noise_range_t=None, 22 | noise_pct=0.67, 23 | noise_std=1.0, 24 | noise_seed=42, 25 | initialize=True, 26 | ) -> None: 27 | super().__init__( 28 | optimizer, param_group_field="lr", 29 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 30 | initialize=initialize) 31 | 32 | self.decay_t = decay_t 33 | self.decay_rate = decay_rate 34 | self.warmup_t = warmup_t 35 | self.warmup_lr_init = warmup_lr_init 36 | self.t_in_epochs = t_in_epochs 37 | if self.warmup_t: 38 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 39 | super().update_groups(self.warmup_lr_init) 40 | else: 41 | self.warmup_steps = [1 for _ in self.base_values] 42 | 43 | def get_curr_decay_steps(self, t): 44 | # find where in the array t goes, 45 | # assumes self.decay_t is sorted 46 | return bisect.bisect_right(self.decay_t, t+1) 47 | 48 | def _get_lr(self, t): 49 | if t < self.warmup_t: 50 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 51 | else: 52 | lrs = [v * (self.decay_rate ** self.get_curr_decay_steps(t)) for v in self.base_values] 53 | return lrs 54 | 55 | def get_epoch_values(self, epoch: int): 56 | if self.t_in_epochs: 57 | return self._get_lr(epoch) 58 | else: 59 | return None 60 | 61 | def get_update_values(self, num_updates: int): 62 | if not self.t_in_epochs: 63 | return self._get_lr(num_updates) 64 | else: 65 | return None 66 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import torch 4 | import yaml 5 | import comfy.utils 6 | import numpy as np 7 | import tempfile 8 | from pathlib import Path 9 | 10 | USE_SYMLINKS = False 11 | 12 | here = Path(__file__).parent.resolve() 13 | 14 | config_path = Path(here, "config.yaml") 15 | 16 | ANNOTATOR_CKPTS_PATH = "" 17 | TEMP_DIR = "" 18 | USE_SYMLINKS = False 19 | ORT_PROVIDERS = 
["CUDAExecutionProvider", "DirectMLExecutionProvider", "OpenVINOExecutionProvider", "ROCMExecutionProvider", "CPUExecutionProvider", "CoreMLExecutionProvider"] 20 | 21 | print(f'here: {here}') 22 | 23 | if os.path.exists(config_path): 24 | config = yaml.load(open(config_path, "r"), Loader=yaml.FullLoader) 25 | 26 | ANNOTATOR_CKPTS_PATH = str(Path(here, config["annotator_ckpts_path"])) 27 | TEMP_DIR = str(Path(here, config["custom_temp_path"]).resolve()) 28 | USE_SYMLINKS = config["USE_SYMLINKS"] 29 | ORT_PROVIDERS = config["EP_list"] 30 | 31 | if TEMP_DIR is None: 32 | TEMP_DIR = tempfile.gettempdir() 33 | elif not os.path.isdir(TEMP_DIR): 34 | try: 35 | os.makedirs(TEMP_DIR) 36 | except: 37 | print(f"Failed to create custom temp directory. Using default.") 38 | TEMP_DIR = tempfile.gettempdir() 39 | 40 | if not os.path.isdir(ANNOTATOR_CKPTS_PATH): 41 | try: 42 | os.makedirs(ANNOTATOR_CKPTS_PATH) 43 | except: 44 | print(f"Failed to create config ckpts directory. Using default.") 45 | ANNOTATOR_CKPTS_PATH = str(Path(here, "./ckpts")) 46 | else: 47 | ANNOTATOR_CKPTS_PATH = str(Path(here, "./ckpts")) 48 | TEMP_DIR = tempfile.gettempdir() 49 | USE_SYMLINKS = False 50 | ORT_PROVIDERS = ["CUDAExecutionProvider", "DirectMLExecutionProvider", "OpenVINOExecutionProvider", "ROCMExecutionProvider", "CPUExecutionProvider", "CoreMLExecutionProvider"] 51 | 52 | os.environ['AUX_ANNOTATOR_CKPTS_PATH'] = os.getenv('AUX_ANNOTATOR_CKPTS_PATH', ANNOTATOR_CKPTS_PATH) 53 | os.environ['AUX_TEMP_DIR'] = os.getenv('AUX_TEMP_DIR', str(TEMP_DIR)) 54 | os.environ['AUX_USE_SYMLINKS'] = os.getenv('AUX_USE_SYMLINKS', str(USE_SYMLINKS)) 55 | os.environ['AUX_ORT_PROVIDERS'] = os.getenv('AUX_ORT_PROVIDERS', str(",".join(ORT_PROVIDERS))) 56 | 57 | print(f"Using ckpts path: {ANNOTATOR_CKPTS_PATH}") 58 | print(f"Using symlinks: {USE_SYMLINKS}") 59 | print(f"Using ort providers: {ORT_PROVIDERS}") 60 | -------------------------------------------------------------------------------- /src/timm/models/layers/padding.py: -------------------------------------------------------------------------------- 1 | """ Padding Helpers 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import math 6 | from typing import List, Tuple 7 | 8 | import torch.nn.functional as F 9 | 10 | 11 | # Calculate symmetric padding for a convolution 12 | def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int: 13 | padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 14 | return padding 15 | 16 | 17 | # Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution 18 | def get_same_padding(x: int, k: int, s: int, d: int): 19 | return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0) 20 | 21 | 22 | # Can SAME padding for given args be done statically? 
23 | def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_): 24 | return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 25 | 26 | 27 | # Dynamically pad input x with 'SAME' padding for conv with specified args 28 | def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0): 29 | ih, iw = x.size()[-2:] 30 | pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1]) 31 | if pad_h > 0 or pad_w > 0: 32 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value) 33 | return x 34 | 35 | 36 | def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]: 37 | dynamic = False 38 | if isinstance(padding, str): 39 | # for any string padding, the padding will be calculated for you, one of three ways 40 | padding = padding.lower() 41 | if padding == 'same': 42 | # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact 43 | if is_static_pad(kernel_size, **kwargs): 44 | # static case, no extra overhead 45 | padding = get_padding(kernel_size, **kwargs) 46 | else: 47 | # dynamic 'SAME' padding, has runtime/GPU memory overhead 48 | padding = 0 49 | dynamic = True 50 | elif padding == 'valid': 51 | # 'VALID' padding, same as padding=0 52 | padding = 0 53 | else: 54 | # Default to PyTorch style 'same'-ish symmetric padding 55 | padding = get_padding(kernel_size, **kwargs) 56 | return padding, dynamic 57 | -------------------------------------------------------------------------------- /src/timm/utils/jit.py: -------------------------------------------------------------------------------- 1 | """ JIT scripting/tracing utils 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import os 6 | 7 | import torch 8 | 9 | 10 | def set_jit_legacy(): 11 | """ Set JIT executor to legacy w/ support for op fusion 12 | This is hopefully a temporary need in 1.5/1.5.1/1.6 to restore performance due to changes 13 | in the JIT exectutor. These API are not supported so could change. 14 | """ 15 | # 16 | assert hasattr(torch._C, '_jit_set_profiling_executor'), "Old JIT behavior doesn't exist!" 
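    # fall back to the legacy executor: turn off the profiling executor/mode and re-enable op fusion on GPU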
17 | torch._C._jit_set_profiling_executor(False) 18 | torch._C._jit_set_profiling_mode(False) 19 | torch._C._jit_override_can_fuse_on_gpu(True) 20 | #torch._C._jit_set_texpr_fuser_enabled(True) 21 | 22 | 23 | def set_jit_fuser(fuser): 24 | if fuser == "te": 25 | # default fuser should be == 'te' 26 | torch._C._jit_set_profiling_executor(True) 27 | torch._C._jit_set_profiling_mode(True) 28 | torch._C._jit_override_can_fuse_on_cpu(False) 29 | torch._C._jit_override_can_fuse_on_gpu(True) 30 | torch._C._jit_set_texpr_fuser_enabled(True) 31 | try: 32 | torch._C._jit_set_nvfuser_enabled(False) 33 | except Exception: 34 | pass 35 | elif fuser == "old" or fuser == "legacy": 36 | torch._C._jit_set_profiling_executor(False) 37 | torch._C._jit_set_profiling_mode(False) 38 | torch._C._jit_override_can_fuse_on_gpu(True) 39 | torch._C._jit_set_texpr_fuser_enabled(False) 40 | try: 41 | torch._C._jit_set_nvfuser_enabled(False) 42 | except Exception: 43 | pass 44 | elif fuser == "nvfuser" or fuser == "nvf": 45 | os.environ['PYTORCH_NVFUSER_DISABLE_FALLBACK'] = '1' 46 | #os.environ['PYTORCH_NVFUSER_DISABLE_FMA'] = '1' 47 | #os.environ['PYTORCH_NVFUSER_JIT_OPT_LEVEL'] = '0' 48 | torch._C._jit_set_texpr_fuser_enabled(False) 49 | torch._C._jit_set_profiling_executor(True) 50 | torch._C._jit_set_profiling_mode(True) 51 | torch._C._jit_can_fuse_on_cpu() 52 | torch._C._jit_can_fuse_on_gpu() 53 | torch._C._jit_override_can_fuse_on_cpu(False) 54 | torch._C._jit_override_can_fuse_on_gpu(False) 55 | torch._C._jit_set_nvfuser_guard_mode(True) 56 | torch._C._jit_set_nvfuser_enabled(True) 57 | else: 58 | assert False, f"Invalid jit fuser ({fuser})" 59 | -------------------------------------------------------------------------------- /src/timm/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .beit import * 2 | from .byoanet import * 3 | from .byobnet import * 4 | from .cait import * 5 | from .coat import * 6 | from .convit import * 7 | from .convmixer import * 8 | from .convnext import * 9 | from .crossvit import * 10 | from .cspnet import * 11 | from .deit import * 12 | from .densenet import * 13 | from .dla import * 14 | from .dpn import * 15 | from .edgenext import * 16 | from .efficientformer import * 17 | from .efficientnet import * 18 | from .gcvit import * 19 | from .ghostnet import * 20 | from .gluon_resnet import * 21 | from .gluon_xception import * 22 | from .hardcorenas import * 23 | from .hrnet import * 24 | from .inception_resnet_v2 import * 25 | from .inception_v3 import * 26 | from .inception_v4 import * 27 | from .levit import * 28 | from .maxxvit import * 29 | from .mlp_mixer import * 30 | from .mobilenetv3 import * 31 | from .mobilevit import * 32 | from .mvitv2 import * 33 | from .nasnet import * 34 | from .nest import * 35 | from .nfnet import * 36 | from .pit import * 37 | from .pnasnet import * 38 | from .poolformer import * 39 | from .pvt_v2 import * 40 | from .regnet import * 41 | from .res2net import * 42 | from .resnest import * 43 | from .resnet import * 44 | from .resnetv2 import * 45 | from .rexnet import * 46 | from .selecsls import * 47 | from .senet import * 48 | from .sequencer import * 49 | from .sknet import * 50 | from .swin_transformer import * 51 | from .swin_transformer_v2 import * 52 | from .swin_transformer_v2_cr import * 53 | from .tnt import * 54 | from .tresnet import * 55 | from .twins import * 56 | from .vgg import * 57 | from .visformer import * 58 | from .vision_transformer import * 59 | from 
.vision_transformer_hybrid import * 60 | from .vision_transformer_relpos import * 61 | from .volo import * 62 | from .vovnet import * 63 | from .xception import * 64 | from .xception_aligned import * 65 | from .xcit import * 66 | 67 | from .factory import create_model, parse_model_name, safe_model_name 68 | from .helpers import load_checkpoint, resume_checkpoint, model_parameters 69 | from .layers import TestTimePoolHead, apply_test_time_pool 70 | from .layers import convert_splitbn_model, convert_sync_batchnorm 71 | from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit 72 | from .layers import set_fast_norm 73 | from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\ 74 | is_model_pretrained, get_pretrained_cfg, has_pretrained_cfg_key, is_pretrained_cfg_key, get_pretrained_cfg_value 75 | -------------------------------------------------------------------------------- /src/timm/models/layers/classifier.py: -------------------------------------------------------------------------------- 1 | """ Classifier head and layer factory 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | from torch import nn as nn 6 | from torch.nn import functional as F 7 | 8 | from .adaptive_avgmax_pool import SelectAdaptivePool2d 9 | 10 | 11 | def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False): 12 | flatten_in_pool = not use_conv # flatten when we use a Linear layer after pooling 13 | if not pool_type: 14 | assert num_classes == 0 or use_conv,\ 15 | 'Pooling can only be disabled if classifier is also removed or conv classifier is used' 16 | flatten_in_pool = False # disable flattening if pooling is pass-through (no pooling) 17 | global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten_in_pool) 18 | num_pooled_features = num_features * global_pool.feat_mult() 19 | return global_pool, num_pooled_features 20 | 21 | 22 | def _create_fc(num_features, num_classes, use_conv=False): 23 | if num_classes <= 0: 24 | fc = nn.Identity() # pass-through (no classifier) 25 | elif use_conv: 26 | fc = nn.Conv2d(num_features, num_classes, 1, bias=True) 27 | else: 28 | fc = nn.Linear(num_features, num_classes, bias=True) 29 | return fc 30 | 31 | 32 | def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False): 33 | global_pool, num_pooled_features = _create_pool(num_features, num_classes, pool_type, use_conv=use_conv) 34 | fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv) 35 | return global_pool, fc 36 | 37 | 38 | class ClassifierHead(nn.Module): 39 | """Classifier head w/ configurable global pooling and dropout.""" 40 | 41 | def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0., use_conv=False): 42 | super(ClassifierHead, self).__init__() 43 | self.drop_rate = drop_rate 44 | self.global_pool, num_pooled_features = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv) 45 | self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv) 46 | self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity() 47 | 48 | def forward(self, x, pre_logits: bool = False): 49 | x = self.global_pool(x) 50 | if self.drop_rate: 51 | x = F.dropout(x, p=float(self.drop_rate), training=self.training) 52 | if pre_logits: 53 | return x.flatten(1) 54 | else: 55 | x = self.fc(x) 56 | return self.flatten(x) 57 | -------------------------------------------------------------------------------- 
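A minimal usage sketch for the `ClassifierHead` defined above (an editorial addition, not a file from this repository). The import path assumes the vendored `src/timm` package is importable as `timm`; the channel and class counts are illustrative only:

import torch
from timm.models.layers.classifier import ClassifierHead

# head for a backbone that emits 2048-channel feature maps, pooled and mapped to 1000 classes
head = ClassifierHead(in_chs=2048, num_classes=1000, pool_type='avg', drop_rate=0.1)
features = torch.randn(2, 2048, 7, 7)       # (B, C, H, W) feature maps
logits = head(features)                     # global pool -> dropout -> fc, shape (2, 1000)
pooled = head(features, pre_logits=True)    # pooled features before the fc layer, shape (2, 2048)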
/nodes/image/load_node.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import cv2 4 | import torch 5 | import requests 6 | import itertools 7 | import folder_paths 8 | import psutil 9 | import numpy as np 10 | from comfy.utils import common_upscale 11 | from io import BytesIO 12 | from PIL import Image, ImageSequence, ImageOps 13 | 14 | 15 | 16 | def pil2tensor(img): 17 | output_images = [] 18 | output_masks = [] 19 | for i in ImageSequence.Iterator(img): 20 | i = ImageOps.exif_transpose(i) 21 | if i.mode == 'I': 22 | i = i.point(lambda i: i * (1 / 255)) 23 | image = i.convert("RGB") 24 | image = np.array(image).astype(np.float32) / 255.0 25 | image = torch.from_numpy(image)[None,] 26 | if 'A' in i.getbands(): 27 | mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0 28 | mask = 1. - torch.from_numpy(mask) 29 | else: 30 | mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") 31 | output_images.append(image) 32 | output_masks.append(mask.unsqueeze(0)) 33 | 34 | if len(output_images) > 1: 35 | output_image = torch.cat(output_images, dim=0) 36 | output_mask = torch.cat(output_masks, dim=0) 37 | else: 38 | output_image = output_images[0] 39 | output_mask = output_masks[0] 40 | 41 | return (output_image, output_mask) 42 | 43 | 44 | def load_image(image_source): 45 | if image_source.startswith('http'): 46 | print(image_source) 47 | response = requests.get(image_source) 48 | img = Image.open(BytesIO(response.content)) 49 | file_name = image_source.split('/')[-1] 50 | else: 51 | img = Image.open(image_source) 52 | file_name = os.path.basename(image_source) 53 | return img, file_name 54 | 55 | 56 | class LoadImageNode: 57 | @classmethod 58 | def INPUT_TYPES(cls): 59 | return { 60 | "required": { 61 | "path": ("STRING", {"multiline": True, "dynamicPrompts": False}) 62 | } 63 | } 64 | 65 | 66 | RETURN_TYPES = ("IMAGE", "MASK") 67 | FUNCTION = "load_image" 68 | CATEGORY = "tbox/Image" 69 | 70 | def load_image(self, path): 71 | filepaht = path.split('\n')[0] 72 | img, name = load_image(filepaht) 73 | img_out, mask_out = pil2tensor(img) 74 | return (img_out, mask_out) 75 | 76 | 77 | 78 | if __name__ == "__main__": 79 | img, name = load_image("https://creativestorage.blob.core.chinacloudapi.cn/test/bird.png") 80 | img_out, mask_out = pil2tensor(img) 81 | 82 | -------------------------------------------------------------------------------- /src/timm/optim/sgdp.py: -------------------------------------------------------------------------------- 1 | """ 2 | SGDP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/sgdp.py 3 | 4 | Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217 5 | Code: https://github.com/clovaai/AdamP 6 | 7 | Copyright (c) 2020-present NAVER Corp. 
8 | MIT license 9 | """ 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | from torch.optim.optimizer import Optimizer, required 14 | import math 15 | 16 | from .adamp import projection 17 | 18 | 19 | class SGDP(Optimizer): 20 | def __init__(self, params, lr=required, momentum=0, dampening=0, 21 | weight_decay=0, nesterov=False, eps=1e-8, delta=0.1, wd_ratio=0.1): 22 | defaults = dict( 23 | lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay, 24 | nesterov=nesterov, eps=eps, delta=delta, wd_ratio=wd_ratio) 25 | super(SGDP, self).__init__(params, defaults) 26 | 27 | @torch.no_grad() 28 | def step(self, closure=None): 29 | loss = None 30 | if closure is not None: 31 | with torch.enable_grad(): 32 | loss = closure() 33 | 34 | for group in self.param_groups: 35 | weight_decay = group['weight_decay'] 36 | momentum = group['momentum'] 37 | dampening = group['dampening'] 38 | nesterov = group['nesterov'] 39 | 40 | for p in group['params']: 41 | if p.grad is None: 42 | continue 43 | grad = p.grad 44 | state = self.state[p] 45 | 46 | # State initialization 47 | if len(state) == 0: 48 | state['momentum'] = torch.zeros_like(p) 49 | 50 | # SGD 51 | buf = state['momentum'] 52 | buf.mul_(momentum).add_(grad, alpha=1. - dampening) 53 | if nesterov: 54 | d_p = grad + momentum * buf 55 | else: 56 | d_p = buf 57 | 58 | # Projection 59 | wd_ratio = 1. 60 | if len(p.shape) > 1: 61 | d_p, wd_ratio = projection(p, grad, d_p, group['delta'], group['wd_ratio'], group['eps']) 62 | 63 | # Weight decay 64 | if weight_decay != 0: 65 | p.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio / (1-momentum)) 66 | 67 | # Step 68 | p.add_(d_p, alpha=-group['lr']) 69 | 70 | return loss 71 | -------------------------------------------------------------------------------- /src/dwpose/dw_onnx/cv_ox_yolo_nas.py: -------------------------------------------------------------------------------- 1 | # Source: https://github.com/Hyuto/yolo-nas-onnx/tree/master/yolo-nas-py 2 | # Inspired from: https://github.com/Deci-AI/super-gradients/blob/3.1.1/src/super_gradients/training/processing/processing.py 3 | 4 | import numpy as np 5 | import cv2 6 | 7 | def preprocess(img, input_size, swap=(2, 0, 1)): 8 | if len(img.shape) == 3: 9 | padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114 10 | else: 11 | padded_img = np.ones(input_size, dtype=np.uint8) * 114 12 | 13 | r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1]) 14 | resized_img = cv2.resize( 15 | img, 16 | (int(img.shape[1] * r), int(img.shape[0] * r)), 17 | interpolation=cv2.INTER_LINEAR, 18 | ).astype(np.uint8) 19 | padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img 20 | 21 | padded_img = padded_img.transpose(swap) 22 | padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) 23 | return padded_img, r 24 | 25 | def inference_detector(session, oriImg, detect_classes=[0], dtype=np.uint8): 26 | """ 27 | This function is only compatible with onnx models exported from the new API with built-in NMS 28 | ```py 29 | from super_gradients.conversion.conversion_enums import ExportQuantizationMode 30 | from super_gradients.common.object_names import Models 31 | from super_gradients.training import models 32 | 33 | model = models.get(Models.YOLO_NAS_L, pretrained_weights="coco") 34 | 35 | export_result = model.export( 36 | "yolo_nas/yolo_nas_l_fp16.onnx", 37 | quantization_mode=ExportQuantizationMode.FP16, 38 | device="cuda" 39 | ) 40 | ``` 41 | """ 42 | input_shape = (640,640) 43 | 
img, ratio = preprocess(oriImg, input_shape) 44 | input = img[None, :, :, :] 45 | input = input.astype(dtype) 46 | if "InferenceSession" in type(session).__name__: 47 | input_name = session.get_inputs()[0].name 48 | output = session.run(None, {input_name: input}) 49 | else: 50 | outNames = session.getUnconnectedOutLayersNames() 51 | session.setInput(input) 52 | output = session.forward(outNames) 53 | num_preds, pred_boxes, pred_scores, pred_classes = output 54 | num_preds = num_preds[0,0] 55 | if num_preds == 0: 56 | return None 57 | idxs = np.where((np.isin(pred_classes[0, :num_preds], detect_classes)) & (pred_scores[0, :num_preds] > 0.3)) 58 | if (len(idxs) == 0) or (idxs[0].size == 0): 59 | return None 60 | return pred_boxes[0, idxs].squeeze(axis=0) / ratio 61 | -------------------------------------------------------------------------------- /src/timm/optim/lookahead.py: -------------------------------------------------------------------------------- 1 | """ Lookahead Optimizer Wrapper. 2 | Implementation modified from: https://github.com/alphadl/lookahead.pytorch 3 | Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import torch 8 | from torch.optim.optimizer import Optimizer 9 | from collections import defaultdict 10 | 11 | 12 | class Lookahead(Optimizer): 13 | def __init__(self, base_optimizer, alpha=0.5, k=6): 14 | # NOTE super().__init__() not called on purpose 15 | if not 0.0 <= alpha <= 1.0: 16 | raise ValueError(f'Invalid slow update rate: {alpha}') 17 | if not 1 <= k: 18 | raise ValueError(f'Invalid lookahead steps: {k}') 19 | defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0) 20 | self._base_optimizer = base_optimizer 21 | self.param_groups = base_optimizer.param_groups 22 | self.defaults = base_optimizer.defaults 23 | self.defaults.update(defaults) 24 | self.state = defaultdict(dict) 25 | # manually add our defaults to the param groups 26 | for name, default in defaults.items(): 27 | for group in self._base_optimizer.param_groups: 28 | group.setdefault(name, default) 29 | 30 | @torch.no_grad() 31 | def update_slow(self, group): 32 | for fast_p in group["params"]: 33 | if fast_p.grad is None: 34 | continue 35 | param_state = self._base_optimizer.state[fast_p] 36 | if 'lookahead_slow_buff' not in param_state: 37 | param_state['lookahead_slow_buff'] = torch.empty_like(fast_p) 38 | param_state['lookahead_slow_buff'].copy_(fast_p) 39 | slow = param_state['lookahead_slow_buff'] 40 | slow.add_(fast_p - slow, alpha=group['lookahead_alpha']) 41 | fast_p.copy_(slow) 42 | 43 | def sync_lookahead(self): 44 | for group in self._base_optimizer.param_groups: 45 | self.update_slow(group) 46 | 47 | @torch.no_grad() 48 | def step(self, closure=None): 49 | loss = self._base_optimizer.step(closure) 50 | for group in self._base_optimizer.param_groups: 51 | group['lookahead_step'] += 1 52 | if group['lookahead_step'] % group['lookahead_k'] == 0: 53 | self.update_slow(group) 54 | return loss 55 | 56 | def state_dict(self): 57 | return self._base_optimizer.state_dict() 58 | 59 | def load_state_dict(self, state_dict): 60 | self._base_optimizer.load_state_dict(state_dict) 61 | self.param_groups = self._base_optimizer.param_groups 62 | -------------------------------------------------------------------------------- /nodes/video/batch_node.py: -------------------------------------------------------------------------------- 1 | 2 | import hashlib 3 | import os 4 | 5 | 
class BatchManagerNode: 6 | def __init__(self, frames_per_batch=-1): 7 | print("BatchNode init") 8 | self.frames_per_batch = frames_per_batch 9 | self.inputs = {} 10 | self.outputs = {} 11 | self.unique_id = None 12 | self.has_closed_inputs = False 13 | self.total_frames = float('inf') 14 | 15 | def reset(self): 16 | print("BatchNode reset") 17 | self.close_inputs() 18 | for key in self.outputs: 19 | if getattr(self.outputs[key][-1], "gi_suspended", False): 20 | try: 21 | self.outputs[key][-1].send(None) 22 | except StopIteration: 23 | pass 24 | self.__init__(self.frames_per_batch) 25 | def has_open_inputs(self): 26 | return len(self.inputs) > 0 27 | def close_inputs(self): 28 | for key in self.inputs: 29 | if getattr(self.inputs[key][-1], "gi_suspended", False): 30 | try: 31 | self.inputs[key][-1].send(1) 32 | except StopIteration: 33 | pass 34 | self.inputs = {} 35 | 36 | @classmethod 37 | def INPUT_TYPES(s): 38 | return { 39 | "required": {"frames_per_batch": ("INT", {"default": 16, "min": 1, "max": 128, "step": 1})}, 40 | "hidden": {"prompt": "PROMPT", "unique_id": "UNIQUE_ID"}, 41 | } 42 | 43 | RETURN_TYPES = ("BatchManager",) 44 | RETURN_NAMES = ("meta_batch",) 45 | CATEGORY = "tbox/Video" 46 | FUNCTION = "update_batch" 47 | 48 | def update_batch(self, frames_per_batch, prompt=None, unique_id=None): 49 | if unique_id is not None and prompt is not None: 50 | requeue = prompt[unique_id]['inputs'].get('requeue', 0) 51 | else: 52 | requeue = 0 53 | print(f'update_batch >> unique_id: {unique_id}; requeue: {requeue}') 54 | if requeue == 0: 55 | self.reset() 56 | self.frames_per_batch = frames_per_batch 57 | self.unique_id = unique_id 58 | else: 59 | num_batches = (self.total_frames+self.frames_per_batch-1)//frames_per_batch 60 | print(f'Meta-Batch {requeue}/{num_batches}') 61 | #onExecuted seems to not be called unless some message is sent 62 | return (self,) 63 | 64 | @classmethod 65 | def IS_CHANGED(self, frames_per_batch, prompt=None, unique_id=None): 66 | print(f"BatchManagerNode >>> IS_CHANGED : {result}") 67 | random_bytes = os.urandom(32) 68 | result = hashlib.sha256(random_bytes).hexdigest() 69 | return result 70 | -------------------------------------------------------------------------------- /src/timm/models/layers/global_context.py: -------------------------------------------------------------------------------- 1 | """ Global Context Attention Block 2 | 3 | Paper: `GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond` 4 | - https://arxiv.org/abs/1904.11492 5 | 6 | Official code consulted as reference: https://github.com/xvjiarui/GCNet 7 | 8 | Hacked together by / Copyright 2021 Ross Wightman 9 | """ 10 | from torch import nn as nn 11 | import torch.nn.functional as F 12 | 13 | from .create_act import create_act_layer, get_act_layer 14 | from .helpers import make_divisible 15 | from .mlp import ConvMlp 16 | from .norm import LayerNorm2d 17 | 18 | 19 | class GlobalContext(nn.Module): 20 | 21 | def __init__(self, channels, use_attn=True, fuse_add=False, fuse_scale=True, init_last_zero=False, 22 | rd_ratio=1./8, rd_channels=None, rd_divisor=1, act_layer=nn.ReLU, gate_layer='sigmoid'): 23 | super(GlobalContext, self).__init__() 24 | act_layer = get_act_layer(act_layer) 25 | 26 | self.conv_attn = nn.Conv2d(channels, 1, kernel_size=1, bias=True) if use_attn else None 27 | 28 | if rd_channels is None: 29 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) 
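        # both optional fusion branches below project the pooled context through a ConvMlp of width rd_channels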
30 | if fuse_add: 31 | self.mlp_add = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d) 32 | else: 33 | self.mlp_add = None 34 | if fuse_scale: 35 | self.mlp_scale = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d) 36 | else: 37 | self.mlp_scale = None 38 | 39 | self.gate = create_act_layer(gate_layer) 40 | self.init_last_zero = init_last_zero 41 | self.reset_parameters() 42 | 43 | def reset_parameters(self): 44 | if self.conv_attn is not None: 45 | nn.init.kaiming_normal_(self.conv_attn.weight, mode='fan_in', nonlinearity='relu') 46 | if self.mlp_add is not None: 47 | nn.init.zeros_(self.mlp_add.fc2.weight) 48 | 49 | def forward(self, x): 50 | B, C, H, W = x.shape 51 | 52 | if self.conv_attn is not None: 53 | attn = self.conv_attn(x).reshape(B, 1, H * W) # (B, 1, H * W) 54 | attn = F.softmax(attn, dim=-1).unsqueeze(3) # (B, 1, H * W, 1) 55 | context = x.reshape(B, C, H * W).unsqueeze(1) @ attn 56 | context = context.view(B, C, 1, 1) 57 | else: 58 | context = x.mean(dim=(2, 3), keepdim=True) 59 | 60 | if self.mlp_scale is not None: 61 | mlp_x = self.mlp_scale(context) 62 | x = x * self.gate(mlp_x) 63 | if self.mlp_add is not None: 64 | mlp_x = self.mlp_add(context) 65 | x = x + mlp_x 66 | 67 | return x 68 | -------------------------------------------------------------------------------- /src/timm/models/layers/fast_norm.py: -------------------------------------------------------------------------------- 1 | """ 'Fast' Normalization Functions 2 | 3 | For GroupNorm and LayerNorm these functions bypass typical AMP upcast to float32. 4 | 5 | Additionally, for LayerNorm, the APEX fused LN is used if available (which also does not upcast) 6 | 7 | Hacked together by / Copyright 2022 Ross Wightman 8 | """ 9 | from typing import List, Optional 10 | 11 | import torch 12 | from torch.nn import functional as F 13 | 14 | try: 15 | from apex.normalization.fused_layer_norm import fused_layer_norm_affine 16 | has_apex = True 17 | except ImportError: 18 | has_apex = False 19 | 20 | 21 | # fast (ie lower precision LN) can be disabled with this flag if issues crop up 22 | _USE_FAST_NORM = False # defaulting to False for now 23 | 24 | 25 | def is_fast_norm(): 26 | return _USE_FAST_NORM 27 | 28 | 29 | def set_fast_norm(enable=True): 30 | global _USE_FAST_NORM 31 | _USE_FAST_NORM = enable 32 | 33 | 34 | def fast_group_norm( 35 | x: torch.Tensor, 36 | num_groups: int, 37 | weight: Optional[torch.Tensor] = None, 38 | bias: Optional[torch.Tensor] = None, 39 | eps: float = 1e-5 40 | ) -> torch.Tensor: 41 | if torch.jit.is_scripting(): 42 | # currently cannot use is_autocast_enabled within torchscript 43 | return F.group_norm(x, num_groups, weight, bias, eps) 44 | 45 | if torch.is_autocast_enabled(): 46 | # normally native AMP casts GN inputs to float32 47 | # here we use the low precision autocast dtype 48 | # FIXME what to do re CPU autocast? 
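        # use the active autocast dtype (fp16/bf16) for the GroupNorm inputs instead of the float32 upcast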
49 | dt = torch.get_autocast_gpu_dtype() 50 | x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) 51 | 52 | with torch.cuda.amp.autocast(enabled=False): 53 | return F.group_norm(x, num_groups, weight, bias, eps) 54 | 55 | 56 | def fast_layer_norm( 57 | x: torch.Tensor, 58 | normalized_shape: List[int], 59 | weight: Optional[torch.Tensor] = None, 60 | bias: Optional[torch.Tensor] = None, 61 | eps: float = 1e-5 62 | ) -> torch.Tensor: 63 | if torch.jit.is_scripting(): 64 | # currently cannot use is_autocast_enabled within torchscript 65 | return F.layer_norm(x, normalized_shape, weight, bias, eps) 66 | 67 | if has_apex: 68 | return fused_layer_norm_affine(x, weight, bias, normalized_shape, eps) 69 | 70 | if torch.is_autocast_enabled(): 71 | # normally native AMP casts LN inputs to float32 72 | # apex LN does not, this is behaving like Apex 73 | dt = torch.get_autocast_gpu_dtype() 74 | # FIXME what to do re CPU autocast? 75 | x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) 76 | 77 | with torch.cuda.amp.autocast(enabled=False): 78 | return F.layer_norm(x, normalized_shape, weight, bias, eps) 79 | -------------------------------------------------------------------------------- /src/timm/models/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .activations import * 2 | from .adaptive_avgmax_pool import \ 3 | adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d 4 | from .blur_pool import BlurPool2d 5 | from .classifier import ClassifierHead, create_classifier 6 | from .cond_conv2d import CondConv2d, get_condconv_initializer 7 | from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\ 8 | set_layer_config 9 | from .conv2d_same import Conv2dSame, conv2d_same 10 | from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct 11 | from .create_act import create_act_layer, get_act_layer, get_act_fn 12 | from .create_attn import get_attn, create_attn 13 | from .create_conv2d import create_conv2d 14 | from .create_norm import get_norm_layer, create_norm_layer 15 | from .create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer 16 | from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path 17 | from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn 18 | from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\ 19 | EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a 20 | from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm 21 | from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d 22 | from .gather_excite import GatherExcite 23 | from .global_context import GlobalContext 24 | from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple 25 | from .inplace_abn import InplaceAbn 26 | from .linear import Linear 27 | from .mixed_conv2d import MixedConv2d 28 | from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp 29 | from .non_local_attn import NonLocalAttn, BatNonLocalAttn 30 | from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d 31 | from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm 32 | from .padding import get_padding, get_same_padding, pad_same 33 | from .patch_embed import PatchEmbed 34 | from .pool2d_same import AvgPool2dSame, create_pool2d 35 | from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, 
EffectiveSqueezeExcite 36 | from .selective_kernel import SelectiveKernel 37 | from .separable_conv import SeparableConv2d, SeparableConvNormAct 38 | from .space_to_depth import SpaceToDepthModule 39 | from .split_attn import SplitAttn 40 | from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model 41 | from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame 42 | from .test_time_pool import TestTimePoolHead, apply_test_time_pool 43 | from .trace_utils import _assert, _float_to_int 44 | from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_ 45 | -------------------------------------------------------------------------------- /src/timm/models/layers/filter_response_norm.py: -------------------------------------------------------------------------------- 1 | """ Filter Response Norm in PyTorch 2 | 3 | Based on `Filter Response Normalization Layer` - https://arxiv.org/abs/1911.09737 4 | 5 | Hacked together by / Copyright 2021 Ross Wightman 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | 10 | from .create_act import create_act_layer 11 | from .trace_utils import _assert 12 | 13 | 14 | def inv_instance_rms(x, eps: float = 1e-5): 15 | rms = x.square().float().mean(dim=(2, 3), keepdim=True).add(eps).rsqrt().to(x.dtype) 16 | return rms.expand(x.shape) 17 | 18 | 19 | class FilterResponseNormTlu2d(nn.Module): 20 | def __init__(self, num_features, apply_act=True, eps=1e-5, rms=True, **_): 21 | super(FilterResponseNormTlu2d, self).__init__() 22 | self.apply_act = apply_act # apply activation (non-linearity) 23 | self.rms = rms 24 | self.eps = eps 25 | self.weight = nn.Parameter(torch.ones(num_features)) 26 | self.bias = nn.Parameter(torch.zeros(num_features)) 27 | self.tau = nn.Parameter(torch.zeros(num_features)) if apply_act else None 28 | self.reset_parameters() 29 | 30 | def reset_parameters(self): 31 | nn.init.ones_(self.weight) 32 | nn.init.zeros_(self.bias) 33 | if self.tau is not None: 34 | nn.init.zeros_(self.tau) 35 | 36 | def forward(self, x): 37 | _assert(x.dim() == 4, 'expected 4D input') 38 | x_dtype = x.dtype 39 | v_shape = (1, -1, 1, 1) 40 | x = x * inv_instance_rms(x, self.eps) 41 | x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) 42 | return torch.maximum(x, self.tau.reshape(v_shape).to(dtype=x_dtype)) if self.tau is not None else x 43 | 44 | 45 | class FilterResponseNormAct2d(nn.Module): 46 | def __init__(self, num_features, apply_act=True, act_layer=nn.ReLU, inplace=None, rms=True, eps=1e-5, **_): 47 | super(FilterResponseNormAct2d, self).__init__() 48 | if act_layer is not None and apply_act: 49 | self.act = create_act_layer(act_layer, inplace=inplace) 50 | else: 51 | self.act = nn.Identity() 52 | self.rms = rms 53 | self.eps = eps 54 | self.weight = nn.Parameter(torch.ones(num_features)) 55 | self.bias = nn.Parameter(torch.zeros(num_features)) 56 | self.reset_parameters() 57 | 58 | def reset_parameters(self): 59 | nn.init.ones_(self.weight) 60 | nn.init.zeros_(self.bias) 61 | 62 | def forward(self, x): 63 | _assert(x.dim() == 4, 'expected 4D input') 64 | x_dtype = x.dtype 65 | v_shape = (1, -1, 1, 1) 66 | x = x * inv_instance_rms(x, self.eps) 67 | x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) 68 | return self.act(x) 69 | -------------------------------------------------------------------------------- /src/timm/models/layers/activations_jit.py: 
-------------------------------------------------------------------------------- 1 | """ Activations 2 | 3 | A collection of jit-scripted activations fn and modules with a common interface so that they can 4 | easily be swapped. All have an `inplace` arg even if not used. 5 | 6 | All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not 7 | currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted 8 | versions if they contain in-place ops. 9 | 10 | Hacked together by / Copyright 2020 Ross Wightman 11 | """ 12 | 13 | import torch 14 | from torch import nn as nn 15 | from torch.nn import functional as F 16 | 17 | 18 | @torch.jit.script 19 | def swish_jit(x, inplace: bool = False): 20 | """Swish - Described in: https://arxiv.org/abs/1710.05941 21 | """ 22 | return x.mul(x.sigmoid()) 23 | 24 | 25 | @torch.jit.script 26 | def mish_jit(x, _inplace: bool = False): 27 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 28 | """ 29 | return x.mul(F.softplus(x).tanh()) 30 | 31 | 32 | class SwishJit(nn.Module): 33 | def __init__(self, inplace: bool = False): 34 | super(SwishJit, self).__init__() 35 | 36 | def forward(self, x): 37 | return swish_jit(x) 38 | 39 | 40 | class MishJit(nn.Module): 41 | def __init__(self, inplace: bool = False): 42 | super(MishJit, self).__init__() 43 | 44 | def forward(self, x): 45 | return mish_jit(x) 46 | 47 | 48 | @torch.jit.script 49 | def hard_sigmoid_jit(x, inplace: bool = False): 50 | # return F.relu6(x + 3.) / 6. 51 | return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? 52 | 53 | 54 | class HardSigmoidJit(nn.Module): 55 | def __init__(self, inplace: bool = False): 56 | super(HardSigmoidJit, self).__init__() 57 | 58 | def forward(self, x): 59 | return hard_sigmoid_jit(x) 60 | 61 | 62 | @torch.jit.script 63 | def hard_swish_jit(x, inplace: bool = False): 64 | # return x * (F.relu6(x + 3.) / 6) 65 | return x * (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? 66 | 67 | 68 | class HardSwishJit(nn.Module): 69 | def __init__(self, inplace: bool = False): 70 | super(HardSwishJit, self).__init__() 71 | 72 | def forward(self, x): 73 | return hard_swish_jit(x) 74 | 75 | 76 | @torch.jit.script 77 | def hard_mish_jit(x, inplace: bool = False): 78 | """ Hard Mish 79 | Experimental, based on notes by Mish author Diganta Misra at 80 | https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md 81 | """ 82 | return 0.5 * x * (x + 2).clamp(min=0, max=2) 83 | 84 | 85 | class HardMishJit(nn.Module): 86 | def __init__(self, inplace: bool = False): 87 | super(HardMishJit, self).__init__() 88 | 89 | def forward(self, x): 90 | return hard_mish_jit(x) 91 | -------------------------------------------------------------------------------- /src/timm/models/layers/separable_conv.py: -------------------------------------------------------------------------------- 1 | """ Depthwise Separable Conv Modules 2 | 3 | Basic DWS convs. Other variations of DWS exist with batch norm or activations between the 4 | DW and PW convs such as the Depthwise modules in MobileNetV2 / EfficientNet and Xception. 
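For illustration: SeparableConv2d(32, 64, kernel_size=3) builds a depthwise 3x3 conv over the 32 input channels followed by a 1x1 pointwise conv that projects to 64 channels, while SeparableConvNormAct appends a norm + activation after the pointwise conv.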
5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | from torch import nn as nn 9 | 10 | from .create_conv2d import create_conv2d 11 | from .create_norm_act import get_norm_act_layer 12 | 13 | 14 | class SeparableConvNormAct(nn.Module): 15 | """ Separable Conv w/ trailing Norm and Activation 16 | """ 17 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, 18 | channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, 19 | apply_act=True, drop_layer=None): 20 | super(SeparableConvNormAct, self).__init__() 21 | 22 | self.conv_dw = create_conv2d( 23 | in_channels, int(in_channels * channel_multiplier), kernel_size, 24 | stride=stride, dilation=dilation, padding=padding, depthwise=True) 25 | 26 | self.conv_pw = create_conv2d( 27 | int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) 28 | 29 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer) 30 | norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {} 31 | self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs) 32 | 33 | @property 34 | def in_channels(self): 35 | return self.conv_dw.in_channels 36 | 37 | @property 38 | def out_channels(self): 39 | return self.conv_pw.out_channels 40 | 41 | def forward(self, x): 42 | x = self.conv_dw(x) 43 | x = self.conv_pw(x) 44 | x = self.bn(x) 45 | return x 46 | 47 | 48 | SeparableConvBnAct = SeparableConvNormAct 49 | 50 | 51 | class SeparableConv2d(nn.Module): 52 | """ Separable Conv 53 | """ 54 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, 55 | channel_multiplier=1.0, pw_kernel_size=1): 56 | super(SeparableConv2d, self).__init__() 57 | 58 | self.conv_dw = create_conv2d( 59 | in_channels, int(in_channels * channel_multiplier), kernel_size, 60 | stride=stride, dilation=dilation, padding=padding, depthwise=True) 61 | 62 | self.conv_pw = create_conv2d( 63 | int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) 64 | 65 | @property 66 | def in_channels(self): 67 | return self.conv_dw.in_channels 68 | 69 | @property 70 | def out_channels(self): 71 | return self.conv_pw.out_channels 72 | 73 | def forward(self, x): 74 | x = self.conv_dw(x) 75 | x = self.conv_pw(x) 76 | return x 77 | -------------------------------------------------------------------------------- /src/timm/data/parsers/parser_image_tar.py: -------------------------------------------------------------------------------- 1 | """ A dataset parser that reads single tarfile based datasets 2 | 3 | This parser can read datasets consisting if a single tarfile containing images. 4 | I am planning to deprecated it in favour of ParerImageInTar. 
5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | import os 9 | import tarfile 10 | 11 | from timm.utils.misc import natural_key 12 | 13 | from .class_map import load_class_map 14 | from .img_extensions import get_img_extensions 15 | from .parser import Parser 16 | 17 | 18 | def extract_tarinfo(tarfile, class_to_idx=None, sort=True): 19 | extensions = get_img_extensions(as_set=True) 20 | files = [] 21 | labels = [] 22 | for ti in tarfile.getmembers(): 23 | if not ti.isfile(): 24 | continue 25 | dirname, basename = os.path.split(ti.path) 26 | label = os.path.basename(dirname) 27 | ext = os.path.splitext(basename)[1] 28 | if ext.lower() in extensions: 29 | files.append(ti) 30 | labels.append(label) 31 | if class_to_idx is None: 32 | unique_labels = set(labels) 33 | sorted_labels = list(sorted(unique_labels, key=natural_key)) 34 | class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)} 35 | tarinfo_and_targets = [(f, class_to_idx[l]) for f, l in zip(files, labels) if l in class_to_idx] 36 | if sort: 37 | tarinfo_and_targets = sorted(tarinfo_and_targets, key=lambda k: natural_key(k[0].path)) 38 | return tarinfo_and_targets, class_to_idx 39 | 40 | 41 | class ParserImageTar(Parser): 42 | """ Single tarfile dataset where classes are mapped to folders within tar 43 | NOTE: This class is being deprecated in favour of the more capable ParserImageInTar that can 44 | operate on folders of tars or tars in tars. 45 | """ 46 | def __init__(self, root, class_map=''): 47 | super().__init__() 48 | 49 | class_to_idx = None 50 | if class_map: 51 | class_to_idx = load_class_map(class_map, root) 52 | assert os.path.isfile(root) 53 | self.root = root 54 | 55 | with tarfile.open(root) as tf: # cannot keep this open across processes, reopen later 56 | self.samples, self.class_to_idx = extract_tarinfo(tf, class_to_idx) 57 | self.imgs = self.samples 58 | self.tarfile = None # lazy init in __getitem__ 59 | 60 | def __getitem__(self, index): 61 | if self.tarfile is None: 62 | self.tarfile = tarfile.open(self.root) 63 | tarinfo, target = self.samples[index] 64 | fileobj = self.tarfile.extractfile(tarinfo) 65 | return fileobj, target 66 | 67 | def __len__(self): 68 | return len(self.samples) 69 | 70 | def _filename(self, index, basename=False, absolute=False): 71 | filename = self.samples[index][0].name 72 | if basename: 73 | filename = os.path.basename(filename) 74 | return filename 75 | -------------------------------------------------------------------------------- /src/midas/midas_net.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=256, non_negative=True): 17 | """Init. 18 | 19 | Args: 20 | path (str, optional): Path to saved model. Defaults to None. 21 | features (int, optional): Number of features. Defaults to 256. 22 | backbone (str, optional): Backbone network for encoder. 
Defaults to resnet50 23 | """ 24 | print("Loading weights: ", path) 25 | 26 | super(MidasNet, self).__init__() 27 | 28 | use_pretrained = False if path is None else True 29 | 30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) 31 | 32 | self.scratch.refinenet4 = FeatureFusionBlock(features) 33 | self.scratch.refinenet3 = FeatureFusionBlock(features) 34 | self.scratch.refinenet2 = FeatureFusionBlock(features) 35 | self.scratch.refinenet1 = FeatureFusionBlock(features) 36 | 37 | self.scratch.output_conv = nn.Sequential( 38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 39 | Interpolate(scale_factor=2, mode="bilinear"), 40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 41 | nn.ReLU(True), 42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 43 | nn.ReLU(True) if non_negative else nn.Identity(), 44 | ) 45 | 46 | if path: 47 | self.load(path) 48 | 49 | def forward(self, x): 50 | """Forward pass. 51 | 52 | Args: 53 | x (tensor): input data (image) 54 | 55 | Returns: 56 | tensor: depth 57 | """ 58 | 59 | layer_1 = self.pretrained.layer1(x) 60 | layer_2 = self.pretrained.layer2(layer_1) 61 | layer_3 = self.pretrained.layer3(layer_2) 62 | layer_4 = self.pretrained.layer4(layer_3) 63 | 64 | layer_1_rn = self.scratch.layer1_rn(layer_1) 65 | layer_2_rn = self.scratch.layer2_rn(layer_2) 66 | layer_3_rn = self.scratch.layer3_rn(layer_3) 67 | layer_4_rn = self.scratch.layer4_rn(layer_4) 68 | 69 | path_4 = self.scratch.refinenet4(layer_4_rn) 70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 73 | 74 | out = self.scratch.output_conv(path_1) 75 | 76 | return torch.squeeze(out, dim=1) 77 | -------------------------------------------------------------------------------- /nodes/mask/mask_node.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import numpy as np 4 | 5 | class MaskSubNode: 6 | @classmethod 7 | def INPUT_TYPES(cls): 8 | return { 9 | "required": { 10 | "mask": ("MASK",), 11 | }, 12 | "optional": { 13 | "src1": ("MASK",), 14 | "src2": ("MASK",), 15 | "src3": ("MASK",), 16 | "src4": ("MASK",), 17 | "src5": ("MASK",), 18 | "src6": ("MASK",), 19 | } 20 | } 21 | 22 | CATEGORY = "mask" 23 | RETURN_TYPES = ("MASK",) 24 | 25 | FUNCTION = "sub" 26 | CATEGORY = "tbox/Mask" 27 | 28 | def sub_mask(self, dst, src): 29 | if src is not None: 30 | mask = src.reshape((-1, src.shape[-2], src.shape[-1])) 31 | return dst - mask 32 | return dst 33 | 34 | def sub(self, mask, src1=None, src2=None, src3=None, src4=None, src5=None, src6=None): 35 | print(f'mask shape: {mask.shape}') 36 | output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone() 37 | output[:, :, :] = self.sub_mask(output, src1) 38 | output[:, :, :] = self.sub_mask(output, src2) 39 | output[:, :, :] = self.sub_mask(output, src3) 40 | output[:, :, :] = self.sub_mask(output, src4) 41 | output[:, :, :] = self.sub_mask(output, src5) 42 | output[:, :, :] = self.sub_mask(output, src6) 43 | output = torch.clamp(output, 0.0, 1.0) 44 | return (output, ) 45 | 46 | class MaskAddNode: 47 | @classmethod 48 | def INPUT_TYPES(cls): 49 | return { 50 | "required": { 51 | "mask": ("MASK",), 52 | }, 53 | "optional": { 54 | "src1": ("MASK",), 55 | "src2": ("MASK",), 56 | "src3": ("MASK",), 57 | "src4": ("MASK",), 58 | "src5": ("MASK",), 59 | "src6": ("MASK",), 60 | } 61 | } 62 | 63 |
CATEGORY = "mask" 64 | RETURN_TYPES = ("MASK",) 65 | 66 | FUNCTION = "add" 67 | CATEGORY = "tbox/Mask" 68 | 69 | def add_mask(self, dst, src): 70 | if src is not None: 71 | mask = src.reshape((-1, src.shape[-2], src.shape[-1])) 72 | return dst + mask 73 | return dst 74 | 75 | def add(self, mask, src1=None, src2=None, src3=None, src4=None, src5=None, src6=None): 76 | print(f'mask shape: {mask.shape}') 77 | output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone() 78 | output[:, :, :] = self.add_mask(output, src1) 79 | output[:, :, :] = self.add_mask(output, src2) 80 | output[:, :, :] = self.add_mask(output, src3) 81 | output[:, :, :] = self.add_mask(output, src4) 82 | output[:, :, :] = self.add_mask(output, src5) 83 | output[:, :, :] = self.add_mask(output, src6) 84 | output = torch.clamp(output, 0.0, 1.0) 85 | return (output, ) 86 | -------------------------------------------------------------------------------- /nodes/image/save_node.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | from PIL import Image, ImageSequence, ImageOps 5 | 6 | #from load_node import load_image, pil2tensor 7 | 8 | def save_image(img, filepath, format, quality): 9 | try: 10 | if format in ["jpg", "jpeg"]: 11 | img.convert("RGB").save(filepath, format="JPEG", quality=quality, subsampling=0) 12 | elif format == "webp": 13 | img.save(filepath, format="WEBP", quality=quality, method=6) 14 | elif format == "bmp": 15 | img.save(filepath, format="BMP") 16 | else: 17 | img.save(filepath, format="PNG", optimize=True) 18 | except Exception as e: 19 | print(f"Error saving {filepath}: {str(e)}") 20 | 21 | class SaveImageNode: 22 | @classmethod 23 | def INPUT_TYPES(cls): 24 | return { 25 | "required": { 26 | "images": ("IMAGE",), 27 | "path": ("STRING", {"multiline": True, "dynamicPrompts": False}), 28 | "quality": ([100, 95, 90, 85, 80, 75, 70, 60, 50], {"default": 100}), 29 | } 30 | } 31 | RETURN_TYPES = () 32 | FUNCTION = "save_image" 33 | CATEGORY = "tbox/Image" 34 | OUTPUT_NODE = True 35 | 36 | def save_image(self, images, path, quality): 37 | filepath = path.split('\n')[0] 38 | format = os.path.splitext(filepath)[1][1:] 39 | image = images[0] 40 | img = Image.fromarray((255. * image.cpu().numpy()).astype(np.uint8)) 41 | save_image(img, filepath, format, quality) 42 | return {} 43 | 44 | class SaveImagesNode: 45 | @classmethod 46 | def INPUT_TYPES(cls): 47 | return { 48 | "required": { 49 | "images": ("IMAGE",), 50 | "path": ("STRING", {"multiline": False, "dynamicPrompts": False}), 51 | "prefix": ("STRING", {"default": "image"}), 52 | "format": (["PNG", "JPG", "WEBP", "BMP"],), 53 | "quality": ([100, 95, 90, 85, 80, 75, 70, 60, 50], {"default": 100}), 54 | } 55 | } 56 | RETURN_TYPES = () 57 | FUNCTION = "save_image" 58 | CATEGORY = "tbox/Image" 59 | OUTPUT_NODE = True 60 | 61 | def save_image(self, images, path, prefix, format, quality): 62 | format = format.lower() 63 | for i, image in enumerate(images): 64 | img = Image.fromarray((255.
* image.cpu().numpy()).astype(np.uint8)) 65 | filepath = self.generate_filename(path, prefix, i, format) 66 | save_image(img, filepath, format, quality) 67 | return {} 68 | 69 | def IS_CHANGED(s, images): 70 | return time.time() 71 | 72 | def generate_filename(self, save_dir, prefix, index, format): 73 | base_filename = f"{prefix}_{index+1}.{format}" 74 | filename = os.path.join(save_dir, base_filename) 75 | counter = 1 76 | while os.path.exists(filename): 77 | filename = os.path.join(save_dir, f"{prefix}_{index+1}_{counter}.{format}") 78 | counter += 1 79 | return filename 80 | -------------------------------------------------------------------------------- /src/facefusion/gfpgan_onnx.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import os 4 | import sys 5 | import argparse 6 | import cv2 7 | import numpy as np 8 | #import timeit 9 | import onnxruntime 10 | from facefusion.affine import create_box_mask, warp_face_by_landmark, paste_back 11 | 12 | class GFPGANOnnx: 13 | def __init__(self, model_path, providers): 14 | self.session = onnxruntime.InferenceSession(model_path, providers=providers) 15 | inputs = self.session.get_inputs() 16 | self.input_size = (inputs[0].shape[2], inputs[0].shape[3]) 17 | self.input_name = inputs[0].name 18 | self.affine = False 19 | 20 | def pre_process(self, image): 21 | img = cv2.resize(image, self.input_size) 22 | img = img/255.0 23 | img[:,:,0] = (img[:,:,0]-0.5)/0.5 24 | img[:,:,1] = (img[:,:,1]-0.5)/0.5 25 | img[:,:,2] = (img[:,:,2]-0.5)/0.5 26 | img = np.float32(img[np.newaxis,:,:,:]) 27 | img = img.transpose(0, 3, 1, 2) 28 | return img 29 | 30 | def post_process(self, output, height, width): 31 | output = output.clip(-1,1) 32 | output = (output + 1) / 2 33 | output = output.transpose(1, 2, 0) 34 | # output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) 35 | output = (output * 255.0).round() 36 | output = cv2.resize(output, (width, height)) 37 | return output 38 | 39 | def run(self, image): 40 | height, width = image.shape[0], image.shape[1] 41 | img = self.pre_process(image) 42 | #t = timeit.default_timer() 43 | outputs = self.session.run(None, {self.input_name: img}) 44 | output = outputs[0][0] 45 | output = self.post_process(output, height, width) 46 | #print('infer time:',timeit.default_timer()-t) 47 | output = output.astype(np.uint8) 48 | return output 49 | 50 | if __name__ == "__main__": 51 | from yoloface_onnx import YoloFaceOnnx 52 | providers=['CPUExecutionProvider'] 53 | model_path = '/Users/wadahana/workspace/AI/sd/ComfyUI/models/facefusion/gfpgan_1.4.onnx' 54 | yolo_path = '/Users/wadahana/workspace/AI/sd/ComfyUI/models/facefusion/yoloface_8n.onnx' 55 | 56 | detector = YoloFaceOnnx(model_path=yolo_path, providers=providers) 57 | session = GFPGANOnnx(model_path=model_path, providers=providers) 58 | 59 | 60 | 61 | image = cv2.imread('/Users/wadahana/Desktop/anime-3.jpeg') 62 | 63 | face_list = detector.detect(image=image, conf=0.7) 64 | print(f'total of face: {len(face_list)}') 65 | 66 | output = image 67 | for index, face in enumerate(face_list): 68 | cropped, affine_matrix = warp_face_by_landmark(image, face.landmarks, session.input_size) 69 | box_mask = create_box_mask(session.input_size, 0.3, (0,0,0,0)) 70 | crop_mask = np.minimum.reduce([box_mask]).clip(0, 1) 71 | result = session.run(cropped) 72 | cv2.imwrite(f'/Users/wadahana/Desktop/output_{index}.jpg', result) 73 | output = paste_back(output, result, crop_mask, affine_matrix) 74 | 75 | 
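    # every detected face has now been enhanced and pasted back into `output`; write the composited frame once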
cv2.imwrite(f'/Users/wadahana/Desktop/output.jpg', output) 76 | -------------------------------------------------------------------------------- /nodes/face/face_enhance_node.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import cv2 4 | import numpy as np 5 | from PIL import Image 6 | from facefusion.gfpgan_onnx import GFPGANOnnx 7 | from facefusion.yoloface_onnx import YoloFaceOnnx 8 | from facefusion.affine import create_box_mask, warp_face_by_landmark, paste_back 9 | 10 | import folder_paths 11 | from ..utils import tensor2pil, pil2tensor 12 | 13 | # class GFPGANProvider: 14 | # @classmethod 15 | # def INPUT_TYPES(s): 16 | # return { 17 | # "required": { 18 | # "model_name": ("IMAGE", ["gfpgan_1.4.onnx"]), 19 | # }, 20 | # } 21 | 22 | # RETURN_TYPES = ("GFPGAN_MODEL",) 23 | # RETURN_NAMES = ("model",) 24 | # FUNCTION = "load_model" 25 | # CATEGORY = "tbox/facefusion" 26 | 27 | # def load_model(self, model_name): 28 | # return (model_name,) 29 | 30 | 31 | class GFPGANNode: 32 | @classmethod 33 | def INPUT_TYPES(cls): 34 | return { 35 | "required": { 36 | "images": ("IMAGE",), 37 | "model_name": (['gfpgan_1.3', 'gfpgan_1.4'], {"default": 'gfpgan_1.4'}), 38 | "device": (['CPU', 'CUDA', 'CoreML', 'ROCM'], {"default": 'CPU'}), 39 | "weight": ("FLOAT", {"default": 0.8}), 40 | } 41 | } 42 | 43 | RETURN_TYPES = ("IMAGE", ) 44 | FUNCTION = "process" 45 | CATEGORY = "tbox/FaceFusion" 46 | 47 | def process(self, images, model_name, device='CPU', weight=0.8): 48 | providers = ['CPUExecutionProvider'] 49 | if device== 'CUDA': 50 | providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] 51 | elif device == 'CoreML': 52 | providers = ['CoreMLExecutionProvider', 'CPUExecutionProvider'] 53 | elif device == 'ROCM': 54 | providers = ['ROCMExecutionProvider', 'CPUExecutionProvider'] 55 | 56 | gfpgan_path = folder_paths.get_full_path("facefusion", f'{model_name}.onnx') 57 | yolo_path = folder_paths.get_full_path("facefusion", 'yoloface_8n.onnx') 58 | 59 | detector = YoloFaceOnnx(model_path=yolo_path, providers=providers) 60 | enhancer = GFPGANOnnx(model_path=gfpgan_path, providers=providers) 61 | 62 | image_list = [] 63 | for i, img in enumerate(images): 64 | pil = tensor2pil(img) 65 | image = np.ascontiguousarray(pil) 66 | image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) 67 | output = image 68 | face_list = detector.detect(image=image, conf=0.7) 69 | for index, face in enumerate(face_list): 70 | cropped, affine_matrix = warp_face_by_landmark(image, face.landmarks, enhancer.input_size) 71 | box_mask = create_box_mask(enhancer.input_size, 0.3, (0,0,0,0)) 72 | crop_mask = np.minimum.reduce([box_mask]).clip(0, 1) 73 | result = enhancer.run(cropped) 74 | output = paste_back(output, result, crop_mask, affine_matrix) 75 | image = cv2.cvtColor(output, cv2.COLOR_BGR2RGB) 76 | pil = Image.fromarray(image) 77 | image_list.append(pil2tensor(pil)) 78 | image_list = torch.stack([tensor.squeeze() for tensor in image_list]) 79 | return (image_list,) 80 | 81 | 82 | -------------------------------------------------------------------------------- /src/timm/models/layers/squeeze_excite.py: -------------------------------------------------------------------------------- 1 | """ Squeeze-and-Excitation Channel Attention 2 | 3 | An SE implementation originally based on PyTorch SE-Net impl. 4 | Has since evolved with additional functionality / configuration. 
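For illustration: SEModule(channels=64) pools an NCHW tensor to per-channel statistics, passes them through a small fc1 -> act -> fc2 bottleneck, and re-weights the input channels with a sigmoid gate, so the output keeps the input shape.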
5 | 6 | Paper: `Squeeze-and-Excitation Networks` - https://arxiv.org/abs/1709.01507 7 | 8 | Also included is Effective Squeeze-Excitation (ESE). 9 | Paper: `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 10 | 11 | Hacked together by / Copyright 2021 Ross Wightman 12 | """ 13 | from torch import nn as nn 14 | 15 | from .create_act import create_act_layer 16 | from .helpers import make_divisible 17 | 18 | 19 | class SEModule(nn.Module): 20 | """ SE Module as defined in original SE-Nets with a few additions 21 | Additions include: 22 | * divisor can be specified to keep channels % div == 0 (default: 8) 23 | * reduction channels can be specified directly by arg (if rd_channels is set) 24 | * reduction channels can be specified by float rd_ratio (default: 1/16) 25 | * global max pooling can be added to the squeeze aggregation 26 | * customizable activation, normalization, and gate layer 27 | """ 28 | def __init__( 29 | self, channels, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8, add_maxpool=False, 30 | bias=True, act_layer=nn.ReLU, norm_layer=None, gate_layer='sigmoid'): 31 | super(SEModule, self).__init__() 32 | self.add_maxpool = add_maxpool 33 | if not rd_channels: 34 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) 35 | self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=bias) 36 | self.bn = norm_layer(rd_channels) if norm_layer else nn.Identity() 37 | self.act = create_act_layer(act_layer, inplace=True) 38 | self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=bias) 39 | self.gate = create_act_layer(gate_layer) 40 | 41 | def forward(self, x): 42 | x_se = x.mean((2, 3), keepdim=True) 43 | if self.add_maxpool: 44 | # experimental codepath, may remove or change 45 | x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True) 46 | x_se = self.fc1(x_se) 47 | x_se = self.act(self.bn(x_se)) 48 | x_se = self.fc2(x_se) 49 | return x * self.gate(x_se) 50 | 51 | 52 | SqueezeExcite = SEModule # alias 53 | 54 | 55 | class EffectiveSEModule(nn.Module): 56 | """ 'Effective Squeeze-Excitation 57 | From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 58 | """ 59 | def __init__(self, channels, add_maxpool=False, gate_layer='hard_sigmoid', **_): 60 | super(EffectiveSEModule, self).__init__() 61 | self.add_maxpool = add_maxpool 62 | self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0) 63 | self.gate = create_act_layer(gate_layer) 64 | 65 | def forward(self, x): 66 | x_se = x.mean((2, 3), keepdim=True) 67 | if self.add_maxpool: 68 | # experimental codepath, may remove or change 69 | x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True) 70 | x_se = self.fc(x_se) 71 | return x * self.gate(x_se) 72 | 73 | 74 | EffectiveSqueezeExcite = EffectiveSEModule # alias 75 | -------------------------------------------------------------------------------- /src/timm/models/layers/pool2d_same.py: -------------------------------------------------------------------------------- 1 | """ AvgPool2d w/ Same Padding 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from typing import List, Tuple, Optional 9 | 10 | from .helpers import to_2tuple 11 | from .padding import pad_same, get_padding_value 12 | 13 | 14 | def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), 15 | ceil_mode: bool = False, count_include_pad: bool = True): 16 | # 
FIXME how to deal with count_include_pad vs not for external padding? 17 | x = pad_same(x, kernel_size, stride) 18 | return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad) 19 | 20 | 21 | class AvgPool2dSame(nn.AvgPool2d): 22 | """ Tensorflow like 'SAME' wrapper for 2D average pooling 23 | """ 24 | def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True): 25 | kernel_size = to_2tuple(kernel_size) 26 | stride = to_2tuple(stride) 27 | super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad) 28 | 29 | def forward(self, x): 30 | x = pad_same(x, self.kernel_size, self.stride) 31 | return F.avg_pool2d( 32 | x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad) 33 | 34 | 35 | def max_pool2d_same( 36 | x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), 37 | dilation: List[int] = (1, 1), ceil_mode: bool = False): 38 | x = pad_same(x, kernel_size, stride, value=-float('inf')) 39 | return F.max_pool2d(x, kernel_size, stride, (0, 0), dilation, ceil_mode) 40 | 41 | 42 | class MaxPool2dSame(nn.MaxPool2d): 43 | """ Tensorflow like 'SAME' wrapper for 2D max pooling 44 | """ 45 | def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False): 46 | kernel_size = to_2tuple(kernel_size) 47 | stride = to_2tuple(stride) 48 | dilation = to_2tuple(dilation) 49 | super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode) 50 | 51 | def forward(self, x): 52 | x = pad_same(x, self.kernel_size, self.stride, value=-float('inf')) 53 | return F.max_pool2d(x, self.kernel_size, self.stride, (0, 0), self.dilation, self.ceil_mode) 54 | 55 | 56 | def create_pool2d(pool_type, kernel_size, stride=None, **kwargs): 57 | stride = stride or kernel_size 58 | padding = kwargs.pop('padding', '') 59 | padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, **kwargs) 60 | if is_dynamic: 61 | if pool_type == 'avg': 62 | return AvgPool2dSame(kernel_size, stride=stride, **kwargs) 63 | elif pool_type == 'max': 64 | return MaxPool2dSame(kernel_size, stride=stride, **kwargs) 65 | else: 66 | assert False, f'Unsupported pool type {pool_type}' 67 | else: 68 | if pool_type == 'avg': 69 | return nn.AvgPool2d(kernel_size, stride=stride, padding=padding, **kwargs) 70 | elif pool_type == 'max': 71 | return nn.MaxPool2d(kernel_size, stride=stride, padding=padding, **kwargs) 72 | else: 73 | assert False, f'Unsupported pool type {pool_type}' 74 | -------------------------------------------------------------------------------- /src/midas/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | from einops import rearrange 7 | from PIL import Image 8 | 9 | from common import HWC3, common_input_validate, resize_image_with_pad, custom_hf_download 10 | from .api import MiDaSInference 11 | 12 | HF_MODEL_NAME = "lllyasviel/Annotators" 13 | 14 | class MidasDetector: 15 | def __init__(self, model): 16 | self.model = model 17 | self.device = "cpu" 18 | 19 | @classmethod 20 | def from_pretrained(cls, pretrained_model_or_path=HF_MODEL_NAME, model_type="dpt_hybrid", filename="dpt_hybrid-midas-501f0c75.pt"): 21 | subfolder = "annotator/ckpts" if pretrained_model_or_path == "lllyasviel/ControlNet" else '' 22 | model_path = custom_hf_download(pretrained_model_or_path, filename, subfolder=subfolder) 23 | model = 
MiDaSInference(model_type=model_type, model_path=model_path) 24 | return cls(model) 25 | 26 | 27 | def to(self, device): 28 | self.model.to(device) 29 | self.device = device 30 | return self 31 | 32 | def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1, depth_and_normal=False, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs): 33 | input_image, output_type = common_input_validate(input_image, output_type, **kwargs) 34 | detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method) 35 | image_depth = detected_map 36 | with torch.no_grad(): 37 | image_depth = torch.from_numpy(image_depth).float() 38 | image_depth = image_depth.to(self.device) 39 | image_depth = image_depth / 127.5 - 1.0 40 | image_depth = rearrange(image_depth, 'h w c -> 1 c h w') 41 | depth = self.model(image_depth)[0] 42 | 43 | depth_pt = depth.clone() 44 | depth_pt -= torch.min(depth_pt) 45 | depth_pt /= torch.max(depth_pt) 46 | depth_pt = depth_pt.cpu().numpy() 47 | depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8) 48 | 49 | if depth_and_normal: 50 | depth_np = depth.cpu().numpy() 51 | x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3) 52 | y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3) 53 | z = np.ones_like(x) * a 54 | x[depth_pt < bg_th] = 0 55 | y[depth_pt < bg_th] = 0 56 | normal = np.stack([x, y, z], axis=2) 57 | normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5 58 | normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8)[:, :, ::-1] 59 | 60 | depth_image = HWC3(depth_image) 61 | if depth_and_normal: 62 | normal_image = HWC3(normal_image) 63 | 64 | 65 | depth_image = remove_pad(depth_image) 66 | if depth_and_normal: 67 | normal_image = remove_pad(normal_image) 68 | 69 | if output_type == "pil": 70 | depth_image = Image.fromarray(depth_image) 71 | if depth_and_normal: 72 | normal_image = Image.fromarray(normal_image) 73 | 74 | if depth_and_normal: 75 | return depth_image, normal_image 76 | else: 77 | return depth_image 78 | -------------------------------------------------------------------------------- /src/timm/data/config.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from .constants import * 3 | 4 | 5 | _logger = logging.getLogger(__name__) 6 | 7 | 8 | def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, verbose=False): 9 | new_config = {} 10 | default_cfg = default_cfg 11 | if not default_cfg and model is not None and hasattr(model, 'default_cfg'): 12 | default_cfg = model.default_cfg 13 | 14 | # Resolve input/image size 15 | in_chans = 3 16 | if 'chans' in args and args['chans'] is not None: 17 | in_chans = args['chans'] 18 | 19 | input_size = (in_chans, 224, 224) 20 | if 'input_size' in args and args['input_size'] is not None: 21 | assert isinstance(args['input_size'], (tuple, list)) 22 | assert len(args['input_size']) == 3 23 | input_size = tuple(args['input_size']) 24 | in_chans = input_size[0] # input_size overrides in_chans 25 | elif 'img_size' in args and args['img_size'] is not None: 26 | assert isinstance(args['img_size'], int) 27 | input_size = (in_chans, args['img_size'], args['img_size']) 28 | else: 29 | if use_test_size and 'test_input_size' in default_cfg: 30 | input_size = default_cfg['test_input_size'] 31 | elif 'input_size' in default_cfg: 32 | input_size = default_cfg['input_size'] 33 | new_config['input_size'] = input_size 34 | 35 | # resolve interpolation method 36 | new_config['interpolation'] 
= 'bicubic' 37 | if 'interpolation' in args and args['interpolation']: 38 | new_config['interpolation'] = args['interpolation'] 39 | elif 'interpolation' in default_cfg: 40 | new_config['interpolation'] = default_cfg['interpolation'] 41 | 42 | # resolve dataset + model mean for normalization 43 | new_config['mean'] = IMAGENET_DEFAULT_MEAN 44 | if 'mean' in args and args['mean'] is not None: 45 | mean = tuple(args['mean']) 46 | if len(mean) == 1: 47 | mean = tuple(list(mean) * in_chans) 48 | else: 49 | assert len(mean) == in_chans 50 | new_config['mean'] = mean 51 | elif 'mean' in default_cfg: 52 | new_config['mean'] = default_cfg['mean'] 53 | 54 | # resolve dataset + model std deviation for normalization 55 | new_config['std'] = IMAGENET_DEFAULT_STD 56 | if 'std' in args and args['std'] is not None: 57 | std = tuple(args['std']) 58 | if len(std) == 1: 59 | std = tuple(list(std) * in_chans) 60 | else: 61 | assert len(std) == in_chans 62 | new_config['std'] = std 63 | elif 'std' in default_cfg: 64 | new_config['std'] = default_cfg['std'] 65 | 66 | # resolve default crop percentage 67 | crop_pct = DEFAULT_CROP_PCT 68 | if 'crop_pct' in args and args['crop_pct'] is not None: 69 | crop_pct = args['crop_pct'] 70 | else: 71 | if use_test_size and 'test_crop_pct' in default_cfg: 72 | crop_pct = default_cfg['test_crop_pct'] 73 | elif 'crop_pct' in default_cfg: 74 | crop_pct = default_cfg['crop_pct'] 75 | new_config['crop_pct'] = crop_pct 76 | 77 | if verbose: 78 | _logger.info('Data processing configuration for current model + dataset:') 79 | for n, v in new_config.items(): 80 | _logger.info('\t%s: %s' % (n, str(v))) 81 | 82 | return new_config 83 | -------------------------------------------------------------------------------- /src/timm/models/layers/split_attn.py: -------------------------------------------------------------------------------- 1 | """ Split Attention Conv2d (for ResNeSt Models) 2 | 3 | Paper: `ResNeSt: Split-Attention Networks` - /https://arxiv.org/abs/2004.08955 4 | 5 | Adapted from original PyTorch impl at https://github.com/zhanghang1989/ResNeSt 6 | 7 | Modified for torchscript compat, performance, and consistency with timm by Ross Wightman 8 | """ 9 | import torch 10 | import torch.nn.functional as F 11 | from torch import nn 12 | 13 | from .helpers import make_divisible 14 | 15 | 16 | class RadixSoftmax(nn.Module): 17 | def __init__(self, radix, cardinality): 18 | super(RadixSoftmax, self).__init__() 19 | self.radix = radix 20 | self.cardinality = cardinality 21 | 22 | def forward(self, x): 23 | batch = x.size(0) 24 | if self.radix > 1: 25 | x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) 26 | x = F.softmax(x, dim=1) 27 | x = x.reshape(batch, -1) 28 | else: 29 | x = torch.sigmoid(x) 30 | return x 31 | 32 | 33 | class SplitAttn(nn.Module): 34 | """Split-Attention (aka Splat) 35 | """ 36 | def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=None, 37 | dilation=1, groups=1, bias=False, radix=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8, 38 | act_layer=nn.ReLU, norm_layer=None, drop_layer=None, **kwargs): 39 | super(SplitAttn, self).__init__() 40 | out_channels = out_channels or in_channels 41 | self.radix = radix 42 | mid_chs = out_channels * radix 43 | if rd_channels is None: 44 | attn_chs = make_divisible(in_channels * radix * rd_ratio, min_value=32, divisor=rd_divisor) 45 | else: 46 | attn_chs = rd_channels * radix 47 | 48 | padding = kernel_size // 2 if padding is None else padding 49 | self.conv = nn.Conv2d( 50 | 
in_channels, mid_chs, kernel_size, stride, padding, dilation, 51 | groups=groups * radix, bias=bias, **kwargs) 52 | self.bn0 = norm_layer(mid_chs) if norm_layer else nn.Identity() 53 | self.drop = drop_layer() if drop_layer is not None else nn.Identity() 54 | self.act0 = act_layer(inplace=True) 55 | self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups) 56 | self.bn1 = norm_layer(attn_chs) if norm_layer else nn.Identity() 57 | self.act1 = act_layer(inplace=True) 58 | self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups) 59 | self.rsoftmax = RadixSoftmax(radix, groups) 60 | 61 | def forward(self, x): 62 | x = self.conv(x) 63 | x = self.bn0(x) 64 | x = self.drop(x) 65 | x = self.act0(x) 66 | 67 | B, RC, H, W = x.shape 68 | if self.radix > 1: 69 | x = x.reshape((B, self.radix, RC // self.radix, H, W)) 70 | x_gap = x.sum(dim=1) 71 | else: 72 | x_gap = x 73 | x_gap = x_gap.mean((2, 3), keepdim=True) 74 | x_gap = self.fc1(x_gap) 75 | x_gap = self.bn1(x_gap) 76 | x_gap = self.act1(x_gap) 77 | x_attn = self.fc2(x_gap) 78 | 79 | x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1) 80 | if self.radix > 1: 81 | out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1) 82 | else: 83 | out = x * x_attn 84 | return out.contiguous() 85 | -------------------------------------------------------------------------------- /src/densepose/__init__.py: -------------------------------------------------------------------------------- 1 | import torchvision # Fix issue Unknown builtin op: torchvision::nms 2 | import cv2 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from einops import rearrange 7 | from PIL import Image 8 | 9 | from common import HWC3, resize_image_with_pad, common_input_validate, custom_hf_download 10 | from .densepose import DensePoseMaskedColormapResultsVisualizer, _extract_i_from_iuvarr, densepose_chart_predictor_output_to_result_with_confidences 11 | 12 | N_PART_LABELS = 24 13 | DENSEPOSE_MODEL_NAME = "LayerNorm/DensePose-TorchScript-with-hint-image" 14 | 15 | class DenseposeDetector: 16 | def __init__(self, model): 17 | self.dense_pose_estimation = model 18 | self.device = "cpu" 19 | self.result_visualizer = DensePoseMaskedColormapResultsVisualizer( 20 | alpha=1, 21 | data_extractor=_extract_i_from_iuvarr, 22 | segm_extractor=_extract_i_from_iuvarr, 23 | val_scale = 255.0 / N_PART_LABELS 24 | ) 25 | 26 | @classmethod 27 | def from_pretrained(cls, pretrained_model_or_path=DENSEPOSE_MODEL_NAME, filename="densepose_r50_fpn_dl.torchscript"): 28 | torchscript_model_path = custom_hf_download(pretrained_model_or_path, filename) 29 | densepose = torch.jit.load(torchscript_model_path, map_location="cpu") 30 | return cls(densepose) 31 | 32 | def to(self, device): 33 | self.dense_pose_estimation.to(device) 34 | self.device = device 35 | return self 36 | 37 | def __call__(self, input_image, detect_resolution=512, output_type="pil", upscale_method="INTER_CUBIC", cmap="viridis", **kwargs): 38 | input_image, output_type = common_input_validate(input_image, output_type, **kwargs) 39 | input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method) 40 | H, W = input_image.shape[:2] 41 | 42 | hint_image_canvas = np.zeros([H, W], dtype=np.uint8) 43 | hint_image_canvas = np.tile(hint_image_canvas[:, :, np.newaxis], [1, 1, 3]) 44 | 45 | input_image = rearrange(torch.from_numpy(input_image).to(self.device), 'h w c -> c h w') 46 | 47 | pred_boxes, corase_segm, fine_segm, u, v = self.dense_pose_estimation(input_image) 48 | 49 | 
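        # the TorchScript detector returns, per detection, a box plus coarse/fine segmentation and U/V chart
        # tensors; each detection is converted into a DensePose chart result below and rendered onto the canvas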
extractor = densepose_chart_predictor_output_to_result_with_confidences 50 | densepose_results = [extractor(pred_boxes[i:i+1], corase_segm[i:i+1], fine_segm[i:i+1], u[i:i+1], v[i:i+1]) for i in range(len(pred_boxes))] 51 | 52 | if cmap=="viridis": 53 | self.result_visualizer.mask_visualizer.cmap = cv2.COLORMAP_VIRIDIS 54 | hint_image = self.result_visualizer.visualize(hint_image_canvas, densepose_results) 55 | hint_image = cv2.cvtColor(hint_image, cv2.COLOR_BGR2RGB) 56 | hint_image[:, :, 0][hint_image[:, :, 0] == 0] = 68 57 | hint_image[:, :, 1][hint_image[:, :, 1] == 0] = 1 58 | hint_image[:, :, 2][hint_image[:, :, 2] == 0] = 84 59 | else: 60 | self.result_visualizer.mask_visualizer.cmap = cv2.COLORMAP_PARULA 61 | hint_image = self.result_visualizer.visualize(hint_image_canvas, densepose_results) 62 | hint_image = cv2.cvtColor(hint_image, cv2.COLOR_BGR2RGB) 63 | 64 | detected_map = remove_pad(HWC3(hint_image)) 65 | if output_type == "pil": 66 | detected_map = Image.fromarray(detected_map) 67 | return detected_map 68 | -------------------------------------------------------------------------------- /src/timm/models/factory.py: -------------------------------------------------------------------------------- 1 | from urllib.parse import urlsplit, urlunsplit 2 | import os 3 | 4 | from .registry import is_model, is_model_in_modules, model_entrypoint 5 | from .helpers import load_checkpoint 6 | from .layers import set_layer_config 7 | from .hub import load_model_config_from_hf 8 | 9 | 10 | def parse_model_name(model_name): 11 | model_name = model_name.replace('hf_hub', 'hf-hub') # NOTE for backwards compat, to deprecate hf_hub use 12 | parsed = urlsplit(model_name) 13 | assert parsed.scheme in ('', 'timm', 'hf-hub') 14 | if parsed.scheme == 'hf-hub': 15 | # FIXME may use fragment as revision, currently `@` in URI path 16 | return parsed.scheme, parsed.path 17 | else: 18 | model_name = os.path.split(parsed.path)[-1] 19 | return 'timm', model_name 20 | 21 | 22 | def safe_model_name(model_name, remove_source=True): 23 | def make_safe(name): 24 | return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_') 25 | if remove_source: 26 | model_name = parse_model_name(model_name)[-1] 27 | return make_safe(model_name) 28 | 29 | 30 | def create_model( 31 | model_name, 32 | pretrained=False, 33 | pretrained_cfg=None, 34 | checkpoint_path='', 35 | scriptable=None, 36 | exportable=None, 37 | no_jit=None, 38 | **kwargs): 39 | """Create a model 40 | 41 | Args: 42 | model_name (str): name of model to instantiate 43 | pretrained (bool): load pretrained ImageNet-1k weights if true 44 | checkpoint_path (str): path of checkpoint to load after model is initialized 45 | scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet) 46 | exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet) 47 | no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only) 48 | 49 | Keyword Args: 50 | drop_rate (float): dropout rate for training (default: 0.0) 51 | global_pool (str): global pool type (default: 'avg') 52 | **: other kwargs are model specific 53 | """ 54 | # Parameters that aren't supported by all models or are intended to only override model defaults if set 55 | # should default to None in command line args/cfg. Remove them if they are present and not set so that 56 | # non-supporting models don't break and default args remain in effect. 
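    # e.g. kwargs of {'num_classes': 10, 'drop_rate': None} are reduced to {'num_classes': 10} here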
57 | kwargs = {k: v for k, v in kwargs.items() if v is not None} 58 | 59 | model_source, model_name = parse_model_name(model_name) 60 | if model_source == 'hf-hub': 61 | # FIXME hf-hub source overrides any passed in pretrained_cfg, warn? 62 | # For model names specified in the form `hf-hub:path/architecture_name@revision`, 63 | # load model weights + pretrained_cfg from Hugging Face hub. 64 | pretrained_cfg, model_name = load_model_config_from_hf(model_name) 65 | 66 | if not is_model(model_name): 67 | raise RuntimeError('Unknown model (%s)' % model_name) 68 | 69 | create_fn = model_entrypoint(model_name) 70 | with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit): 71 | model = create_fn(pretrained=pretrained, pretrained_cfg=pretrained_cfg, **kwargs) 72 | 73 | if checkpoint_path: 74 | load_checkpoint(model, checkpoint_path) 75 | 76 | return model 77 | -------------------------------------------------------------------------------- /src/timm/models/layers/conv_bn_act.py: -------------------------------------------------------------------------------- 1 | """ Conv2d + BN + Act 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import functools 6 | from torch import nn as nn 7 | 8 | from .create_conv2d import create_conv2d 9 | from .create_norm_act import get_norm_act_layer 10 | 11 | 12 | class ConvNormAct(nn.Module): 13 | def __init__( 14 | self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1, 15 | bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, drop_layer=None): 16 | super(ConvNormAct, self).__init__() 17 | self.conv = create_conv2d( 18 | in_channels, out_channels, kernel_size, stride=stride, 19 | padding=padding, dilation=dilation, groups=groups, bias=bias) 20 | 21 | # NOTE for backwards compatibility with models that use separate norm and act layer definitions 22 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer) 23 | # NOTE for backwards (weight) compatibility, norm layer name remains `.bn` 24 | norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {} 25 | self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs) 26 | 27 | @property 28 | def in_channels(self): 29 | return self.conv.in_channels 30 | 31 | @property 32 | def out_channels(self): 33 | return self.conv.out_channels 34 | 35 | def forward(self, x): 36 | x = self.conv(x) 37 | x = self.bn(x) 38 | return x 39 | 40 | 41 | ConvBnAct = ConvNormAct 42 | 43 | 44 | def create_aa(aa_layer, channels, stride=2, enable=True): 45 | if not aa_layer or not enable: 46 | return nn.Identity() 47 | if isinstance(aa_layer, functools.partial): 48 | if issubclass(aa_layer.func, nn.AvgPool2d): 49 | return aa_layer() 50 | else: 51 | return aa_layer(channels) 52 | elif issubclass(aa_layer, nn.AvgPool2d): 53 | return aa_layer(stride) 54 | else: 55 | return aa_layer(channels=channels, stride=stride) 56 | 57 | 58 | class ConvNormActAa(nn.Module): 59 | def __init__( 60 | self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1, 61 | bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, aa_layer=None, drop_layer=None): 62 | super(ConvNormActAa, self).__init__() 63 | use_aa = aa_layer is not None and stride == 2 64 | 65 | self.conv = create_conv2d( 66 | in_channels, out_channels, kernel_size, stride=1 if use_aa else stride, 67 | padding=padding, dilation=dilation, groups=groups, bias=bias) 68 | 69 | # NOTE for backwards compatibility with models that use 
separate norm and act layer definitions 70 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer) 71 | # NOTE for backwards (weight) compatibility, norm layer name remains `.bn` 72 | norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {} 73 | self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs) 74 | self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa) 75 | 76 | @property 77 | def in_channels(self): 78 | return self.conv.in_channels 79 | 80 | @property 81 | def out_channels(self): 82 | return self.conv.out_channels 83 | 84 | def forward(self, x): 85 | x = self.conv(x) 86 | x = self.bn(x) 87 | x = self.aa(x) 88 | return x 89 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | from .utils import here 4 | import platform 5 | 6 | sys.path.insert(0, str(Path(here, "src").resolve())) 7 | 8 | from .nodes.image.load_node import LoadImageNode 9 | from .nodes.image.save_node import SaveImageNode 10 | from .nodes.image.save_node import SaveImagesNode 11 | from .nodes.image.size_node import ImageResizeNode 12 | from .nodes.image.size_node import ImageSizeNode 13 | from .nodes.image.size_node import ConstrainImageNode 14 | from .nodes.image.watermark_node import WatermarkNode 15 | from .nodes.mask.mask_node import MaskAddNode 16 | from .nodes.video.load_node import LoadVideoNode 17 | from .nodes.video.save_node import SaveVideoNode 18 | from .nodes.video.info_node import VideoInfoNode 19 | from .nodes.video.batch_node import BatchManagerNode 20 | from .nodes.preprocessor.canny_node import Canny_Preprocessor 21 | from .nodes.preprocessor.lineart_node import LineArt_Preprocessor 22 | from .nodes.preprocessor.lineart_node import Lineart_Standard_Preprocessor 23 | from .nodes.preprocessor.midas_node import MIDAS_Depth_Map_Preprocessor 24 | from .nodes.preprocessor.dwpose_node import DWPose_Preprocessor, AnimalPose_Preprocessor 25 | from .nodes.preprocessor.densepose_node import DensePose_Preprocessor 26 | from .nodes.face.face_enhance_node import GFPGANNode 27 | from .nodes.other.vram_node import PurgeVRAMNode 28 | 29 | NODE_CLASS_MAPPINGS = { 30 | "PurgeVRAMNode": PurgeVRAMNode, 31 | "GFPGANNode": GFPGANNode, 32 | "MaskAddNode": MaskAddNode, 33 | "ImageLoader": LoadImageNode, 34 | "ImageSaver": SaveImageNode, 35 | "ImagesSaver": SaveImagesNode, 36 | "ImageResize": ImageResizeNode, 37 | "ImageSize": ImageSizeNode, 38 | "WatermarkNode": WatermarkNode, 39 | "VideoLoader": LoadVideoNode, 40 | "VideoSaver": SaveVideoNode, 41 | "VideoInfo": VideoInfoNode, 42 | "BatchManager": BatchManagerNode, 43 | "ConstrainImageNode": ConstrainImageNode, 44 | "DensePosePreprocessor": DensePose_Preprocessor, 45 | "DWPosePreprocessor": DWPose_Preprocessor, 46 | "AnimalPosePreprocessor": AnimalPose_Preprocessor, 47 | "MiDaSDepthPreprocessor": MIDAS_Depth_Map_Preprocessor, 48 | "CannyPreprocessor": Canny_Preprocessor, 49 | "LineArtPreprocessor": LineArt_Preprocessor, 50 | "LineartStandardPreprocessor": Lineart_Standard_Preprocessor, 51 | } 52 | 53 | NODE_DISPLAY_NAME_MAPPINGS = { 54 | "PurgeVRAMNode":"PurgeVRAMNode", 55 | "GFPGANNode": "GFPGANNode", 56 | "MaskAddNode": "MaskAddNode", 57 | "ImageLoader": "Image Load", 58 | "ImageSaver": "Image Save", 59 | "ImagesSaver": "Image List Save", 60 | "ImageResize": "Image Resize", 61 | "ImageSize": "Image Size", 62 | "WatermarkNode": "Watermark", 63 
| "VideoLoader": "Video Load", 64 | "VideoSaver": "Video Save", 65 | "VideoInfo": "Video Info", 66 | "BatchManager": "Batch Manager", 67 | "ConstrainImageNode": "Image Constrain", 68 | "DensePosePreprocessor": "DensePose Estimator", 69 | "DWPosePreprocessor": "DWPose Estimator", 70 | "AnimalPosePreprocessor": "AnimalPose Estimator", 71 | "MiDaSDepthPreprocessor": "MiDaS Depth Estimator", 72 | "CannyPreprocessor": "Canny Edge Estimator", 73 | "LineArtPreprocessor": "Realistic Lineart", 74 | "LineartStandardPreprocessor": "Standard Lineart", 75 | } 76 | 77 | 78 | if platform.system() == "Darwin": 79 | WEB_DIRECTORY = "./web" 80 | __all__ = ["WEB_DIRECTORY", "NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] 81 | else: 82 | __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] -------------------------------------------------------------------------------- /src/timm/loss/asymmetric_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class AsymmetricLossMultiLabel(nn.Module): 6 | def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False): 7 | super(AsymmetricLossMultiLabel, self).__init__() 8 | 9 | self.gamma_neg = gamma_neg 10 | self.gamma_pos = gamma_pos 11 | self.clip = clip 12 | self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss 13 | self.eps = eps 14 | 15 | def forward(self, x, y): 16 | """" 17 | Parameters 18 | ---------- 19 | x: input logits 20 | y: targets (multi-label binarized vector) 21 | """ 22 | 23 | # Calculating Probabilities 24 | x_sigmoid = torch.sigmoid(x) 25 | xs_pos = x_sigmoid 26 | xs_neg = 1 - x_sigmoid 27 | 28 | # Asymmetric Clipping 29 | if self.clip is not None and self.clip > 0: 30 | xs_neg = (xs_neg + self.clip).clamp(max=1) 31 | 32 | # Basic CE calculation 33 | los_pos = y * torch.log(xs_pos.clamp(min=self.eps)) 34 | los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps)) 35 | loss = los_pos + los_neg 36 | 37 | # Asymmetric Focusing 38 | if self.gamma_neg > 0 or self.gamma_pos > 0: 39 | if self.disable_torch_grad_focal_loss: 40 | torch._C.set_grad_enabled(False) 41 | pt0 = xs_pos * y 42 | pt1 = xs_neg * (1 - y) # pt = p if t > 0 else 1-p 43 | pt = pt0 + pt1 44 | one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y) 45 | one_sided_w = torch.pow(1 - pt, one_sided_gamma) 46 | if self.disable_torch_grad_focal_loss: 47 | torch._C.set_grad_enabled(True) 48 | loss *= one_sided_w 49 | 50 | return -loss.sum() 51 | 52 | 53 | class AsymmetricLossSingleLabel(nn.Module): 54 | def __init__(self, gamma_pos=1, gamma_neg=4, eps: float = 0.1, reduction='mean'): 55 | super(AsymmetricLossSingleLabel, self).__init__() 56 | 57 | self.eps = eps 58 | self.logsoftmax = nn.LogSoftmax(dim=-1) 59 | self.targets_classes = [] # prevent gpu repeated memory allocation 60 | self.gamma_pos = gamma_pos 61 | self.gamma_neg = gamma_neg 62 | self.reduction = reduction 63 | 64 | def forward(self, inputs, target, reduction=None): 65 | """" 66 | Parameters 67 | ---------- 68 | x: input logits 69 | y: targets (1-hot vector) 70 | """ 71 | 72 | num_classes = inputs.size()[-1] 73 | log_preds = self.logsoftmax(inputs) 74 | self.targets_classes = torch.zeros_like(inputs).scatter_(1, target.long().unsqueeze(1), 1) 75 | 76 | # ASL weights 77 | targets = self.targets_classes 78 | anti_targets = 1 - targets 79 | xs_pos = torch.exp(log_preds) 80 | xs_neg = 1 - xs_pos 81 | xs_pos = xs_pos * targets 82 | xs_neg = xs_neg * anti_targets 83 | 
asymmetric_w = torch.pow(1 - xs_pos - xs_neg, 84 | self.gamma_pos * targets + self.gamma_neg * anti_targets) 85 | log_preds = log_preds * asymmetric_w 86 | 87 | if self.eps > 0: # label smoothing 88 | self.targets_classes.mul_(1 - self.eps).add_(self.eps / num_classes) 89 | 90 | # loss calculation 91 | loss = - self.targets_classes.mul(log_preds) 92 | 93 | loss = loss.sum(dim=-1) 94 | if self.reduction == 'mean': 95 | loss = loss.mean() 96 | 97 | return loss 98 | -------------------------------------------------------------------------------- /src/timm/models/layers/config.py: -------------------------------------------------------------------------------- 1 | """ Model / Layer Config singleton state 2 | """ 3 | from typing import Any, Optional 4 | 5 | __all__ = [ 6 | 'is_exportable', 'is_scriptable', 'is_no_jit', 7 | 'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config' 8 | ] 9 | 10 | # Set to True if prefer to have layers with no jit optimization (includes activations) 11 | _NO_JIT = False 12 | 13 | # Set to True if prefer to have activation layers with no jit optimization 14 | # NOTE not currently used as no difference between no_jit and no_activation jit as only layers obeying 15 | # the jit flags so far are activations. This will change as more layers are updated and/or added. 16 | _NO_ACTIVATION_JIT = False 17 | 18 | # Set to True if exporting a model with Same padding via ONNX 19 | _EXPORTABLE = False 20 | 21 | # Set to True if wanting to use torch.jit.script on a model 22 | _SCRIPTABLE = False 23 | 24 | 25 | def is_no_jit(): 26 | return _NO_JIT 27 | 28 | 29 | class set_no_jit: 30 | def __init__(self, mode: bool) -> None: 31 | global _NO_JIT 32 | self.prev = _NO_JIT 33 | _NO_JIT = mode 34 | 35 | def __enter__(self) -> None: 36 | pass 37 | 38 | def __exit__(self, *args: Any) -> bool: 39 | global _NO_JIT 40 | _NO_JIT = self.prev 41 | return False 42 | 43 | 44 | def is_exportable(): 45 | return _EXPORTABLE 46 | 47 | 48 | class set_exportable: 49 | def __init__(self, mode: bool) -> None: 50 | global _EXPORTABLE 51 | self.prev = _EXPORTABLE 52 | _EXPORTABLE = mode 53 | 54 | def __enter__(self) -> None: 55 | pass 56 | 57 | def __exit__(self, *args: Any) -> bool: 58 | global _EXPORTABLE 59 | _EXPORTABLE = self.prev 60 | return False 61 | 62 | 63 | def is_scriptable(): 64 | return _SCRIPTABLE 65 | 66 | 67 | class set_scriptable: 68 | def __init__(self, mode: bool) -> None: 69 | global _SCRIPTABLE 70 | self.prev = _SCRIPTABLE 71 | _SCRIPTABLE = mode 72 | 73 | def __enter__(self) -> None: 74 | pass 75 | 76 | def __exit__(self, *args: Any) -> bool: 77 | global _SCRIPTABLE 78 | _SCRIPTABLE = self.prev 79 | return False 80 | 81 | 82 | class set_layer_config: 83 | """ Layer config context manager that allows setting all layer config flags at once. 84 | If a flag arg is None, it will not change the current value. 
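(Editor's aside) A minimal usage sketch for the config flags defined in this file, assuming the vendored timm package is importable (the repository __init__.py puts src/ on sys.path); the flag choices below are illustrative:

    from timm.models.layers.config import set_layer_config, is_scriptable, is_exportable

    # The flags are process-global; the context manager restores the previous values on exit.
    with set_layer_config(scriptable=True, exportable=True, no_jit=True):
        assert is_scriptable() and is_exportable()
        # build / torch.jit.script / export a model here
    assert not is_scriptable()   # back to the default (False) outside the context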
85 | """ 86 | def __init__( 87 | self, 88 | scriptable: Optional[bool] = None, 89 | exportable: Optional[bool] = None, 90 | no_jit: Optional[bool] = None, 91 | no_activation_jit: Optional[bool] = None): 92 | global _SCRIPTABLE 93 | global _EXPORTABLE 94 | global _NO_JIT 95 | global _NO_ACTIVATION_JIT 96 | self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT 97 | if scriptable is not None: 98 | _SCRIPTABLE = scriptable 99 | if exportable is not None: 100 | _EXPORTABLE = exportable 101 | if no_jit is not None: 102 | _NO_JIT = no_jit 103 | if no_activation_jit is not None: 104 | _NO_ACTIVATION_JIT = no_activation_jit 105 | 106 | def __enter__(self) -> None: 107 | pass 108 | 109 | def __exit__(self, *args: Any) -> bool: 110 | global _SCRIPTABLE 111 | global _EXPORTABLE 112 | global _NO_JIT 113 | global _NO_ACTIVATION_JIT 114 | _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev 115 | return False 116 | -------------------------------------------------------------------------------- /src/timm/data/parsers/parser_image_folder.py: -------------------------------------------------------------------------------- 1 | """ A dataset parser that reads images from folders 2 | 3 | Folders are scannerd recursively to find image files. Labels are based 4 | on the folder hierarchy, just leaf folders by default. 5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | import os 9 | from typing import Dict, List, Optional, Set, Tuple, Union 10 | 11 | from timm.utils.misc import natural_key 12 | 13 | from .class_map import load_class_map 14 | from .img_extensions import get_img_extensions 15 | from .parser import Parser 16 | 17 | 18 | def find_images_and_targets( 19 | folder: str, 20 | types: Optional[Union[List, Tuple, Set]] = None, 21 | class_to_idx: Optional[Dict] = None, 22 | leaf_name_only: bool = True, 23 | sort: bool = True 24 | ): 25 | """ Walk folder recursively to discover images and map them to classes by folder names. 
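(Editor's aside) A small usage sketch for the folder parser defined in this file; the dataset path is hypothetical, and __getitem__ returns an open binary file handle plus an integer class index:

    from timm.data.parsers.parser_image_folder import ParserImageFolder, find_images_and_targets

    samples, class_to_idx = find_images_and_targets('/data/train')   # [(path, class_idx), ...], {label: idx}
    parser = ParserImageFolder('/data/train')
    img_file, target = parser[0]
    print(len(parser), target)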
26 | 27 | Args: 28 | folder: root of folder to recrusively search 29 | types: types (file extensions) to search for in path 30 | class_to_idx: specify mapping for class (folder name) to class index if set 31 | leaf_name_only: use only leaf-name of folder walk for class names 32 | sort: re-sort found images by name (for consistent ordering) 33 | 34 | Returns: 35 | A list of image and target tuples, class_to_idx mapping 36 | """ 37 | types = get_img_extensions(as_set=True) if not types else set(types) 38 | labels = [] 39 | filenames = [] 40 | for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True): 41 | rel_path = os.path.relpath(root, folder) if (root != folder) else '' 42 | label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_') 43 | for f in files: 44 | base, ext = os.path.splitext(f) 45 | if ext.lower() in types: 46 | filenames.append(os.path.join(root, f)) 47 | labels.append(label) 48 | if class_to_idx is None: 49 | # building class index 50 | unique_labels = set(labels) 51 | sorted_labels = list(sorted(unique_labels, key=natural_key)) 52 | class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)} 53 | images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx] 54 | if sort: 55 | images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0])) 56 | return images_and_targets, class_to_idx 57 | 58 | 59 | class ParserImageFolder(Parser): 60 | 61 | def __init__( 62 | self, 63 | root, 64 | class_map=''): 65 | super().__init__() 66 | 67 | self.root = root 68 | class_to_idx = None 69 | if class_map: 70 | class_to_idx = load_class_map(class_map, root) 71 | self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx) 72 | if len(self.samples) == 0: 73 | raise RuntimeError( 74 | f'Found 0 images in subfolders of {root}. ' 75 | f'Supported image extensions are {", ".join(get_img_extensions())}') 76 | 77 | def __getitem__(self, index): 78 | path, target = self.samples[index] 79 | return open(path, 'rb'), target 80 | 81 | def __len__(self): 82 | return len(self.samples) 83 | 84 | def _filename(self, index, basename=False, absolute=False): 85 | filename = self.samples[index][0] 86 | if basename: 87 | filename = os.path.basename(filename) 88 | elif not absolute: 89 | filename = os.path.relpath(filename, self.root) 90 | return filename 91 | -------------------------------------------------------------------------------- /src/dwpose/hand.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import json 3 | import numpy as np 4 | import math 5 | import time 6 | from scipy.ndimage.filters import gaussian_filter 7 | import matplotlib.pyplot as plt 8 | import matplotlib 9 | import torch 10 | from skimage.measure import label 11 | 12 | from .model import handpose_model 13 | from . 
import util 14 | 15 | class Hand(object): 16 | def __init__(self, model_path): 17 | self.model = handpose_model() 18 | # if torch.cuda.is_available(): 19 | # self.model = self.model.cuda() 20 | # print('cuda') 21 | model_dict = util.transfer(self.model, torch.load(model_path)) 22 | self.model.load_state_dict(model_dict) 23 | self.model.eval() 24 | 25 | def __call__(self, oriImgRaw): 26 | scale_search = [0.5, 1.0, 1.5, 2.0] 27 | # scale_search = [0.5] 28 | boxsize = 368 29 | stride = 8 30 | padValue = 128 31 | thre = 0.05 32 | multiplier = [x * boxsize for x in scale_search] 33 | 34 | wsize = 128 35 | heatmap_avg = np.zeros((wsize, wsize, 22)) 36 | 37 | Hr, Wr, Cr = oriImgRaw.shape 38 | 39 | oriImg = cv2.GaussianBlur(oriImgRaw, (0, 0), 0.8) 40 | 41 | for m in range(len(multiplier)): 42 | scale = multiplier[m] 43 | imageToTest = util.smart_resize(oriImg, (scale, scale)) 44 | 45 | imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue) 46 | im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5 47 | im = np.ascontiguousarray(im) 48 | 49 | data = torch.from_numpy(im).float() 50 | if torch.cuda.is_available(): 51 | data = data.cuda() 52 | 53 | with torch.no_grad(): 54 | data = data.to(self.cn_device) 55 | output = self.model(data).cpu().numpy() 56 | 57 | # extract outputs, resize, and remove padding 58 | heatmap = np.transpose(np.squeeze(output), (1, 2, 0)) # output 1 is heatmaps 59 | heatmap = util.smart_resize_k(heatmap, fx=stride, fy=stride) 60 | heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :] 61 | heatmap = util.smart_resize(heatmap, (wsize, wsize)) 62 | 63 | heatmap_avg += heatmap / len(multiplier) 64 | 65 | all_peaks = [] 66 | for part in range(21): 67 | map_ori = heatmap_avg[:, :, part] 68 | one_heatmap = gaussian_filter(map_ori, sigma=3) 69 | binary = np.ascontiguousarray(one_heatmap > thre, dtype=np.uint8) 70 | 71 | if np.sum(binary) == 0: 72 | all_peaks.append([0, 0]) 73 | continue 74 | label_img, label_numbers = label(binary, return_num=True, connectivity=binary.ndim) 75 | max_index = np.argmax([np.sum(map_ori[label_img == i]) for i in range(1, label_numbers + 1)]) + 1 76 | label_img[label_img != max_index] = 0 77 | map_ori[label_img == 0] = 0 78 | 79 | y, x = util.npmax(map_ori) 80 | y = int(float(y) * float(Hr) / float(wsize)) 81 | x = int(float(x) * float(Wr) / float(wsize)) 82 | all_peaks.append([x, y]) 83 | return np.array(all_peaks) 84 | 85 | if __name__ == "__main__": 86 | hand_estimation = Hand('../model/hand_pose_model.pth') 87 | 88 | # test_image = '../images/hand.jpg' 89 | test_image = '../images/hand.jpg' 90 | oriImg = cv2.imread(test_image) # B,G,R order 91 | peaks = hand_estimation(oriImg) 92 | canvas = util.draw_handpose(oriImg, peaks, True) 93 | cv2.imshow('', canvas) 94 | cv2.waitKey(0) -------------------------------------------------------------------------------- /src/timm/models/layers/inplace_abn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | 4 | try: 5 | from inplace_abn.functions import inplace_abn, inplace_abn_sync 6 | has_iabn = True 7 | except ImportError: 8 | has_iabn = False 9 | 10 | def inplace_abn(x, weight, bias, running_mean, running_var, 11 | training=True, momentum=0.1, eps=1e-05, activation="leaky_relu", activation_param=0.01): 12 | raise ImportError( 13 | "Please install InplaceABN:'pip install 
git+https://github.com/mapillary/inplace_abn.git@v1.0.12'") 14 | 15 | def inplace_abn_sync(**kwargs): 16 | inplace_abn(**kwargs) 17 | 18 | 19 | class InplaceAbn(nn.Module): 20 | """Activated Batch Normalization 21 | 22 | This gathers a BatchNorm and an activation function in a single module 23 | 24 | Parameters 25 | ---------- 26 | num_features : int 27 | Number of feature channels in the input and output. 28 | eps : float 29 | Small constant to prevent numerical issues. 30 | momentum : float 31 | Momentum factor applied to compute running statistics. 32 | affine : bool 33 | If `True` apply learned scale and shift transformation after normalization. 34 | act_layer : str or nn.Module type 35 | Name or type of the activation functions, one of: `leaky_relu`, `elu` 36 | act_param : float 37 | Negative slope for the `leaky_relu` activation. 38 | """ 39 | 40 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, apply_act=True, 41 | act_layer="leaky_relu", act_param=0.01, drop_layer=None): 42 | super(InplaceAbn, self).__init__() 43 | self.num_features = num_features 44 | self.affine = affine 45 | self.eps = eps 46 | self.momentum = momentum 47 | if apply_act: 48 | if isinstance(act_layer, str): 49 | assert act_layer in ('leaky_relu', 'elu', 'identity', '') 50 | self.act_name = act_layer if act_layer else 'identity' 51 | else: 52 | # convert act layer passed as type to string 53 | if act_layer == nn.ELU: 54 | self.act_name = 'elu' 55 | elif act_layer == nn.LeakyReLU: 56 | self.act_name = 'leaky_relu' 57 | elif act_layer is None or act_layer == nn.Identity: 58 | self.act_name = 'identity' 59 | else: 60 | assert False, f'Invalid act layer {act_layer.__name__} for IABN' 61 | else: 62 | self.act_name = 'identity' 63 | self.act_param = act_param 64 | if self.affine: 65 | self.weight = nn.Parameter(torch.ones(num_features)) 66 | self.bias = nn.Parameter(torch.zeros(num_features)) 67 | else: 68 | self.register_parameter('weight', None) 69 | self.register_parameter('bias', None) 70 | self.register_buffer('running_mean', torch.zeros(num_features)) 71 | self.register_buffer('running_var', torch.ones(num_features)) 72 | self.reset_parameters() 73 | 74 | def reset_parameters(self): 75 | nn.init.constant_(self.running_mean, 0) 76 | nn.init.constant_(self.running_var, 1) 77 | if self.affine: 78 | nn.init.constant_(self.weight, 1) 79 | nn.init.constant_(self.bias, 0) 80 | 81 | def forward(self, x): 82 | output = inplace_abn( 83 | x, self.weight, self.bias, self.running_mean, self.running_var, 84 | self.training, self.momentum, self.eps, self.act_name, self.act_param) 85 | if isinstance(output, tuple): 86 | output = output[0] 87 | return output 88 | -------------------------------------------------------------------------------- /src/timm/models/layers/split_batchnorm.py: -------------------------------------------------------------------------------- 1 | """ Split BatchNorm 2 | 3 | A PyTorch BatchNorm layer that splits input batch into N equal parts and passes each through 4 | a separate BN layer. The first split is passed through the parent BN layers with weight/bias 5 | keys the same as the original BN. All other splits pass through BN sub-layers under the '.aux_bn' 6 | namespace. 
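(Editor's aside) A minimal sketch of converting a model to SplitBatchNorm2d via the helper defined below; the toy model and batch size are illustrative:

    import torch
    import torch.nn as nn
    from timm.models.layers.split_batchnorm import convert_splitbn_model

    model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
    model = convert_splitbn_model(model, num_splits=2)   # BatchNorm2d children are replaced recursively
    x = torch.randn(4, 3, 32, 32)                        # in training mode the batch must divide evenly into num_splits
    y = model(x)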
7 | 8 | This allows easily removing the auxiliary BN layers after training to efficiently 9 | achieve the 'Auxiliary BatchNorm' as described in the AdvProp Paper, section 4.2, 10 | 'Disentangled Learning via An Auxiliary BN' 11 | 12 | Hacked together by / Copyright 2020 Ross Wightman 13 | """ 14 | import torch 15 | import torch.nn as nn 16 | 17 | 18 | class SplitBatchNorm2d(torch.nn.BatchNorm2d): 19 | 20 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, 21 | track_running_stats=True, num_splits=2): 22 | super().__init__(num_features, eps, momentum, affine, track_running_stats) 23 | assert num_splits > 1, 'Should have at least one aux BN layer (num_splits at least 2)' 24 | self.num_splits = num_splits 25 | self.aux_bn = nn.ModuleList([ 26 | nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) for _ in range(num_splits - 1)]) 27 | 28 | def forward(self, input: torch.Tensor): 29 | if self.training: # aux BN only relevant while training 30 | split_size = input.shape[0] // self.num_splits 31 | assert input.shape[0] == split_size * self.num_splits, "batch size must be evenly divisible by num_splits" 32 | split_input = input.split(split_size) 33 | x = [super().forward(split_input[0])] 34 | for i, a in enumerate(self.aux_bn): 35 | x.append(a(split_input[i + 1])) 36 | return torch.cat(x, dim=0) 37 | else: 38 | return super().forward(input) 39 | 40 | 41 | def convert_splitbn_model(module, num_splits=2): 42 | """ 43 | Recursively traverse module and its children to replace all instances of 44 | ``torch.nn.modules.batchnorm._BatchNorm`` with `SplitBatchnorm2d`. 45 | Args: 46 | module (torch.nn.Module): input module 47 | num_splits: number of separate batchnorm layers to split input across 48 | Example:: 49 | >>> # model is an instance of torch.nn.Module 50 | >>> model = timm.models.convert_splitbn_model(model, num_splits=2) 51 | """ 52 | mod = module 53 | if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm): 54 | return module 55 | if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): 56 | mod = SplitBatchNorm2d( 57 | module.num_features, module.eps, module.momentum, module.affine, 58 | module.track_running_stats, num_splits=num_splits) 59 | mod.running_mean = module.running_mean 60 | mod.running_var = module.running_var 61 | mod.num_batches_tracked = module.num_batches_tracked 62 | if module.affine: 63 | mod.weight.data = module.weight.data.clone().detach() 64 | mod.bias.data = module.bias.data.clone().detach() 65 | for aux in mod.aux_bn: 66 | aux.running_mean = module.running_mean.clone() 67 | aux.running_var = module.running_var.clone() 68 | aux.num_batches_tracked = module.num_batches_tracked.clone() 69 | if module.affine: 70 | aux.weight.data = module.weight.data.clone().detach() 71 | aux.bias.data = module.bias.data.clone().detach() 72 | for name, child in module.named_children(): 73 | mod.add_module(name, convert_splitbn_model(child, num_splits=num_splits)) 74 | del module 75 | return mod 76 | -------------------------------------------------------------------------------- /src/timm/optim/radam.py: -------------------------------------------------------------------------------- 1 | """RAdam Optimizer. 
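(Editor's aside) A short, generic usage sketch that applies to the optimizers vendored under src/timm/optim (RAdam shown); the model, data, and hyper-parameters are placeholders:

    import torch
    import torch.nn.functional as F
    from timm.optim.radam import RAdam

    model = torch.nn.Linear(4, 1)
    opt = RAdam(model.parameters(), lr=1e-3, weight_decay=1e-4)
    x, y = torch.randn(8, 4), torch.randn(8, 1)
    for _ in range(3):
        opt.zero_grad()
        loss = F.mse_loss(model(x), y)
        loss.backward()
        opt.step()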
2 | Implementation lifted from: https://github.com/LiyuanLucasLiu/RAdam 3 | Paper: `On the Variance of the Adaptive Learning Rate and Beyond` - https://arxiv.org/abs/1908.03265 4 | """ 5 | import math 6 | import torch 7 | from torch.optim.optimizer import Optimizer 8 | 9 | 10 | class RAdam(Optimizer): 11 | 12 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 13 | defaults = dict( 14 | lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, 15 | buffer=[[None, None, None] for _ in range(10)]) 16 | super(RAdam, self).__init__(params, defaults) 17 | 18 | def __setstate__(self, state): 19 | super(RAdam, self).__setstate__(state) 20 | 21 | @torch.no_grad() 22 | def step(self, closure=None): 23 | loss = None 24 | if closure is not None: 25 | with torch.enable_grad(): 26 | loss = closure() 27 | 28 | for group in self.param_groups: 29 | 30 | for p in group['params']: 31 | if p.grad is None: 32 | continue 33 | grad = p.grad.float() 34 | if grad.is_sparse: 35 | raise RuntimeError('RAdam does not support sparse gradients') 36 | 37 | p_fp32 = p.float() 38 | 39 | state = self.state[p] 40 | 41 | if len(state) == 0: 42 | state['step'] = 0 43 | state['exp_avg'] = torch.zeros_like(p_fp32) 44 | state['exp_avg_sq'] = torch.zeros_like(p_fp32) 45 | else: 46 | state['exp_avg'] = state['exp_avg'].type_as(p_fp32) 47 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_fp32) 48 | 49 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 50 | beta1, beta2 = group['betas'] 51 | 52 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 53 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 54 | 55 | state['step'] += 1 56 | buffered = group['buffer'][int(state['step'] % 10)] 57 | if state['step'] == buffered[0]: 58 | num_sma, step_size = buffered[1], buffered[2] 59 | else: 60 | buffered[0] = state['step'] 61 | beta2_t = beta2 ** state['step'] 62 | num_sma_max = 2 / (1 - beta2) - 1 63 | num_sma = num_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 64 | buffered[1] = num_sma 65 | 66 | # more conservative since it's an approximated value 67 | if num_sma >= 5: 68 | step_size = group['lr'] * math.sqrt( 69 | (1 - beta2_t) * 70 | (num_sma - 4) / (num_sma_max - 4) * 71 | (num_sma - 2) / num_sma * 72 | num_sma_max / (num_sma_max - 2)) / (1 - beta1 ** state['step']) 73 | else: 74 | step_size = group['lr'] / (1 - beta1 ** state['step']) 75 | buffered[2] = step_size 76 | 77 | if group['weight_decay'] != 0: 78 | p_fp32.add_(p_fp32, alpha=-group['weight_decay'] * group['lr']) 79 | 80 | # more conservative since it's an approximated value 81 | if num_sma >= 5: 82 | denom = exp_avg_sq.sqrt().add_(group['eps']) 83 | p_fp32.addcdiv_(exp_avg, denom, value=-step_size) 84 | else: 85 | p_fp32.add_(exp_avg, alpha=-step_size) 86 | 87 | p.copy_(p_fp32) 88 | 89 | return loss 90 | -------------------------------------------------------------------------------- /src/timm/models/layers/create_attn.py: -------------------------------------------------------------------------------- 1 | """ Attention Factory 2 | 3 | Hacked together by / Copyright 2021 Ross Wightman 4 | """ 5 | import torch 6 | from functools import partial 7 | 8 | from .bottleneck_attn import BottleneckAttn 9 | from .cbam import CbamModule, LightCbamModule 10 | from .eca import EcaModule, CecaModule 11 | from .gather_excite import GatherExcite 12 | from .global_context import GlobalContext 13 | from .halo_attn import HaloAttn 14 | from .lambda_layer import LambdaLayer 15 | from .non_local_attn import NonLocalAttn, 
BatNonLocalAttn 16 | from .selective_kernel import SelectiveKernel 17 | from .split_attn import SplitAttn 18 | from .squeeze_excite import SEModule, EffectiveSEModule 19 | 20 | 21 | def get_attn(attn_type): 22 | if isinstance(attn_type, torch.nn.Module): 23 | return attn_type 24 | module_cls = None 25 | if attn_type: 26 | if isinstance(attn_type, str): 27 | attn_type = attn_type.lower() 28 | # Lightweight attention modules (channel and/or coarse spatial). 29 | # Typically added to existing network architecture blocks in addition to existing convolutions. 30 | if attn_type == 'se': 31 | module_cls = SEModule 32 | elif attn_type == 'ese': 33 | module_cls = EffectiveSEModule 34 | elif attn_type == 'eca': 35 | module_cls = EcaModule 36 | elif attn_type == 'ecam': 37 | module_cls = partial(EcaModule, use_mlp=True) 38 | elif attn_type == 'ceca': 39 | module_cls = CecaModule 40 | elif attn_type == 'ge': 41 | module_cls = GatherExcite 42 | elif attn_type == 'gc': 43 | module_cls = GlobalContext 44 | elif attn_type == 'gca': 45 | module_cls = partial(GlobalContext, fuse_add=True, fuse_scale=False) 46 | elif attn_type == 'cbam': 47 | module_cls = CbamModule 48 | elif attn_type == 'lcbam': 49 | module_cls = LightCbamModule 50 | 51 | # Attention / attention-like modules w/ significant params 52 | # Typically replace some of the existing workhorse convs in a network architecture. 53 | # All of these accept a stride argument and can spatially downsample the input. 54 | elif attn_type == 'sk': 55 | module_cls = SelectiveKernel 56 | elif attn_type == 'splat': 57 | module_cls = SplitAttn 58 | 59 | # Self-attention / attention-like modules w/ significant compute and/or params 60 | # Typically replace some of the existing workhorse convs in a network architecture. 61 | # All of these accept a stride argument and can spatially downsample the input. 62 | elif attn_type == 'lambda': 63 | return LambdaLayer 64 | elif attn_type == 'bottleneck': 65 | return BottleneckAttn 66 | elif attn_type == 'halo': 67 | return HaloAttn 68 | elif attn_type == 'nl': 69 | module_cls = NonLocalAttn 70 | elif attn_type == 'bat': 71 | module_cls = BatNonLocalAttn 72 | 73 | # Woops! 
74 | else: 75 | assert False, "Invalid attn module (%s)" % attn_type 76 | elif isinstance(attn_type, bool): 77 | if attn_type: 78 | module_cls = SEModule 79 | else: 80 | module_cls = attn_type 81 | return module_cls 82 | 83 | 84 | def create_attn(attn_type, channels, **kwargs): 85 | module_cls = get_attn(attn_type) 86 | if module_cls is not None: 87 | # NOTE: it's expected the first (positional) argument of all attention layers is the # input channels 88 | return module_cls(channels, **kwargs) 89 | return None 90 | -------------------------------------------------------------------------------- /src/midas/backbones/levit.py: -------------------------------------------------------------------------------- 1 | import timm 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | 6 | from .utils import activations, get_activation, Transpose 7 | 8 | 9 | def forward_levit(pretrained, x): 10 | pretrained.model.forward_features(x) 11 | 12 | layer_1 = pretrained.activations["1"] 13 | layer_2 = pretrained.activations["2"] 14 | layer_3 = pretrained.activations["3"] 15 | 16 | layer_1 = pretrained.act_postprocess1(layer_1) 17 | layer_2 = pretrained.act_postprocess2(layer_2) 18 | layer_3 = pretrained.act_postprocess3(layer_3) 19 | 20 | return layer_1, layer_2, layer_3 21 | 22 | 23 | def _make_levit_backbone( 24 | model, 25 | hooks=[3, 11, 21], 26 | patch_grid=[14, 14] 27 | ): 28 | pretrained = nn.Module() 29 | 30 | pretrained.model = model 31 | pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) 32 | pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) 33 | pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) 34 | 35 | pretrained.activations = activations 36 | 37 | patch_grid_size = np.array(patch_grid, dtype=int) 38 | 39 | pretrained.act_postprocess1 = nn.Sequential( 40 | Transpose(1, 2), 41 | nn.Unflatten(2, torch.Size(patch_grid_size.tolist())) 42 | ) 43 | pretrained.act_postprocess2 = nn.Sequential( 44 | Transpose(1, 2), 45 | nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 2).astype(int)).tolist())) 46 | ) 47 | pretrained.act_postprocess3 = nn.Sequential( 48 | Transpose(1, 2), 49 | nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 4).astype(int)).tolist())) 50 | ) 51 | 52 | return pretrained 53 | 54 | 55 | class ConvTransposeNorm(nn.Sequential): 56 | """ 57 | Modification of 58 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: ConvNorm 59 | such that ConvTranspose2d is used instead of Conv2d. 
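(Editor's aside) A quick shape check for the ConvTranspose2d + BatchNorm block defined below (run with the class in scope); the sizes are arbitrary and only illustrate the roughly 2x spatial upsampling used in this MiDaS LeViT backbone:

    import torch

    block = ConvTransposeNorm(8, 4, kernel_size=3, stride=2, pad=1)
    x = torch.randn(2, 8, 14, 14)
    print(block(x).shape)   # torch.Size([2, 4, 27, 27]) -- spatial size roughly doubled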
60 | """ 61 | 62 | def __init__( 63 | self, in_chs, out_chs, kernel_size=1, stride=1, pad=0, dilation=1, 64 | groups=1, bn_weight_init=1): 65 | super().__init__() 66 | self.add_module('c', 67 | nn.ConvTranspose2d(in_chs, out_chs, kernel_size, stride, pad, dilation, groups, bias=False)) 68 | self.add_module('bn', nn.BatchNorm2d(out_chs)) 69 | 70 | nn.init.constant_(self.bn.weight, bn_weight_init) 71 | 72 | @torch.no_grad() 73 | def fuse(self): 74 | c, bn = self._modules.values() 75 | w = bn.weight / (bn.running_var + bn.eps) ** 0.5 76 | w = c.weight * w[:, None, None, None] 77 | b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5 78 | m = nn.ConvTranspose2d( 79 | w.size(1), w.size(0), w.shape[2:], stride=self.c.stride, 80 | padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) 81 | m.weight.data.copy_(w) 82 | m.bias.data.copy_(b) 83 | return m 84 | 85 | 86 | def stem_b4_transpose(in_chs, out_chs, activation): 87 | """ 88 | Modification of 89 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: stem_b16 90 | such that ConvTranspose2d is used instead of Conv2d and stem is also reduced to the half. 91 | """ 92 | return nn.Sequential( 93 | ConvTransposeNorm(in_chs, out_chs, 3, 2, 1), 94 | activation(), 95 | ConvTransposeNorm(out_chs, out_chs // 2, 3, 2, 1), 96 | activation()) 97 | 98 | 99 | def _make_pretrained_levit_384(pretrained, hooks=None): 100 | model = timm.create_model("levit_384", pretrained=pretrained) 101 | 102 | hooks = [3, 11, 21] if hooks == None else hooks 103 | return _make_levit_backbone( 104 | model, 105 | hooks=hooks 106 | ) 107 | -------------------------------------------------------------------------------- /src/timm/scheduler/plateau_lr.py: -------------------------------------------------------------------------------- 1 | """ Plateau Scheduler 2 | 3 | Adapts PyTorch plateau scheduler and allows application of noise, warmup. 
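(Editor's aside) A minimal sketch of driving this scheduler from a training loop; step() takes the epoch index plus the monitored metric, and every value below is a placeholder:

    import torch
    from timm.scheduler.plateau_lr import PlateauLRScheduler

    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = PlateauLRScheduler(
        optimizer, decay_rate=0.5, patience_t=2, warmup_t=3, warmup_lr_init=0.01, mode='min')

    for epoch in range(10):
        val_loss = 1.0   # stand-in for a real validation metric
        scheduler.step(epoch, metric=val_loss)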
4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import torch 8 | 9 | from .scheduler import Scheduler 10 | 11 | 12 | class PlateauLRScheduler(Scheduler): 13 | """Decay the LR by a factor every time the validation loss plateaus.""" 14 | 15 | def __init__(self, 16 | optimizer, 17 | decay_rate=0.1, 18 | patience_t=10, 19 | verbose=True, 20 | threshold=1e-4, 21 | cooldown_t=0, 22 | warmup_t=0, 23 | warmup_lr_init=0, 24 | lr_min=0, 25 | mode='max', 26 | noise_range_t=None, 27 | noise_type='normal', 28 | noise_pct=0.67, 29 | noise_std=1.0, 30 | noise_seed=None, 31 | initialize=True, 32 | ): 33 | super().__init__( 34 | optimizer, 35 | 'lr', 36 | noise_range_t=noise_range_t, 37 | noise_type=noise_type, 38 | noise_pct=noise_pct, 39 | noise_std=noise_std, 40 | noise_seed=noise_seed, 41 | initialize=initialize, 42 | ) 43 | 44 | self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 45 | self.optimizer, 46 | patience=patience_t, 47 | factor=decay_rate, 48 | verbose=verbose, 49 | threshold=threshold, 50 | cooldown=cooldown_t, 51 | mode=mode, 52 | min_lr=lr_min 53 | ) 54 | 55 | self.warmup_t = warmup_t 56 | self.warmup_lr_init = warmup_lr_init 57 | if self.warmup_t: 58 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 59 | super().update_groups(self.warmup_lr_init) 60 | else: 61 | self.warmup_steps = [1 for _ in self.base_values] 62 | self.restore_lr = None 63 | 64 | def state_dict(self): 65 | return { 66 | 'best': self.lr_scheduler.best, 67 | 'last_epoch': self.lr_scheduler.last_epoch, 68 | } 69 | 70 | def load_state_dict(self, state_dict): 71 | self.lr_scheduler.best = state_dict['best'] 72 | if 'last_epoch' in state_dict: 73 | self.lr_scheduler.last_epoch = state_dict['last_epoch'] 74 | 75 | # override the base class step fn completely 76 | def step(self, epoch, metric=None): 77 | if epoch <= self.warmup_t: 78 | lrs = [self.warmup_lr_init + epoch * s for s in self.warmup_steps] 79 | super().update_groups(lrs) 80 | else: 81 | if self.restore_lr is not None: 82 | # restore actual LR from before our last noise perturbation before stepping base 83 | for i, param_group in enumerate(self.optimizer.param_groups): 84 | param_group['lr'] = self.restore_lr[i] 85 | self.restore_lr = None 86 | 87 | self.lr_scheduler.step(metric, epoch) # step the base scheduler 88 | 89 | if self._is_apply_noise(epoch): 90 | self._apply_noise(epoch) 91 | 92 | def _apply_noise(self, epoch): 93 | noise = self._calculate_noise(epoch) 94 | 95 | # apply the noise on top of previous LR, cache the old value so we can restore for normal 96 | # stepping of base scheduler 97 | restore_lr = [] 98 | for i, param_group in enumerate(self.optimizer.param_groups): 99 | old_lr = float(param_group['lr']) 100 | restore_lr.append(old_lr) 101 | new_lr = old_lr + old_lr * noise 102 | param_group['lr'] = new_lr 103 | self.restore_lr = restore_lr 104 | -------------------------------------------------------------------------------- /src/timm/optim/adamp.py: -------------------------------------------------------------------------------- 1 | """ 2 | AdamP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/adamp.py 3 | 4 | Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217 5 | Code: https://github.com/clovaai/AdamP 6 | 7 | Copyright (c) 2020-present NAVER Corp. 
8 | MIT license 9 | """ 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | from torch.optim.optimizer import Optimizer 14 | import math 15 | 16 | 17 | def _channel_view(x) -> torch.Tensor: 18 | return x.reshape(x.size(0), -1) 19 | 20 | 21 | def _layer_view(x) -> torch.Tensor: 22 | return x.reshape(1, -1) 23 | 24 | 25 | def projection(p, grad, perturb, delta: float, wd_ratio: float, eps: float): 26 | wd = 1. 27 | expand_size = (-1,) + (1,) * (len(p.shape) - 1) 28 | for view_func in [_channel_view, _layer_view]: 29 | param_view = view_func(p) 30 | grad_view = view_func(grad) 31 | cosine_sim = F.cosine_similarity(grad_view, param_view, dim=1, eps=eps).abs_() 32 | 33 | # FIXME this is a problem for PyTorch XLA 34 | if cosine_sim.max() < delta / math.sqrt(param_view.size(1)): 35 | p_n = p / param_view.norm(p=2, dim=1).add_(eps).reshape(expand_size) 36 | perturb -= p_n * view_func(p_n * perturb).sum(dim=1).reshape(expand_size) 37 | wd = wd_ratio 38 | return perturb, wd 39 | 40 | return perturb, wd 41 | 42 | 43 | class AdamP(Optimizer): 44 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 45 | weight_decay=0, delta=0.1, wd_ratio=0.1, nesterov=False): 46 | defaults = dict( 47 | lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, 48 | delta=delta, wd_ratio=wd_ratio, nesterov=nesterov) 49 | super(AdamP, self).__init__(params, defaults) 50 | 51 | @torch.no_grad() 52 | def step(self, closure=None): 53 | loss = None 54 | if closure is not None: 55 | with torch.enable_grad(): 56 | loss = closure() 57 | 58 | for group in self.param_groups: 59 | for p in group['params']: 60 | if p.grad is None: 61 | continue 62 | 63 | grad = p.grad 64 | beta1, beta2 = group['betas'] 65 | nesterov = group['nesterov'] 66 | 67 | state = self.state[p] 68 | 69 | # State initialization 70 | if len(state) == 0: 71 | state['step'] = 0 72 | state['exp_avg'] = torch.zeros_like(p) 73 | state['exp_avg_sq'] = torch.zeros_like(p) 74 | 75 | # Adam 76 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 77 | 78 | state['step'] += 1 79 | bias_correction1 = 1 - beta1 ** state['step'] 80 | bias_correction2 = 1 - beta2 ** state['step'] 81 | 82 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 83 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 84 | 85 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 86 | step_size = group['lr'] / bias_correction1 87 | 88 | if nesterov: 89 | perturb = (beta1 * exp_avg + (1 - beta1) * grad) / denom 90 | else: 91 | perturb = exp_avg / denom 92 | 93 | # Projection 94 | wd_ratio = 1. 95 | if len(p.shape) > 1: 96 | perturb, wd_ratio = projection(p, grad, perturb, group['delta'], group['wd_ratio'], group['eps']) 97 | 98 | # Weight decay 99 | if group['weight_decay'] > 0: 100 | p.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio) 101 | 102 | # Step 103 | p.add_(perturb, alpha=-step_size) 104 | 105 | return loss 106 | -------------------------------------------------------------------------------- /src/timm/models/layers/create_norm_act.py: -------------------------------------------------------------------------------- 1 | """ NormAct (Normalizaiton + Activation Layer) Factory 2 | 3 | Create norm + act combo modules that attempt to be backwards compatible with separate norm + act 4 | isntances in models. Where these are used it will be possible to swap separate BN + act layers with 5 | combined modules like IABN or EvoNorms. 
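(Editor's aside) A small sketch of the factory defined in this file; the channel count is arbitrary:

    import torch.nn as nn
    from timm.models.layers.create_norm_act import get_norm_act_layer, create_norm_act_layer

    norm_act_cls = get_norm_act_layer(nn.BatchNorm2d, act_layer=nn.ReLU)   # resolves to BatchNormAct2d (as a partial)
    bn_act = norm_act_cls(64)                                              # combined norm + activation for 64 channels
    evo = create_norm_act_layer('evonorms0', 64, apply_act=True)           # string names resolve through _NORM_ACT_MAP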
6 | 7 | Hacked together by / Copyright 2020 Ross Wightman 8 | """ 9 | import types 10 | import functools 11 | 12 | from .evo_norm import * 13 | from .filter_response_norm import FilterResponseNormAct2d, FilterResponseNormTlu2d 14 | from .norm_act import BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d 15 | from .inplace_abn import InplaceAbn 16 | 17 | _NORM_ACT_MAP = dict( 18 | batchnorm=BatchNormAct2d, 19 | batchnorm2d=BatchNormAct2d, 20 | groupnorm=GroupNormAct, 21 | groupnorm1=functools.partial(GroupNormAct, num_groups=1), 22 | layernorm=LayerNormAct, 23 | layernorm2d=LayerNormAct2d, 24 | evonormb0=EvoNorm2dB0, 25 | evonormb1=EvoNorm2dB1, 26 | evonormb2=EvoNorm2dB2, 27 | evonorms0=EvoNorm2dS0, 28 | evonorms0a=EvoNorm2dS0a, 29 | evonorms1=EvoNorm2dS1, 30 | evonorms1a=EvoNorm2dS1a, 31 | evonorms2=EvoNorm2dS2, 32 | evonorms2a=EvoNorm2dS2a, 33 | frn=FilterResponseNormAct2d, 34 | frntlu=FilterResponseNormTlu2d, 35 | inplaceabn=InplaceAbn, 36 | iabn=InplaceAbn, 37 | ) 38 | _NORM_ACT_TYPES = {m for n, m in _NORM_ACT_MAP.items()} 39 | # has act_layer arg to define act type 40 | _NORM_ACT_REQUIRES_ARG = { 41 | BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d, FilterResponseNormAct2d, InplaceAbn} 42 | 43 | 44 | def create_norm_act_layer(layer_name, num_features, act_layer=None, apply_act=True, jit=False, **kwargs): 45 | layer = get_norm_act_layer(layer_name, act_layer=act_layer) 46 | layer_instance = layer(num_features, apply_act=apply_act, **kwargs) 47 | if jit: 48 | layer_instance = torch.jit.script(layer_instance) 49 | return layer_instance 50 | 51 | 52 | def get_norm_act_layer(norm_layer, act_layer=None): 53 | assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial)) 54 | assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial)) 55 | norm_act_kwargs = {} 56 | 57 | # unbind partial fn, so args can be rebound later 58 | if isinstance(norm_layer, functools.partial): 59 | norm_act_kwargs.update(norm_layer.keywords) 60 | norm_layer = norm_layer.func 61 | 62 | if isinstance(norm_layer, str): 63 | layer_name = norm_layer.replace('_', '').lower().split('-')[0] 64 | norm_act_layer = _NORM_ACT_MAP.get(layer_name, None) 65 | elif norm_layer in _NORM_ACT_TYPES: 66 | norm_act_layer = norm_layer 67 | elif isinstance(norm_layer, types.FunctionType): 68 | # if function type, must be a lambda/fn that creates a norm_act layer 69 | norm_act_layer = norm_layer 70 | else: 71 | type_name = norm_layer.__name__.lower() 72 | if type_name.startswith('batchnorm'): 73 | norm_act_layer = BatchNormAct2d 74 | elif type_name.startswith('groupnorm'): 75 | norm_act_layer = GroupNormAct 76 | elif type_name.startswith('groupnorm1'): 77 | norm_act_layer = functools.partial(GroupNormAct, num_groups=1) 78 | elif type_name.startswith('layernorm2d'): 79 | norm_act_layer = LayerNormAct2d 80 | elif type_name.startswith('layernorm'): 81 | norm_act_layer = LayerNormAct 82 | else: 83 | assert False, f"No equivalent norm_act layer for {type_name}" 84 | 85 | if norm_act_layer in _NORM_ACT_REQUIRES_ARG: 86 | # pass `act_layer` through for backwards compat where `act_layer=None` implies no activation. 
87 | # In the future, may force use of `apply_act` with `act_layer` arg bound to relevant NormAct types 88 | norm_act_kwargs.setdefault('act_layer', act_layer) 89 | if norm_act_kwargs: 90 | norm_act_layer = functools.partial(norm_act_layer, **norm_act_kwargs) # bind/rebind args 91 | return norm_act_layer 92 | -------------------------------------------------------------------------------- /src/timm/models/layers/gather_excite.py: -------------------------------------------------------------------------------- 1 | """ Gather-Excite Attention Block 2 | 3 | Paper: `Gather-Excite: Exploiting Feature Context in CNNs` - https://arxiv.org/abs/1810.12348 4 | 5 | Official code here, but it's only partial impl in Caffe: https://github.com/hujie-frank/GENet 6 | 7 | I've tried to support all of the extent both w/ and w/o params. I don't believe I've seen another 8 | impl that covers all of the cases. 9 | 10 | NOTE: extent=0 + extra_params=False is equivalent to Squeeze-and-Excitation 11 | 12 | Hacked together by / Copyright 2021 Ross Wightman 13 | """ 14 | import math 15 | 16 | from torch import nn as nn 17 | import torch.nn.functional as F 18 | 19 | from .create_act import create_act_layer, get_act_layer 20 | from .create_conv2d import create_conv2d 21 | from .helpers import make_divisible 22 | from .mlp import ConvMlp 23 | 24 | 25 | class GatherExcite(nn.Module): 26 | """ Gather-Excite Attention Module 27 | """ 28 | def __init__( 29 | self, channels, feat_size=None, extra_params=False, extent=0, use_mlp=True, 30 | rd_ratio=1./16, rd_channels=None, rd_divisor=1, add_maxpool=False, 31 | act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, gate_layer='sigmoid'): 32 | super(GatherExcite, self).__init__() 33 | self.add_maxpool = add_maxpool 34 | act_layer = get_act_layer(act_layer) 35 | self.extent = extent 36 | if extra_params: 37 | self.gather = nn.Sequential() 38 | if extent == 0: 39 | assert feat_size is not None, 'spatial feature size must be specified for global extent w/ params' 40 | self.gather.add_module( 41 | 'conv1', create_conv2d(channels, channels, kernel_size=feat_size, stride=1, depthwise=True)) 42 | if norm_layer: 43 | self.gather.add_module(f'norm1', nn.BatchNorm2d(channels)) 44 | else: 45 | assert extent % 2 == 0 46 | num_conv = int(math.log2(extent)) 47 | for i in range(num_conv): 48 | self.gather.add_module( 49 | f'conv{i + 1}', 50 | create_conv2d(channels, channels, kernel_size=3, stride=2, depthwise=True)) 51 | if norm_layer: 52 | self.gather.add_module(f'norm{i + 1}', nn.BatchNorm2d(channels)) 53 | if i != num_conv - 1: 54 | self.gather.add_module(f'act{i + 1}', act_layer(inplace=True)) 55 | else: 56 | self.gather = None 57 | if self.extent == 0: 58 | self.gk = 0 59 | self.gs = 0 60 | else: 61 | assert extent % 2 == 0 62 | self.gk = self.extent * 2 - 1 63 | self.gs = self.extent 64 | 65 | if not rd_channels: 66 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) 
67 | self.mlp = ConvMlp(channels, rd_channels, act_layer=act_layer) if use_mlp else nn.Identity() 68 | self.gate = create_act_layer(gate_layer) 69 | 70 | def forward(self, x): 71 | size = x.shape[-2:] 72 | if self.gather is not None: 73 | x_ge = self.gather(x) 74 | else: 75 | if self.extent == 0: 76 | # global extent 77 | x_ge = x.mean(dim=(2, 3), keepdims=True) 78 | if self.add_maxpool: 79 | # experimental codepath, may remove or change 80 | x_ge = 0.5 * x_ge + 0.5 * x.amax((2, 3), keepdim=True) 81 | else: 82 | x_ge = F.avg_pool2d( 83 | x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2, count_include_pad=False) 84 | if self.add_maxpool: 85 | # experimental codepath, may remove or change 86 | x_ge = 0.5 * x_ge + 0.5 * F.max_pool2d(x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2) 87 | x_ge = self.mlp(x_ge) 88 | if x_ge.shape[-1] != 1 or x_ge.shape[-2] != 1: 89 | x_ge = F.interpolate(x_ge, size=size) 90 | return x * self.gate(x_ge) 91 | -------------------------------------------------------------------------------- /src/timm/optim/nadam.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.optim.optimizer import Optimizer 5 | 6 | 7 | class Nadam(Optimizer): 8 | """Implements Nadam algorithm (a variant of Adam based on Nesterov momentum). 9 | 10 | It has been proposed in `Incorporating Nesterov Momentum into Adam`__. 11 | 12 | Arguments: 13 | params (iterable): iterable of parameters to optimize or dicts defining 14 | parameter groups 15 | lr (float, optional): learning rate (default: 2e-3) 16 | betas (Tuple[float, float], optional): coefficients used for computing 17 | running averages of gradient and its square 18 | eps (float, optional): term added to the denominator to improve 19 | numerical stability (default: 1e-8) 20 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 21 | schedule_decay (float, optional): momentum schedule decay (default: 4e-3) 22 | 23 | __ http://cs229.stanford.edu/proj2015/054_report.pdf 24 | __ http://www.cs.toronto.edu/~fritz/absps/momentum.pdf 25 | 26 | Originally taken from: https://github.com/pytorch/pytorch/pull/1408 27 | NOTE: Has potential issues but does work well on some problems. 28 | """ 29 | 30 | def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8, 31 | weight_decay=0, schedule_decay=4e-3): 32 | if not 0.0 <= lr: 33 | raise ValueError("Invalid learning rate: {}".format(lr)) 34 | defaults = dict( 35 | lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, schedule_decay=schedule_decay) 36 | super(Nadam, self).__init__(params, defaults) 37 | 38 | @torch.no_grad() 39 | def step(self, closure=None): 40 | """Performs a single optimization step. 41 | 42 | Arguments: 43 | closure (callable, optional): A closure that reevaluates the model 44 | and returns the loss. 45 | """ 46 | loss = None 47 | if closure is not None: 48 | with torch.enable_grad(): 49 | loss = closure() 50 | 51 | for group in self.param_groups: 52 | for p in group['params']: 53 | if p.grad is None: 54 | continue 55 | grad = p.grad 56 | state = self.state[p] 57 | 58 | # State initialization 59 | if len(state) == 0: 60 | state['step'] = 0 61 | state['m_schedule'] = 1. 
62 | state['exp_avg'] = torch.zeros_like(p) 63 | state['exp_avg_sq'] = torch.zeros_like(p) 64 | 65 | # Warming momentum schedule 66 | m_schedule = state['m_schedule'] 67 | schedule_decay = group['schedule_decay'] 68 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 69 | beta1, beta2 = group['betas'] 70 | eps = group['eps'] 71 | state['step'] += 1 72 | t = state['step'] 73 | bias_correction2 = 1 - beta2 ** t 74 | 75 | if group['weight_decay'] != 0: 76 | grad = grad.add(p, alpha=group['weight_decay']) 77 | 78 | momentum_cache_t = beta1 * (1. - 0.5 * (0.96 ** (t * schedule_decay))) 79 | momentum_cache_t_1 = beta1 * (1. - 0.5 * (0.96 ** ((t + 1) * schedule_decay))) 80 | m_schedule_new = m_schedule * momentum_cache_t 81 | m_schedule_next = m_schedule * momentum_cache_t * momentum_cache_t_1 82 | state['m_schedule'] = m_schedule_new 83 | 84 | # Decay the first and second moment running average coefficient 85 | exp_avg.mul_(beta1).add_(grad, alpha=1. - beta1) 86 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1. - beta2) 87 | 88 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) 89 | p.addcdiv_(grad, denom, value=-group['lr'] * (1. - momentum_cache_t) / (1. - m_schedule_new)) 90 | p.addcdiv_(exp_avg, denom, value=-group['lr'] * momentum_cache_t_1 / (1. - m_schedule_next)) 91 | 92 | return loss 93 | --------------------------------------------------------------------------------
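(Editor's aside) Nadam above follows the same Optimizer interface as the other vendored optimizers; a brief construction sketch showing its extra schedule_decay knob, with illustrative values:

    import torch
    from timm.optim.nadam import Nadam

    model = torch.nn.Linear(4, 1)
    opt = Nadam(model.parameters(), lr=2e-3, betas=(0.9, 0.999), weight_decay=0, schedule_decay=4e-3)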