├── README.md ├── data ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── base_dataset.cpython-36.pyc │ ├── image_folder.cpython-36.pyc │ ├── unaligned_dataset.cpython-36.pyc │ └── unaligned_double_dataset.cpython-36.pyc ├── base_dataset.py ├── image_folder.py ├── single_dataset.py ├── singleimage_dataset.py ├── template_dataset.py ├── unaligned_dataset.py └── unaligned_double_dataset.py ├── datasets ├── bibtex │ ├── cityscapes.tex │ ├── facades.tex │ ├── handbags.tex │ ├── shoes.tex │ └── transattr.tex ├── combine_A_and_B.py ├── detect_cat_face.py ├── download_cut_dataset.sh ├── download_pix2pix_dataset.sh ├── make_dataset_aligned.py ├── prepare_cityscapes_dataset.py └── single_image_monet_etretat │ ├── trainA │ └── monet.jpg │ └── trainB │ └── etretat-normandy-france.jpg ├── images ├── detection1.gif ├── detection2.gif ├── fusion.gif └── method_final.jpg ├── models ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── base_model.cpython-36.pyc │ ├── cut_model.cpython-36.pyc │ ├── mae.cpython-36.pyc │ ├── models_mae.cpython-36.pyc │ ├── mutilvitgloballocal_model.cpython-36.pyc │ ├── networks.cpython-36.pyc │ ├── patchnce.cpython-36.pyc │ ├── region0_model.cpython-36.pyc │ ├── region_model.cpython-36.pyc │ ├── stylegan_networks.cpython-36.pyc │ ├── vit2Gmask_model.cpython-36.pyc │ ├── vit2_model.cpython-36.pyc │ ├── vit2patchmask_model.cpython-36.pyc │ ├── vit2tokenmask_model.cpython-36.pyc │ ├── vitD_model.cpython-36.pyc │ ├── vit_model.cpython-36.pyc │ ├── vitdonly2_model.cpython-36.pyc │ ├── vitdonly_model.cpython-36.pyc │ ├── vitgloballocal_model.cpython-36.pyc │ └── vitlocalgloballocal_model.cpython-36.pyc ├── base_model.py ├── cut_model.py ├── cycle_gan_model.py ├── networks.py ├── patchnce.py ├── roma_model.py ├── roma_single_model.py ├── stylegan_networks.py ├── template_model.py └── util │ ├── __pycache__ │ └── pos_embed.cpython-36.pyc │ ├── crop.py │ ├── datasets.py │ ├── lars.py │ ├── lr_decay.py │ ├── lr_sched.py │ ├── 
misc.py │ └── pos_embed.py ├── options ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── base_options.cpython-36.pyc │ ├── test_options.cpython-36.pyc │ └── train_options.cpython-36.pyc ├── base_options.py ├── test_options.py └── train_options.py ├── scripts ├── test.sh └── train.sh ├── test.py ├── timm ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ └── version.cpython-36.pyc ├── data │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── auto_augment.cpython-36.pyc │ │ ├── config.cpython-36.pyc │ │ ├── constants.cpython-36.pyc │ │ ├── dataset.cpython-36.pyc │ │ ├── dataset_factory.cpython-36.pyc │ │ ├── distributed_sampler.cpython-36.pyc │ │ ├── loader.cpython-36.pyc │ │ ├── mixup.cpython-36.pyc │ │ ├── random_erasing.cpython-36.pyc │ │ ├── real_labels.cpython-36.pyc │ │ ├── transforms.cpython-36.pyc │ │ └── transforms_factory.cpython-36.pyc │ ├── auto_augment.py │ ├── config.py │ ├── constants.py │ ├── dataset.py │ ├── dataset_factory.py │ ├── distributed_sampler.py │ ├── loader.py │ ├── mixup.py │ ├── parsers │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── class_map.cpython-36.pyc │ │ │ ├── constants.cpython-36.pyc │ │ │ ├── parser.cpython-36.pyc │ │ │ ├── parser_factory.cpython-36.pyc │ │ │ ├── parser_image_folder.cpython-36.pyc │ │ │ ├── parser_image_in_tar.cpython-36.pyc │ │ │ └── parser_image_tar.cpython-36.pyc │ │ ├── class_map.py │ │ ├── constants.py │ │ ├── parser.py │ │ ├── parser_factory.py │ │ ├── parser_image_folder.py │ │ ├── parser_image_in_tar.py │ │ ├── parser_image_tar.py │ │ └── parser_tfds.py │ ├── random_erasing.py │ ├── real_labels.py │ ├── tf_preprocessing.py │ ├── transforms.py │ └── transforms_factory.py ├── loss │ ├── __init__.py │ ├── asymmetric_loss.py │ ├── binary_cross_entropy.py │ ├── cross_entropy.py │ └── jsd.py ├── models │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── beit.cpython-36.pyc │ │ ├── 
byoanet.cpython-36.pyc │ │ ├── byobnet.cpython-36.pyc │ │ ├── cait.cpython-36.pyc │ │ ├── coat.cpython-36.pyc │ │ ├── convit.cpython-36.pyc │ │ ├── convmixer.cpython-36.pyc │ │ ├── crossvit.cpython-36.pyc │ │ ├── cspnet.cpython-36.pyc │ │ ├── densenet.cpython-36.pyc │ │ ├── dla.cpython-36.pyc │ │ ├── dpn.cpython-36.pyc │ │ ├── efficientnet.cpython-36.pyc │ │ ├── efficientnet_blocks.cpython-36.pyc │ │ ├── efficientnet_builder.cpython-36.pyc │ │ ├── factory.cpython-36.pyc │ │ ├── features.cpython-36.pyc │ │ ├── fx_features.cpython-36.pyc │ │ ├── ghostnet.cpython-36.pyc │ │ ├── gluon_resnet.cpython-36.pyc │ │ ├── gluon_xception.cpython-36.pyc │ │ ├── hardcorenas.cpython-36.pyc │ │ ├── helpers.cpython-36.pyc │ │ ├── hrnet.cpython-36.pyc │ │ ├── hub.cpython-36.pyc │ │ ├── inception_resnet_v2.cpython-36.pyc │ │ ├── inception_v3.cpython-36.pyc │ │ ├── inception_v4.cpython-36.pyc │ │ ├── levit.cpython-36.pyc │ │ ├── mlp_mixer.cpython-36.pyc │ │ ├── mobilenetv3.cpython-36.pyc │ │ ├── nasnet.cpython-36.pyc │ │ ├── nest.cpython-36.pyc │ │ ├── nfnet.cpython-36.pyc │ │ ├── pit.cpython-36.pyc │ │ ├── pnasnet.cpython-36.pyc │ │ ├── registry.cpython-36.pyc │ │ ├── regnet.cpython-36.pyc │ │ ├── res2net.cpython-36.pyc │ │ ├── resnest.cpython-36.pyc │ │ ├── resnet.cpython-36.pyc │ │ ├── resnetv2.cpython-36.pyc │ │ ├── rexnet.cpython-36.pyc │ │ ├── selecsls.cpython-36.pyc │ │ ├── senet.cpython-36.pyc │ │ ├── sknet.cpython-36.pyc │ │ ├── swin_transformer.cpython-36.pyc │ │ ├── tnt.cpython-36.pyc │ │ ├── tresnet.cpython-36.pyc │ │ ├── twins.cpython-36.pyc │ │ ├── vgg.cpython-36.pyc │ │ ├── visformer.cpython-36.pyc │ │ ├── vision_transformer.cpython-36.pyc │ │ ├── vision_transformer_hybrid.cpython-36.pyc │ │ ├── vovnet.cpython-36.pyc │ │ ├── xception.cpython-36.pyc │ │ ├── xception_aligned.cpython-36.pyc │ │ └── xcit.cpython-36.pyc │ ├── beit.py │ ├── byoanet.py │ ├── byobnet.py │ ├── cait.py │ ├── coat.py │ ├── convit.py │ ├── convmixer.py │ ├── crossvit.py │ ├── cspnet.py │ ├── 
densenet.py │ ├── dla.py │ ├── dpn.py │ ├── efficientnet.py │ ├── efficientnet_blocks.py │ ├── efficientnet_builder.py │ ├── factory.py │ ├── features.py │ ├── fx_features.py │ ├── ghostnet.py │ ├── gluon_resnet.py │ ├── gluon_xception.py │ ├── hardcorenas.py │ ├── helpers.py │ ├── hrnet.py │ ├── hub.py │ ├── inception_resnet_v2.py │ ├── inception_v3.py │ ├── inception_v4.py │ ├── layers │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── activations.cpython-36.pyc │ │ │ ├── activations_jit.cpython-36.pyc │ │ │ ├── activations_me.cpython-36.pyc │ │ │ ├── adaptive_avgmax_pool.cpython-36.pyc │ │ │ ├── blur_pool.cpython-36.pyc │ │ │ ├── bottleneck_attn.cpython-36.pyc │ │ │ ├── cbam.cpython-36.pyc │ │ │ ├── classifier.cpython-36.pyc │ │ │ ├── cond_conv2d.cpython-36.pyc │ │ │ ├── config.cpython-36.pyc │ │ │ ├── conv2d_same.cpython-36.pyc │ │ │ ├── conv_bn_act.cpython-36.pyc │ │ │ ├── create_act.cpython-36.pyc │ │ │ ├── create_attn.cpython-36.pyc │ │ │ ├── create_conv2d.cpython-36.pyc │ │ │ ├── create_norm_act.cpython-36.pyc │ │ │ ├── drop.cpython-36.pyc │ │ │ ├── eca.cpython-36.pyc │ │ │ ├── evo_norm.cpython-36.pyc │ │ │ ├── gather_excite.cpython-36.pyc │ │ │ ├── global_context.cpython-36.pyc │ │ │ ├── halo_attn.cpython-36.pyc │ │ │ ├── helpers.cpython-36.pyc │ │ │ ├── inplace_abn.cpython-36.pyc │ │ │ ├── lambda_layer.cpython-36.pyc │ │ │ ├── linear.cpython-36.pyc │ │ │ ├── mixed_conv2d.cpython-36.pyc │ │ │ ├── mlp.cpython-36.pyc │ │ │ ├── non_local_attn.cpython-36.pyc │ │ │ ├── norm.cpython-36.pyc │ │ │ ├── norm_act.cpython-36.pyc │ │ │ ├── padding.cpython-36.pyc │ │ │ ├── patch_embed.cpython-36.pyc │ │ │ ├── pool2d_same.cpython-36.pyc │ │ │ ├── selective_kernel.cpython-36.pyc │ │ │ ├── separable_conv.cpython-36.pyc │ │ │ ├── space_to_depth.cpython-36.pyc │ │ │ ├── split_attn.cpython-36.pyc │ │ │ ├── split_batchnorm.cpython-36.pyc │ │ │ ├── squeeze_excite.cpython-36.pyc │ │ │ ├── std_conv.cpython-36.pyc │ │ │ ├── 
test_time_pool.cpython-36.pyc │ │ │ ├── trace_utils.cpython-36.pyc │ │ │ └── weight_init.cpython-36.pyc │ │ ├── activations.py │ │ ├── activations_jit.py │ │ ├── activations_me.py │ │ ├── adaptive_avgmax_pool.py │ │ ├── attention_pool2d.py │ │ ├── blur_pool.py │ │ ├── bottleneck_attn.py │ │ ├── cbam.py │ │ ├── classifier.py │ │ ├── cond_conv2d.py │ │ ├── config.py │ │ ├── conv2d_same.py │ │ ├── conv_bn_act.py │ │ ├── create_act.py │ │ ├── create_attn.py │ │ ├── create_conv2d.py │ │ ├── create_norm_act.py │ │ ├── drop.py │ │ ├── eca.py │ │ ├── evo_norm.py │ │ ├── gather_excite.py │ │ ├── global_context.py │ │ ├── halo_attn.py │ │ ├── helpers.py │ │ ├── inplace_abn.py │ │ ├── lambda_layer.py │ │ ├── linear.py │ │ ├── median_pool.py │ │ ├── mixed_conv2d.py │ │ ├── mlp.py │ │ ├── non_local_attn.py │ │ ├── norm.py │ │ ├── norm_act.py │ │ ├── padding.py │ │ ├── patch_embed.py │ │ ├── pool2d_same.py │ │ ├── selective_kernel.py │ │ ├── separable_conv.py │ │ ├── space_to_depth.py │ │ ├── split_attn.py │ │ ├── split_batchnorm.py │ │ ├── squeeze_excite.py │ │ ├── std_conv.py │ │ ├── test_time_pool.py │ │ ├── trace_utils.py │ │ └── weight_init.py │ ├── levit.py │ ├── mlp_mixer.py │ ├── mobilenetv3.py │ ├── nasnet.py │ ├── nest.py │ ├── nfnet.py │ ├── pit.py │ ├── pnasnet.py │ ├── pruned │ │ ├── ecaresnet101d_pruned.txt │ │ ├── ecaresnet50d_pruned.txt │ │ ├── efficientnet_b1_pruned.txt │ │ ├── efficientnet_b2_pruned.txt │ │ └── efficientnet_b3_pruned.txt │ ├── registry.py │ ├── regnet.py │ ├── res2net.py │ ├── resnest.py │ ├── resnet.py │ ├── resnetv2.py │ ├── rexnet.py │ ├── selecsls.py │ ├── senet.py │ ├── sknet.py │ ├── swin_transformer.py │ ├── tnt.py │ ├── tresnet.py │ ├── twins.py │ ├── vgg.py │ ├── visformer.py │ ├── vision_transformer.py │ ├── vision_transformer_hybrid.py │ ├── vovnet.py │ ├── xception.py │ ├── xception_aligned.py │ └── xcit.py ├── optim │ ├── __init__.py │ ├── adabelief.py │ ├── adafactor.py │ ├── adahessian.py │ ├── adamp.py │ ├── adamw.py │ ├── 
lamb.py │ ├── lars.py │ ├── lookahead.py │ ├── madgrad.py │ ├── nadam.py │ ├── nvnovograd.py │ ├── optim_factory.py │ ├── radam.py │ ├── rmsprop_tf.py │ └── sgdp.py ├── scheduler │ ├── __init__.py │ ├── cosine_lr.py │ ├── multistep_lr.py │ ├── plateau_lr.py │ ├── poly_lr.py │ ├── scheduler.py │ ├── scheduler_factory.py │ ├── step_lr.py │ └── tanh_lr.py ├── utils │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── agc.cpython-36.pyc │ │ ├── checkpoint_saver.cpython-36.pyc │ │ ├── clip_grad.cpython-36.pyc │ │ ├── cuda.cpython-36.pyc │ │ ├── distributed.cpython-36.pyc │ │ ├── jit.cpython-36.pyc │ │ ├── log.cpython-36.pyc │ │ ├── metrics.cpython-36.pyc │ │ ├── misc.cpython-36.pyc │ │ ├── model.cpython-36.pyc │ │ ├── model_ema.cpython-36.pyc │ │ ├── random.cpython-36.pyc │ │ └── summary.cpython-36.pyc │ ├── agc.py │ ├── checkpoint_saver.py │ ├── clip_grad.py │ ├── cuda.py │ ├── distributed.py │ ├── jit.py │ ├── log.py │ ├── metrics.py │ ├── misc.py │ ├── model.py │ ├── model_ema.py │ ├── random.py │ └── summary.py └── version.py ├── train.py └── util ├── __init__.py ├── __pycache__ ├── __init__.cpython-36.pyc ├── html.cpython-36.pyc ├── util.cpython-36.pyc └── visualizer.cpython-36.pyc ├── get_data.py ├── html.py ├── image_pool.py ├── util.py └── visualizer.py /data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/base_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/data/__pycache__/base_dataset.cpython-36.pyc -------------------------------------------------------------------------------- 
/data/__pycache__/image_folder.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/data/__pycache__/image_folder.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/unaligned_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/data/__pycache__/unaligned_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/unaligned_double_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/data/__pycache__/unaligned_double_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /data/image_folder.py: -------------------------------------------------------------------------------- 1 | """A modified image folder class 2 | 3 | We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py) 4 | so that this class can load images from both current directory and its subdirectories. 
5 | """ 6 | 7 | import torch.utils.data as data 8 | 9 | from PIL import Image 10 | import os 11 | import os.path 12 | 13 | IMG_EXTENSIONS = [ 14 | '.jpg', '.JPG', '.jpeg', '.JPEG', 15 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 16 | '.tif', '.TIF', '.tiff', '.TIFF', 17 | ] 18 | 19 | 20 | def is_image_file(filename): 21 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 22 | 23 | 24 | def make_dataset(dir, max_dataset_size=float("inf")): 25 | images = [] 26 | assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir 27 | 28 | for root, _, fnames in sorted(os.walk(dir, followlinks=True)): 29 | for fname in fnames: 30 | if is_image_file(fname): 31 | path = os.path.join(root, fname) 32 | images.append(path) 33 | return images[:min(max_dataset_size, len(images))] 34 | 35 | 36 | def default_loader(path): 37 | return Image.open(path).convert('RGB') 38 | 39 | 40 | class ImageFolder(data.Dataset): 41 | 42 | def __init__(self, root, transform=None, return_paths=False, 43 | loader=default_loader): 44 | imgs = make_dataset(root) 45 | if len(imgs) == 0: 46 | raise(RuntimeError("Found 0 images in: " + root + "\n" 47 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 48 | 49 | self.root = root 50 | self.imgs = imgs 51 | self.transform = transform 52 | self.return_paths = return_paths 53 | self.loader = loader 54 | 55 | def __getitem__(self, index): 56 | path = self.imgs[index] 57 | img = self.loader(path) 58 | if self.transform is not None: 59 | img = self.transform(img) 60 | if self.return_paths: 61 | return img, path 62 | else: 63 | return img 64 | 65 | def __len__(self): 66 | return len(self.imgs) 67 | -------------------------------------------------------------------------------- /data/single_dataset.py: -------------------------------------------------------------------------------- 1 | from data.base_dataset import BaseDataset, get_transform 2 | from data.image_folder import make_dataset 3 | from PIL 
import Image 4 | 5 | 6 | class SingleDataset(BaseDataset): 7 | """This dataset class can load a set of images specified by the path --dataroot /path/to/data. 8 | 9 | It can be used for generating CycleGAN results only for one side with the model option '--model test'. 10 | """ 11 | 12 | def __init__(self, opt): 13 | """Initialize this dataset class. 14 | 15 | Parameters: 16 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 17 | """ 18 | BaseDataset.__init__(self, opt) 19 | self.A_paths = sorted(make_dataset(opt.dataroot, opt.max_dataset_size)) 20 | input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc 21 | self.transform = get_transform(opt, grayscale=(input_nc == 1)) 22 | 23 | def __getitem__(self, index): 24 | """Return a data point and its metadata information. 25 | 26 | Parameters: 27 | index -- a random integer for data indexing 28 | 29 | Returns a dictionary that contains A and A_paths 30 | A(tensor) -- an image in one domain 31 | A_paths(str) -- the path of the image 32 | """ 33 | A_path = self.A_paths[index] 34 | A_img = Image.open(A_path).convert('RGB') 35 | A = self.transform(A_img) 36 | return {'A': A, 'A_paths': A_path} 37 | 38 | def __len__(self): 39 | """Return the total number of images in the dataset.""" 40 | return len(self.A_paths) 41 | -------------------------------------------------------------------------------- /data/template_dataset.py: -------------------------------------------------------------------------------- 1 | """Dataset class template 2 | 3 | This module provides a template for users to implement custom datasets. 4 | You can specify '--dataset_mode template' to use this dataset. 5 | The class name should be consistent with both the filename and its dataset_mode option.
6 | The filename should be <dataset_mode>_dataset.py 7 | The class name should be <Dataset_mode>Dataset 8 | You need to implement the following functions: 9 | -- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options. 10 | -- <__init__>: Initialize this dataset class. 11 | -- <__getitem__>: Return a data point and its metadata information. 12 | -- <__len__>: Return the number of images. 13 | """ 14 | from data.base_dataset import BaseDataset, get_transform 15 | # from data.image_folder import make_dataset 16 | # from PIL import Image 17 | 18 | 19 | class TemplateDataset(BaseDataset): 20 | """A template dataset class for you to implement custom datasets.""" 21 | @staticmethod 22 | def modify_commandline_options(parser, is_train): 23 | """Add new dataset-specific options, and rewrite default values for existing options. 24 | 25 | Parameters: 26 | parser -- original option parser 27 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 28 | 29 | Returns: 30 | the modified parser. 31 | """ 32 | parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option') 33 | parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0) # specify dataset-specific default values 34 | return parser 35 | 36 | def __init__(self, opt): 37 | """Initialize this dataset class. 38 | 39 | Parameters: 40 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 41 | 42 | A few things can be done here. 43 | - save the options (already done in BaseDataset) 44 | - get image paths and meta information of the dataset. 45 | - define the image transformation.
46 | """ 47 | # save the option and dataset root 48 | BaseDataset.__init__(self, opt) 49 | # get the image paths of your dataset; 50 | self.image_paths = [] # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root 51 | # define the default transform function. You can use <base_dataset.get_transform>; You can also define your custom transform function 52 | self.transform = get_transform(opt) 53 | 54 | def __getitem__(self, index): 55 | """Return a data point and its metadata information. 56 | 57 | Parameters: 58 | index -- a random integer for data indexing 59 | 60 | Returns: 61 | a dictionary of data with their names. It usually contains the data itself and its metadata information. 62 | 63 | Step 1: get a random image path: e.g., path = self.image_paths[index] 64 | Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB'). 65 | Step 3: convert your data to a PyTorch tensor. You can use helper functions such as self.transform. e.g., data = self.transform(image) 66 | Step 4: return a data point as a dictionary. 67 | """ 68 | path = 'temp' # needs to be a string 69 | data_A = None # needs to be a tensor 70 | data_B = None # needs to be a tensor 71 | return {'data_A': data_A, 'data_B': data_B, 'path': path} 72 | 73 | def __len__(self): 74 | """Return the total number of images.""" 75 | return len(self.image_paths) 76 | -------------------------------------------------------------------------------- /datasets/bibtex/cityscapes.tex: -------------------------------------------------------------------------------- 1 | @inproceedings{Cordts2016Cityscapes, 2 | title={The Cityscapes Dataset for Semantic Urban Scene Understanding}, 3 | author={Cordts, Marius and Omran, Mohamed and Ramos, Sebastian and Rehfeld, Timo and Enzweiler, Markus and Benenson, Rodrigo and Franke, Uwe and Roth, Stefan and Schiele, Bernt}, 4 | booktitle={Proc.
of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 5 | year={2016} 6 | } 7 | -------------------------------------------------------------------------------- /datasets/bibtex/facades.tex: -------------------------------------------------------------------------------- 1 | @INPROCEEDINGS{Tylecek13, 2 | author = {Radim Tyle{\v c}ek and Radim {\v S}{\' a}ra}, 3 | title = {Spatial Pattern Templates for Recognition of Objects with Regular Structure}, 4 | booktitle = {Proc. GCPR}, 5 | year = {2013}, 6 | address = {Saarbr{\"u}cken, Germany}, 7 | } 8 | -------------------------------------------------------------------------------- /datasets/bibtex/handbags.tex: -------------------------------------------------------------------------------- 1 | @inproceedings{zhu2016generative, 2 | title={Generative Visual Manipulation on the Natural Image Manifold}, 3 | author={Zhu, Jun-Yan and Kr{\"a}henb{\"u}hl, Philipp and Shechtman, Eli and Efros, Alexei A.}, 4 | booktitle={Proceedings of European Conference on Computer Vision (ECCV)}, 5 | year={2016} 6 | } 7 | 8 | @InProceedings{xie15hed, 9 | author = {Xie, Saining and Tu, Zhuowen}, 10 | Title = {Holistically-Nested Edge Detection}, 11 | Booktitle = "Proceedings of IEEE International Conference on Computer Vision", 12 | Year = {2015}, 13 | } 14 | -------------------------------------------------------------------------------- /datasets/bibtex/shoes.tex: -------------------------------------------------------------------------------- 1 | @InProceedings{fine-grained, 2 | author = {A. Yu and K.
Grauman}, 3 | title = {{F}ine-{G}rained {V}isual {C}omparisons with {L}ocal {L}earning}, 4 | booktitle = {Computer Vision and Pattern Recognition (CVPR)}, 5 | month = {June}, 6 | year = {2014} 7 | } 8 | 9 | @InProceedings{xie15hed, 10 | author = {"Xie, Saining and Tu, Zhuowen"}, 11 | Title = {Holistically-Nested Edge Detection}, 12 | Booktitle = "Proceedings of IEEE International Conference on Computer Vision", 13 | Year = {2015}, 14 | } 15 | -------------------------------------------------------------------------------- /datasets/bibtex/transattr.tex: -------------------------------------------------------------------------------- 1 | @article {Laffont14, 2 | title = {Transient Attributes for High-Level Understanding and Editing of Outdoor Scenes}, 3 | author = {Pierre-Yves Laffont and Zhile Ren and Xiaofeng Tao and Chao Qian and James Hays}, 4 | journal = {ACM Transactions on Graphics (proceedings of SIGGRAPH)}, 5 | volume = {33}, 6 | number = {4}, 7 | year = {2014} 8 | } 9 | -------------------------------------------------------------------------------- /datasets/combine_A_and_B.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | import argparse 5 | 6 | parser = argparse.ArgumentParser('create image pairs') 7 | parser.add_argument('--fold_A', dest='fold_A', help='input directory for image A', type=str, default='../dataset/50kshoes_edges') 8 | parser.add_argument('--fold_B', dest='fold_B', help='input directory for image B', type=str, default='../dataset/50kshoes_jpg') 9 | parser.add_argument('--fold_AB', dest='fold_AB', help='output directory', type=str, default='../dataset/test_AB') 10 | parser.add_argument('--num_imgs', dest='num_imgs', help='number of images', type=int, default=1000000) 11 | parser.add_argument('--use_AB', dest='use_AB', help='if true: (0001_A, 0001_B) to (0001_AB)', action='store_true') 12 | args = parser.parse_args() 13 | 14 | for arg in vars(args): 
15 | print('[%s] = ' % arg, getattr(args, arg)) 16 | 17 | splits = os.listdir(args.fold_A) 18 | 19 | for sp in splits: 20 | img_fold_A = os.path.join(args.fold_A, sp) 21 | img_fold_B = os.path.join(args.fold_B, sp) 22 | img_list = os.listdir(img_fold_A) 23 | if args.use_AB: 24 | img_list = [img_path for img_path in img_list if '_A.' in img_path] 25 | 26 | num_imgs = min(args.num_imgs, len(img_list)) 27 | print('split = %s, use %d/%d images' % (sp, num_imgs, len(img_list))) 28 | img_fold_AB = os.path.join(args.fold_AB, sp) 29 | if not os.path.isdir(img_fold_AB): 30 | os.makedirs(img_fold_AB) 31 | print('split = %s, number of images = %d' % (sp, num_imgs)) 32 | for n in range(num_imgs): 33 | name_A = img_list[n] 34 | path_A = os.path.join(img_fold_A, name_A) 35 | if args.use_AB: 36 | name_B = name_A.replace('_A.', '_B.') 37 | else: 38 | name_B = name_A 39 | path_B = os.path.join(img_fold_B, name_B) 40 | if os.path.isfile(path_A) and os.path.isfile(path_B): 41 | name_AB = name_A 42 | if args.use_AB: 43 | name_AB = name_AB.replace('_A.', '.') # remove _A 44 | path_AB = os.path.join(img_fold_AB, name_AB) 45 | im_A = cv2.imread(path_A, 1) # python2: cv2.CV_LOAD_IMAGE_COLOR; python3: cv2.IMREAD_COLOR 46 | im_B = cv2.imread(path_B, 1) # python2: cv2.CV_LOAD_IMAGE_COLOR; python3: cv2.IMREAD_COLOR 47 | im_AB = np.concatenate([im_A, im_B], 1) 48 | cv2.imwrite(path_AB, im_AB) 49 | -------------------------------------------------------------------------------- /datasets/detect_cat_face.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import glob 4 | import argparse 5 | 6 | 7 | def get_file_paths(folder): 8 | image_file_paths = [] 9 | for root, dirs, filenames in os.walk(folder): 10 | filenames = sorted(filenames) 11 | for filename in filenames: 12 | input_path = os.path.abspath(root) 13 | file_path = os.path.join(input_path, filename) 14 | if filename.endswith('.png') or filename.endswith('.jpg'): 15 | 
image_file_paths.append(file_path) 16 | 17 | break # prevent descending into subfolders 18 | return image_file_paths 19 | 20 | 21 | SF = 1.05 22 | N = 3 23 | 24 | 25 | def detect_cat(img_path, cat_cascade, output_dir, ratio=0.05, border_ratio=0.25): 26 | print('processing {}'.format(img_path)) 27 | output_width = 286 28 | img = cv2.imread(img_path) 29 | gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 30 | H, W = img.shape[0], img.shape[1] 31 | minH = int(H * ratio) 32 | minW = int(W * ratio) 33 | cats = cat_cascade.detectMultiScale(gray, scaleFactor=SF, minNeighbors=N, minSize=(minH, minW)) 34 | 35 | for cat_id, (x, y, w, h) in enumerate(cats): 36 | x1 = max(0, x - w * border_ratio) 37 | x2 = min(W, x + w * (1 + border_ratio)) 38 | y1 = max(0, y - h * border_ratio) 39 | y2 = min(H, y + h * (1 + border_ratio)) 40 | img_crop = img[int(y1):int(y2), int(x1):int(x2)] 41 | img_name = os.path.basename(img_path) 42 | out_path = os.path.join(output_dir, img_name.replace('.jpg', '_cat%d.jpg' % cat_id)) 43 | print('write', out_path) 44 | img_crop = cv2.resize(img_crop, (output_width, output_width), interpolation=cv2.INTER_CUBIC) 45 | cv2.imwrite(out_path, img_crop, [int(cv2.IMWRITE_JPEG_QUALITY), 100]) 46 | 47 | 48 | if __name__ == "__main__": 49 | parser = argparse.ArgumentParser(description='detecting cat faces using opencv detector') 50 | parser.add_argument('--input_dir', type=str, help='input image directory') 51 | parser.add_argument('--output_dir', type=str, help='wihch directory to store cropped cat faces') 52 | parser.add_argument('--use_ext', action='store_true', help='if use haarcascade_frontalcatface_extended or not') 53 | args = parser.parse_args() 54 | 55 | if args.use_ext: 56 | cat_cascade = cv2.CascadeClassifier('haarcascade_frontalcatface.xml') 57 | else: 58 | cat_cascade = cv2.CascadeClassifier('haarcascade_frontalcatface_extended.xml') 59 | img_paths = get_file_paths(args.input_dir) 60 | print('total number of images {} from {}'.format(len(img_paths), 
args.input_dir)) 61 | if not os.path.exists(args.output_dir): 62 | os.makedirs(args.output_dir) 63 | for img_path in img_paths: 64 | detect_cat(img_path, cat_cascade, args.output_dir) 65 | -------------------------------------------------------------------------------- /datasets/download_cut_dataset.sh: -------------------------------------------------------------------------------- 1 | set -ex 2 | 3 | FILE=$1 4 | 5 | if [[ $FILE != "ae_photos" && $FILE != "apple2orange" && $FILE != "summer2winter_yosemite" && $FILE != "horse2zebra" && $FILE != "monet2photo" && $FILE != "cezanne2photo" && $FILE != "ukiyoe2photo" && $FILE != "vangogh2photo" && $FILE != "maps" && $FILE != "cityscapes" && $FILE != "facades" && $FILE != "iphone2dslr_flower" && $FILE != "mini" && $FILE != "mini_pix2pix" && $FILE != "mini_colorization" && $FILE != "grumpifycat" ]]; then 6 | echo "Available datasets are: apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos, grumpifycat, mini, mini_pix2pix, mini_colorization" 7 | exit 1 8 | fi 9 | 10 | if [[ $FILE == "cityscapes" ]]; then 11 | echo "Due to license issue, we cannot provide the Cityscapes dataset from our repository. Please download the Cityscapes dataset from https://cityscapes-dataset.com, and use the script ./datasets/prepare_cityscapes_dataset.py." 12 | echo "You need to download gtFine_trainvaltest.zip and leftImg8bit_trainvaltest.zip.
For further instruction, please read ./datasets/prepare_cityscapes_dataset.py" 13 | exit 1 14 | fi 15 | 16 | echo "Specified [$FILE]" 17 | URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/$FILE.zip 18 | ZIP_FILE=./datasets/$FILE.zip 19 | TARGET_DIR=./datasets/$FILE/ 20 | wget --no-check-certificate -N $URL -O $ZIP_FILE 21 | mkdir -p $TARGET_DIR # -p: under set -e a plain mkdir aborts the script when the folder already exists 22 | unzip $ZIP_FILE -d ./datasets/ 23 | rm $ZIP_FILE 24 | -------------------------------------------------------------------------------- /datasets/download_pix2pix_dataset.sh: -------------------------------------------------------------------------------- 1 | set -ex 2 | 3 | FILE=$1 4 | 5 | if [[ $FILE != "cityscapes" && $FILE != "night2day" && $FILE != "edges2handbags" && $FILE != "edges2shoes" && $FILE != "facades" && $FILE != "maps" ]]; then 6 | echo "Available datasets are cityscapes, night2day, edges2handbags, edges2shoes, facades, maps" 7 | exit 1 8 | fi 9 | 10 | if [[ $FILE == "cityscapes" ]]; then 11 | echo "Due to license issue, we cannot provide the Cityscapes dataset from our repository. Please download the Cityscapes dataset from https://cityscapes-dataset.com, and use the script ./datasets/prepare_cityscapes_dataset.py." 12 | echo "You need to download gtFine_trainvaltest.zip and leftImg8bit_trainvaltest.zip.
For further instruction, please read ./datasets/prepare_cityscapes_dataset.py" 13 | exit 1 14 | fi 15 | 16 | echo "Specified [$FILE]" 17 | 18 | URL=http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/$FILE.tar.gz 19 | TAR_FILE=./datasets/$FILE.tar.gz 20 | TARGET_DIR=./datasets/$FILE/ 21 | wget -N $URL -O $TAR_FILE 22 | mkdir -p $TARGET_DIR 23 | tar -zxvf $TAR_FILE -C ./datasets/ 24 | rm $TAR_FILE 25 | -------------------------------------------------------------------------------- /datasets/make_dataset_aligned.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from PIL import Image 4 | 5 | 6 | def get_file_paths(folder): 7 | image_file_paths = [] 8 | for root, dirs, filenames in os.walk(folder): 9 | filenames = sorted(filenames) 10 | for filename in filenames: 11 | input_path = os.path.abspath(root) 12 | file_path = os.path.join(input_path, filename) 13 | if filename.endswith('.png') or filename.endswith('.jpg'): 14 | image_file_paths.append(file_path) 15 | 16 | break # prevent descending into subfolders 17 | return image_file_paths 18 | 19 | 20 | def align_images(a_file_paths, b_file_paths, target_path): 21 | if not os.path.exists(target_path): 22 | os.makedirs(target_path) 23 | 24 | for i in range(len(a_file_paths)): 25 | img_a = Image.open(a_file_paths[i]) 26 | img_b = Image.open(b_file_paths[i]) 27 | assert(img_a.size == img_b.size) 28 | 29 | aligned_image = Image.new("RGB", (img_a.size[0] * 2, img_a.size[1])) 30 | aligned_image.paste(img_a, (0, 0)) 31 | aligned_image.paste(img_b, (img_a.size[0], 0)) 32 | aligned_image.save(os.path.join(target_path, '{:04d}.jpg'.format(i))) 33 | 34 | 35 | if __name__ == '__main__': 36 | import argparse 37 | parser = argparse.ArgumentParser() 38 | parser.add_argument( 39 | '--dataset-path', 40 | dest='dataset_path', 41 | help='Which folder to process (it should have subfolders testA, testB, trainA and trainB' 42 | ) 43 | args = parser.parse_args() 44 | 45 | dataset_folder 
= args.dataset_path 46 | print(dataset_folder) 47 | 48 | test_a_path = os.path.join(dataset_folder, 'testA') 49 | test_b_path = os.path.join(dataset_folder, 'testB') 50 | test_a_file_paths = get_file_paths(test_a_path) 51 | test_b_file_paths = get_file_paths(test_b_path) 52 | assert(len(test_a_file_paths) == len(test_b_file_paths)) 53 | test_path = os.path.join(dataset_folder, 'test') 54 | 55 | train_a_path = os.path.join(dataset_folder, 'trainA') 56 | train_b_path = os.path.join(dataset_folder, 'trainB') 57 | train_a_file_paths = get_file_paths(train_a_path) 58 | train_b_file_paths = get_file_paths(train_b_path) 59 | assert(len(train_a_file_paths) == len(train_b_file_paths)) 60 | train_path = os.path.join(dataset_folder, 'train') 61 | 62 | align_images(test_a_file_paths, test_b_file_paths, test_path) 63 | align_images(train_a_file_paths, train_b_file_paths, train_path) 64 | -------------------------------------------------------------------------------- /datasets/single_image_monet_etretat/trainA/monet.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/datasets/single_image_monet_etretat/trainA/monet.jpg -------------------------------------------------------------------------------- /datasets/single_image_monet_etretat/trainB/etretat-normandy-france.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/datasets/single_image_monet_etretat/trainB/etretat-normandy-france.jpg -------------------------------------------------------------------------------- /images/detection1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/images/detection1.gif 
-------------------------------------------------------------------------------- /images/detection2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/images/detection2.gif -------------------------------------------------------------------------------- /images/fusion.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/images/fusion.gif -------------------------------------------------------------------------------- /images/method_final.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/images/method_final.jpg -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | """This package contains modules related to objective functions, optimizations, and network architectures. 2 | 3 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel. 4 | You need to implement the following five functions: 5 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). 6 | -- : unpack data from dataset and apply preprocessing. 7 | -- : produce intermediate results. 8 | -- : calculate loss, gradients, and update network weights. 9 | -- : (optionally) add model-specific options and set default options. 10 | 11 | In the function <__init__>, you need to define four lists: 12 | -- self.loss_names (str list): specify the training losses that you want to plot and save. 13 | -- self.model_names (str list): define networks used in our training. 
14 | -- self.visual_names (str list): specify the images that you want to display and save. 15 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage. 16 | 17 | Now you can use the model class by specifying flag '--model dummy'. 18 | See our template model class 'template_model.py' for more details. 19 | """ 20 | 21 | import importlib 22 | from models.base_model import BaseModel 23 | 24 | 25 | def find_model_using_name(model_name): 26 | """Import the module "models/[model_name]_model.py". 27 | 28 | In the file, the class called DatasetNameModel() will 29 | be instantiated. It has to be a subclass of BaseModel, 30 | and it is case-insensitive. 31 | """ 32 | model_filename = "models." + model_name + "_model" 33 | modellib = importlib.import_module(model_filename) 34 | model = None 35 | target_model_name = model_name.replace('_', '') + 'model' 36 | for name, cls in modellib.__dict__.items(): 37 | if name.lower() == target_model_name.lower() \ 38 | and issubclass(cls, BaseModel): 39 | model = cls 40 | 41 | if model is None: 42 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)) 43 | exit(0) 44 | 45 | return model 46 | 47 | 48 | def get_option_setter(model_name): 49 | """Return the static method of the model class.""" 50 | model_class = find_model_using_name(model_name) 51 | return model_class.modify_commandline_options 52 | 53 | 54 | def create_model(opt): 55 | """Create a model given the option. 56 | 57 | This function warps the class CustomDatasetDataLoader. 
58 | This is the main interface between this package and 'train.py'/'test.py' 59 | 60 | Example: 61 | >>> from models import create_model 62 | >>> model = create_model(opt) 63 | """ 64 | model = find_model_using_name(opt.model) 65 | instance = model(opt) 66 | print("model [%s] was created" % type(instance).__name__) 67 | return instance 68 | -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/base_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/models/__pycache__/base_model.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/cut_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/models/__pycache__/cut_model.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/mae.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/models/__pycache__/mae.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/models_mae.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/models/__pycache__/models_mae.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/mutilvitgloballocal_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/models/__pycache__/mutilvitgloballocal_model.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/networks.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/models/__pycache__/networks.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/patchnce.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/models/__pycache__/patchnce.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/region0_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/models/__pycache__/region0_model.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/region_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/models/__pycache__/region_model.cpython-36.pyc -------------------------------------------------------------------------------- 
/models/__pycache__/stylegan_networks.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/models/__pycache__/stylegan_networks.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/vit2Gmask_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/models/__pycache__/vit2Gmask_model.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/vit2_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/models/__pycache__/vit2_model.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/vit2patchmask_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/models/__pycache__/vit2patchmask_model.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/vit2tokenmask_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/models/__pycache__/vit2tokenmask_model.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/vitD_model.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/models/__pycache__/vitD_model.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/vit_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/models/__pycache__/vit_model.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/vitdonly2_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/models/__pycache__/vitdonly2_model.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/vitdonly_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/models/__pycache__/vitdonly_model.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/vitgloballocal_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/models/__pycache__/vitgloballocal_model.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/vitlocalgloballocal_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/models/__pycache__/vitlocalgloballocal_model.cpython-36.pyc -------------------------------------------------------------------------------- 
/models/patchnce.py: -------------------------------------------------------------------------------- 1 | from packaging import version 2 | import torch 3 | from torch import nn 4 | 5 | 6 | class PatchNCELoss(nn.Module): 7 | def __init__(self, opt): 8 | super().__init__() 9 | self.opt = opt 10 | self.cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction='none') 11 | self.mask_dtype = torch.uint8 if version.parse(torch.__version__) < version.parse('1.2.0') else torch.bool 12 | 13 | def forward(self, feat_q, feat_k): 14 | num_patches = feat_q.shape[0] 15 | dim = feat_q.shape[1] 16 | feat_k = feat_k.detach() 17 | 18 | # pos logit 19 | l_pos = torch.bmm( 20 | feat_q.view(num_patches, 1, -1), feat_k.view(num_patches, -1, 1)) 21 | l_pos = l_pos.view(num_patches, 1) 22 | 23 | # neg logit 24 | 25 | # Should the negatives from the other samples of a minibatch be utilized? 26 | # In CUT and FastCUT, we found that it's best to only include negatives 27 | # from the same image. Therefore, we set 28 | # --nce_includes_all_negatives_from_minibatch as False 29 | # However, for single-image translation, the minibatch consists of 30 | # crops from the "same" high-resolution image. 31 | # Therefore, we will include the negatives from the entire minibatch. 32 | if self.opt.nce_includes_all_negatives_from_minibatch: 33 | # reshape features as if they are all negatives of minibatch of size 1. 34 | batch_dim_for_bmm = 1 35 | else: 36 | batch_dim_for_bmm = self.opt.batch_size 37 | 38 | # reshape features to batch size 39 | feat_q = feat_q.view(batch_dim_for_bmm, -1, dim) 40 | feat_k = feat_k.view(batch_dim_for_bmm, -1, dim) 41 | npatches = feat_q.size(1) 42 | l_neg_curbatch = torch.bmm(feat_q, feat_k.transpose(2, 1)) 43 | 44 | # diagonal entries are similarity between same features, and hence meaningless. 
45 | # just fill the diagonal with very small number, which is exp(-10) and almost zero 46 | diagonal = torch.eye(npatches, device=feat_q.device, dtype=self.mask_dtype)[None, :, :] 47 | l_neg_curbatch.masked_fill_(diagonal, -10.0) 48 | l_neg = l_neg_curbatch.view(-1, npatches) 49 | 50 | out = torch.cat((l_pos, l_neg), dim=1) / self.opt.nce_T 51 | 52 | loss = self.cross_entropy_loss(out, torch.zeros(out.size(0), dtype=torch.long, 53 | device=feat_q.device)) 54 | 55 | return loss 56 | -------------------------------------------------------------------------------- /models/util/__pycache__/pos_embed.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/models/util/__pycache__/pos_embed.cpython-36.pyc -------------------------------------------------------------------------------- /models/util/crop.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | import torch 10 | 11 | from torchvision import transforms 12 | from torchvision.transforms import functional as F 13 | 14 | 15 | class RandomResizedCrop(transforms.RandomResizedCrop): 16 | """ 17 | RandomResizedCrop for matching TF/TPU implementation: no for-loop is used. 18 | This may lead to results different with torchvision's version. 
19 | Following BYOL's TF code: 20 | https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206 21 | """ 22 | @staticmethod 23 | def get_params(img, scale, ratio): 24 | width, height = F._get_image_size(img) 25 | area = height * width 26 | 27 | target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() 28 | log_ratio = torch.log(torch.tensor(ratio)) 29 | aspect_ratio = torch.exp( 30 | torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) 31 | ).item() 32 | 33 | w = int(round(math.sqrt(target_area * aspect_ratio))) 34 | h = int(round(math.sqrt(target_area / aspect_ratio))) 35 | 36 | w = min(w, width) 37 | h = min(h, height) 38 | 39 | i = torch.randint(0, height - h + 1, size=(1,)).item() 40 | j = torch.randint(0, width - w + 1, size=(1,)).item() 41 | 42 | return i, j, h, w -------------------------------------------------------------------------------- /models/util/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # -------------------------------------------------------- 10 | 11 | import os 12 | import PIL 13 | 14 | from torchvision import datasets, transforms 15 | 16 | from timm.data import create_transform 17 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 18 | 19 | 20 | def build_dataset(is_train, args): 21 | transform = build_transform(is_train, args) 22 | 23 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 24 | dataset = datasets.ImageFolder(root, transform=transform) 25 | 26 | print(dataset) 27 | 28 | return dataset 29 | 30 | 31 | def build_transform(is_train, args): 32 | mean = IMAGENET_DEFAULT_MEAN 33 | std = IMAGENET_DEFAULT_STD 34 | # train transform 35 | if is_train: 36 | # this should always dispatch to transforms_imagenet_train 37 | transform = create_transform( 38 | input_size=args.input_size, 39 | is_training=True, 40 | color_jitter=args.color_jitter, 41 | auto_augment=args.aa, 42 | interpolation='bicubic', 43 | re_prob=args.reprob, 44 | re_mode=args.remode, 45 | re_count=args.recount, 46 | mean=mean, 47 | std=std, 48 | ) 49 | return transform 50 | 51 | # eval transform 52 | t = [] 53 | if args.input_size <= 224: 54 | crop_pct = 224 / 256 55 | else: 56 | crop_pct = 1.0 57 | size = int(args.input_size / crop_pct) 58 | t.append( 59 | transforms.Resize(size, interpolation=PIL.Image.BICUBIC), # to maintain same ratio w.r.t. 224 images 60 | ) 61 | t.append(transforms.CenterCrop(args.input_size)) 62 | 63 | t.append(transforms.ToTensor()) 64 | t.append(transforms.Normalize(mean, std)) 65 | return transforms.Compose(t) 66 | -------------------------------------------------------------------------------- /models/util/lars.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # LARS optimizer, implementation from MoCo v3: 8 | # https://github.com/facebookresearch/moco-v3 9 | # -------------------------------------------------------- 10 | 11 | import torch 12 | 13 | 14 | class LARS(torch.optim.Optimizer): 15 | """ 16 | LARS optimizer, no rate scaling or weight decay for parameters <= 1D. 17 | """ 18 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001): 19 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient) 20 | super().__init__(params, defaults) 21 | 22 | @torch.no_grad() 23 | def step(self): 24 | for g in self.param_groups: 25 | for p in g['params']: 26 | dp = p.grad 27 | 28 | if dp is None: 29 | continue 30 | 31 | if p.ndim > 1: # if not normalization gamma/beta or bias 32 | dp = dp.add(p, alpha=g['weight_decay']) 33 | param_norm = torch.norm(p) 34 | update_norm = torch.norm(dp) 35 | one = torch.ones_like(param_norm) 36 | q = torch.where(param_norm > 0., 37 | torch.where(update_norm > 0, 38 | (g['trust_coefficient'] * param_norm / update_norm), one), 39 | one) 40 | dp = dp.mul(q) 41 | 42 | param_state = self.state[p] 43 | if 'mu' not in param_state: 44 | param_state['mu'] = torch.zeros_like(p) 45 | mu = param_state['mu'] 46 | mu.mul_(g['momentum']).add_(dp) 47 | p.add_(mu, alpha=-g['lr']) -------------------------------------------------------------------------------- /models/util/lr_decay.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # ELECTRA https://github.com/google-research/electra 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import json 13 | 14 | 15 | def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75): 16 | """ 17 | Parameter groups for layer-wise lr decay 18 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 19 | """ 20 | param_group_names = {} 21 | param_groups = {} 22 | 23 | num_layers = len(model.blocks) + 1 24 | 25 | layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) 26 | 27 | for n, p in model.named_parameters(): 28 | if not p.requires_grad: 29 | continue 30 | 31 | # no decay: all 1D parameters and model specific ones 32 | if p.ndim == 1 or n in no_weight_decay_list: 33 | g_decay = "no_decay" 34 | this_decay = 0. 35 | else: 36 | g_decay = "decay" 37 | this_decay = weight_decay 38 | 39 | layer_id = get_layer_id_for_vit(n, num_layers) 40 | group_name = "layer_%d_%s" % (layer_id, g_decay) 41 | 42 | if group_name not in param_group_names: 43 | this_scale = layer_scales[layer_id] 44 | 45 | param_group_names[group_name] = { 46 | "lr_scale": this_scale, 47 | "weight_decay": this_decay, 48 | "params": [], 49 | } 50 | param_groups[group_name] = { 51 | "lr_scale": this_scale, 52 | "weight_decay": this_decay, 53 | "params": [], 54 | } 55 | 56 | param_group_names[group_name]["params"].append(n) 57 | param_groups[group_name]["params"].append(p) 58 | 59 | # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) 60 | 61 | return list(param_groups.values()) 62 | 63 | 64 | def get_layer_id_for_vit(name, num_layers): 65 | """ 66 | Assign a parameter with its layer id 67 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 68 | """ 69 | if name in ['cls_token', 
'pos_embed']: 70 | return 0 71 | elif name.startswith('patch_embed'): 72 | return 0 73 | elif name.startswith('blocks'): 74 | return int(name.split('.')[1]) + 1 75 | else: 76 | return num_layers -------------------------------------------------------------------------------- /models/util/lr_sched.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | def adjust_learning_rate(optimizer, epoch, args): 10 | """Decay the learning rate with half-cycle cosine after warmup""" 11 | if epoch < args.warmup_epochs: 12 | lr = args.lr * epoch / args.warmup_epochs 13 | else: 14 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \ 15 | (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) 16 | for param_group in optimizer.param_groups: 17 | if "lr_scale" in param_group: 18 | param_group["lr"] = lr * param_group["lr_scale"] 19 | else: 20 | param_group["lr"] = lr 21 | return lr 22 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- 1 | """This package options includes option modules: training options, test options, and basic options (used in both training and test).""" 2 | -------------------------------------------------------------------------------- /options/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/options/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /options/__pycache__/base_options.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/options/__pycache__/base_options.cpython-36.pyc -------------------------------------------------------------------------------- /options/__pycache__/test_options.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/options/__pycache__/test_options.cpython-36.pyc -------------------------------------------------------------------------------- /options/__pycache__/train_options.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/options/__pycache__/train_options.cpython-36.pyc -------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TestOptions(BaseOptions): 5 | """This class includes test options. 6 | 7 | It also includes shared options defined in BaseOptions. 8 | """ 9 | 10 | def initialize(self, parser): 11 | parser = BaseOptions.initialize(self, parser) # define shared options 12 | parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') 13 | parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') 14 | # Dropout and Batchnorm has different behavioir during training and test. 
15 | parser.add_argument('--eval', action='store_true', help='use eval mode during test time.') 16 | parser.add_argument('--num_test', type=int, default=50, help='how many test images to run') 17 | 18 | # To avoid cropping, the load_size should be the same as crop_size 19 | parser.set_defaults(load_size=parser.get_default('crop_size')) 20 | self.isTrain = False 21 | return parser 22 | -------------------------------------------------------------------------------- /scripts/test.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python test.py --dataroot /path/of/test_dataset --checkpoints_dir ./checkpoints --name train1 --model roma_single --num_test 10000 --epoch latest 2 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | # Train for video mode 2 | CUDA_VISIBLE_DEVICES=0 python train.py --dataroot /path --name ROMA_name --dataset_mode unaligned_double --no_flip --local_nums 64 --display_env ROMA_env --model roma --side_length 7 --lambda_spatial 5.0 --lambda_global 5.0 --lambda_motion 1.0 --atten_layers 1,3,5 --lr 0.00001 3 | 4 | # Train for image mode 5 | CUDA_VISIBLE_DEVICES=0 python train.py --dataroot /path --name ROMA_name --dataset_mode unaligned --local_nums 64 --display_env ROMA_env --model roma --side_length 7 --lambda_spatial 5.0 --lambda_global 5.0 --atten_layers 1,3,5 --lr 0.00001 -------------------------------------------------------------------------------- /timm/__init__.py: -------------------------------------------------------------------------------- 1 | from .version import __version__ 2 | from .models import create_model, list_models, is_model, list_modules, model_entrypoint, \ 3 | is_scriptable, is_exportable, set_scriptable, set_exportable, has_model_default_key, is_model_default_key, \ 4 | get_model_default_value, is_model_pretrained 5 | 
-------------------------------------------------------------------------------- /timm/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /timm/__pycache__/version.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/__pycache__/version.cpython-36.pyc -------------------------------------------------------------------------------- /timm/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\ 2 | rand_augment_transform, auto_augment_transform 3 | from .config import resolve_data_config 4 | from .constants import * 5 | from .dataset import ImageDataset, IterableImageDataset, AugMixDataset 6 | from .dataset_factory import create_dataset 7 | from .loader import create_loader 8 | from .mixup import Mixup, FastCollateMixup 9 | from .parsers import create_parser 10 | from .real_labels import RealLabelsImagenet 11 | from .transforms import * 12 | from .transforms_factory import create_transform -------------------------------------------------------------------------------- /timm/data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /timm/data/__pycache__/auto_augment.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/data/__pycache__/auto_augment.cpython-36.pyc -------------------------------------------------------------------------------- /timm/data/__pycache__/config.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/data/__pycache__/config.cpython-36.pyc -------------------------------------------------------------------------------- /timm/data/__pycache__/constants.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/data/__pycache__/constants.cpython-36.pyc -------------------------------------------------------------------------------- /timm/data/__pycache__/dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/data/__pycache__/dataset.cpython-36.pyc -------------------------------------------------------------------------------- /timm/data/__pycache__/dataset_factory.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/data/__pycache__/dataset_factory.cpython-36.pyc -------------------------------------------------------------------------------- /timm/data/__pycache__/distributed_sampler.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/data/__pycache__/distributed_sampler.cpython-36.pyc 
-------------------------------------------------------------------------------- /timm/data/__pycache__/loader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/data/__pycache__/loader.cpython-36.pyc -------------------------------------------------------------------------------- /timm/data/__pycache__/mixup.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/data/__pycache__/mixup.cpython-36.pyc -------------------------------------------------------------------------------- /timm/data/__pycache__/random_erasing.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/data/__pycache__/random_erasing.cpython-36.pyc -------------------------------------------------------------------------------- /timm/data/__pycache__/real_labels.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/data/__pycache__/real_labels.cpython-36.pyc -------------------------------------------------------------------------------- /timm/data/__pycache__/transforms.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/data/__pycache__/transforms.cpython-36.pyc -------------------------------------------------------------------------------- /timm/data/__pycache__/transforms_factory.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/data/__pycache__/transforms_factory.cpython-36.pyc -------------------------------------------------------------------------------- /timm/data/config.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from .constants import * 3 | 4 | 5 | _logger = logging.getLogger(__name__) 6 | 7 | 8 | def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, verbose=False): 9 | new_config = {} 10 | default_cfg = default_cfg 11 | if not default_cfg and model is not None and hasattr(model, 'default_cfg'): 12 | default_cfg = model.default_cfg 13 | 14 | # Resolve input/image size 15 | in_chans = 3 16 | if 'chans' in args and args['chans'] is not None: 17 | in_chans = args['chans'] 18 | 19 | input_size = (in_chans, 224, 224) 20 | if 'input_size' in args and args['input_size'] is not None: 21 | assert isinstance(args['input_size'], (tuple, list)) 22 | assert len(args['input_size']) == 3 23 | input_size = tuple(args['input_size']) 24 | in_chans = input_size[0] # input_size overrides in_chans 25 | elif 'img_size' in args and args['img_size'] is not None: 26 | assert isinstance(args['img_size'], int) 27 | input_size = (in_chans, args['img_size'], args['img_size']) 28 | else: 29 | if use_test_size and 'test_input_size' in default_cfg: 30 | input_size = default_cfg['test_input_size'] 31 | elif 'input_size' in default_cfg: 32 | input_size = default_cfg['input_size'] 33 | new_config['input_size'] = input_size 34 | 35 | # resolve interpolation method 36 | new_config['interpolation'] = 'bicubic' 37 | if 'interpolation' in args and args['interpolation']: 38 | new_config['interpolation'] = args['interpolation'] 39 | elif 'interpolation' in default_cfg: 40 | new_config['interpolation'] = default_cfg['interpolation'] 41 | 42 | # resolve dataset + model mean for normalization 43 | new_config['mean'] = IMAGENET_DEFAULT_MEAN 44 | if 
'mean' in args and args['mean'] is not None: 45 | mean = tuple(args['mean']) 46 | if len(mean) == 1: 47 | mean = tuple(list(mean) * in_chans) 48 | else: 49 | assert len(mean) == in_chans 50 | new_config['mean'] = mean 51 | elif 'mean' in default_cfg: 52 | new_config['mean'] = default_cfg['mean'] 53 | 54 | # resolve dataset + model std deviation for normalization 55 | new_config['std'] = IMAGENET_DEFAULT_STD 56 | if 'std' in args and args['std'] is not None: 57 | std = tuple(args['std']) 58 | if len(std) == 1: 59 | std = tuple(list(std) * in_chans) 60 | else: 61 | assert len(std) == in_chans 62 | new_config['std'] = std 63 | elif 'std' in default_cfg: 64 | new_config['std'] = default_cfg['std'] 65 | 66 | # resolve default crop percentage 67 | new_config['crop_pct'] = DEFAULT_CROP_PCT 68 | if 'crop_pct' in args and args['crop_pct'] is not None: 69 | new_config['crop_pct'] = args['crop_pct'] 70 | elif 'crop_pct' in default_cfg: 71 | new_config['crop_pct'] = default_cfg['crop_pct'] 72 | 73 | if verbose: 74 | _logger.info('Data processing configuration for current model + dataset:') 75 | for n, v in new_config.items(): 76 | _logger.info('\t%s: %s' % (n, str(v))) 77 | 78 | return new_config 79 | -------------------------------------------------------------------------------- /timm/data/constants.py: -------------------------------------------------------------------------------- 1 | DEFAULT_CROP_PCT = 0.875 2 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) 3 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) 4 | IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) 5 | IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) 6 | IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255) 7 | IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3) 8 | -------------------------------------------------------------------------------- /timm/data/parsers/__init__.py: -------------------------------------------------------------------------------- 1 | from .parser_factory import create_parser 2 | 
-------------------------------------------------------------------------------- /timm/data/parsers/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/data/parsers/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /timm/data/parsers/__pycache__/class_map.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/data/parsers/__pycache__/class_map.cpython-36.pyc -------------------------------------------------------------------------------- /timm/data/parsers/__pycache__/constants.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/data/parsers/__pycache__/constants.cpython-36.pyc -------------------------------------------------------------------------------- /timm/data/parsers/__pycache__/parser.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/data/parsers/__pycache__/parser.cpython-36.pyc -------------------------------------------------------------------------------- /timm/data/parsers/__pycache__/parser_factory.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/data/parsers/__pycache__/parser_factory.cpython-36.pyc -------------------------------------------------------------------------------- /timm/data/parsers/__pycache__/parser_image_folder.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/data/parsers/__pycache__/parser_image_folder.cpython-36.pyc -------------------------------------------------------------------------------- /timm/data/parsers/__pycache__/parser_image_in_tar.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/data/parsers/__pycache__/parser_image_in_tar.cpython-36.pyc -------------------------------------------------------------------------------- /timm/data/parsers/__pycache__/parser_image_tar.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/data/parsers/__pycache__/parser_image_tar.cpython-36.pyc -------------------------------------------------------------------------------- /timm/data/parsers/class_map.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def load_class_map(map_or_filename, root=''): 5 | if isinstance(map_or_filename, dict): 6 | assert dict, 'class_map dict must be non-empty' 7 | return map_or_filename 8 | class_map_path = map_or_filename 9 | if not os.path.exists(class_map_path): 10 | class_map_path = os.path.join(root, class_map_path) 11 | assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % map_or_filename 12 | class_map_ext = os.path.splitext(map_or_filename)[-1].lower() 13 | if class_map_ext == '.txt': 14 | with open(class_map_path) as f: 15 | class_to_idx = {v.strip(): k for k, v in enumerate(f)} 16 | else: 17 | assert False, f'Unsupported class map file extension ({class_map_ext}).' 
18 | return class_to_idx 19 | 20 | -------------------------------------------------------------------------------- /timm/data/parsers/constants.py: -------------------------------------------------------------------------------- 1 | IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg') 2 | -------------------------------------------------------------------------------- /timm/data/parsers/parser.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | 4 | class Parser: 5 | def __init__(self): 6 | pass 7 | 8 | @abstractmethod 9 | def _filename(self, index, basename=False, absolute=False): 10 | pass 11 | 12 | def filename(self, index, basename=False, absolute=False): 13 | return self._filename(index, basename=basename, absolute=absolute) 14 | 15 | def filenames(self, basename=False, absolute=False): 16 | return [self._filename(index, basename=basename, absolute=absolute) for index in range(len(self))] 17 | 18 | -------------------------------------------------------------------------------- /timm/data/parsers/parser_factory.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .parser_image_folder import ParserImageFolder 4 | from .parser_image_tar import ParserImageTar 5 | from .parser_image_in_tar import ParserImageInTar 6 | 7 | 8 | def create_parser(name, root, split='train', **kwargs): 9 | name = name.lower() 10 | name = name.split('/', 2) 11 | prefix = '' 12 | if len(name) > 1: 13 | prefix = name[0] 14 | name = name[-1] 15 | 16 | # FIXME improve the selection right now just tfds prefix or fallback path, will need options to 17 | # explicitly select other options shortly 18 | if prefix == 'tfds': 19 | from .parser_tfds import ParserTfds # defer tensorflow import 20 | parser = ParserTfds(root, name, split=split, **kwargs) 21 | else: 22 | assert os.path.exists(root) 23 | # default fallback path (backwards compat), use image tar if root is a .tar 
file, otherwise image folder 24 | # FIXME support split here, in parser? 25 | if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar': 26 | parser = ParserImageInTar(root, **kwargs) 27 | else: 28 | parser = ParserImageFolder(root, **kwargs) 29 | return parser 30 | -------------------------------------------------------------------------------- /timm/data/parsers/parser_image_folder.py: -------------------------------------------------------------------------------- 1 | """ A dataset parser that reads images from folders 2 | 3 | Folders are scannerd recursively to find image files. Labels are based 4 | on the folder hierarchy, just leaf folders by default. 5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | import os 9 | 10 | from timm.utils.misc import natural_key 11 | 12 | from .parser import Parser 13 | from .class_map import load_class_map 14 | from .constants import IMG_EXTENSIONS 15 | 16 | 17 | def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True): 18 | labels = [] 19 | filenames = [] 20 | for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True): 21 | rel_path = os.path.relpath(root, folder) if (root != folder) else '' 22 | label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_') 23 | for f in files: 24 | base, ext = os.path.splitext(f) 25 | if ext.lower() in types: 26 | filenames.append(os.path.join(root, f)) 27 | labels.append(label) 28 | if class_to_idx is None: 29 | # building class index 30 | unique_labels = set(labels) 31 | sorted_labels = list(sorted(unique_labels, key=natural_key)) 32 | class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)} 33 | images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx] 34 | if sort: 35 | images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0])) 36 | return images_and_targets, class_to_idx 37 | 38 | 39 | class 
ParserImageFolder(Parser): 40 | 41 | def __init__( 42 | self, 43 | root, 44 | class_map=''): 45 | super().__init__() 46 | 47 | self.root = root 48 | class_to_idx = None 49 | if class_map: 50 | class_to_idx = load_class_map(class_map, root) 51 | self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx) 52 | if len(self.samples) == 0: 53 | raise RuntimeError( 54 | f'Found 0 images in subfolders of {root}. Supported image extensions are {", ".join(IMG_EXTENSIONS)}') 55 | 56 | def __getitem__(self, index): 57 | path, target = self.samples[index] 58 | return open(path, 'rb'), target 59 | 60 | def __len__(self): 61 | return len(self.samples) 62 | 63 | def _filename(self, index, basename=False, absolute=False): 64 | filename = self.samples[index][0] 65 | if basename: 66 | filename = os.path.basename(filename) 67 | elif not absolute: 68 | filename = os.path.relpath(filename, self.root) 69 | return filename 70 | -------------------------------------------------------------------------------- /timm/data/parsers/parser_image_tar.py: -------------------------------------------------------------------------------- 1 | """ A dataset parser that reads single tarfile based datasets 2 | 3 | This parser can read datasets consisting if a single tarfile containing images. 4 | I am planning to deprecated it in favour of ParerImageInTar. 
5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | import os 9 | import tarfile 10 | 11 | from .parser import Parser 12 | from .class_map import load_class_map 13 | from .constants import IMG_EXTENSIONS 14 | from timm.utils.misc import natural_key 15 | 16 | 17 | def extract_tarinfo(tarfile, class_to_idx=None, sort=True): 18 | files = [] 19 | labels = [] 20 | for ti in tarfile.getmembers(): 21 | if not ti.isfile(): 22 | continue 23 | dirname, basename = os.path.split(ti.path) 24 | label = os.path.basename(dirname) 25 | ext = os.path.splitext(basename)[1] 26 | if ext.lower() in IMG_EXTENSIONS: 27 | files.append(ti) 28 | labels.append(label) 29 | if class_to_idx is None: 30 | unique_labels = set(labels) 31 | sorted_labels = list(sorted(unique_labels, key=natural_key)) 32 | class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)} 33 | tarinfo_and_targets = [(f, class_to_idx[l]) for f, l in zip(files, labels) if l in class_to_idx] 34 | if sort: 35 | tarinfo_and_targets = sorted(tarinfo_and_targets, key=lambda k: natural_key(k[0].path)) 36 | return tarinfo_and_targets, class_to_idx 37 | 38 | 39 | class ParserImageTar(Parser): 40 | """ Single tarfile dataset where classes are mapped to folders within tar 41 | NOTE: This class is being deprecated in favour of the more capable ParserImageInTar that can 42 | operate on folders of tars or tars in tars. 
43 | """ 44 | def __init__(self, root, class_map=''): 45 | super().__init__() 46 | 47 | class_to_idx = None 48 | if class_map: 49 | class_to_idx = load_class_map(class_map, root) 50 | assert os.path.isfile(root) 51 | self.root = root 52 | 53 | with tarfile.open(root) as tf: # cannot keep this open across processes, reopen later 54 | self.samples, self.class_to_idx = extract_tarinfo(tf, class_to_idx) 55 | self.imgs = self.samples 56 | self.tarfile = None # lazy init in __getitem__ 57 | 58 | def __getitem__(self, index): 59 | if self.tarfile is None: 60 | self.tarfile = tarfile.open(self.root) 61 | tarinfo, target = self.samples[index] 62 | fileobj = self.tarfile.extractfile(tarinfo) 63 | return fileobj, target 64 | 65 | def __len__(self): 66 | return len(self.samples) 67 | 68 | def _filename(self, index, basename=False, absolute=False): 69 | filename = self.samples[index][0].name 70 | if basename: 71 | filename = os.path.basename(filename) 72 | return filename 73 | -------------------------------------------------------------------------------- /timm/data/real_labels.py: -------------------------------------------------------------------------------- 1 | """ Real labels evaluator for ImageNet 2 | Paper: `Are we done with ImageNet?` - https://arxiv.org/abs/2006.07159 3 | Based on Numpy example at https://github.com/google-research/reassessed-imagenet 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import os 8 | import json 9 | import numpy as np 10 | 11 | 12 | class RealLabelsImagenet: 13 | 14 | def __init__(self, filenames, real_json='real.json', topk=(1, 5)): 15 | with open(real_json) as real_labels: 16 | real_labels = json.load(real_labels) 17 | real_labels = {f'ILSVRC2012_val_{i + 1:08d}.JPEG': labels for i, labels in enumerate(real_labels)} 18 | self.real_labels = real_labels 19 | self.filenames = filenames 20 | assert len(self.filenames) == len(self.real_labels) 21 | self.topk = topk 22 | self.is_correct = {k: [] for k in topk} 23 | 
self.sample_idx = 0 24 | 25 | def add_result(self, output): 26 | maxk = max(self.topk) 27 | _, pred_batch = output.topk(maxk, 1, True, True) 28 | pred_batch = pred_batch.cpu().numpy() 29 | for pred in pred_batch: 30 | filename = self.filenames[self.sample_idx] 31 | filename = os.path.basename(filename) 32 | if self.real_labels[filename]: 33 | for k in self.topk: 34 | self.is_correct[k].append( 35 | any([p in self.real_labels[filename] for p in pred[:k]])) 36 | self.sample_idx += 1 37 | 38 | def get_accuracy(self, k=None): 39 | if k is None: 40 | return {k: float(np.mean(self.is_correct[k])) * 100 for k in self.topk} 41 | else: 42 | return float(np.mean(self.is_correct[k])) * 100 43 | -------------------------------------------------------------------------------- /timm/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .asymmetric_loss import AsymmetricLossMultiLabel, AsymmetricLossSingleLabel 2 | from .binary_cross_entropy import BinaryCrossEntropy 3 | from .cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 4 | from .jsd import JsdCrossEntropy 5 | -------------------------------------------------------------------------------- /timm/loss/asymmetric_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class AsymmetricLossMultiLabel(nn.Module): 6 | def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False): 7 | super(AsymmetricLossMultiLabel, self).__init__() 8 | 9 | self.gamma_neg = gamma_neg 10 | self.gamma_pos = gamma_pos 11 | self.clip = clip 12 | self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss 13 | self.eps = eps 14 | 15 | def forward(self, x, y): 16 | """" 17 | Parameters 18 | ---------- 19 | x: input logits 20 | y: targets (multi-label binarized vector) 21 | """ 22 | 23 | # Calculating Probabilities 24 | x_sigmoid = 
torch.sigmoid(x) 25 | xs_pos = x_sigmoid 26 | xs_neg = 1 - x_sigmoid 27 | 28 | # Asymmetric Clipping 29 | if self.clip is not None and self.clip > 0: 30 | xs_neg = (xs_neg + self.clip).clamp(max=1) 31 | 32 | # Basic CE calculation 33 | los_pos = y * torch.log(xs_pos.clamp(min=self.eps)) 34 | los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps)) 35 | loss = los_pos + los_neg 36 | 37 | # Asymmetric Focusing 38 | if self.gamma_neg > 0 or self.gamma_pos > 0: 39 | if self.disable_torch_grad_focal_loss: 40 | torch._C.set_grad_enabled(False) 41 | pt0 = xs_pos * y 42 | pt1 = xs_neg * (1 - y) # pt = p if t > 0 else 1-p 43 | pt = pt0 + pt1 44 | one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y) 45 | one_sided_w = torch.pow(1 - pt, one_sided_gamma) 46 | if self.disable_torch_grad_focal_loss: 47 | torch._C.set_grad_enabled(True) 48 | loss *= one_sided_w 49 | 50 | return -loss.sum() 51 | 52 | 53 | class AsymmetricLossSingleLabel(nn.Module): 54 | def __init__(self, gamma_pos=1, gamma_neg=4, eps: float = 0.1, reduction='mean'): 55 | super(AsymmetricLossSingleLabel, self).__init__() 56 | 57 | self.eps = eps 58 | self.logsoftmax = nn.LogSoftmax(dim=-1) 59 | self.targets_classes = [] # prevent gpu repeated memory allocation 60 | self.gamma_pos = gamma_pos 61 | self.gamma_neg = gamma_neg 62 | self.reduction = reduction 63 | 64 | def forward(self, inputs, target, reduction=None): 65 | """" 66 | Parameters 67 | ---------- 68 | x: input logits 69 | y: targets (1-hot vector) 70 | """ 71 | 72 | num_classes = inputs.size()[-1] 73 | log_preds = self.logsoftmax(inputs) 74 | self.targets_classes = torch.zeros_like(inputs).scatter_(1, target.long().unsqueeze(1), 1) 75 | 76 | # ASL weights 77 | targets = self.targets_classes 78 | anti_targets = 1 - targets 79 | xs_pos = torch.exp(log_preds) 80 | xs_neg = 1 - xs_pos 81 | xs_pos = xs_pos * targets 82 | xs_neg = xs_neg * anti_targets 83 | asymmetric_w = torch.pow(1 - xs_pos - xs_neg, 84 | self.gamma_pos * targets + self.gamma_neg 
* anti_targets) 85 | log_preds = log_preds * asymmetric_w 86 | 87 | if self.eps > 0: # label smoothing 88 | self.targets_classes.mul_(1 - self.eps).add_(self.eps / num_classes) 89 | 90 | # loss calculation 91 | loss = - self.targets_classes.mul(log_preds) 92 | 93 | loss = loss.sum(dim=-1) 94 | if self.reduction == 'mean': 95 | loss = loss.mean() 96 | 97 | return loss 98 | -------------------------------------------------------------------------------- /timm/loss/binary_cross_entropy.py: -------------------------------------------------------------------------------- 1 | """ Binary Cross Entropy w/ a few extras 2 | 3 | Hacked together by / Copyright 2021 Ross Wightman 4 | """ 5 | from typing import Optional 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class BinaryCrossEntropy(nn.Module): 13 | """ BCE with optional one-hot from dense targets, label smoothing, thresholding 14 | NOTE for experiments comparing CE to BCE /w label smoothing, may remove 15 | """ 16 | def __init__( 17 | self, smoothing=0.1, target_threshold: Optional[float] = None, weight: Optional[torch.Tensor] = None, 18 | reduction: str = 'mean', pos_weight: Optional[torch.Tensor] = None): 19 | super(BinaryCrossEntropy, self).__init__() 20 | assert 0. <= smoothing < 1.0 21 | self.smoothing = smoothing 22 | self.target_threshold = target_threshold 23 | self.reduction = reduction 24 | self.register_buffer('weight', weight) 25 | self.register_buffer('pos_weight', pos_weight) 26 | 27 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 28 | assert x.shape[0] == target.shape[0] 29 | if target.shape != x.shape: 30 | # NOTE currently assume smoothing or other label softening is applied upstream if targets are already sparse 31 | num_classes = x.shape[-1] 32 | # FIXME should off/on be different for smoothing w/ BCE? Other impl out there differ 33 | off_value = self.smoothing / num_classes 34 | on_value = 1. 
- self.smoothing + off_value 35 | target = target.long().view(-1, 1) 36 | target = torch.full( 37 | (target.size()[0], num_classes), 38 | off_value, 39 | device=x.device, dtype=x.dtype).scatter_(1, target, on_value) 40 | if self.target_threshold is not None: 41 | # Make target 0, or 1 if threshold set 42 | target = target.gt(self.target_threshold).to(dtype=target.dtype) 43 | return F.binary_cross_entropy_with_logits( 44 | x, target, 45 | self.weight, 46 | pos_weight=self.pos_weight, 47 | reduction=self.reduction) 48 | -------------------------------------------------------------------------------- /timm/loss/cross_entropy.py: -------------------------------------------------------------------------------- 1 | """ Cross Entropy w/ smoothing or soft targets 2 | 3 | Hacked together by / Copyright 2021 Ross Wightman 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class LabelSmoothingCrossEntropy(nn.Module): 12 | """ NLL loss with label smoothing. 13 | """ 14 | def __init__(self, smoothing=0.1): 15 | super(LabelSmoothingCrossEntropy, self).__init__() 16 | assert smoothing < 1.0 17 | self.smoothing = smoothing 18 | self.confidence = 1. 
- smoothing 19 | 20 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 21 | logprobs = F.log_softmax(x, dim=-1) 22 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) 23 | nll_loss = nll_loss.squeeze(1) 24 | smooth_loss = -logprobs.mean(dim=-1) 25 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 26 | return loss.mean() 27 | 28 | 29 | class SoftTargetCrossEntropy(nn.Module): 30 | 31 | def __init__(self): 32 | super(SoftTargetCrossEntropy, self).__init__() 33 | 34 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 35 | loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1) 36 | return loss.mean() 37 | -------------------------------------------------------------------------------- /timm/loss/jsd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .cross_entropy import LabelSmoothingCrossEntropy 6 | 7 | 8 | class JsdCrossEntropy(nn.Module): 9 | """ Jensen-Shannon Divergence + Cross-Entropy Loss 10 | 11 | Based on impl here: https://github.com/google-research/augmix/blob/master/imagenet.py 12 | From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - 13 | https://arxiv.org/abs/1912.02781 14 | 15 | Hacked together by / Copyright 2020 Ross Wightman 16 | """ 17 | def __init__(self, num_splits=3, alpha=12, smoothing=0.1): 18 | super().__init__() 19 | self.num_splits = num_splits 20 | self.alpha = alpha 21 | if smoothing is not None and smoothing > 0: 22 | self.cross_entropy_loss = LabelSmoothingCrossEntropy(smoothing) 23 | else: 24 | self.cross_entropy_loss = torch.nn.CrossEntropyLoss() 25 | 26 | def __call__(self, output, target): 27 | split_size = output.shape[0] // self.num_splits 28 | assert split_size * self.num_splits == output.shape[0] 29 | logits_split = torch.split(output, split_size) 30 | 31 | # Cross-entropy is only 
computed on clean images 32 | loss = self.cross_entropy_loss(logits_split[0], target[:split_size]) 33 | probs = [F.softmax(logits, dim=1) for logits in logits_split] 34 | 35 | # Clamp mixture distribution to avoid exploding KL divergence 36 | logp_mixture = torch.clamp(torch.stack(probs).mean(axis=0), 1e-7, 1).log() 37 | loss += self.alpha * sum([F.kl_div( 38 | logp_mixture, p_split, reduction='batchmean') for p_split in probs]) / len(probs) 39 | return loss 40 | -------------------------------------------------------------------------------- /timm/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .beit import * 2 | from .byoanet import * 3 | from .byobnet import * 4 | from .cait import * 5 | from .coat import * 6 | from .convit import * 7 | from .convmixer import * 8 | from .crossvit import * 9 | from .cspnet import * 10 | from .densenet import * 11 | from .dla import * 12 | from .dpn import * 13 | from .efficientnet import * 14 | from .ghostnet import * 15 | from .gluon_resnet import * 16 | from .gluon_xception import * 17 | from .hardcorenas import * 18 | from .hrnet import * 19 | from .inception_resnet_v2 import * 20 | from .inception_v3 import * 21 | from .inception_v4 import * 22 | from .levit import * 23 | from .mlp_mixer import * 24 | from .mobilenetv3 import * 25 | from .nasnet import * 26 | from .nest import * 27 | from .nfnet import * 28 | from .pit import * 29 | from .pnasnet import * 30 | from .regnet import * 31 | from .res2net import * 32 | from .resnest import * 33 | from .resnet import * 34 | from .resnetv2 import * 35 | from .rexnet import * 36 | from .selecsls import * 37 | from .senet import * 38 | from .sknet import * 39 | from .swin_transformer import * 40 | from .tnt import * 41 | from .tresnet import * 42 | from .twins import * 43 | from .vgg import * 44 | from .visformer import * 45 | from .vision_transformer import * 46 | from .vision_transformer_hybrid import * 47 | from .vovnet 
import * 48 | from .xception import * 49 | from .xception_aligned import * 50 | from .xcit import * 51 | 52 | from .factory import create_model, split_model_name, safe_model_name 53 | from .helpers import load_checkpoint, resume_checkpoint, model_parameters 54 | from .layers import TestTimePoolHead, apply_test_time_pool 55 | from .layers import convert_splitbn_model 56 | from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit 57 | from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\ 58 | has_model_default_key, is_model_default_key, get_model_default_value, is_model_pretrained 59 | -------------------------------------------------------------------------------- /timm/models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/beit.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/beit.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/byoanet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/byoanet.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/byobnet.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/byobnet.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/cait.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/cait.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/coat.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/coat.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/convit.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/convit.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/convmixer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/convmixer.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/crossvit.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/crossvit.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/cspnet.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/cspnet.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/densenet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/densenet.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/dla.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/dla.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/dpn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/dpn.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/efficientnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/efficientnet.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/efficientnet_blocks.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/efficientnet_blocks.cpython-36.pyc 
-------------------------------------------------------------------------------- /timm/models/__pycache__/efficientnet_builder.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/efficientnet_builder.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/factory.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/factory.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/features.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/features.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/fx_features.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/fx_features.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/ghostnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/ghostnet.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/gluon_resnet.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/gluon_resnet.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/gluon_xception.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/gluon_xception.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/hardcorenas.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/hardcorenas.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/helpers.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/helpers.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/hrnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/hrnet.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/hub.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/hub.cpython-36.pyc -------------------------------------------------------------------------------- 
/timm/models/__pycache__/inception_resnet_v2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/inception_resnet_v2.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/inception_v3.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/inception_v3.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/inception_v4.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/inception_v4.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/levit.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/levit.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/mlp_mixer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/mlp_mixer.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/mobilenetv3.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/mobilenetv3.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/nasnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/nasnet.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/nest.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/nest.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/nfnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/nfnet.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/pit.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/pit.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/pnasnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/pnasnet.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/registry.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/registry.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/regnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/regnet.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/res2net.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/res2net.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/resnest.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/resnest.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/resnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/resnet.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/resnetv2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/resnetv2.cpython-36.pyc 
-------------------------------------------------------------------------------- /timm/models/__pycache__/rexnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/rexnet.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/selecsls.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/selecsls.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/senet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/senet.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/sknet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/sknet.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/swin_transformer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/swin_transformer.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/tnt.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/tnt.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/tresnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/tresnet.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/twins.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/twins.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/vgg.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/vgg.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/visformer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/visformer.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/vision_transformer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/vision_transformer.cpython-36.pyc -------------------------------------------------------------------------------- 
/timm/models/__pycache__/vision_transformer_hybrid.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/vision_transformer_hybrid.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/vovnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/vovnet.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/xception.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/xception.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/xception_aligned.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/xception_aligned.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/xcit.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/__pycache__/xcit.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/fx_features.py: -------------------------------------------------------------------------------- 1 | """ PyTorch FX Based Feature Extraction Helpers 2 | Using 
https://pytorch.org/vision/stable/feature_extraction.html 3 | """ 4 | from typing import Callable 5 | from torch import nn 6 | 7 | from .features import _get_feature_info 8 | 9 | try: 10 | from torchvision.models.feature_extraction import create_feature_extractor 11 | has_fx_feature_extraction = True 12 | except ImportError: 13 | has_fx_feature_extraction = False 14 | 15 | # Layers we went to treat as leaf modules 16 | from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame, DropPath 17 | from .layers.non_local_attn import BilinearAttnTransform 18 | from .layers.pool2d_same import MaxPool2dSame, AvgPool2dSame 19 | 20 | # NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here 21 | # BUT modules from timm.models should use the registration mechanism below 22 | _leaf_modules = { 23 | BatchNormAct2d, # reason: flow control for jit scripting 24 | BilinearAttnTransform, # reason: flow control t <= 1 25 | BlurPool2d, # reason: TypeError: F.conv2d received Proxy in groups=x.shape[1] 26 | # Reason: get_same_padding has a max which raises a control flow error 27 | Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame, 28 | CondConv2d, # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0]) 29 | DropPath, # reason: TypeError: rand recieved Proxy in `size` argument 30 | } 31 | 32 | try: 33 | from .layers import InplaceAbn 34 | _leaf_modules.add(InplaceAbn) 35 | except ImportError: 36 | pass 37 | 38 | 39 | def register_notrace_module(module: nn.Module): 40 | """ 41 | Any module not under timm.models.layers should get this decorator if we don't want to trace through it. 
42 | """ 43 | _leaf_modules.add(module) 44 | return module 45 | 46 | 47 | # Functions we want to autowrap (treat them as leaves) 48 | _autowrap_functions = set() 49 | 50 | 51 | def register_notrace_function(func: Callable): 52 | """ 53 | Decorator for functions which ought not to be traced through 54 | """ 55 | _autowrap_functions.add(func) 56 | return func 57 | 58 | 59 | class FeatureGraphNet(nn.Module): 60 | def __init__(self, model, out_indices, out_map=None): 61 | super().__init__() 62 | assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction' 63 | self.feature_info = _get_feature_info(model, out_indices) 64 | if out_map is not None: 65 | assert len(out_map) == len(out_indices) 66 | return_nodes = {info['module']: out_map[i] if out_map is not None else info['module'] 67 | for i, info in enumerate(self.feature_info) if i in out_indices} 68 | self.graph_module = create_feature_extractor( 69 | model, return_nodes, 70 | tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}) 71 | 72 | def forward(self, x): 73 | return list(self.graph_module(x).values()) 74 | -------------------------------------------------------------------------------- /timm/models/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .activations import * 2 | from .adaptive_avgmax_pool import \ 3 | adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d 4 | from .blur_pool import BlurPool2d 5 | from .classifier import ClassifierHead, create_classifier 6 | from .cond_conv2d import CondConv2d, get_condconv_initializer 7 | from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\ 8 | set_layer_config 9 | from .conv2d_same import Conv2dSame, conv2d_same 10 | from .conv_bn_act import ConvBnAct 11 | from .create_act import create_act_layer, get_act_layer, get_act_fn 
12 | from .create_attn import get_attn, create_attn 13 | from .create_conv2d import create_conv2d 14 | from .create_norm_act import get_norm_act_layer, create_norm_act, convert_norm_act 15 | from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path 16 | from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn 17 | from .evo_norm import EvoNormBatch2d, EvoNormSample2d 18 | from .gather_excite import GatherExcite 19 | from .global_context import GlobalContext 20 | from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible 21 | from .inplace_abn import InplaceAbn 22 | from .linear import Linear 23 | from .mixed_conv2d import MixedConv2d 24 | from .mlp import Mlp, GluMlp, GatedMlp 25 | from .non_local_attn import NonLocalAttn, BatNonLocalAttn 26 | from .norm import GroupNorm, LayerNorm2d 27 | from .norm_act import BatchNormAct2d, GroupNormAct 28 | from .padding import get_padding, get_same_padding, pad_same 29 | from .patch_embed import PatchEmbed 30 | from .pool2d_same import AvgPool2dSame, create_pool2d 31 | from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite 32 | from .selective_kernel import SelectiveKernel 33 | from .separable_conv import SeparableConv2d, SeparableConvBnAct 34 | from .space_to_depth import SpaceToDepthModule 35 | from .split_attn import SplitAttn 36 | from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model 37 | from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame 38 | from .test_time_pool import TestTimePoolHead, apply_test_time_pool 39 | from .trace_utils import _assert, _float_to_int 40 | from .weight_init import trunc_normal_, variance_scaling_, lecun_normal_ 41 | -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/activations.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/activations.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/activations_jit.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/activations_jit.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/activations_me.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/activations_me.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/adaptive_avgmax_pool.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/adaptive_avgmax_pool.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/blur_pool.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/blur_pool.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/bottleneck_attn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/bottleneck_attn.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/cbam.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/cbam.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/classifier.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/classifier.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/cond_conv2d.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/cond_conv2d.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/config.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/config.cpython-36.pyc 
-------------------------------------------------------------------------------- /timm/models/layers/__pycache__/conv2d_same.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/conv2d_same.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/conv_bn_act.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/conv_bn_act.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/create_act.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/create_act.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/create_attn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/create_attn.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/create_conv2d.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/create_conv2d.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/create_norm_act.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/create_norm_act.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/drop.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/drop.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/eca.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/eca.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/evo_norm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/evo_norm.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/gather_excite.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/gather_excite.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/global_context.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/global_context.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/halo_attn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/halo_attn.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/helpers.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/helpers.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/inplace_abn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/inplace_abn.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/lambda_layer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/lambda_layer.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/linear.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/linear.cpython-36.pyc 
-------------------------------------------------------------------------------- /timm/models/layers/__pycache__/mixed_conv2d.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/mixed_conv2d.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/mlp.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/mlp.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/non_local_attn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/non_local_attn.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/norm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/norm.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/norm_act.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/norm_act.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/padding.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/padding.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/patch_embed.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/patch_embed.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/pool2d_same.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/pool2d_same.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/selective_kernel.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/selective_kernel.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/separable_conv.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/separable_conv.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/space_to_depth.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/space_to_depth.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/split_attn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/split_attn.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/split_batchnorm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/split_batchnorm.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/squeeze_excite.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/squeeze_excite.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/std_conv.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/std_conv.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/test_time_pool.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/test_time_pool.cpython-36.pyc 
-------------------------------------------------------------------------------- /timm/models/layers/__pycache__/trace_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/trace_utils.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/weight_init.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/models/layers/__pycache__/weight_init.cpython-36.pyc -------------------------------------------------------------------------------- /timm/models/layers/activations_jit.py: -------------------------------------------------------------------------------- 1 | """ Activations 2 | 3 | A collection of jit-scripted activations fn and modules with a common interface so that they can 4 | easily be swapped. All have an `inplace` arg even if not used. 5 | 6 | All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not 7 | currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted 8 | versions if they contain in-place ops. 
9 | 10 | Hacked together by / Copyright 2020 Ross Wightman 11 | """ 12 | 13 | import torch 14 | from torch import nn as nn 15 | from torch.nn import functional as F 16 | 17 | 18 | @torch.jit.script 19 | def swish_jit(x, inplace: bool = False): 20 | """Swish - Described in: https://arxiv.org/abs/1710.05941 21 | """ 22 | return x.mul(x.sigmoid()) 23 | 24 | 25 | @torch.jit.script 26 | def mish_jit(x, _inplace: bool = False): 27 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 28 | """ 29 | return x.mul(F.softplus(x).tanh()) 30 | 31 | 32 | class SwishJit(nn.Module): 33 | def __init__(self, inplace: bool = False): 34 | super(SwishJit, self).__init__() 35 | 36 | def forward(self, x): 37 | return swish_jit(x) 38 | 39 | 40 | class MishJit(nn.Module): 41 | def __init__(self, inplace: bool = False): 42 | super(MishJit, self).__init__() 43 | 44 | def forward(self, x): 45 | return mish_jit(x) 46 | 47 | 48 | @torch.jit.script 49 | def hard_sigmoid_jit(x, inplace: bool = False): 50 | # return F.relu6(x + 3.) / 6. 51 | return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? 52 | 53 | 54 | class HardSigmoidJit(nn.Module): 55 | def __init__(self, inplace: bool = False): 56 | super(HardSigmoidJit, self).__init__() 57 | 58 | def forward(self, x): 59 | return hard_sigmoid_jit(x) 60 | 61 | 62 | @torch.jit.script 63 | def hard_swish_jit(x, inplace: bool = False): 64 | # return x * (F.relu6(x + 3.) / 6) 65 | return x * (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? 
66 | 67 | 68 | class HardSwishJit(nn.Module): 69 | def __init__(self, inplace: bool = False): 70 | super(HardSwishJit, self).__init__() 71 | 72 | def forward(self, x): 73 | return hard_swish_jit(x) 74 | 75 | 76 | @torch.jit.script 77 | def hard_mish_jit(x, inplace: bool = False): 78 | """ Hard Mish 79 | Experimental, based on notes by Mish author Diganta Misra at 80 | https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md 81 | """ 82 | return 0.5 * x * (x + 2).clamp(min=0, max=2) 83 | 84 | 85 | class HardMishJit(nn.Module): 86 | def __init__(self, inplace: bool = False): 87 | super(HardMishJit, self).__init__() 88 | 89 | def forward(self, x): 90 | return hard_mish_jit(x) 91 | -------------------------------------------------------------------------------- /timm/models/layers/blur_pool.py: -------------------------------------------------------------------------------- 1 | """ 2 | BlurPool layer inspired by 3 | - Kornia's Max_BlurPool2d 4 | - Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar` 5 | 6 | Hacked together by Chris Ha and Ross Wightman 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import numpy as np 13 | from .padding import get_padding 14 | 15 | 16 | class BlurPool2d(nn.Module): 17 | r"""Creates a module that computes blurs and downsample a given feature map. 18 | See :cite:`zhang2019shiftinvar` for more details. 19 | Corresponds to the Downsample class, which does blurring and subsampling 20 | 21 | Args: 22 | channels = Number of input channels 23 | filt_size (int): binomial filter size for blurring. currently supports 3 (default) and 5. 24 | stride (int): downsampling filter stride 25 | 26 | Returns: 27 | torch.Tensor: the transformed tensor. 
28 | """ 29 | def __init__(self, channels, filt_size=3, stride=2) -> None: 30 | super(BlurPool2d, self).__init__() 31 | assert filt_size > 1 32 | self.channels = channels 33 | self.filt_size = filt_size 34 | self.stride = stride 35 | self.padding = [get_padding(filt_size, stride, dilation=1)] * 4 36 | coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs.astype(np.float32)) 37 | blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :].repeat(self.channels, 1, 1, 1) 38 | self.register_buffer('filt', blur_filter, persistent=False) 39 | 40 | def forward(self, x: torch.Tensor) -> torch.Tensor: 41 | x = F.pad(x, self.padding, 'reflect') 42 | return F.conv2d(x, self.filt, stride=self.stride, groups=x.shape[1]) 43 | -------------------------------------------------------------------------------- /timm/models/layers/classifier.py: -------------------------------------------------------------------------------- 1 | """ Classifier head and layer factory 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | from torch import nn as nn 6 | from torch.nn import functional as F 7 | 8 | from .adaptive_avgmax_pool import SelectAdaptivePool2d 9 | from .linear import Linear 10 | 11 | 12 | def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False): 13 | flatten_in_pool = not use_conv # flatten when we use a Linear layer after pooling 14 | if not pool_type: 15 | assert num_classes == 0 or use_conv,\ 16 | 'Pooling can only be disabled if classifier is also removed or conv classifier is used' 17 | flatten_in_pool = False # disable flattening if pooling is pass-through (no pooling) 18 | global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten_in_pool) 19 | num_pooled_features = num_features * global_pool.feat_mult() 20 | return global_pool, num_pooled_features 21 | 22 | 23 | def _create_fc(num_features, num_classes, use_conv=False): 24 | if num_classes <= 0: 25 | fc = nn.Identity() # pass-through (no 
classifier) 26 | elif use_conv: 27 | fc = nn.Conv2d(num_features, num_classes, 1, bias=True) 28 | else: 29 | # NOTE: using my Linear wrapper that fixes AMP + torchscript casting issue 30 | fc = Linear(num_features, num_classes, bias=True) 31 | return fc 32 | 33 | 34 | def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False): 35 | global_pool, num_pooled_features = _create_pool(num_features, num_classes, pool_type, use_conv=use_conv) 36 | fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv) 37 | return global_pool, fc 38 | 39 | 40 | class ClassifierHead(nn.Module): 41 | """Classifier head w/ configurable global pooling and dropout.""" 42 | 43 | def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0., use_conv=False): 44 | super(ClassifierHead, self).__init__() 45 | self.drop_rate = drop_rate 46 | self.global_pool, num_pooled_features = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv) 47 | self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv) 48 | self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity() 49 | 50 | def forward(self, x): 51 | x = self.global_pool(x) 52 | if self.drop_rate: 53 | x = F.dropout(x, p=float(self.drop_rate), training=self.training) 54 | x = self.fc(x) 55 | x = self.flatten(x) 56 | return x 57 | -------------------------------------------------------------------------------- /timm/models/layers/config.py: -------------------------------------------------------------------------------- 1 | """ Model / Layer Config singleton state 2 | """ 3 | from typing import Any, Optional 4 | 5 | __all__ = [ 6 | 'is_exportable', 'is_scriptable', 'is_no_jit', 7 | 'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config' 8 | ] 9 | 10 | # Set to True if prefer to have layers with no jit optimization (includes activations) 11 | _NO_JIT = False 12 | 13 | # Set to True if prefer to have activation layers with no jit optimization 14 | # NOTE not 
currently used as no difference between no_jit and no_activation jit as only layers obeying 15 | # the jit flags so far are activations. This will change as more layers are updated and/or added. 16 | _NO_ACTIVATION_JIT = False 17 | 18 | # Set to True if exporting a model with Same padding via ONNX 19 | _EXPORTABLE = False 20 | 21 | # Set to True if wanting to use torch.jit.script on a model 22 | _SCRIPTABLE = False 23 | 24 | 25 | def is_no_jit(): 26 | return _NO_JIT 27 | 28 | 29 | class set_no_jit: 30 | def __init__(self, mode: bool) -> None: 31 | global _NO_JIT 32 | self.prev = _NO_JIT 33 | _NO_JIT = mode 34 | 35 | def __enter__(self) -> None: 36 | pass 37 | 38 | def __exit__(self, *args: Any) -> bool: 39 | global _NO_JIT 40 | _NO_JIT = self.prev 41 | return False 42 | 43 | 44 | def is_exportable(): 45 | return _EXPORTABLE 46 | 47 | 48 | class set_exportable: 49 | def __init__(self, mode: bool) -> None: 50 | global _EXPORTABLE 51 | self.prev = _EXPORTABLE 52 | _EXPORTABLE = mode 53 | 54 | def __enter__(self) -> None: 55 | pass 56 | 57 | def __exit__(self, *args: Any) -> bool: 58 | global _EXPORTABLE 59 | _EXPORTABLE = self.prev 60 | return False 61 | 62 | 63 | def is_scriptable(): 64 | return _SCRIPTABLE 65 | 66 | 67 | class set_scriptable: 68 | def __init__(self, mode: bool) -> None: 69 | global _SCRIPTABLE 70 | self.prev = _SCRIPTABLE 71 | _SCRIPTABLE = mode 72 | 73 | def __enter__(self) -> None: 74 | pass 75 | 76 | def __exit__(self, *args: Any) -> bool: 77 | global _SCRIPTABLE 78 | _SCRIPTABLE = self.prev 79 | return False 80 | 81 | 82 | class set_layer_config: 83 | """ Layer config context manager that allows setting all layer config flags at once. 84 | If a flag arg is None, it will not change the current value. 
85 | """ 86 | def __init__( 87 | self, 88 | scriptable: Optional[bool] = None, 89 | exportable: Optional[bool] = None, 90 | no_jit: Optional[bool] = None, 91 | no_activation_jit: Optional[bool] = None): 92 | global _SCRIPTABLE 93 | global _EXPORTABLE 94 | global _NO_JIT 95 | global _NO_ACTIVATION_JIT 96 | self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT 97 | if scriptable is not None: 98 | _SCRIPTABLE = scriptable 99 | if exportable is not None: 100 | _EXPORTABLE = exportable 101 | if no_jit is not None: 102 | _NO_JIT = no_jit 103 | if no_activation_jit is not None: 104 | _NO_ACTIVATION_JIT = no_activation_jit 105 | 106 | def __enter__(self) -> None: 107 | pass 108 | 109 | def __exit__(self, *args: Any) -> bool: 110 | global _SCRIPTABLE 111 | global _EXPORTABLE 112 | global _NO_JIT 113 | global _NO_ACTIVATION_JIT 114 | _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev 115 | return False 116 | -------------------------------------------------------------------------------- /timm/models/layers/conv2d_same.py: -------------------------------------------------------------------------------- 1 | """ Conv2d w/ Same Padding 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from typing import Tuple, Optional 9 | 10 | from .padding import pad_same, get_padding_value 11 | 12 | 13 | def conv2d_same( 14 | x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1), 15 | padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1): 16 | x = pad_same(x, weight.shape[-2:], stride, dilation) 17 | return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) 18 | 19 | 20 | class Conv2dSame(nn.Conv2d): 21 | """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions 22 | """ 23 | 24 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 25 | padding=0, dilation=1, 
groups=1, bias=True): 26 | super(Conv2dSame, self).__init__( 27 | in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) 28 | 29 | def forward(self, x): 30 | return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 31 | 32 | 33 | def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): 34 | padding = kwargs.pop('padding', '') 35 | kwargs.setdefault('bias', False) 36 | padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs) 37 | if is_dynamic: 38 | return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) 39 | else: 40 | return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) 41 | 42 | 43 | -------------------------------------------------------------------------------- /timm/models/layers/conv_bn_act.py: -------------------------------------------------------------------------------- 1 | """ Conv2d + BN + Act 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | from torch import nn as nn 6 | 7 | from .create_conv2d import create_conv2d 8 | from .create_norm_act import convert_norm_act 9 | 10 | 11 | class ConvBnAct(nn.Module): 12 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1, 13 | bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, aa_layer=None, 14 | drop_block=None): 15 | super(ConvBnAct, self).__init__() 16 | use_aa = aa_layer is not None 17 | 18 | self.conv = create_conv2d( 19 | in_channels, out_channels, kernel_size, stride=1 if use_aa else stride, 20 | padding=padding, dilation=dilation, groups=groups, bias=bias) 21 | 22 | # NOTE for backwards compatibility with models that use separate norm and act layer definitions 23 | norm_act_layer = convert_norm_act(norm_layer, act_layer) 24 | self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block) 25 | self.aa = aa_layer(channels=out_channels) if stride == 2 and use_aa else None 26 | 27 | 
@property 28 | def in_channels(self): 29 | return self.conv.in_channels 30 | 31 | @property 32 | def out_channels(self): 33 | return self.conv.out_channels 34 | 35 | def forward(self, x): 36 | x = self.conv(x) 37 | x = self.bn(x) 38 | if self.aa is not None: 39 | x = self.aa(x) 40 | return x 41 | -------------------------------------------------------------------------------- /timm/models/layers/create_conv2d.py: -------------------------------------------------------------------------------- 1 | """ Create Conv2d Factory Method 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | 6 | from .mixed_conv2d import MixedConv2d 7 | from .cond_conv2d import CondConv2d 8 | from .conv2d_same import create_conv2d_pad 9 | 10 | 11 | def create_conv2d(in_channels, out_channels, kernel_size, **kwargs): 12 | """ Select a 2d convolution implementation based on arguments 13 | Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d. 14 | 15 | Used extensively by EfficientNet, MobileNetv3 and related networks. 16 | """ 17 | if isinstance(kernel_size, list): 18 | assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently 19 | assert 'groups' not in kwargs # MixedConv groups are defined by kernel list 20 | # We're going to use only lists for defining the MixedConv2d kernel groups, 21 | # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. 
22 | m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs) 23 | else: 24 | depthwise = kwargs.pop('depthwise', False) 25 | # for DW out_channels must be multiple of in_channels as must have out_channels % groups == 0 26 | groups = in_channels if depthwise else kwargs.pop('groups', 1) 27 | if 'num_experts' in kwargs and kwargs['num_experts'] > 0: 28 | m = CondConv2d(in_channels, out_channels, kernel_size, groups=groups, **kwargs) 29 | else: 30 | m = create_conv2d_pad(in_channels, out_channels, kernel_size, groups=groups, **kwargs) 31 | return m 32 | -------------------------------------------------------------------------------- /timm/models/layers/create_norm_act.py: -------------------------------------------------------------------------------- 1 | """ NormAct (Normalizaiton + Activation Layer) Factory 2 | 3 | Create norm + act combo modules that attempt to be backwards compatible with separate norm + act 4 | isntances in models. Where these are used it will be possible to swap separate BN + act layers with 5 | combined modules like IABN or EvoNorms. 
6 | 7 | Hacked together by / Copyright 2020 Ross Wightman 8 | """ 9 | import types 10 | import functools 11 | 12 | import torch 13 | import torch.nn as nn 14 | 15 | from .evo_norm import EvoNormBatch2d, EvoNormSample2d 16 | from .norm_act import BatchNormAct2d, GroupNormAct 17 | from .inplace_abn import InplaceAbn 18 | 19 | _NORM_ACT_TYPES = {BatchNormAct2d, GroupNormAct, EvoNormBatch2d, EvoNormSample2d, InplaceAbn} 20 | _NORM_ACT_REQUIRES_ARG = {BatchNormAct2d, GroupNormAct, InplaceAbn} # requires act_layer arg to define act type 21 | 22 | 23 | def get_norm_act_layer(layer_class): 24 | layer_class = layer_class.replace('_', '').lower() 25 | if layer_class.startswith("batchnorm"): 26 | layer = BatchNormAct2d 27 | elif layer_class.startswith("groupnorm"): 28 | layer = GroupNormAct 29 | elif layer_class == "evonormbatch": 30 | layer = EvoNormBatch2d 31 | elif layer_class == "evonormsample": 32 | layer = EvoNormSample2d 33 | elif layer_class == "iabn" or layer_class == "inplaceabn": 34 | layer = InplaceAbn 35 | else: 36 | assert False, "Invalid norm_act layer (%s)" % layer_class 37 | return layer 38 | 39 | 40 | def create_norm_act(layer_type, num_features, apply_act=True, jit=False, **kwargs): 41 | layer_parts = layer_type.split('-') # e.g. batchnorm-leaky_relu 42 | assert len(layer_parts) in (1, 2) 43 | layer = get_norm_act_layer(layer_parts[0]) 44 | #activation_class = layer_parts[1].lower() if len(layer_parts) > 1 else '' # FIXME support string act selection? 
45 | layer_instance = layer(num_features, apply_act=apply_act, **kwargs) 46 | if jit: 47 | layer_instance = torch.jit.script(layer_instance) 48 | return layer_instance 49 | 50 | 51 | def convert_norm_act(norm_layer, act_layer): 52 | assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial)) 53 | assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial)) 54 | norm_act_kwargs = {} 55 | 56 | # unbind partial fn, so args can be rebound later 57 | if isinstance(norm_layer, functools.partial): 58 | norm_act_kwargs.update(norm_layer.keywords) 59 | norm_layer = norm_layer.func 60 | 61 | if isinstance(norm_layer, str): 62 | norm_act_layer = get_norm_act_layer(norm_layer) 63 | elif norm_layer in _NORM_ACT_TYPES: 64 | norm_act_layer = norm_layer 65 | elif isinstance(norm_layer, types.FunctionType): 66 | # if function type, must be a lambda/fn that creates a norm_act layer 67 | norm_act_layer = norm_layer 68 | else: 69 | type_name = norm_layer.__name__.lower() 70 | if type_name.startswith('batchnorm'): 71 | norm_act_layer = BatchNormAct2d 72 | elif type_name.startswith('groupnorm'): 73 | norm_act_layer = GroupNormAct 74 | else: 75 | assert False, f"No equivalent norm_act layer for {type_name}" 76 | 77 | if norm_act_layer in _NORM_ACT_REQUIRES_ARG: 78 | # pass `act_layer` through for backwards compat where `act_layer=None` implies no activation. 
79 | # In the future, may force use of `apply_act` with `act_layer` arg bound to relevant NormAct types 80 | norm_act_kwargs.setdefault('act_layer', act_layer) 81 | if norm_act_kwargs: 82 | norm_act_layer = functools.partial(norm_act_layer, **norm_act_kwargs) # bind/rebind args 83 | return norm_act_layer 84 | -------------------------------------------------------------------------------- /timm/models/layers/global_context.py: -------------------------------------------------------------------------------- 1 | """ Global Context Attention Block 2 | 3 | Paper: `GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond` 4 | - https://arxiv.org/abs/1904.11492 5 | 6 | Official code consulted as reference: https://github.com/xvjiarui/GCNet 7 | 8 | Hacked together by / Copyright 2021 Ross Wightman 9 | """ 10 | from torch import nn as nn 11 | import torch.nn.functional as F 12 | 13 | from .create_act import create_act_layer, get_act_layer 14 | from .helpers import make_divisible 15 | from .mlp import ConvMlp 16 | from .norm import LayerNorm2d 17 | 18 | 19 | class GlobalContext(nn.Module): 20 | 21 | def __init__(self, channels, use_attn=True, fuse_add=False, fuse_scale=True, init_last_zero=False, 22 | rd_ratio=1./8, rd_channels=None, rd_divisor=1, act_layer=nn.ReLU, gate_layer='sigmoid'): 23 | super(GlobalContext, self).__init__() 24 | act_layer = get_act_layer(act_layer) 25 | 26 | self.conv_attn = nn.Conv2d(channels, 1, kernel_size=1, bias=True) if use_attn else None 27 | 28 | if rd_channels is None: 29 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) 
30 | if fuse_add: 31 | self.mlp_add = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d) 32 | else: 33 | self.mlp_add = None 34 | if fuse_scale: 35 | self.mlp_scale = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d) 36 | else: 37 | self.mlp_scale = None 38 | 39 | self.gate = create_act_layer(gate_layer) 40 | self.init_last_zero = init_last_zero 41 | self.reset_parameters() 42 | 43 | def reset_parameters(self): 44 | if self.conv_attn is not None: 45 | nn.init.kaiming_normal_(self.conv_attn.weight, mode='fan_in', nonlinearity='relu') 46 | if self.mlp_add is not None: 47 | nn.init.zeros_(self.mlp_add.fc2.weight) 48 | 49 | def forward(self, x): 50 | B, C, H, W = x.shape 51 | 52 | if self.conv_attn is not None: 53 | attn = self.conv_attn(x).reshape(B, 1, H * W) # (B, 1, H * W) 54 | attn = F.softmax(attn, dim=-1).unsqueeze(3) # (B, 1, H * W, 1) 55 | context = x.reshape(B, C, H * W).unsqueeze(1) @ attn 56 | context = context.view(B, C, 1, 1) 57 | else: 58 | context = x.mean(dim=(2, 3), keepdim=True) 59 | 60 | if self.mlp_scale is not None: 61 | mlp_x = self.mlp_scale(context) 62 | x = x * self.gate(mlp_x) 63 | if self.mlp_add is not None: 64 | mlp_x = self.mlp_add(context) 65 | x = x + mlp_x 66 | 67 | return x 68 | -------------------------------------------------------------------------------- /timm/models/layers/helpers.py: -------------------------------------------------------------------------------- 1 | """ Layer/Module Helpers 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | from itertools import repeat 6 | import collections.abc 7 | 8 | 9 | # From PyTorch internals 10 | def _ntuple(n): 11 | def parse(x): 12 | if isinstance(x, collections.abc.Iterable): 13 | return x 14 | return tuple(repeat(x, n)) 15 | return parse 16 | 17 | 18 | to_1tuple = _ntuple(1) 19 | to_2tuple = _ntuple(2) 20 | to_3tuple = _ntuple(3) 21 | to_4tuple = _ntuple(4) 22 | to_ntuple = _ntuple 23 | 24 | 25 | def 
make_divisible(v, divisor=8, min_value=None, round_limit=.9): 26 | min_value = min_value or divisor 27 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 28 | # Make sure that round down does not go down by more than 10%. 29 | if new_v < round_limit * v: 30 | new_v += divisor 31 | return new_v 32 | -------------------------------------------------------------------------------- /timm/models/layers/inplace_abn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | 4 | try: 5 | from inplace_abn.functions import inplace_abn, inplace_abn_sync 6 | has_iabn = True 7 | except ImportError: 8 | has_iabn = False 9 | 10 | def inplace_abn(x, weight, bias, running_mean, running_var, 11 | training=True, momentum=0.1, eps=1e-05, activation="leaky_relu", activation_param=0.01): 12 | raise ImportError( 13 | "Please install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.12'") 14 | 15 | def inplace_abn_sync(**kwargs): 16 | inplace_abn(**kwargs) 17 | 18 | 19 | class InplaceAbn(nn.Module): 20 | """Activated Batch Normalization 21 | 22 | This gathers a BatchNorm and an activation function in a single module 23 | 24 | Parameters 25 | ---------- 26 | num_features : int 27 | Number of feature channels in the input and output. 28 | eps : float 29 | Small constant to prevent numerical issues. 30 | momentum : float 31 | Momentum factor applied to compute running statistics. 32 | affine : bool 33 | If `True` apply learned scale and shift transformation after normalization. 34 | act_layer : str or nn.Module type 35 | Name or type of the activation functions, one of: `leaky_relu`, `elu` 36 | act_param : float 37 | Negative slope for the `leaky_relu` activation. 
38 | """ 39 | 40 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, apply_act=True, 41 | act_layer="leaky_relu", act_param=0.01, drop_block=None): 42 | super(InplaceAbn, self).__init__() 43 | self.num_features = num_features 44 | self.affine = affine 45 | self.eps = eps 46 | self.momentum = momentum 47 | if apply_act: 48 | if isinstance(act_layer, str): 49 | assert act_layer in ('leaky_relu', 'elu', 'identity', '') 50 | self.act_name = act_layer if act_layer else 'identity' 51 | else: 52 | # convert act layer passed as type to string 53 | if act_layer == nn.ELU: 54 | self.act_name = 'elu' 55 | elif act_layer == nn.LeakyReLU: 56 | self.act_name = 'leaky_relu' 57 | elif act_layer == nn.Identity: 58 | self.act_name = 'identity' 59 | else: 60 | assert False, f'Invalid act layer {act_layer.__name__} for IABN' 61 | else: 62 | self.act_name = 'identity' 63 | self.act_param = act_param 64 | if self.affine: 65 | self.weight = nn.Parameter(torch.ones(num_features)) 66 | self.bias = nn.Parameter(torch.zeros(num_features)) 67 | else: 68 | self.register_parameter('weight', None) 69 | self.register_parameter('bias', None) 70 | self.register_buffer('running_mean', torch.zeros(num_features)) 71 | self.register_buffer('running_var', torch.ones(num_features)) 72 | self.reset_parameters() 73 | 74 | def reset_parameters(self): 75 | nn.init.constant_(self.running_mean, 0) 76 | nn.init.constant_(self.running_var, 1) 77 | if self.affine: 78 | nn.init.constant_(self.weight, 1) 79 | nn.init.constant_(self.bias, 0) 80 | 81 | def forward(self, x): 82 | output = inplace_abn( 83 | x, self.weight, self.bias, self.running_mean, self.running_var, 84 | self.training, self.momentum, self.eps, self.act_name, self.act_param) 85 | if isinstance(output, tuple): 86 | output = output[0] 87 | return output 88 | -------------------------------------------------------------------------------- /timm/models/layers/linear.py: 
-------------------------------------------------------------------------------- 1 | """ Linear layer (alternate definition) 2 | """ 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn as nn 6 | 7 | 8 | class Linear(nn.Linear): 9 | r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b` 10 | 11 | Wraps torch.nn.Linear to support AMP + torchscript usage by manually casting 12 | weight & bias to input.dtype to work around an issue w/ torch.addmm in this use case. 13 | """ 14 | def forward(self, input: torch.Tensor) -> torch.Tensor: 15 | if torch.jit.is_scripting(): 16 | bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None 17 | return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias) 18 | else: 19 | return F.linear(input, self.weight, self.bias) 20 | -------------------------------------------------------------------------------- /timm/models/layers/median_pool.py: -------------------------------------------------------------------------------- 1 | """ Median Pool 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from .helpers import to_2tuple, to_4tuple 7 | 8 | 9 | class MedianPool2d(nn.Module): 10 | """ Median pool (usable as median filter when stride=1) module. 
11 | 12 | Args: 13 | kernel_size: size of pooling kernel, int or 2-tuple 14 | stride: pool stride, int or 2-tuple 15 | padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad 16 | same: override padding and enforce same padding, boolean 17 | """ 18 | def __init__(self, kernel_size=3, stride=1, padding=0, same=False): 19 | super(MedianPool2d, self).__init__() 20 | self.k = to_2tuple(kernel_size) 21 | self.stride = to_2tuple(stride) 22 | self.padding = to_4tuple(padding) # convert to l, r, t, b 23 | self.same = same 24 | 25 | def _padding(self, x): 26 | if self.same: 27 | ih, iw = x.size()[2:] 28 | if ih % self.stride[0] == 0: 29 | ph = max(self.k[0] - self.stride[0], 0) 30 | else: 31 | ph = max(self.k[0] - (ih % self.stride[0]), 0) 32 | if iw % self.stride[1] == 0: 33 | pw = max(self.k[1] - self.stride[1], 0) 34 | else: 35 | pw = max(self.k[1] - (iw % self.stride[1]), 0) 36 | pl = pw // 2 37 | pr = pw - pl 38 | pt = ph // 2 39 | pb = ph - pt 40 | padding = (pl, pr, pt, pb) 41 | else: 42 | padding = self.padding 43 | return padding 44 | 45 | def forward(self, x): 46 | x = F.pad(x, self._padding(x), mode='reflect') 47 | x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1]) 48 | x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0] 49 | return x 50 | -------------------------------------------------------------------------------- /timm/models/layers/mixed_conv2d.py: -------------------------------------------------------------------------------- 1 | """ PyTorch Mixed Convolution 2 | 3 | Paper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595) 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | 8 | import torch 9 | from torch import nn as nn 10 | 11 | from .conv2d_same import create_conv2d_pad 12 | 13 | 14 | def _split_channels(num_chan, num_groups): 15 | split = [num_chan // num_groups for _ in range(num_groups)] 16 | split[0] += num_chan - sum(split) 17 | return split 
18 | 19 | 20 | class MixedConv2d(nn.ModuleDict): 21 | """ Mixed Grouped Convolution 22 | 23 | Based on MDConv and GroupedConv in MixNet impl: 24 | https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py 25 | """ 26 | def __init__(self, in_channels, out_channels, kernel_size=3, 27 | stride=1, padding='', dilation=1, depthwise=False, **kwargs): 28 | super(MixedConv2d, self).__init__() 29 | 30 | kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] 31 | num_groups = len(kernel_size) 32 | in_splits = _split_channels(in_channels, num_groups) 33 | out_splits = _split_channels(out_channels, num_groups) 34 | self.in_channels = sum(in_splits) 35 | self.out_channels = sum(out_splits) 36 | for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)): 37 | conv_groups = in_ch if depthwise else 1 38 | # use add_module to keep key space clean 39 | self.add_module( 40 | str(idx), 41 | create_conv2d_pad( 42 | in_ch, out_ch, k, stride=stride, 43 | padding=padding, dilation=dilation, groups=conv_groups, **kwargs) 44 | ) 45 | self.splits = in_splits 46 | 47 | def forward(self, x): 48 | x_split = torch.split(x, self.splits, 1) 49 | x_out = [c(x_split[i]) for i, c in enumerate(self.values())] 50 | x = torch.cat(x_out, 1) 51 | return x 52 | -------------------------------------------------------------------------------- /timm/models/layers/norm.py: -------------------------------------------------------------------------------- 1 | """ Normalization layers and wrappers 2 | """ 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class GroupNorm(nn.GroupNorm): 9 | def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True): 10 | # NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN 11 | super().__init__(num_groups, num_channels, eps=eps, affine=affine) 12 | 13 | def forward(self, x): 14 | return F.group_norm(x, 
self.num_groups, self.weight, self.bias, self.eps) 15 | 16 | 17 | class LayerNorm2d(nn.LayerNorm): 18 | """ LayerNorm for channels of '2D' spatial BCHW tensors """ 19 | def __init__(self, num_channels): 20 | super().__init__(num_channels) 21 | 22 | def forward(self, x: torch.Tensor) -> torch.Tensor: 23 | return F.layer_norm( 24 | x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) 25 | -------------------------------------------------------------------------------- /timm/models/layers/padding.py: -------------------------------------------------------------------------------- 1 | """ Padding Helpers 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import math 6 | from typing import List, Tuple 7 | 8 | import torch.nn.functional as F 9 | 10 | 11 | # Calculate symmetric padding for a convolution 12 | def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int: 13 | padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 14 | return padding 15 | 16 | 17 | # Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution 18 | def get_same_padding(x: int, k: int, s: int, d: int): 19 | return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0) 20 | 21 | 22 | # Can SAME padding for given args be done statically? 
23 | def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_): 24 | return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 25 | 26 | 27 | # Dynamically pad input x with 'SAME' padding for conv with specified args 28 | def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0): 29 | ih, iw = x.size()[-2:] 30 | pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1]) 31 | if pad_h > 0 or pad_w > 0: 32 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value) 33 | return x 34 | 35 | 36 | def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]: 37 | dynamic = False 38 | if isinstance(padding, str): 39 | # for any string padding, the padding will be calculated for you, one of three ways 40 | padding = padding.lower() 41 | if padding == 'same': 42 | # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact 43 | if is_static_pad(kernel_size, **kwargs): 44 | # static case, no extra overhead 45 | padding = get_padding(kernel_size, **kwargs) 46 | else: 47 | # dynamic 'SAME' padding, has runtime/GPU memory overhead 48 | padding = 0 49 | dynamic = True 50 | elif padding == 'valid': 51 | # 'VALID' padding, same as padding=0 52 | padding = 0 53 | else: 54 | # Default to PyTorch style 'same'-ish symmetric padding 55 | padding = get_padding(kernel_size, **kwargs) 56 | return padding, dynamic 57 | -------------------------------------------------------------------------------- /timm/models/layers/patch_embed.py: -------------------------------------------------------------------------------- 1 | """ Image to Patch Embedding using Conv2d 2 | 3 | A convolution based approach to patchifying a 2D image w/ embedding projection. 
4 | 5 | Based on the impl in https://github.com/google-research/vision_transformer 6 | 7 | Hacked together by / Copyright 2020 Ross Wightman 8 | """ 9 | from torch import nn as nn 10 | 11 | from .helpers import to_2tuple 12 | from .trace_utils import _assert 13 | 14 | 15 | class PatchEmbed(nn.Module): 16 | """ 2D Image to Patch Embedding 17 | """ 18 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): 19 | super().__init__() 20 | img_size = to_2tuple(img_size) 21 | patch_size = to_2tuple(patch_size) 22 | self.img_size = img_size 23 | self.patch_size = patch_size 24 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 25 | self.num_patches = self.grid_size[0] * self.grid_size[1] 26 | self.flatten = flatten 27 | 28 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 29 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 30 | 31 | def forward(self, x): 32 | B, C, H, W = x.shape 33 | _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") 34 | _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") 35 | x = self.proj(x) 36 | if self.flatten: 37 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 38 | x = self.norm(x) 39 | return x 40 | -------------------------------------------------------------------------------- /timm/models/layers/pool2d_same.py: -------------------------------------------------------------------------------- 1 | """ AvgPool2d w/ Same Padding 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from typing import List, Tuple, Optional 9 | 10 | from .helpers import to_2tuple 11 | from .padding import pad_same, get_padding_value 12 | 13 | 14 | def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), 
15 | ceil_mode: bool = False, count_include_pad: bool = True): 16 | # FIXME how to deal with count_include_pad vs not for external padding? 17 | x = pad_same(x, kernel_size, stride) 18 | return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad) 19 | 20 | 21 | class AvgPool2dSame(nn.AvgPool2d): 22 | """ Tensorflow like 'SAME' wrapper for 2D average pooling 23 | """ 24 | def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True): 25 | kernel_size = to_2tuple(kernel_size) 26 | stride = to_2tuple(stride) 27 | super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad) 28 | 29 | def forward(self, x): 30 | x = pad_same(x, self.kernel_size, self.stride) 31 | return F.avg_pool2d( 32 | x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad) 33 | 34 | 35 | def max_pool2d_same( 36 | x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), 37 | dilation: List[int] = (1, 1), ceil_mode: bool = False): 38 | x = pad_same(x, kernel_size, stride, value=-float('inf')) 39 | return F.max_pool2d(x, kernel_size, stride, (0, 0), dilation, ceil_mode) 40 | 41 | 42 | class MaxPool2dSame(nn.MaxPool2d): 43 | """ Tensorflow like 'SAME' wrapper for 2D max pooling 44 | """ 45 | def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False): 46 | kernel_size = to_2tuple(kernel_size) 47 | stride = to_2tuple(stride) 48 | dilation = to_2tuple(dilation) 49 | super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode) 50 | 51 | def forward(self, x): 52 | x = pad_same(x, self.kernel_size, self.stride, value=-float('inf')) 53 | return F.max_pool2d(x, self.kernel_size, self.stride, (0, 0), self.dilation, self.ceil_mode) 54 | 55 | 56 | def create_pool2d(pool_type, kernel_size, stride=None, **kwargs): 57 | stride = stride or kernel_size 58 | padding = kwargs.pop('padding', '') 59 | padding, is_dynamic = 
get_padding_value(padding, kernel_size, stride=stride, **kwargs) 60 | if is_dynamic: 61 | if pool_type == 'avg': 62 | return AvgPool2dSame(kernel_size, stride=stride, **kwargs) 63 | elif pool_type == 'max': 64 | return MaxPool2dSame(kernel_size, stride=stride, **kwargs) 65 | else: 66 | assert False, f'Unsupported pool type {pool_type}' 67 | else: 68 | if pool_type == 'avg': 69 | return nn.AvgPool2d(kernel_size, stride=stride, padding=padding, **kwargs) 70 | elif pool_type == 'max': 71 | return nn.MaxPool2d(kernel_size, stride=stride, padding=padding, **kwargs) 72 | else: 73 | assert False, f'Unsupported pool type {pool_type}' 74 | -------------------------------------------------------------------------------- /timm/models/layers/separable_conv.py: -------------------------------------------------------------------------------- 1 | """ Depthwise Separable Conv Modules 2 | 3 | Basic DWS convs. Other variations of DWS exist with batch norm or activations between the 4 | DW and PW convs such as the Depthwise modules in MobileNetV2 / EfficientNet and Xception. 
5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | from torch import nn as nn 9 | 10 | from .create_conv2d import create_conv2d 11 | from .create_norm_act import convert_norm_act 12 | 13 | 14 | class SeparableConvBnAct(nn.Module): 15 | """ Separable Conv w/ trailing Norm and Activation 16 | """ 17 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, 18 | channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, 19 | apply_act=True, drop_block=None): 20 | super(SeparableConvBnAct, self).__init__() 21 | 22 | self.conv_dw = create_conv2d( 23 | in_channels, int(in_channels * channel_multiplier), kernel_size, 24 | stride=stride, dilation=dilation, padding=padding, depthwise=True) 25 | 26 | self.conv_pw = create_conv2d( 27 | int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) 28 | 29 | norm_act_layer = convert_norm_act(norm_layer, act_layer) 30 | self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block) 31 | 32 | @property 33 | def in_channels(self): 34 | return self.conv_dw.in_channels 35 | 36 | @property 37 | def out_channels(self): 38 | return self.conv_pw.out_channels 39 | 40 | def forward(self, x): 41 | x = self.conv_dw(x) 42 | x = self.conv_pw(x) 43 | if self.bn is not None: 44 | x = self.bn(x) 45 | return x 46 | 47 | 48 | class SeparableConv2d(nn.Module): 49 | """ Separable Conv 50 | """ 51 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, 52 | channel_multiplier=1.0, pw_kernel_size=1): 53 | super(SeparableConv2d, self).__init__() 54 | 55 | self.conv_dw = create_conv2d( 56 | in_channels, int(in_channels * channel_multiplier), kernel_size, 57 | stride=stride, dilation=dilation, padding=padding, depthwise=True) 58 | 59 | self.conv_pw = create_conv2d( 60 | int(in_channels * channel_multiplier), out_channels, pw_kernel_size, 
padding=padding, bias=bias) 61 | 62 | @property 63 | def in_channels(self): 64 | return self.conv_dw.in_channels 65 | 66 | @property 67 | def out_channels(self): 68 | return self.conv_pw.out_channels 69 | 70 | def forward(self, x): 71 | x = self.conv_dw(x) 72 | x = self.conv_pw(x) 73 | return x 74 | -------------------------------------------------------------------------------- /timm/models/layers/space_to_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class SpaceToDepth(nn.Module): 6 | def __init__(self, block_size=4): 7 | super().__init__() 8 | assert block_size == 4 9 | self.bs = block_size 10 | 11 | def forward(self, x): 12 | N, C, H, W = x.size() 13 | x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs) 14 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) 15 | x = x.view(N, C * (self.bs ** 2), H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs) 16 | return x 17 | 18 | 19 | @torch.jit.script 20 | class SpaceToDepthJit(object): 21 | def __call__(self, x: torch.Tensor): 22 | # assuming hard-coded that block_size==4 for acceleration 23 | N, C, H, W = x.size() 24 | x = x.view(N, C, H // 4, 4, W // 4, 4) # (N, C, H//bs, bs, W//bs, bs) 25 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) 26 | x = x.view(N, C * 16, H // 4, W // 4) # (N, C*bs^2, H//bs, W//bs) 27 | return x 28 | 29 | 30 | class SpaceToDepthModule(nn.Module): 31 | def __init__(self, no_jit=False): 32 | super().__init__() 33 | if not no_jit: 34 | self.op = SpaceToDepthJit() 35 | else: 36 | self.op = SpaceToDepth() 37 | 38 | def forward(self, x): 39 | return self.op(x) 40 | 41 | 42 | class DepthToSpace(nn.Module): 43 | 44 | def __init__(self, block_size): 45 | super().__init__() 46 | self.bs = block_size 47 | 48 | def forward(self, x): 49 | N, C, H, W = x.size() 50 | x = x.view(N, self.bs, self.bs, C // 
(self.bs ** 2), H, W) # (N, bs, bs, C//bs^2, H, W) 51 | x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs) 52 | x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs) # (N, C//bs^2, H * bs, W * bs) 53 | return x 54 | -------------------------------------------------------------------------------- /timm/models/layers/split_attn.py: -------------------------------------------------------------------------------- 1 | """ Split Attention Conv2d (for ResNeSt Models) 2 | 3 | Paper: `ResNeSt: Split-Attention Networks` - /https://arxiv.org/abs/2004.08955 4 | 5 | Adapted from original PyTorch impl at https://github.com/zhanghang1989/ResNeSt 6 | 7 | Modified for torchscript compat, performance, and consistency with timm by Ross Wightman 8 | """ 9 | import torch 10 | import torch.nn.functional as F 11 | from torch import nn 12 | 13 | from .helpers import make_divisible 14 | 15 | 16 | class RadixSoftmax(nn.Module): 17 | def __init__(self, radix, cardinality): 18 | super(RadixSoftmax, self).__init__() 19 | self.radix = radix 20 | self.cardinality = cardinality 21 | 22 | def forward(self, x): 23 | batch = x.size(0) 24 | if self.radix > 1: 25 | x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) 26 | x = F.softmax(x, dim=1) 27 | x = x.reshape(batch, -1) 28 | else: 29 | x = torch.sigmoid(x) 30 | return x 31 | 32 | 33 | class SplitAttn(nn.Module): 34 | """Split-Attention (aka Splat) 35 | """ 36 | def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=None, 37 | dilation=1, groups=1, bias=False, radix=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8, 38 | act_layer=nn.ReLU, norm_layer=None, drop_block=None, **kwargs): 39 | super(SplitAttn, self).__init__() 40 | out_channels = out_channels or in_channels 41 | self.radix = radix 42 | self.drop_block = drop_block 43 | mid_chs = out_channels * radix 44 | if rd_channels is None: 45 | attn_chs = make_divisible(in_channels * radix * rd_ratio, min_value=32, 
divisor=rd_divisor) 46 | else: 47 | attn_chs = rd_channels * radix 48 | 49 | padding = kernel_size // 2 if padding is None else padding 50 | self.conv = nn.Conv2d( 51 | in_channels, mid_chs, kernel_size, stride, padding, dilation, 52 | groups=groups * radix, bias=bias, **kwargs) 53 | self.bn0 = norm_layer(mid_chs) if norm_layer else nn.Identity() 54 | self.act0 = act_layer(inplace=True) 55 | self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups) 56 | self.bn1 = norm_layer(attn_chs) if norm_layer else nn.Identity() 57 | self.act1 = act_layer(inplace=True) 58 | self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups) 59 | self.rsoftmax = RadixSoftmax(radix, groups) 60 | 61 | def forward(self, x): 62 | x = self.conv(x) 63 | x = self.bn0(x) 64 | if self.drop_block is not None: 65 | x = self.drop_block(x) 66 | x = self.act0(x) 67 | 68 | B, RC, H, W = x.shape 69 | if self.radix > 1: 70 | x = x.reshape((B, self.radix, RC // self.radix, H, W)) 71 | x_gap = x.sum(dim=1) 72 | else: 73 | x_gap = x 74 | x_gap = x_gap.mean((2, 3), keepdim=True) 75 | x_gap = self.fc1(x_gap) 76 | x_gap = self.bn1(x_gap) 77 | x_gap = self.act1(x_gap) 78 | x_attn = self.fc2(x_gap) 79 | 80 | x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1) 81 | if self.radix > 1: 82 | out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1) 83 | else: 84 | out = x * x_attn 85 | return out.contiguous() 86 | -------------------------------------------------------------------------------- /timm/models/layers/split_batchnorm.py: -------------------------------------------------------------------------------- 1 | """ Split BatchNorm 2 | 3 | A PyTorch BatchNorm layer that splits input batch into N equal parts and passes each through 4 | a separate BN layer. The first split is passed through the parent BN layers with weight/bias 5 | keys the same as the original BN. All other splits pass through BN sub-layers under the '.aux_bn' 6 | namespace. 
7 | 8 | This allows easily removing the auxiliary BN layers after training to efficiently 9 | achieve the 'Auxiliary BatchNorm' as described in the AdvProp Paper, section 4.2, 10 | 'Disentangled Learning via An Auxiliary BN' 11 | 12 | Hacked together by / Copyright 2020 Ross Wightman 13 | """ 14 | import torch 15 | import torch.nn as nn 16 | 17 | 18 | class SplitBatchNorm2d(torch.nn.BatchNorm2d): 19 | 20 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, 21 | track_running_stats=True, num_splits=2): 22 | super().__init__(num_features, eps, momentum, affine, track_running_stats) 23 | assert num_splits > 1, 'Should have at least one aux BN layer (num_splits at least 2)' 24 | self.num_splits = num_splits 25 | self.aux_bn = nn.ModuleList([ 26 | nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) for _ in range(num_splits - 1)]) 27 | 28 | def forward(self, input: torch.Tensor): 29 | if self.training: # aux BN only relevant while training 30 | split_size = input.shape[0] // self.num_splits 31 | assert input.shape[0] == split_size * self.num_splits, "batch size must be evenly divisible by num_splits" 32 | split_input = input.split(split_size) 33 | x = [super().forward(split_input[0])] 34 | for i, a in enumerate(self.aux_bn): 35 | x.append(a(split_input[i + 1])) 36 | return torch.cat(x, dim=0) 37 | else: 38 | return super().forward(input) 39 | 40 | 41 | def convert_splitbn_model(module, num_splits=2): 42 | """ 43 | Recursively traverse module and its children to replace all instances of 44 | ``torch.nn.modules.batchnorm._BatchNorm`` with `SplitBatchnorm2d`. 
45 | Args: 46 | module (torch.nn.Module): input module 47 | num_splits: number of separate batchnorm layers to split input across 48 | Example:: 49 | >>> # model is an instance of torch.nn.Module 50 | >>> model = timm.models.convert_splitbn_model(model, num_splits=2) 51 | """ 52 | mod = module 53 | if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm): 54 | return module 55 | if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): 56 | mod = SplitBatchNorm2d( 57 | module.num_features, module.eps, module.momentum, module.affine, 58 | module.track_running_stats, num_splits=num_splits) 59 | mod.running_mean = module.running_mean 60 | mod.running_var = module.running_var 61 | mod.num_batches_tracked = module.num_batches_tracked 62 | if module.affine: 63 | mod.weight.data = module.weight.data.clone().detach() 64 | mod.bias.data = module.bias.data.clone().detach() 65 | for aux in mod.aux_bn: 66 | aux.running_mean = module.running_mean.clone() 67 | aux.running_var = module.running_var.clone() 68 | aux.num_batches_tracked = module.num_batches_tracked.clone() 69 | if module.affine: 70 | aux.weight.data = module.weight.data.clone().detach() 71 | aux.bias.data = module.bias.data.clone().detach() 72 | for name, child in module.named_children(): 73 | mod.add_module(name, convert_splitbn_model(child, num_splits=num_splits)) 74 | del module 75 | return mod 76 | -------------------------------------------------------------------------------- /timm/models/layers/squeeze_excite.py: -------------------------------------------------------------------------------- 1 | """ Squeeze-and-Excitation Channel Attention 2 | 3 | An SE implementation originally based on PyTorch SE-Net impl. 4 | Has since evolved with additional functionality / configuration. 5 | 6 | Paper: `Squeeze-and-Excitation Networks` - https://arxiv.org/abs/1709.01507 7 | 8 | Also included is Effective Squeeze-Excitation (ESE). 
9 | Paper: `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 10 | 11 | Hacked together by / Copyright 2021 Ross Wightman 12 | """ 13 | from torch import nn as nn 14 | 15 | from .create_act import create_act_layer 16 | from .helpers import make_divisible 17 | 18 | 19 | class SEModule(nn.Module): 20 | """ SE Module as defined in original SE-Nets with a few additions 21 | Additions include: 22 | * divisor can be specified to keep channels % div == 0 (default: 8) 23 | * reduction channels can be specified directly by arg (if rd_channels is set) 24 | * reduction channels can be specified by float rd_ratio (default: 1/16) 25 | * global max pooling can be added to the squeeze aggregation 26 | * customizable activation, normalization, and gate layer 27 | """ 28 | def __init__( 29 | self, channels, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8, add_maxpool=False, 30 | act_layer=nn.ReLU, norm_layer=None, gate_layer='sigmoid'): 31 | super(SEModule, self).__init__() 32 | self.add_maxpool = add_maxpool 33 | if not rd_channels: 34 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) 
35 | self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=True) 36 | self.bn = norm_layer(rd_channels) if norm_layer else nn.Identity() 37 | self.act = create_act_layer(act_layer, inplace=True) 38 | self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=True) 39 | self.gate = create_act_layer(gate_layer) 40 | 41 | def forward(self, x): 42 | x_se = x.mean((2, 3), keepdim=True) 43 | if self.add_maxpool: 44 | # experimental codepath, may remove or change 45 | x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True) 46 | x_se = self.fc1(x_se) 47 | x_se = self.act(self.bn(x_se)) 48 | x_se = self.fc2(x_se) 49 | return x * self.gate(x_se) 50 | 51 | 52 | SqueezeExcite = SEModule # alias 53 | 54 | 55 | class EffectiveSEModule(nn.Module): 56 | """ 'Effective Squeeze-Excitation 57 | From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 58 | """ 59 | def __init__(self, channels, add_maxpool=False, gate_layer='hard_sigmoid', **_): 60 | super(EffectiveSEModule, self).__init__() 61 | self.add_maxpool = add_maxpool 62 | self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0) 63 | self.gate = create_act_layer(gate_layer) 64 | 65 | def forward(self, x): 66 | x_se = x.mean((2, 3), keepdim=True) 67 | if self.add_maxpool: 68 | # experimental codepath, may remove or change 69 | x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True) 70 | x_se = self.fc(x_se) 71 | return x * self.gate(x_se) 72 | 73 | 74 | EffectiveSqueezeExcite = EffectiveSEModule # alias 75 | -------------------------------------------------------------------------------- /timm/models/layers/test_time_pool.py: -------------------------------------------------------------------------------- 1 | """ Test Time Pooling (Average-Max Pool) 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | 6 | import logging 7 | from torch import nn 8 | import torch.nn.functional as F 9 | 10 | from .adaptive_avgmax_pool import adaptive_avgmax_pool2d 11 


_logger = logging.getLogger(__name__)


class TestTimePoolHead(nn.Module):
    """Wrap a classifier so its head runs on larger-than-trained inputs.

    The base model's classifier is converted to (or reused as) a 1x1 conv. It is
    applied to feature maps that were average-pooled with the original pooling
    window (stride 1). The result is then reduced to 1x1 with
    `adaptive_avgmax_pool2d`.
    """

    def __init__(self, base, original_pool=7):
        super(TestTimePoolHead, self).__init__()
        self.base = base
        self.original_pool = original_pool
        base_fc = self.base.get_classifier()
        if isinstance(base_fc, nn.Conv2d):
            self.fc = base_fc
        else:
            # Re-express a linear classifier as an equivalent 1x1 conv by copying its weights.
            self.fc = nn.Conv2d(
                self.base.num_features, self.base.num_classes, kernel_size=1, bias=True)
            self.fc.weight.data.copy_(base_fc.weight.data.view(self.fc.weight.size()))
            self.fc.bias.data.copy_(base_fc.bias.data.view(self.fc.bias.size()))
        self.base.reset_classifier(0)  # delete original fc layer

    def forward(self, x):
        x = self.base.forward_features(x)
        x = F.avg_pool2d(x, kernel_size=self.original_pool, stride=1)
        x = self.fc(x)
        x = adaptive_avgmax_pool2d(x, 1)
        return x.view(x.size(0), -1)


def apply_test_time_pool(model, config, use_test_size=True):
    """Wrap `model` in TestTimePoolHead when the target input is larger than its default.

    Pooling is only applied when BOTH spatial dims of `config['input_size']` exceed
    the model's default (or `test_input_size`, if `use_test_size`) input size.

    Returns:
        (model, test_time_pool): the possibly-wrapped model and whether it was wrapped.
    """
    test_time_pool = False
    if not hasattr(model, 'default_cfg') or not model.default_cfg:
        return model, False
    if use_test_size and 'test_input_size' in model.default_cfg:
        df_input_size = model.default_cfg['test_input_size']
    else:
        df_input_size = model.default_cfg['input_size']
    if config['input_size'][-1] > df_input_size[-1] and config['input_size'][-2] > df_input_size[-2]:
        _logger.info('Target input size %s > pretrained default %s, using test time pooling' %
                     (str(config['input_size'][-2:]), str(df_input_size[-2:])))
        model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size'])
        test_time_pool = True
    return model, test_time_pool


# ---- /timm/models/layers/trace_utils.py ----
try:
    from torch import _assert
except ImportError:
    # Fallback for torch versions without `torch._assert`.
    def _assert(condition: bool, message: str):
        assert condition, message


def _float_to_int(x: float) -> int:
    """
    Symbolic tracing helper to substitute for inbuilt `int`.
    Hint: Inbuilt `int` can't accept an argument of type `Proxy`
    """
    return int(x)


# ---- /timm/models/layers/weight_init.py ----
import torch
import math
import warnings

from torch.nn.init import _calculate_fan_in_and_fan_out


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.
    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
    """Fill `tensor` in place so its values have variance `scale / denom`.

    Args:
        tensor: tensor to initialize (modified in place; nothing is returned).
        scale: numerator of the target variance.
        mode: 'fan_in', 'fan_out' or 'fan_avg' (mean of fan-in and fan-out) used as `denom`.
        distribution: 'truncated_normal', 'normal' or 'uniform'.

    Raises:
        ValueError: if `mode` or `distribution` is not one of the values above.
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == 'fan_in':
        denom = fan_in
    elif mode == 'fan_out':
        denom = fan_out
    elif mode == 'fan_avg':
        denom = (fan_in + fan_out) / 2
    else:
        # Previously fell through and failed later with UnboundLocalError on `denom`.
        raise ValueError(f"invalid mode {mode}")

    variance = scale / denom

    if distribution == "truncated_normal":
        # constant is stddev of standard normal truncated to (-2, 2)
        trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
    elif distribution == "normal":
        tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        bound = math.sqrt(3 * variance)
        tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")


def lecun_normal_(tensor):
    """LeCun normal init: fan-in variance scaling with a truncated normal (in place)."""
    variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')


# ---- /timm/optim/__init__.py ----
from .adabelief import AdaBelief
from .adafactor import Adafactor
from .adahessian import Adahessian
from .adamp import AdamP
from .adamw import AdamW
from .lamb import Lamb
from .lars import Lars
from .lookahead import Lookahead
from .madgrad import MADGRAD
from .nadam import Nadam
from .nvnovograd import NvNovoGrad
from .radam import RAdam
from .rmsprop_tf import RMSpropTF
from .sgdp import SGDP
from .optim_factory import create_optimizer, create_optimizer_v2, optimizer_kwargs


# ---- /timm/optim/lookahead.py ----
""" Lookahead Optimizer Wrapper.
Implementation modified from: https://github.com/alphadl/lookahead.pytorch
Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610

Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
from torch.optim.optimizer import Optimizer
from collections import defaultdict


class Lookahead(Optimizer):
    """Lookahead wrapper around `base_optimizer`.

    After every `k` base-optimizer steps of a param group, each slow weight moves
    toward its fast weight (slow += alpha * (fast - slow)). The fast weight is then
    reset to the slow weight. Param groups and state dict are shared with / delegated
    to the base optimizer.
    """

    def __init__(self, base_optimizer, alpha=0.5, k=6):
        # NOTE super().__init__() not called on purpose
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f'Invalid slow update rate: {alpha}')
        if not 1 <= k:
            raise ValueError(f'Invalid lookahead steps: {k}')
        defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
        self._base_optimizer = base_optimizer
        self.param_groups = base_optimizer.param_groups
        self.defaults = base_optimizer.defaults
        self.defaults.update(defaults)
        self.state = defaultdict(dict)
        # manually add our defaults to the param groups
        for name, default in defaults.items():
            for group in self._base_optimizer.param_groups:
                group.setdefault(name, default)

    @torch.no_grad()
    def update_slow(self, group):
        """Interpolate slow weights toward fast weights for one group, then sync fast to slow."""
        for fast_p in group["params"]:
            if fast_p.grad is None:
                continue
            param_state = self._base_optimizer.state[fast_p]
            if 'lookahead_slow_buff' not in param_state:
                # Slow buffer is created lazily, initialised from the current fast weights.
                param_state['lookahead_slow_buff'] = torch.empty_like(fast_p)
                param_state['lookahead_slow_buff'].copy_(fast_p)
            slow = param_state['lookahead_slow_buff']
            slow.add_(fast_p - slow, alpha=group['lookahead_alpha'])
            fast_p.copy_(slow)

    def sync_lookahead(self):
        """Force a slow-weight update on every param group regardless of step count."""
        for group in self._base_optimizer.param_groups:
            self.update_slow(group)

    @torch.no_grad()
    def step(self, closure=None):
        loss = self._base_optimizer.step(closure)
        for group in self._base_optimizer.param_groups:
            group['lookahead_step'] += 1
            if group['lookahead_step'] % group['lookahead_k'] == 0:
                self.update_slow(group)
        return loss

    def state_dict(self):
        return self._base_optimizer.state_dict()

    def load_state_dict(self, state_dict):
        self._base_optimizer.load_state_dict(state_dict)
        self.param_groups = self._base_optimizer.param_groups


# ---- /timm/optim/radam.py ----
"""RAdam Optimizer.
Implementation lifted from: https://github.com/LiyuanLucasLiu/RAdam
Paper: `On the Variance of the Adaptive Learning Rate and Beyond` - https://arxiv.org/abs/1908.03265
"""
import math
import torch
from torch.optim.optimizer import Optimizer


class RAdam(Optimizer):
    """Rectified Adam.

    Moments are kept in fp32 and the update is computed on `p.float()`, then copied
    back into `p`. `group['buffer']` caches `(step, num_sma, step_size)` in 10 slots
    keyed by `step % 10`, so params sharing a step count reuse the computation.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        defaults = dict(
            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
            buffer=[[None, None, None] for _ in range(10)])
        super(RAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(RAdam, self).__setstate__(state)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.float()
                if grad.is_sparse:
                    raise RuntimeError('RAdam does not support sparse gradients')

                p_fp32 = p.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                state['step'] += 1
                buffered = group['buffer'][int(state['step'] % 10)]
                if state['step'] == buffered[0]:
                    num_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2 ** state['step']
                    num_sma_max = 2 / (1 - beta2) - 1
                    num_sma = num_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                    buffered[1] = num_sma

                    # more conservative since it's an approximated value
                    if num_sma >= 5:
                        step_size = group['lr'] * math.sqrt(
                            (1 - beta2_t) *
                            (num_sma - 4) / (num_sma_max - 4) *
                            (num_sma - 2) / num_sma *
                            num_sma_max / (num_sma_max - 2)) / (1 - beta1 ** state['step'])
                    else:
                        step_size = group['lr'] / (1 - beta1 ** state['step'])
                    buffered[2] = step_size

                if group['weight_decay'] != 0:
                    # p -= wd * lr * p, applied directly to the weights.
                    p_fp32.add_(p_fp32, alpha=-group['weight_decay'] * group['lr'])

                # more conservative since it's an approximated value
                if num_sma >= 5:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_fp32.addcdiv_(exp_avg, denom, value=-step_size)
                else:
                    p_fp32.add_(exp_avg, alpha=-step_size)

                p.copy_(p_fp32)

        return loss


# ---- /timm/optim/sgdp.py ----
"""
SGDP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/sgdp.py

Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217
Code: https://github.com/clovaai/AdamP

Copyright (c) 2020-present NAVER Corp.
MIT license
"""

import torch
import torch.nn.functional as F
from torch.optim.optimizer import Optimizer, required
import math

from .adamp import projection


class SGDP(Optimizer):
    """SGD with momentum plus the SGDP projection for multi-dimensional params.

    For params with more than one dimension, the update direction and the
    weight-decay ratio come from `projection` (defined in `.adamp`).
    NOTE(review): weight decay divides by `(1 - momentum)`, so `momentum=1` with
    non-zero weight decay raises ZeroDivisionError.
    """

    def __init__(self, params, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False, eps=1e-8, delta=0.1, wd_ratio=0.1):
        defaults = dict(
            lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay,
            nesterov=nesterov, eps=eps, delta=delta, wd_ratio=wd_ratio)
        super(SGDP, self).__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['momentum'] = torch.zeros_like(p)

                # SGD
                buf = state['momentum']
                buf.mul_(momentum).add_(grad, alpha=1. - dampening)
                if nesterov:
                    d_p = grad + momentum * buf
                else:
                    d_p = buf

                # Projection
                wd_ratio = 1.
                if len(p.shape) > 1:
                    d_p, wd_ratio = projection(p, grad, d_p, group['delta'], group['wd_ratio'], group['eps'])

                # Weight decay
                if weight_decay != 0:
                    p.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio / (1-momentum))

                # Step
                p.add_(d_p, alpha=-group['lr'])

        return loss


# ---- /timm/scheduler/__init__.py ----
from .cosine_lr import CosineLRScheduler
from .multistep_lr import MultiStepLRScheduler
from .plateau_lr import PlateauLRScheduler
from .poly_lr import PolyLRScheduler
from .step_lr import StepLRScheduler
from .tanh_lr import TanhLRScheduler

from .scheduler_factory import create_scheduler


# ---- /timm/scheduler/multistep_lr.py ----
""" MultiStep LR Scheduler

Basic multi step LR schedule with warmup, noise.
"""
import torch
import bisect
from timm.scheduler.scheduler import Scheduler
from typing import List


class MultiStepLRScheduler(Scheduler):
    """Multiply each base lr by `decay_rate` once per milestone in `decay_t` already reached.

    During the first `warmup_t` steps the lr ramps linearly from `warmup_lr_init`
    toward each base value. `t` counts epochs if `t_in_epochs`, else optimizer updates.
    `self.base_values` is read from the `Scheduler` base class.
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 decay_t: List[int],
                 decay_rate: float = 1.,
                 warmup_t=0,
                 warmup_lr_init=0,
                 t_in_epochs=True,
                 noise_range_t=None,
                 noise_pct=0.67,
                 noise_std=1.0,
                 noise_seed=42,
                 initialize=True,
                 ) -> None:
        super().__init__(
            optimizer, param_group_field="lr",
            noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
            initialize=initialize)

        self.decay_t = decay_t
        self.decay_rate = decay_rate
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        self.t_in_epochs = t_in_epochs
        if self.warmup_t:
            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
            super().update_groups(self.warmup_lr_init)
        else:
            # Placeholder: unused when warmup_t == 0 since `t < 0` never holds for non-negative t.
            self.warmup_steps = [1 for _ in self.base_values]

    def get_curr_decay_steps(self, t):
        # find where in the array t goes,
        # assumes self.decay_t is sorted
        # NOTE: using t+1 means a milestone m is counted once t >= m - 1.
        return bisect.bisect_right(self.decay_t, t+1)

    def _get_lr(self, t):
        if t < self.warmup_t:
            lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
        else:
            lrs = [v * (self.decay_rate ** self.get_curr_decay_steps(t)) for v in self.base_values]
        return lrs

    def get_epoch_values(self, epoch: int):
        if self.t_in_epochs:
            return self._get_lr(epoch)
        else:
            return None

    def get_update_values(self, num_updates: int):
        if not self.t_in_epochs:
            return self._get_lr(num_updates)
        else:
            return None


# ---- /timm/scheduler/step_lr.py ----
""" Step Scheduler

Basic step LR schedule with warmup, noise.

Hacked together by / Copyright 2020 Ross Wightman
"""
import math
import torch

from .scheduler import Scheduler


class StepLRScheduler(Scheduler):
    """Multiply each base lr by `decay_rate ** (t // decay_t)` after an optional linear warmup.

    `t` counts epochs if `t_in_epochs`, else optimizer updates. `self.base_values`
    is read from the `Scheduler` base class.
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 decay_t: float,
                 decay_rate: float = 1.,
                 warmup_t=0,
                 warmup_lr_init=0,
                 t_in_epochs=True,
                 noise_range_t=None,
                 noise_pct=0.67,
                 noise_std=1.0,
                 noise_seed=42,
                 initialize=True,
                 ) -> None:
        super().__init__(
            optimizer, param_group_field="lr",
            noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
            initialize=initialize)

        self.decay_t = decay_t
        self.decay_rate = decay_rate
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        self.t_in_epochs = t_in_epochs
        if self.warmup_t:
            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
            super().update_groups(self.warmup_lr_init)
        else:
            self.warmup_steps = [1 for _ in self.base_values]

    def _get_lr(self, t):
        if t < self.warmup_t:
            lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
        else:
            lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values]
        return lrs

    def get_epoch_values(self, epoch: int):
        if self.t_in_epochs:
            return self._get_lr(epoch)
        else:
            return None

    def get_update_values(self, num_updates: int):
        if not self.t_in_epochs:
            return self._get_lr(num_updates)
        else:
            return None


# ---- /timm/utils/__init__.py ----
from .agc import adaptive_clip_grad
from .checkpoint_saver import CheckpointSaver
from .clip_grad import dispatch_clip_grad
from .cuda import ApexScaler, NativeScaler
from .distributed import distribute_bn, reduce_tensor
from .jit import set_jit_legacy
from .log import setup_default_logging, FormatterNoInfo
from .metrics import AverageMeter, accuracy
from .misc import natural_key, add_bool_arg
from .model import unwrap_model, get_state_dict, freeze, unfreeze
from .model_ema import ModelEma, ModelEmaV2
from .random import random_seed
from .summary import update_summary, get_outdir

# ---- compiled bytecode entries (linked, not inlined) ----
# /timm/utils/__pycache__/__init__.cpython-36.pyc: https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/utils/__pycache__/__init__.cpython-36.pyc
# /timm/utils/__pycache__/agc.cpython-36.pyc: https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/utils/__pycache__/agc.cpython-36.pyc
# /timm/utils/__pycache__/checkpoint_saver.cpython-36.pyc: https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/utils/__pycache__/checkpoint_saver.cpython-36.pyc
# /timm/utils/__pycache__/clip_grad.cpython-36.pyc: https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/utils/__pycache__/clip_grad.cpython-36.pyc
# /timm/utils/__pycache__/cuda.cpython-36.pyc: https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/utils/__pycache__/cuda.cpython-36.pyc
# /timm/utils/__pycache__/distributed.cpython-36.pyc: https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/utils/__pycache__/distributed.cpython-36.pyc
# /timm/utils/__pycache__/jit.cpython-36.pyc: https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/utils/__pycache__/jit.cpython-36.pyc
# /timm/utils/__pycache__/log.cpython-36.pyc: https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/utils/__pycache__/log.cpython-36.pyc
# /timm/utils/__pycache__/metrics.cpython-36.pyc: https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/utils/__pycache__/metrics.cpython-36.pyc
# /timm/utils/__pycache__/misc.cpython-36.pyc: https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/utils/__pycache__/misc.cpython-36.pyc
# /timm/utils/__pycache__/model.cpython-36.pyc: https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/utils/__pycache__/model.cpython-36.pyc
# /timm/utils/__pycache__/model_ema.cpython-36.pyc: https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/utils/__pycache__/model_ema.cpython-36.pyc
# /timm/utils/__pycache__/random.cpython-36.pyc: https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/utils/__pycache__/random.cpython-36.pyc
# /timm/utils/__pycache__/summary.cpython-36.pyc: https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/timm/utils/__pycache__/summary.cpython-36.pyc


# ---- /timm/utils/agc.py ----
""" Adaptive Gradient Clipping

An impl of AGC, as per (https://arxiv.org/abs/2102.06171):

@article{brock2021high,
  author={Andrew Brock and Soham De and Samuel L. Smith and Karen Simonyan},
  title={High-Performance Large-Scale Image Recognition Without Normalization},
  journal={arXiv preprint arXiv:2102.06171},
  year={2021}
}

Code references:
  * Official JAX impl (paper authors): https://github.com/deepmind/deepmind-research/tree/master/nfnets
  * Phil Wang's PyTorch gist: https://gist.github.com/lucidrains/0d6560077edac419ab5d3aa29e674d5c

Hacked together by / Copyright 2021 Ross Wightman
"""
import torch


def unitwise_norm(x, norm_type=2.0):
    """Norm of `x`: a scalar for ndim <= 1, else one norm per slice along dim 0 (keepdim)."""
    if x.ndim <= 1:
        return x.norm(norm_type)
    else:
        # works for nn.ConvNd and nn.Linear where output dim is first in the kernel/weight tensor
        # might need special cases for other weights (possibly MHA) where this may not be true
        return x.norm(norm_type, dim=tuple(range(1, x.ndim)), keepdim=True)


def adaptive_clip_grad(parameters, clip_factor=0.01, eps=1e-3, norm_type=2.0):
    """Clip each gradient in place so its unit-wise norm is at most `clip_factor * max(|p|, eps)`.

    Gradients whose norm is already below that bound are left unchanged.
    Parameters without a gradient are skipped.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    for p in parameters:
        if p.grad is None:
            continue
        p_data = p.detach()
        g_data = p.grad.detach()
        max_norm = unitwise_norm(p_data, norm_type=norm_type).clamp_(min=eps).mul_(clip_factor)
        grad_norm = unitwise_norm(g_data, norm_type=norm_type)
        # 1e-6 floor avoids division by zero for all-zero gradients.
        clipped_grad = g_data * (max_norm / grad_norm.clamp(min=1e-6))
        new_grads = torch.where(grad_norm < max_norm, g_data, clipped_grad)
        p.grad.detach().copy_(new_grads)


# ---- /timm/utils/clip_grad.py ----
import torch

from timm.utils.agc import adaptive_clip_grad


def dispatch_clip_grad(parameters, value: float, mode: str = 'norm', norm_type: float = 2.0):
    """ Dispatch to gradient clipping method

    Args:
        parameters (Iterable): model parameters to clip
        value (float): clipping value/factor/norm, mode dependent
        mode (str): clipping mode, one of 'norm', 'value', 'agc'
        norm_type (float): p-norm, default 2.0

    Raises:
        AssertionError: if `mode` is unknown (raised explicitly so it also fires under `python -O`).
    """
    if mode == 'norm':
        torch.nn.utils.clip_grad_norm_(parameters, value, norm_type=norm_type)
    elif mode == 'value':
        torch.nn.utils.clip_grad_value_(parameters, value)
    elif mode == 'agc':
        adaptive_clip_grad(parameters, value, norm_type=norm_type)
    else:
        # `assert False` is stripped under -O and would silently skip clipping.
        raise AssertionError(f"Unknown clip mode ({mode}).")


# ---- /timm/utils/cuda.py ----
""" CUDA / AMP utils

Hacked together by / Copyright 2020 Ross Wightman
"""
import torch

try:
    from apex import amp
    has_apex = True
except ImportError:
    amp = None
    has_apex = False

from .clip_grad import dispatch_clip_grad


class ApexScaler:
    """Backward + optional grad clipping + optimizer step using NVIDIA apex `amp` loss scaling."""
    state_dict_key = "amp"

    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward(create_graph=create_graph)
        if clip_grad is not None:
            # `parameters` is ignored here; apex's master params are clipped instead.
            dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode)
        optimizer.step()

    def state_dict(self):
        # Returns None when the installed apex `amp` has no state_dict.
        if 'state_dict' in amp.__dict__:
            return amp.state_dict()

    def load_state_dict(self, state_dict):
        if 'load_state_dict' in amp.__dict__:
            amp.load_state_dict(state_dict)


class NativeScaler:
    """Backward + optional grad clipping + optimizer step using `torch.cuda.amp.GradScaler`."""
    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if clip_grad is not None:
            assert parameters is not None
            self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
            dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
        self._scaler.step(optimizer)
        self._scaler.update()

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)


# ---- /timm/utils/distributed.py ----
""" Distributed training/validation utils

Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
from torch import distributed as dist

from .model import unwrap_model


def reduce_tensor(tensor, n):
    """Return a copy of `tensor` summed across the process group, then divided by `n`."""
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= n
    return rt


def distribute_bn(model, world_size, reduce=False):
    """Sync buffers whose names contain 'running_mean'/'running_var' across ranks.

    With `reduce=True` they are averaged over `world_size`; otherwise rank 0's copy is broadcast.
    """
    # ensure every node has the same running bn stats
    for bn_name, bn_buf in unwrap_model(model).named_buffers(recurse=True):
        if ('running_mean' in bn_name) or ('running_var' in bn_name):
            if reduce:
                # average bn stats across whole group
                torch.distributed.all_reduce(bn_buf, op=dist.ReduceOp.SUM)
                bn_buf /= float(world_size)
            else:
                # broadcast bn stats from rank 0 to whole group
                torch.distributed.broadcast(bn_buf, 0)


# ---- /timm/utils/jit.py ----
""" JIT scripting/tracing utils

Hacked together by / Copyright 2020 Ross Wightman
"""
import torch


def set_jit_legacy():
    """ Set JIT executor to legacy w/ support for op fusion
    This is hopefully a temporary need in 1.5/1.5.1/1.6 to restore performance due to changes
    in the JIT executor. These API are not supported so could change.
    """
    assert hasattr(torch._C, '_jit_set_profiling_executor'), "Old JIT behavior doesn't exist!"
    torch._C._jit_set_profiling_executor(False)
    torch._C._jit_set_profiling_mode(False)
    torch._C._jit_override_can_fuse_on_gpu(True)
    #torch._C._jit_set_texpr_fuser_enabled(True)


# ---- /timm/utils/log.py ----
""" Logging helpers

Hacked together by / Copyright 2020 Ross Wightman
"""
import logging
import logging.handlers


class FormatterNoInfo(logging.Formatter):
    """Formatter that prints INFO records as the bare message and other levels with `fmt`."""

    def __init__(self, fmt='%(levelname)s: %(message)s'):
        logging.Formatter.__init__(self, fmt)

    def format(self, record):
        if record.levelno == logging.INFO:
            return str(record.getMessage())
        return logging.Formatter.format(self, record)


def setup_default_logging(default_level=logging.INFO, log_path=''):
    """Attach a console handler to the root logger and, if `log_path` is set, a rotating file handler."""
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(FormatterNoInfo())
    logging.root.addHandler(console_handler)
    logging.root.setLevel(default_level)
    if log_path:
        # Rotate at 2 MiB, keeping 3 backups.
        file_handler = logging.handlers.RotatingFileHandler(log_path, maxBytes=(1024 ** 2 * 2), backupCount=3)
        file_formatter = logging.Formatter("%(asctime)s - %(name)20s: [%(levelname)8s] - %(message)s")
        file_handler.setFormatter(file_formatter)
        logging.root.addHandler(file_handler)


# ---- /timm/utils/metrics.py ----
""" Eval metrics and related

Hacked together by / Copyright 2020 Ross Wightman
"""


class AverageMeter:
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        # `val` is treated as the mean over `n` samples, so it is weighted by `n`.
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k

    Returns one value per entry of `topk`, as a percentage of `target.size(0)`.
    k is capped at the number of classes (`output.size(1)`).
    """
    maxk = min(max(topk), output.size()[1])
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))
    return [correct[:min(k, maxk)].reshape(-1).float().sum(0) * 100. / batch_size for k in topk]


# ---- /timm/utils/misc.py ----
""" Misc utils

Hacked together by / Copyright 2020 Ross Wightman
"""
import re


def natural_key(string_):
    """See http://www.codinghorror.com/blog/archives/001018.html"""
    return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]


def add_bool_arg(parser, name, default=False, help=''):
    """Add mutually exclusive `--name` / `--no-name` flags sharing one destination."""
    dest_name = name.replace('-', '_')
    group = parser.add_mutually_exclusive_group(required=False)
    group.add_argument('--' + name, dest=dest_name, action='store_true', help=help)
    group.add_argument('--no-' + name, dest=dest_name, action='store_false', help=help)
    parser.set_defaults(**{dest_name: default})


# ---- /timm/utils/random.py ----
import random
import numpy as np
import torch


def random_seed(seed=42, rank=0):
    """Seed torch, numpy and `random` with `seed + rank` so each rank differs reproducibly."""
    torch.manual_seed(seed + rank)
    np.random.seed(seed + rank)
    random.seed(seed + rank)
# --------------------------------------------------------------------------------
# /timm/utils/summary.py:
# --------------------------------------------------------------------------------
""" Summary utilities

Hacked together by / Copyright 2020 Ross Wightman
"""
import csv
import os
from collections import OrderedDict
try:
    import wandb
except ImportError:
    # wandb is optional; update_summary() only needs it when log_wandb=True.
    wandb = None


def get_outdir(path, *paths, inc=False):
    """Create (if needed) and return the directory ``os.path.join(path, *paths)``.

    Args:
        path: base directory.
        *paths: extra path components joined onto ``path``.
        inc: if True and the directory already exists, create and return the first
            non-existing sibling ``<dir>-1``, ``<dir>-2``, ... instead.

    Returns:
        The directory path: newly created, or pre-existing when ``inc`` is False.
    """
    outdir = os.path.join(path, *paths)
    if not os.path.exists(outdir):
        os.makedirs(outdir)
    elif inc:
        count = 1
        outdir_inc = outdir + '-' + str(count)
        while os.path.exists(outdir_inc):
            count = count + 1
            outdir_inc = outdir + '-' + str(count)
            # Sanity bound on the suffix search (note: asserts are skipped under `python -O`).
            assert count < 100
        outdir = outdir_inc
        os.makedirs(outdir)
    return outdir


def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=False, log_wandb=False):
    """Append one row of per-epoch metrics to a CSV file, and optionally log it to wandb.

    The row holds ``epoch`` followed by ``train_<key>`` and ``eval_<key>`` columns,
    taken from the two metric dicts in their iteration order.

    Args:
        epoch: value stored in the ``epoch`` column.
        train_metrics: mapping of training metric name -> value.
        eval_metrics: mapping of evaluation metric name -> value.
        filename: CSV file to append to (created if missing).
        write_header: write the column header before the row (pass True on the first call).
        log_wandb: also send the row to ``wandb.log``.

    Raises:
        ImportError: if ``log_wandb`` is True but the wandb package could not be imported.
    """
    rowd = OrderedDict(epoch=epoch)
    rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
    rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
    if log_wandb:
        if wandb is None:
            raise ImportError("log_wandb=True requires the 'wandb' package, which could not be imported")
        wandb.log(rowd)
    # newline='' lets the csv module control line endings itself; without it,
    # rows are separated by blank lines on platforms that translate '\n' (Windows).
    with open(filename, mode='a', newline='') as cf:
        dw = csv.DictWriter(cf, fieldnames=rowd.keys())
        if write_header:  # first iteration (epoch == 1 can't be used)
            dw.writeheader()
        dw.writerow(rowd)

# --------------------------------------------------------------------------------
# /timm/version.py:
# --------------------------------------------------------------------------------
__version__ = '0.5.0'

# --------------------------------------------------------------------------------
# /util/__init__.py:
# --------------------------------------------------------------------------------
"""This package includes a miscellaneous collection of useful helper functions."""
from util import *

# --------------------------------------------------------------------------------
# /util/__pycache__/__init__.cpython-36.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/util/__pycache__/__init__.cpython-36.pyc
# --------------------------------------------------------------------------------
# /util/__pycache__/html.cpython-36.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/util/__pycache__/html.cpython-36.pyc
# --------------------------------------------------------------------------------
# /util/__pycache__/util.cpython-36.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/util/__pycache__/util.cpython-36.pyc
# --------------------------------------------------------------------------------
# /util/__pycache__/visualizer.cpython-36.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/BIT-DA/ROMA/4a1bcb2e72edfb80a88594a6b09e0521b3bfe591/util/__pycache__/visualizer.cpython-36.pyc
# --------------------------------------------------------------------------------
# /util/html.py:
# --------------------------------------------------------------------------------
import dominate
from dominate.tags import meta, h3, table, tr, td, p, a, img, br
import os


class HTML:
    """This HTML class allows us to save images and write texts into a single HTML file.

    It consists of functions such as ``add_header`` (add a text header to the HTML file),
    ``add_images`` (add a row of images to the HTML file), and ``save`` (save the HTML to the disk).
    It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
    """

    def __init__(self, web_dir, title, refresh=0):
        """Initialize the HTML classes

        Parameters:
            web_dir (str) -- a directory that stores the webpage. The HTML file will be created at
                             web_dir/index.html; images are expected in web_dir/images/
            title (str)   -- the webpage name
            refresh (int) -- meta-refresh interval (seconds) added to the page head; 0 disables refreshing

        Both directories are created if they do not exist yet.
        """
        self.title = title
        self.web_dir = web_dir
        self.img_dir = os.path.join(self.web_dir, 'images')
        # Creating img_dir also creates web_dir; exist_ok avoids a check-then-create race.
        os.makedirs(self.img_dir, exist_ok=True)

        self.doc = dominate.document(title=title)
        if refresh > 0:
            with self.doc.head:
                meta(http_equiv="refresh", content=str(refresh))

    def get_image_dir(self):
        """Return the directory that stores images"""
        return self.img_dir

    def add_header(self, text):
        """Insert a header to the HTML file

        Parameters:
            text (str) -- the header text, rendered as an <h3> element
        """
        with self.doc:
            h3(text)

    def add_images(self, ims, txts, links, width=400):
        """add images to the HTML file

        Parameters:
            ims (str list) -- a list of image paths, relative to the 'images' subdirectory
            txts (str list) -- a list of image names shown on the website
            links (str list) -- a list of hyperref links (relative to 'images'); clicking an image opens its link
            width (int) -- display width of every image, in pixels

        All entries are placed in a single new table row. The three lists are paired
        with zip(), so entries beyond the shortest list are ignored.
        """
        self.t = table(border=1, style="table-layout: fixed;")  # Insert a table
        self.doc.add(self.t)
        with self.t:
            with tr():
                for im, txt, link in zip(ims, txts, links):
                    with td(style="word-wrap: break-word;", halign="center", valign="top"):
                        with p():
                            with a(href=os.path.join('images', link)):
                                img(style="width:%dpx" % width, src=os.path.join('images', im))
                            br()
                            p(txt)

    def save(self):
        """Render the document and write it to web_dir/index.html, overwriting any existing file."""
        html_file = '%s/index.html' % self.web_dir
        # `with` guarantees the file is closed even if rendering or writing raises.
        with open(html_file, 'wt') as f:
            f.write(self.doc.render())


if __name__ == '__main__':  # we show an example usage here.
    html = HTML('web/', 'test_html')
    html.add_header('hello world')

    ims, txts, links = [], [], []
    for n in range(4):
        ims.append('image_%d.png' % n)
        txts.append('text_%d' % n)
        links.append('image_%d.png' % n)
    html.add_images(ims, txts, links)
    html.save()

# --------------------------------------------------------------------------------
# /util/image_pool.py:
# --------------------------------------------------------------------------------
import random
import torch


class ImagePool():
    """This class implements an image buffer that stores previously generated images.

    This buffer enables us to update discriminators using a history of generated images
    rather than the ones produced by the latest generators.
    """

    def __init__(self, pool_size):
        """Initialize the ImagePool class

        Parameters:
            pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created
        """
        self.pool_size = pool_size
        if self.pool_size > 0:  # create an empty pool
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        """Return a batch mixing the latest images with previously buffered ones.

        Parameters:
            images: the latest generated images from the generator; iterated along the
                    first dimension (presumably a (N, C, H, W) tensor -- confirm at callers)

        Returns a tensor made by concatenating one entry per input image along dim 0.
        If pool_size == 0 the input is returned untouched. While the buffer is not full,
        every input image is stored and also returned as-is. Once the buffer is full, each input
        image is, with probability 1/2, swapped into a random buffer slot and the evicted
        stored image is returned in its place; otherwise it is returned unchanged and not stored.
        """
        if self.pool_size == 0:  # if the buffer size is 0, do nothing
            return images
        return_images = []
        for image in images:
            # .data drops autograd history; unsqueeze restores a leading batch dim for torch.cat.
            image = torch.unsqueeze(image.data, 0)
            if self.num_imgs < self.pool_size:  # if the buffer is not full; keep inserting current images to the buffer
                self.num_imgs = self.num_imgs + 1
                self.images.append(image)
                return_images.append(image)
            else:
                p = random.uniform(0, 1)
                if p > 0.5:  # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer
                    random_id = random.randint(0, self.pool_size - 1)  # randint is inclusive
                    tmp = self.images[random_id].clone()
                    self.images[random_id] = image
                    return_images.append(tmp)
                else:  # by another 50% chance, the buffer will return the current image
                    return_images.append(image)
        return_images = torch.cat(return_images, 0)  # collect all the images and return
        return return_images

# --------------------------------------------------------------------------------