├── FIFO.yaml
├── LICENSE
├── README.md
├── compute_iou.py
├── configs
│   ├── test_config.py
│   └── train_config.py
├── dataset
│   ├── Foggy_Zurich.py
│   ├── Foggy_Zurich_test.py
│   ├── __init__.py
│   ├── cityscapes_dataset.py
│   ├── cityscapes_list
│   │   ├── clear_lindau.txt
│   │   ├── info.json
│   │   ├── label_lindau.txt
│   │   ├── label_val.txt
│   │   ├── train_foggy_0.005.txt
│   │   ├── train_origin.txt
│   │   ├── val.txt
│   │   └── val_foggy_0.005.txt
│   ├── foggy_driving.py
│   └── paired_cityscapes.py
├── evaluate.py
├── lists_file_names
│   ├── gt_testall_filenames.txt
│   ├── gt_testdense_filenames.txt
│   ├── leftImg8bit_testall_filenames.txt
│   ├── leftImg8bit_testdense_filenames.txt
│   └── leftImg8bit_testfine_filenames.txt
├── main.py
├── model
│   ├── __init__.py
│   ├── fogpassfilter.py
│   ├── refinenetlw.py
│   └── utils.py
├── pytorch_metric_learning
│   ├── __init__.py
│   ├── distances
│   │   ├── __init__.py
│   │   ├── base_distance.py
│   │   ├── cosine_similarity.py
│   │   ├── dot_product_similarity.py
│   │   ├── lp_distance.py
│   │   └── snr_distance.py
│   ├── losses
│   │   ├── __init__.py
│   │   ├── angular_loss.py
│   │   ├── arcface_loss.py
│   │   ├── base_metric_loss_function.py
│   │   ├── circle_loss.py
│   │   ├── contrastive_loss.py
│   │   ├── cosface_loss.py
│   │   ├── cross_batch_memory.py
│   │   ├── fast_ap_loss.py
│   │   ├── generic_pair_loss.py
│   │   ├── intra_pair_variance_loss.py
│   │   ├── large_margin_softmax_loss.py
│   │   ├── lifted_structure_loss.py
│   │   ├── margin_loss.py
│   │   ├── mixins.py
│   │   ├── multi_similarity_loss.py
│   │   ├── n_pairs_loss.py
│   │   ├── nca_loss.py
│   │   ├── normalized_softmax_loss.py
│   │   ├── ntxent_loss.py
│   │   ├── proxy_anchor_loss.py
│   │   ├── proxy_losses.py
│   │   ├── signal_to_noise_ratio_losses.py
│   │   ├── soft_triple_loss.py
│   │   ├── sphereface_loss.py
│   │   ├── supcon_loss.py
│   │   ├── triplet_margin_loss.py
│   │   └── tuplet_margin_loss.py
│   ├── miners
│   │   ├── __init__.py
│   │   ├── angular_miner.py
│   │   ├── base_miner.py
│   │   ├── batch_easy_hard_miner.py
│   │   ├── batch_hard_miner.py
│   │   ├── distance_weighted_miner.py
│   │   ├── embeddings_already_packaged_as_triplets.py
│   │   ├── hdc_miner.py
│   │   ├── maximum_loss_miner.py
│   │   ├── multi_similarity_miner.py
│   │   ├── pair_margin_miner.py
│   │   ├── triplet_margin_miner.py
│   │   └── uniform_histogram_miner.py
│   ├── reducers
│   │   ├── __init__.py
│   │   ├── avg_non_zero_reducer.py
│   │   ├── base_reducer.py
│   │   ├── class_weighted_reducer.py
│   │   ├── divisor_reducer.py
│   │   ├── do_nothing_reducer.py
│   │   ├── mean_reducer.py
│   │   ├── multiple_reducers.py
│   │   ├── per_anchor_reducer.py
│   │   └── threshold_reducer.py
│   ├── regularizers
│   │   ├── __init__.py
│   │   ├── base_regularizer.py
│   │   ├── center_invariant_regularizer.py
│   │   ├── lp_regularizer.py
│   │   ├── regular_face_regularizer.py
│   │   ├── sparse_centers_regularizer.py
│   │   └── zero_mean_regularizer.py
│   ├── samplers
│   │   ├── __init__.py
│   │   ├── fixed_set_of_triplets.py
│   │   ├── hierarchical_sampler.py
│   │   ├── m_per_class_sampler.py
│   │   └── tuples_to_weights_sampler.py
│   ├── testers
│   │   ├── __init__.py
│   │   ├── base_tester.py
│   │   ├── global_embedding_space.py
│   │   ├── global_twostream_embedding_space.py
│   │   └── with_same_parent_label.py
│   ├── trainers
│   │   ├── __init__.py
│   │   ├── base_trainer.py
│   │   ├── cascaded_embeddings.py
│   │   ├── deep_adversarial_metric_learning.py
│   │   ├── metric_loss_only.py
│   │   ├── train_with_classifier.py
│   │   ├── twostream_metric_loss.py
│   │   └── unsupervised_embeddings_using_augmentations.py
│   └── utils
│       ├── __init__.py
│       ├── accuracy_calculator.py
│       ├── common_functions.py
│       ├── distributed.py
│       ├── inference.py
│       ├── key_checker.py
│       ├── logging_presets.py
│       ├── loss_and_miner_utils.py
│       ├── loss_tracker.py
│       ├── module_with_records.py
│       ├── module_with_records_and_reducer.py
│       └── stat_utils.py
└── utils
    ├── __init__.py
    ├── helpers.py
    ├── layer_factory.py
    ├── losses.py
    ├── network.py
    └── optimisers.py
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
Light-Weight RefineNet for non-commercial purposes

Copyright (c) 2018, Vladimir Nekrasov
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
MIT License

Copyright (c) 2022 Multimedia Computing Group, Nanjing University

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/compute_iou.py:
--------------------------------------------------------------------------------
import numpy as np
import argparse
import json
from PIL import Image
from os.path import join
import warnings
warnings.filterwarnings("ignore")

def fast_hist(a, b, n):
    # Accumulate an n x n confusion matrix from flattened label arrays:
    # rows index ground truth, columns index predictions. The mask drops
    # ignore values (e.g. 255) that fall outside [0, n).
    k = (a >= 0) & (a < n)
    return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n)


def per_class_iu(hist):
    # Per-class IoU: diagonal / (row sum + column sum - diagonal).
    return np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))


def label_mapping(input, mapping):
    output = np.copy(input)
    for ind in range(len(mapping)):
        output[input == mapping[ind][0]] = mapping[ind][1]
    return np.array(output, dtype=np.int64)


def compute_mIoU(gt_dir, pred_dir, devkit_dir, dataset):
    """
    Compute mIoU given directories of predicted label maps and ground-truth
    annotations, selected by the dataset flag.
    """
    label = [
        "road",
        "sidewalk",
        "building",
        "wall",
        "fence",
        "pole",
        "light",
        "sign",
        "vegetation",
        "terrain",
        "sky",
        "person",
        "rider",
        "car",
        "truck",
        "bus",
        "train",
        "motorcycle",
        "bicycle"]

    label2train = [
        [0, 255],
        [1, 255],
        [2, 255],
        [3, 255],
        [4, 255],
        [5, 255],
        [6, 255],
        [7, 0],
        [8, 1],
        [9, 255],
        [10, 255],
        [11, 2],
        [12, 3],
        [13, 4],
        [14, 255],
        [15, 255],
        [16, 255],
        [17, 5],
        [18, 255],
        [19, 6],
        [20, 7],
        [21, 8],
        [22, 9],
        [23, 10],
        [24, 11],
        [25, 12],
        [26, 13],
        [27, 14],
        [28, 15],
        [29, 255],
        [30, 255],
        [31, 16],
        [32, 17],
        [33, 18],
        [-1, 255]]

    num_classes = 19
    name_classes = np.array(label)

    hist = np.zeros((num_classes, num_classes))
    if 'FZ' in dataset:
        image_path_list = join(devkit_dir, 'RGB_testv2_filenames.txt')
        label_path_list = join(devkit_dir, 'gt_labelTrainIds_testv2_filenames.txt')
    elif 'FDD' in dataset:
        image_path_list = join(devkit_dir, 'leftImg8bit_testdense_filenames.txt')
        label_path_list = join(devkit_dir, 'gt_testdense_filenames.txt')
    elif 'FD' in dataset:
        image_path_list = join(devkit_dir, 'leftImg8bit_testall_filenames.txt')
        label_path_list = join(devkit_dir, 'gt_testall_filenames.txt')
    elif 'Clindau' in dataset:
        image_path_list = join(devkit_dir, 'clear_lindau.txt')
        label_path_list = join(devkit_dir, 'label_lindau.txt')
    gt_imgs = open(label_path_list, 'r').read().splitlines()
    gt_imgs = [join(gt_dir, x) for x in gt_imgs]

    if 'FZ' not in dataset:
        # Foggy Zurich ground truth already uses train IDs (see the
        # gt_labelTrainIds list above); the other datasets need the mapping.
        mapping = np.array(label2train, dtype=np.int64)
    pred_imgs = open(image_path_list, 'r').read().splitlines()
    pred_imgs = [join(pred_dir, x.split('/')[-1]) for x in pred_imgs]

    for ind in range(len(gt_imgs)):
        pred = np.array(Image.open(pred_imgs[ind]))
        label = np.array(Image.open(gt_imgs[ind]))
        if 'FZ' not in dataset:
            label = label_mapping(label, mapping)
        if len(label.flatten()) != len(pred.flatten()):
            print('Skipping: len(gt) = {:d}, len(pred) = {:d}, {:s}, {:s}'.format(
                len(label.flatten()), len(pred.flatten()), gt_imgs[ind], pred_imgs[ind]))
            continue
        hist += fast_hist(label.flatten(), pred.flatten(), num_classes)

    mIoUs = per_class_iu(hist)
    if 'FZ' in dataset:
        print('Evaluation on Foggy Zurich')
    elif 'FDD' in dataset:
        print('Evaluation on Foggy Driving Dense')
    elif 'FD' in dataset:
        print('Evaluation on Foggy Driving')
    elif 'Clindau' in dataset:
        print('Evaluation on Cityscapes lindau 40')
    miou = round(float(np.nanmean(mIoUs)) * 100, 2)
    print('===> mIoU: ' + str(miou))
    return miou


def miou(args):
    compute_mIoU(args.gt_dir, args.pred_dir, args.devkit_dir, args.dataset)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--gt-dir', type=str, help='directory which stores CityScapes val gt images')
    parser.add_argument('--pred-dir', type=str, help='directory which stores CityScapes val pred images')
    parser.add_argument('--devkit_dir', default='/root/data1/Foggy_Zurich/lists_file_names', help='base directory of zurich')
    parser.add_argument('--dataset', type=str)
    args = parser.parse_args()
    miou(args)
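The two helpers above do all of the work: fast_hist accumulates a confusion matrix over pixels and per_class_iu reads IoU off it. A small self-contained check (toy arrays, not repository code) makes the bookkeeping concrete:

import numpy as np

# Toy sanity check for fast_hist / per_class_iu with n = 3 classes.
# gt/pred are flattened label maps; ignore values would be filtered out
# by the (a >= 0) & (a < n) mask inside fast_hist.
gt   = np.array([0, 0, 1, 1, 2, 2])
pred = np.array([0, 1, 1, 1, 2, 0])

n = 3
k = (gt >= 0) & (gt < n)
hist = np.bincount(n * gt[k] + pred[k], minlength=n ** 2).reshape(n, n)
# hist[i, j] counts pixels with ground truth i predicted as j:
# [[1, 1, 0],
#  [0, 2, 0],
#  [1, 0, 1]]

iou = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
print(iou)                     # [0.333..., 0.666..., 0.5]
print(np.nanmean(iou) * 100)   # mIoU in percent, as compute_mIoU reports it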
--------------------------------------------------------------------------------
/configs/test_config.py:
--------------------------------------------------------------------------------
import argparse
import numpy as np

IMG_MEAN = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32)
MODEL = 'RefineNetNew'
DATA_DIRECTORY = '/root/data1'
DATA_CITY_PATH = './dataset/cityscapes_list/clear_lindau.txt'
DATA_DIRECTORY_CITY = '/root/data1/Cityscapes'
DATA_LIST_PATH_EVAL = '/root/data1/Foggy_Zurich/lists_file_names/RGB_testv2_filenames.txt'
DATA_LIST_PATH_EVAL_FD = './lists_file_names/leftImg8bit_testall_filenames.txt'
DATA_LIST_PATH_EVAL_FDD = './lists_file_names/leftImg8bit_testdense_filenames.txt'
DATA_DIR_EVAL = '/root/data1'
DATA_DIR_EVAL_FD = '/root/data1/Foggy_Driving'
NUM_CLASSES = 19
RESTORE_FROM = 'no model'
SNAPSHOT_DIR = '/root/data1/snapshots/FIFO'
GT_DIR_FZ = '/root/data1/Foggy_Zurich'
GT_DIR_FD = '/root/data1/Foggy_Driving'
GT_DIR_CLINDAU = '/root/data1/Cityscapes/gtFine'
SET = 'val'


def get_arguments():
    parser = argparse.ArgumentParser(description="Evaluation")
    parser.add_argument("--data-dir", type=str, default=DATA_DIRECTORY)
    parser.add_argument("--data-city-list", type=str, default=DATA_CITY_PATH)
    parser.add_argument("--data-list-eval-fd", type=str, default=DATA_LIST_PATH_EVAL_FD)
    parser.add_argument("--data-list-eval-fdd", type=str, default=DATA_LIST_PATH_EVAL_FDD)
    parser.add_argument("--data-dir-city", type=str, default=DATA_DIRECTORY_CITY)
    parser.add_argument("--data-list-eval", type=str, default=DATA_LIST_PATH_EVAL)
    parser.add_argument("--data-dir-eval", type=str, default=DATA_DIR_EVAL)
    parser.add_argument("--data-dir-eval-fd", type=str, default=DATA_DIR_EVAL_FD)
    parser.add_argument("--num-classes", type=int, default=NUM_CLASSES)
    parser.add_argument("--restore-from", type=str, default=RESTORE_FROM)
    parser.add_argument("--snapshot-dir", type=str, default=SNAPSHOT_DIR)
    parser.add_argument("--gpu", type=int, default=0)
    parser.add_argument("--set", type=str, default=SET)
    parser.add_argument("--file-name", type=str, required=True)
    parser.add_argument("--gt-dir-fz", type=str, default=GT_DIR_FZ)
    parser.add_argument("--gt-dir-fd", type=str, default=GT_DIR_FD)
    parser.add_argument("--gt-dir-clindau", type=str, default=GT_DIR_CLINDAU)
    parser.add_argument("--devkit-dir-fz", default='/root/data1/Foggy_Zurich/lists_file_names')
    parser.add_argument("--devkit-dir-fd", default='./lists_file_names')
    parser.add_argument("--devkit-dir-clindau", default='./dataset/cityscapes_list')
    return parser.parse_args()
parser.add_argument("--iter-size", type=int, default=ITER_SIZE) 35 | parser.add_argument("--num-workers", type=int, default=NUM_WORKERS) 36 | parser.add_argument("--data-dir", type=str, default=DATA_DIRECTORY) 37 | parser.add_argument("--data-list", type=str, default=DATA_LIST_PATH) 38 | parser.add_argument("--data-city-list", type=str, default = DATA_CITY_PATH) 39 | parser.add_argument("--data-list-rf", type=str, default=DATA_LIST_RF) 40 | parser.add_argument("--input-size", type=str, default=INPUT_SIZE) 41 | parser.add_argument("--input-size-rf", type=str, default=INPUT_SIZE_RF) 42 | parser.add_argument("--data-dir-cwsf", type=str, default=DATA_DIRECTORY_CWSF) 43 | parser.add_argument("--data-list-cwsf", type=str, default=DATA_LIST_PATH_CWSF) 44 | parser.add_argument("--data-dir-rf", type=str, default=DATA_DIR) 45 | parser.add_argument("--num-classes", type=int, default=NUM_CLASSES) 46 | parser.add_argument("--num-steps", type=int, default=NUM_STEPS) 47 | parser.add_argument("--num-steps-stop", type=int, default=NUM_STEPS_STOP) 48 | parser.add_argument("--random-seed", type=int, default=RANDOM_SEED) 49 | parser.add_argument("--restore-from", type=str, default=RESTORE_FROM) 50 | parser.add_argument("--restore-from-fogpass", type=str, default=RESTORE_FROM_fogpass) 51 | parser.add_argument("--save-pred-every", type=int, default=SAVE_PRED_EVERY) 52 | parser.add_argument("--snapshot-dir", type=str, default=SNAPSHOT_DIR) 53 | parser.add_argument("--gpu", type=int, default=0) 54 | parser.add_argument("--set", type=str, default=SET) 55 | parser.add_argument("--lambda-fsm", type=float, default=0.0000001) 56 | parser.add_argument("--lambda-con", type=float, default=0.0001) 57 | parser.add_argument("--file-name", type=str, required=True) 58 | parser.add_argument("--modeltrain", type=str, required=True) 59 | return parser.parse_args() 60 | 61 | args = get_arguments() -------------------------------------------------------------------------------- /dataset/Foggy_Zurich_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import numpy as np 4 | import random 5 | import matplotlib.pyplot as plt 6 | import collections 7 | import torch 8 | import torchvision 9 | from torch.utils import data 10 | from PIL import Image 11 | from os.path import join 12 | import json 13 | import scipy.misc as m 14 | 15 | class foggyzurichDataSet(data.Dataset): 16 | colors = [ # [ 0, 0, 0], 17 | [128, 64, 128], 18 | [244, 35, 232], 19 | [70, 70, 70], 20 | [102, 102, 156], 21 | [190, 153, 153], 22 | [153, 153, 153], 23 | [250, 170, 30], 24 | [220, 220, 0], 25 | [107, 142, 35], 26 | [152, 251, 152], 27 | [0, 130, 180], 28 | [220, 20, 60], 29 | [255, 0, 0], 30 | [0, 0, 142], 31 | [0, 0, 70], 32 | [0, 60, 100], 33 | [0, 80, 100], 34 | [0, 0, 230], 35 | [119, 11, 32], 36 | ] 37 | 38 | def __init__(self, root, list_path, max_iters=None, crop_size=(321, 321), mean=(128, 128, 128)): 39 | self.root = root 40 | self.list_path = list_path 41 | self.crop_size = crop_size 42 | self.mean = mean 43 | self.img_ids = [i_id.strip() for i_id in open(list_path)] 44 | if not max_iters==None: 45 | self.img_ids = self.img_ids * int(np.ceil(float(max_iters) / len(self.img_ids))) 46 | self.files = [] 47 | 48 | self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1] 49 | self.valid_classes = [ 50 | 7, 51 | 8, 52 | 11, 53 | 12, 54 | 13, 55 | 17, 56 | 19, 57 | 20, 58 | 21, 59 | 22, 60 | 23, 61 | 24, 62 | 25, 63 | 26, 64 | 27, 65 | 28, 66 | 31, 67 | 32, 
--------------------------------------------------------------------------------
/dataset/Foggy_Zurich_test.py:
--------------------------------------------------------------------------------
import os
import os.path as osp
import numpy as np
import random
import matplotlib.pyplot as plt
import collections
import torch
import torchvision
from torch.utils import data
from PIL import Image
from os.path import join
import json

class foggyzurichDataSet(data.Dataset):
    colors = [  # [ 0, 0, 0],
        [128, 64, 128],
        [244, 35, 232],
        [70, 70, 70],
        [102, 102, 156],
        [190, 153, 153],
        [153, 153, 153],
        [250, 170, 30],
        [220, 220, 0],
        [107, 142, 35],
        [152, 251, 152],
        [0, 130, 180],
        [220, 20, 60],
        [255, 0, 0],
        [0, 0, 142],
        [0, 0, 70],
        [0, 60, 100],
        [0, 80, 100],
        [0, 0, 230],
        [119, 11, 32],
    ]

    def __init__(self, root, list_path, max_iters=None, crop_size=(321, 321), mean=(128, 128, 128)):
        self.root = root
        self.list_path = list_path
        self.crop_size = crop_size
        self.mean = mean
        self.img_ids = [i_id.strip() for i_id in open(list_path)]
        if max_iters is not None:
            self.img_ids = self.img_ids * int(np.ceil(float(max_iters) / len(self.img_ids)))
        self.files = []

        self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1]
        self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33]
        self.class_names = [
            "unlabelled", "road", "sidewalk", "building", "wall", "fence", "pole",
            "traffic_light", "traffic_sign", "vegetation", "terrain", "sky", "person",
            "rider", "car", "truck", "bus", "train", "motorcycle", "bicycle",
        ]

        self.ignore_index = 255
        self.class_map = dict(zip(self.valid_classes, range(19)))

        for name in self.img_ids:
            img_file = osp.join(self.root, "./Foggy_Zurich/%s" % (name))
            label_file = osp.join(self.root, "./Foggy_Zurich/%s" % ("gt_labelTrainIds/" + name[4:]))
            self.files.append({
                "img": img_file,
                "label": label_file,
                "name": name
            })

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        datafiles = self.files[index]

        image = Image.open(datafiles["img"]).convert('RGB')
        # scipy.misc.imread/imresize were removed in SciPy >= 1.2; load and
        # nearest-neighbour-resize the label map with PIL instead.
        label = Image.open(datafiles["label"])
        name = datafiles["name"]

        # resize
        image = image.resize(self.crop_size, Image.BICUBIC)
        image = np.asarray(image, np.float32)

        label = label.resize(self.crop_size, Image.NEAREST)
        label = np.asarray(label, np.int64)

        size = image.shape
        image = image[:, :, ::-1]  # change to BGR
        image -= self.mean
        image = image.transpose((2, 0, 1))

        return image.copy(), label.copy(), np.array(size), name

    def encode_segmap(self, mask):
        # Map void classes to the ignore index and valid classes to train IDs 0-18.
        for _voidc in self.void_classes:
            mask[mask == _voidc] = self.ignore_index
        for _validc in self.valid_classes:
            mask[mask == _validc] = self.class_map[_validc]
        return mask

if __name__ == '__main__':
    # Smoke test; point root/list_path at a local Foggy Zurich copy before
    # running (the original passed an is_transform flag the constructor lacks).
    dst = foggyzurichDataSet("./data", "./data/list.txt")
    trainloader = data.DataLoader(dst, batch_size=4)
    for i, batch in enumerate(trainloader):
        imgs, labels, _, _ = batch
        if i == 0:
            img = torchvision.utils.make_grid(imgs).numpy()
            img = np.transpose(img, (1, 2, 0))
            img = img[:, :, ::-1]
            plt.imshow(img)
            plt.show()
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sohyun-l/fifo/6eaecab298acd6e5c3e35a78516463e0ef17e709/dataset/__init__.py
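The void/valid bookkeeping above is easiest to see on a toy mask. The following standalone mirror of encode_segmap (toy values, not repository code) maps raw Cityscapes label IDs to the 19 train IDs:

import numpy as np

# Mirror of encode_segmap: raw Cityscapes label IDs become train IDs 0-18,
# everything else becomes the ignore index 255.
void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1]
valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33]
class_map = dict(zip(valid_classes, range(19)))

mask = np.array([[7, 8, 0],
                 [26, 33, -1]])

for voidc in void_classes:
    mask[mask == voidc] = 255
for validc in valid_classes:
    mask[mask == validc] = class_map[validc]

print(mask)
# [[  0   1 255]
#  [ 13  18 255]]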
--------------------------------------------------------------------------------
/dataset/cityscapes_dataset.py:
--------------------------------------------------------------------------------
import os
import os.path as osp
import numpy as np
import random
import matplotlib.pyplot as plt
import collections
import torch
import torchvision
from torch.utils import data
from PIL import Image

class cityscapesDataSet(data.Dataset):

    def __init__(self, root, list_path, max_iters=None, crop_size=(321, 321), mean=(128, 128, 128), scale=True, mirror=True, ignore_label=255, set='val'):
        self.root = root
        self.list_path = list_path
        self.crop_size = crop_size
        self.scale = scale
        self.ignore_label = ignore_label
        self.mean = mean
        self.is_mirror = mirror
        self.img_ids = [i_id.strip() for i_id in open(list_path)]
        if max_iters is not None:
            self.img_ids = self.img_ids * int(np.ceil(float(max_iters) / len(self.img_ids)))
        self.files = []
        self.set = set
        for name in self.img_ids:
            img_file = osp.join(self.root, "leftImg8bit/%s/%s" % (self.set, name))
            self.files.append({
                "img": img_file,
                "name": name
            })

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        datafiles = self.files[index]

        image = Image.open(datafiles["img"]).convert('RGB')
        name = datafiles["name"]

        image = np.asarray(image, np.float32)

        size = image.shape
        image = image[:, :, ::-1]  # change to BGR
        image -= self.mean
        image = image.transpose((2, 0, 1))

        return image.copy(), np.array(size), name


if __name__ == '__main__':
    # Smoke test; the original instantiated an undefined GTA5DataSet with an
    # is_transform flag this class does not take.
    dst = cityscapesDataSet("./data", "./dataset/cityscapes_list/val.txt")
    trainloader = data.DataLoader(dst, batch_size=4)
    for i, batch in enumerate(trainloader):
        imgs, _, _ = batch
        if i == 0:
            img = torchvision.utils.make_grid(imgs).numpy()
            img = np.transpose(img, (1, 2, 0))
            img = img[:, :, ::-1]
            plt.imshow(img)
            plt.show()
--------------------------------------------------------------------------------
/dataset/cityscapes_list/clear_lindau.txt:
--------------------------------------------------------------------------------
1 | ../val/lindau/lindau_000009_000019_leftImg8bit.png 2 | ../val/lindau/lindau_000037_000019_leftImg8bit.png 3 | ../val/lindau/lindau_000015_000019_leftImg8bit.png 4 | ../val/lindau/lindau_000030_000019_leftImg8bit.png 5 | ../val/lindau/lindau_000012_000019_leftImg8bit.png 6 | ../val/lindau/lindau_000032_000019_leftImg8bit.png 7 | ../val/lindau/lindau_000000_000019_leftImg8bit.png 8 | ../val/lindau/lindau_000031_000019_leftImg8bit.png 9 | ../val/lindau/lindau_000011_000019_leftImg8bit.png 10 | ../val/lindau/lindau_000027_000019_leftImg8bit.png 11 | ../val/lindau/lindau_000026_000019_leftImg8bit.png 12 | ../val/lindau/lindau_000017_000019_leftImg8bit.png 13 | ../val/lindau/lindau_000023_000019_leftImg8bit.png 14 | ../val/lindau/lindau_000005_000019_leftImg8bit.png 15 | ../val/lindau/lindau_000025_000019_leftImg8bit.png 16 | ../val/lindau/lindau_000014_000019_leftImg8bit.png 17 | ../val/lindau/lindau_000004_000019_leftImg8bit.png 18 | ../val/lindau/lindau_000021_000019_leftImg8bit.png 19 | ../val/lindau/lindau_000033_000019_leftImg8bit.png 20 | ../val/lindau/lindau_000013_000019_leftImg8bit.png 21 | ../val/lindau/lindau_000024_000019_leftImg8bit.png 22 | ../val/lindau/lindau_000002_000019_leftImg8bit.png 23 | ../val/lindau/lindau_000016_000019_leftImg8bit.png 24 | ../val/lindau/lindau_000018_000019_leftImg8bit.png 25 | ../val/lindau/lindau_000007_000019_leftImg8bit.png 26 | ../val/lindau/lindau_000022_000019_leftImg8bit.png 27 | ../val/lindau/lindau_000038_000019_leftImg8bit.png 28 | ../val/lindau/lindau_000001_000019_leftImg8bit.png 29 | ../val/lindau/lindau_000036_000019_leftImg8bit.png 30 | ../val/lindau/lindau_000035_000019_leftImg8bit.png 31 | ../val/lindau/lindau_000003_000019_leftImg8bit.png 32 | ../val/lindau/lindau_000034_000019_leftImg8bit.png 33 | ../val/lindau/lindau_000010_000019_leftImg8bit.png 34 | ../val/lindau/lindau_000006_000019_leftImg8bit.png 35 | ../val/lindau/lindau_000019_000019_leftImg8bit.png 36 | ../val/lindau/lindau_000029_000019_leftImg8bit.png 37 | ../val/lindau/lindau_000039_000019_leftImg8bit.png 38 | ../val/lindau/lindau_000020_000019_leftImg8bit.png 39 |
../val/lindau/lindau_000028_000019_leftImg8bit.png 40 | ../val/lindau/lindau_000008_000019_leftImg8bit.png -------------------------------------------------------------------------------- /dataset/cityscapes_list/info.json: -------------------------------------------------------------------------------- 1 | { 2 | "classes":19, 3 | "label2train":[ 4 | [0, 255], 5 | [1, 255], 6 | [2, 255], 7 | [3, 255], 8 | [4, 255], 9 | [5, 255], 10 | [6, 255], 11 | [7, 0], 12 | [8, 1], 13 | [9, 255], 14 | [10, 255], 15 | [11, 2], 16 | [12, 3], 17 | [13, 4], 18 | [14, 255], 19 | [15, 255], 20 | [16, 255], 21 | [17, 5], 22 | [18, 255], 23 | [19, 6], 24 | [20, 7], 25 | [21, 8], 26 | [22, 9], 27 | [23, 10], 28 | [24, 11], 29 | [25, 12], 30 | [26, 13], 31 | [27, 14], 32 | [28, 15], 33 | [29, 255], 34 | [30, 255], 35 | [31, 16], 36 | [32, 17], 37 | [33, 18], 38 | [-1, 255]], 39 | "label":[ 40 | "road", 41 | "sidewalk", 42 | "building", 43 | "wall", 44 | "fence", 45 | "pole", 46 | "light", 47 | "sign", 48 | "vegetation", 49 | "terrain", 50 | "sky", 51 | "person", 52 | "rider", 53 | "car", 54 | "truck", 55 | "bus", 56 | "train", 57 | "motocycle", 58 | "bicycle"], 59 | "palette":[ 60 | [128,64,128], 61 | [244,35,232], 62 | [70,70,70], 63 | [102,102,156], 64 | [190,153,153], 65 | [153,153,153], 66 | [250,170,30], 67 | [220,220,0], 68 | [107,142,35], 69 | [152,251,152], 70 | [70,130,180], 71 | [220,20,60], 72 | [255,0,0], 73 | [0,0,142], 74 | [0,0,70], 75 | [0,60,100], 76 | [0,80,100], 77 | [0,0,230], 78 | [119,11,32], 79 | [0,0,0]], 80 | "mean":[ 81 | 73.158359210711552, 82 | 82.908917542625858, 83 | 72.392398761941593], 84 | "std":[ 85 | 47.675755341814678, 86 | 48.494214368814916, 87 | 47.736546325441594] 88 | } 89 | -------------------------------------------------------------------------------- /dataset/cityscapes_list/label_lindau.txt: -------------------------------------------------------------------------------- 1 | val/lindau/lindau_000009_000019_gtFine_labelIds.png 2 | val/lindau/lindau_000037_000019_gtFine_labelIds.png 3 | val/lindau/lindau_000015_000019_gtFine_labelIds.png 4 | val/lindau/lindau_000030_000019_gtFine_labelIds.png 5 | val/lindau/lindau_000012_000019_gtFine_labelIds.png 6 | val/lindau/lindau_000032_000019_gtFine_labelIds.png 7 | val/lindau/lindau_000000_000019_gtFine_labelIds.png 8 | val/lindau/lindau_000031_000019_gtFine_labelIds.png 9 | val/lindau/lindau_000011_000019_gtFine_labelIds.png 10 | val/lindau/lindau_000027_000019_gtFine_labelIds.png 11 | val/lindau/lindau_000026_000019_gtFine_labelIds.png 12 | val/lindau/lindau_000017_000019_gtFine_labelIds.png 13 | val/lindau/lindau_000023_000019_gtFine_labelIds.png 14 | val/lindau/lindau_000005_000019_gtFine_labelIds.png 15 | val/lindau/lindau_000025_000019_gtFine_labelIds.png 16 | val/lindau/lindau_000014_000019_gtFine_labelIds.png 17 | val/lindau/lindau_000004_000019_gtFine_labelIds.png 18 | val/lindau/lindau_000021_000019_gtFine_labelIds.png 19 | val/lindau/lindau_000033_000019_gtFine_labelIds.png 20 | val/lindau/lindau_000013_000019_gtFine_labelIds.png 21 | val/lindau/lindau_000024_000019_gtFine_labelIds.png 22 | val/lindau/lindau_000002_000019_gtFine_labelIds.png 23 | val/lindau/lindau_000016_000019_gtFine_labelIds.png 24 | val/lindau/lindau_000018_000019_gtFine_labelIds.png 25 | val/lindau/lindau_000007_000019_gtFine_labelIds.png 26 | val/lindau/lindau_000022_000019_gtFine_labelIds.png 27 | val/lindau/lindau_000038_000019_gtFine_labelIds.png 28 | val/lindau/lindau_000001_000019_gtFine_labelIds.png 29 | 
val/lindau/lindau_000036_000019_gtFine_labelIds.png 30 | val/lindau/lindau_000035_000019_gtFine_labelIds.png 31 | val/lindau/lindau_000003_000019_gtFine_labelIds.png 32 | val/lindau/lindau_000034_000019_gtFine_labelIds.png 33 | val/lindau/lindau_000010_000019_gtFine_labelIds.png 34 | val/lindau/lindau_000006_000019_gtFine_labelIds.png 35 | val/lindau/lindau_000019_000019_gtFine_labelIds.png 36 | val/lindau/lindau_000029_000019_gtFine_labelIds.png 37 | val/lindau/lindau_000039_000019_gtFine_labelIds.png 38 | val/lindau/lindau_000020_000019_gtFine_labelIds.png 39 | val/lindau/lindau_000028_000019_gtFine_labelIds.png 40 | val/lindau/lindau_000008_000019_gtFine_labelIds.png
--------------------------------------------------------------------------------
/dataset/foggy_driving.py:
--------------------------------------------------------------------------------
import os
import os.path as osp
import numpy as np
import random
import matplotlib.pyplot as plt
import collections
import torch
import torchvision
from torch.utils import data
from PIL import Image
from os.path import join
import json

class foggydrivingDataSet(data.Dataset):
    colors = [
        [128, 64, 128],
        [244, 35, 232],
        [70, 70, 70],
        [102, 102, 156],
        [190, 153, 153],
        [153, 153, 153],
        [250, 170, 30],
        [220, 220, 0],
        [107, 142, 35],
        [152, 251, 152],
        [0, 130, 180],
        [220, 20, 60],
        [255, 0, 0],
        [0, 0, 142],
        [0, 0, 70],
        [0, 60, 100],
        [0, 80, 100],
        [0, 0, 230],
        [119, 11, 32],
    ]

    def __init__(self, root, list_path, max_iters=None, mean=(104.00698793, 116.66876762, 122.67891434), scale=None):
        self.root = root
        self.list_path = list_path
        self.mean = mean
        self.scale = scale
        self.img_ids = [i_id.strip() for i_id in open(list_path)]
        if max_iters is not None:
            self.img_ids = self.img_ids * int(np.ceil(float(max_iters) / len(self.img_ids)))
        self.files = []

        self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1]
        self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33]
        self.class_names = [
            "unlabelled", "road", "sidewalk", "building", "wall", "fence", "pole",
            "traffic_light", "traffic_sign", "vegetation", "terrain", "sky", "person",
            "rider", "car", "truck", "bus", "train", "motorcycle", "bicycle",
        ]

        self.ignore_index = 255
        # class_map was missing in the original file although encode_segmap uses it.
        self.class_map = dict(zip(self.valid_classes, range(19)))

        for name in self.img_ids:
            img_file = osp.join(self.root, name)
            self.files.append({
                "img": img_file,
                "name": name
            })

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        datafiles = self.files[index]

        image = Image.open(datafiles["img"]).convert('RGB')
        name = datafiles["name"]

        # Guard against the default scale=None, which the original compared to 1.
        if self.scale is not None and self.scale != 1:
            w, h = image.size
            new_size = (int(w * self.scale), int(h * self.scale))
            image = image.resize(new_size, Image.BICUBIC)
        image = np.asarray(image, np.float32)

        size = image.shape
        image = image[:, :, ::-1]  # change to BGR
        image -= self.mean
        image = image.transpose((2, 0, 1))

        return image.copy(), np.array(size), name

    def encode_segmap(self, mask):
        # Map void classes to the ignore index and valid classes to train IDs 0-18.
        for _voidc in self.void_classes:
            mask[mask == _voidc] = self.ignore_index
        for _validc in self.valid_classes:
            mask[mask == _validc] = self.class_map[_validc]
        return mask

if __name__ == '__main__':
    # Smoke test; the original passed an is_transform flag the constructor lacks.
    dst = foggydrivingDataSet("/root/data1", "./lists_file_names/leftImg8bit_testall_filenames.txt")
    trainloader = data.DataLoader(dst, batch_size=4)
    for i, batch in enumerate(trainloader):
        imgs, _, _ = batch
        if i == 0:
            img = torchvision.utils.make_grid(imgs).numpy()
            img = np.transpose(img, (1, 2, 0))
            img = img[:, :, ::-1]
            plt.imshow(img)
            plt.show()
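All three dataset classes above emit BGR, mean-subtracted, channel-first float arrays. A minimal sketch of that preprocessing and its inverse for visualization, using the BGR mean from the configs (the helper names are ours, not the repository's):

import numpy as np

# IMG_MEAN is the BGR mean used throughout the configs; the round trip below
# uses a random toy "image" in place of a real dataset sample.
IMG_MEAN = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32)

def to_network_input(rgb):
    """Replicate the datasets' preprocessing on an HWC RGB float32 array."""
    bgr = rgb[:, :, ::-1]                          # change to BGR
    return (bgr - IMG_MEAN).transpose((2, 0, 1))   # HWC -> CHW

def to_display_image(chw):
    """Invert the preprocessing for visualization."""
    bgr = chw.transpose((1, 2, 0)) + IMG_MEAN      # CHW -> HWC, add mean back
    return np.clip(bgr[:, :, ::-1], 0, 255).astype(np.uint8)

rgb = np.random.randint(0, 256, size=(8, 8, 3)).astype(np.float32)
chw = to_network_input(rgb)
assert np.allclose(to_display_image(chw), rgb.astype(np.uint8), atol=1)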
--------------------------------------------------------------------------------
/lists_file_names/gt_testdense_filenames.txt:
--------------------------------------------------------------------------------
1 | gtCoarse/test_extra/web/web_8050286343_fb56b08ce7_gtCoarse_labelIds.png 2 | gtCoarse/test_extra/web/web_8546258268_5aabe6779e_gtCoarse_labelIds.png 3 | gtCoarse/test_extra/web/web_tumblr_n3wkxqv7Xs1sp7ozko1_gtCoarse_labelIds.png 4 | gtCoarse/test_extra/web/web_3959112536_3a1112e59f_gtCoarse_labelIds.png 5 | gtCoarse/test_extra/web/web_5378529177_51ff165684_gtCoarse_labelIds.png 6 | gtCoarse/test_extra/web/web_12072159786_66ec504c8d_gtCoarse_labelIds.png 7 | gtCoarse/test_extra/web/web_2144738053_f1a3fb42db_gtCoarse_labelIds.png 8 | gtCoarse/test_extra/web/web_Foggy_Portland_gtCoarse_labelIds.png 9 | gtCoarse/test_extra/web/web_6925417403_0d577e62fc_gtCoarse_labelIds.png 10 | gtCoarse/test_extra/web/web_2460794106_dbabe29f49_gtCoarse_labelIds.png 11 | gtCoarse/test_extra/web/web_tim_thumb_gtCoarse_labelIds.png 12 | gtCoarse/test_extra/web/web_U137P200T1D308834F12DT_20100315001722_gtCoarse_labelIds.png 13 | gtCoarse/test_extra/web/web_9443635508_ccb844c48d_gtCoarse_labelIds.png 14 | gtCoarse/test_extra/web/web_5402273377_00e789f3ce_gtCoarse_labelIds.png 15 | gtCoarse/test_extra/web/web_fog_tcr_gtCoarse_labelIds.png 16 | gtCoarse/test_extra/web/web_012217_fogalbion4_gtCoarse_labelIds.png 17 | gtCoarse/test_extra/web/web_33032935975_a8b8617c86_gtCoarse_labelIds.png 18 | gtCoarse/test_extra/web/web_fog_5_gtCoarse_labelIds.png 19 | gtCoarse/test_extra/web/web_8263125502_1fc2c86cc1_gtCoarse_labelIds.png 20 | gtCoarse/test_extra/web/web_4247692808_e8c331417d_gtCoarse_labelIds.png 21 | gtCoarse/test_extra/pedestrian/pedestrian_20170225_094934_gtCoarse_labelIds.png 22 |
--------------------------------------------------------------------------------
/lists_file_names/leftImg8bit_testdense_filenames.txt:
--------------------------------------------------------------------------------
1 | leftImg8bit/test_extra/web/web_8050286343_fb56b08ce7_leftImg8bit.png 2 | leftImg8bit/test_extra/web/web_8546258268_5aabe6779e_leftImg8bit.png 3 | leftImg8bit/test_extra/web/web_tumblr_n3wkxqv7Xs1sp7ozko1_leftImg8bit.png 4 | leftImg8bit/test_extra/web/web_3959112536_3a1112e59f_leftImg8bit.png 5 | leftImg8bit/test_extra/web/web_5378529177_51ff165684_leftImg8bit.png 6 | leftImg8bit/test_extra/web/web_12072159786_66ec504c8d_leftImg8bit.png 7 | leftImg8bit/test_extra/web/web_2144738053_f1a3fb42db_leftImg8bit.png 8 | leftImg8bit/test_extra/web/web_Foggy_Portland_leftImg8bit.png 9 | leftImg8bit/test_extra/web/web_6925417403_0d577e62fc_leftImg8bit.png 10 | leftImg8bit/test_extra/web/web_2460794106_dbabe29f49_leftImg8bit.png 11 | leftImg8bit/test_extra/web/web_tim_thumb_leftImg8bit.png 12 |
leftImg8bit/test_extra/web/web_U137P200T1D308834F12DT_20100315001722_leftImg8bit.png 13 | leftImg8bit/test_extra/web/web_9443635508_ccb844c48d_leftImg8bit.png 14 | leftImg8bit/test_extra/web/web_5402273377_00e789f3ce_leftImg8bit.png 15 | leftImg8bit/test_extra/web/web_fog_tcr_leftImg8bit.png 16 | leftImg8bit/test_extra/web/web_012217_fogalbion4_leftImg8bit.png 17 | leftImg8bit/test_extra/web/web_33032935975_a8b8617c86_leftImg8bit.png 18 | leftImg8bit/test_extra/web/web_fog_5_leftImg8bit.png 19 | leftImg8bit/test_extra/web/web_8263125502_1fc2c86cc1_leftImg8bit.png 20 | leftImg8bit/test_extra/web/web_4247692808_e8c331417d_leftImg8bit.png 21 | leftImg8bit/test_extra/pedestrian/pedestrian_20170225_094934_leftImg8bit.png 22 | -------------------------------------------------------------------------------- /lists_file_names/leftImg8bit_testfine_filenames.txt: -------------------------------------------------------------------------------- 1 | leftImg8bit/test/public/public_20161213_083206_leftImg8bit.png 2 | leftImg8bit/test/public/public_20161213_082948_leftImg8bit.png 3 | leftImg8bit/test/public/public_20161213_081958_leftImg8bit.png 4 | leftImg8bit/test/public/public_20161213_081809_leftImg8bit.png 5 | leftImg8bit/test/public/public_20161213_082944_leftImg8bit.png 6 | leftImg8bit/test/public/public_20161213_081955_leftImg8bit.png 7 | leftImg8bit/test/public/public_20161213_082223_leftImg8bit.png 8 | leftImg8bit/test/public/public_20161213_082805_leftImg8bit.png 9 | leftImg8bit/test/public/public_20161213_082937_leftImg8bit.png 10 | leftImg8bit/test/public/public_20161213_081803_leftImg8bit.png 11 | leftImg8bit/test/public/public_20161213_082952_leftImg8bit.png 12 | leftImg8bit/test/public/public_20161213_081801_leftImg8bit.png 13 | leftImg8bit/test/public/public_20161213_082426_leftImg8bit.png 14 | leftImg8bit/test/public/public_20161213_082012_leftImg8bit.png 15 | leftImg8bit/test/public/public_20161213_081950_leftImg8bit.png 16 | leftImg8bit/test/public/public_20161213_081724_leftImg8bit.png 17 | leftImg8bit/test/public/public_20161213_081721_leftImg8bit.png 18 | leftImg8bit/test/public/public_20161213_082117_leftImg8bit.png 19 | leftImg8bit/test/public/public_20161213_082109_leftImg8bit.png 20 | leftImg8bit/test/public/public_20161213_082746_leftImg8bit.png 21 | leftImg8bit/test/public/public_20161213_082429_leftImg8bit.png 22 | leftImg8bit/test/public/public_20161213_083008_leftImg8bit.png 23 | leftImg8bit/test/public/public_20161213_081946_leftImg8bit.png 24 | leftImg8bit/test/public/public_20161213_081942_leftImg8bit.png 25 | leftImg8bit/test/public/public_20161213_082127_leftImg8bit.png 26 | leftImg8bit/test/public/public_20161213_081817_leftImg8bit.png 27 | leftImg8bit/test/public/public_20161213_082800_leftImg8bit.png 28 | leftImg8bit/test/pedestrian/pedestrian_20161201_101324_leftImg8bit.png 29 | leftImg8bit/test/pedestrian/pedestrian_20161201_102320_leftImg8bit.png 30 | leftImg8bit/test/pedestrian/pedestrian_20161201_101436_leftImg8bit.png 31 | leftImg8bit/test/pedestrian/pedestrian_20161201_101441_leftImg8bit.png 32 | leftImg8bit/test/pedestrian/pedestrian_20161212_094945_leftImg8bit.png 33 | leftImg8bit/test/pedestrian/pedestrian_20161212_095022_leftImg8bit.png 34 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sohyun-l/fifo/6eaecab298acd6e5c3e35a78516463e0ef17e709/model/__init__.py 
--------------------------------------------------------------------------------
/model/fogpassfilter.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn


class FogPassFilter_conv1(nn.Module):
    def __init__(self, inputsize):
        super(FogPassFilter_conv1, self).__init__()

        self.hidden = nn.Linear(inputsize, inputsize // 2)
        self.hidden2 = nn.Linear(inputsize // 2, inputsize // 4)
        self.output = nn.Linear(inputsize // 4, 64)
        self.leakyrelu = nn.LeakyReLU()

    def forward(self, x):
        x = self.hidden(x)
        x = self.leakyrelu(x)
        x = self.hidden2(x)
        x = self.leakyrelu(x)
        x = self.output(x)

        return x


class FogPassFilter_res1(nn.Module):
    def __init__(self, inputsize):
        super(FogPassFilter_res1, self).__init__()

        self.hidden = nn.Linear(inputsize, inputsize // 8)
        self.output = nn.Linear(inputsize // 8, 64)
        self.leakyrelu = nn.LeakyReLU()

    def forward(self, x):
        x = self.hidden(x)
        x = self.leakyrelu(x)
        x = self.output(x)

        return x
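Both filters are plain MLPs that map a flattened per-image feature statistic to a 64-dimensional embedding. A hedged sketch of a forward pass follows; the Gram-matrix construction is only an illustration of a plausible input, since the actual statistic is built in the training script, which is not part of this excerpt:

import torch
from model.fogpassfilter import FogPassFilter_conv1

# Illustrative only: feed the filter a vectorized Gram matrix of an encoder
# feature map (assumed input; sizes are toy values).
feat = torch.randn(4, 64, 32, 32)               # B x C x H x W features
b, c, h, w = feat.size()
flat = feat.view(b, c, h * w)
gram = torch.bmm(flat, flat.transpose(1, 2))    # B x C x C second-order statistic
iu = torch.triu_indices(c, c)                   # keep the upper triangle only
vec = gram[:, iu[0], iu[1]]                     # B x (C * (C + 1) / 2)

fogpass = FogPassFilter_conv1(vec.size(1))      # inputsize = 2080 for C = 64
emb = fogpass(vec)
print(emb.shape)                                # torch.Size([4, 64])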
--------------------------------------------------------------------------------
/model/utils.py:
--------------------------------------------------------------------------------
try:
    from torch.hub import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url
--------------------------------------------------------------------------------
/pytorch_metric_learning/__init__.py:
--------------------------------------------------------------------------------
__version__ = "0.9.99"
--------------------------------------------------------------------------------
/pytorch_metric_learning/distances/__init__.py:
--------------------------------------------------------------------------------
from .base_distance import BaseDistance
from .cosine_similarity import CosineSimilarity
from .dot_product_similarity import DotProductSimilarity
from .lp_distance import LpDistance
from .snr_distance import SNRDistance
--------------------------------------------------------------------------------
/pytorch_metric_learning/distances/base_distance.py:
--------------------------------------------------------------------------------
import torch

from ..utils.module_with_records import ModuleWithRecords


class BaseDistance(ModuleWithRecords):
    def __init__(
        self, normalize_embeddings=True, p=2, power=1, is_inverted=False, **kwargs
    ):
        super().__init__(**kwargs)
        self.normalize_embeddings = normalize_embeddings
        self.p = p
        self.power = power
        self.is_inverted = is_inverted
        self.add_to_recordable_attributes(list_of_names=["p", "power"], is_stat=False)

    def forward(self, query_emb, ref_emb=None):
        self.reset_stats()
        query_emb_normalized = self.maybe_normalize(query_emb)
        if ref_emb is None:
            ref_emb = query_emb
            ref_emb_normalized = query_emb_normalized
        else:
            ref_emb_normalized = self.maybe_normalize(ref_emb)
        self.set_default_stats(
            query_emb, ref_emb, query_emb_normalized, ref_emb_normalized
        )
        mat = self.compute_mat(query_emb_normalized, ref_emb_normalized)
        if self.power != 1:
            mat = mat ** self.power
        assert mat.size() == torch.Size((query_emb.size(0), ref_emb.size(0)))
        return mat

    def compute_mat(self, query_emb, ref_emb):
        raise NotImplementedError

    def pairwise_distance(self, query_emb, ref_emb):
        raise NotImplementedError

    def smallest_dist(self, *args, **kwargs):
        if self.is_inverted:
            return torch.max(*args, **kwargs)
        return torch.min(*args, **kwargs)

    def largest_dist(self, *args, **kwargs):
        if self.is_inverted:
            return torch.min(*args, **kwargs)
        return torch.max(*args, **kwargs)

    # This measures the margin between x and y
    def margin(self, x, y):
        if self.is_inverted:
            return y - x
        return x - y

    def normalize(self, embeddings, dim=1, **kwargs):
        return torch.nn.functional.normalize(embeddings, p=self.p, dim=dim, **kwargs)

    def maybe_normalize(self, embeddings, dim=1, **kwargs):
        if self.normalize_embeddings:
            return self.normalize(embeddings, dim=dim, **kwargs)
        return embeddings

    def get_norm(self, embeddings, dim=1, **kwargs):
        return torch.norm(embeddings, p=self.p, dim=dim, **kwargs)

    def set_default_stats(
        self, query_emb, ref_emb, query_emb_normalized, ref_emb_normalized
    ):
        if self.collect_stats:
            with torch.no_grad():
                stats_dict = {
                    "initial_avg_query_norm": torch.mean(
                        self.get_norm(query_emb)
                    ).item(),
                    "initial_avg_ref_norm": torch.mean(self.get_norm(ref_emb)).item(),
                    "final_avg_query_norm": torch.mean(
                        self.get_norm(query_emb_normalized)
                    ).item(),
                    "final_avg_ref_norm": torch.mean(
                        self.get_norm(ref_emb_normalized)
                    ).item(),
                }
                self.set_stats(stats_dict)

    def set_stats(self, stats_dict):
        for k, v in stats_dict.items():
            self.add_to_recordable_attributes(name=k, is_stat=True)
            setattr(self, k, v)
--------------------------------------------------------------------------------
/pytorch_metric_learning/distances/cosine_similarity.py:
--------------------------------------------------------------------------------
from .dot_product_similarity import DotProductSimilarity


class CosineSimilarity(DotProductSimilarity):
    def __init__(self, **kwargs):
        super().__init__(normalize_embeddings=True, **kwargs)
        assert self.is_inverted
        assert self.normalize_embeddings
--------------------------------------------------------------------------------
/pytorch_metric_learning/distances/dot_product_similarity.py:
--------------------------------------------------------------------------------
import torch

from .base_distance import BaseDistance


class DotProductSimilarity(BaseDistance):
    def __init__(self, **kwargs):
        super().__init__(is_inverted=True, **kwargs)
        assert self.is_inverted

    def compute_mat(self, query_emb, ref_emb):
        return torch.matmul(query_emb, ref_emb.t())

    def pairwise_distance(self, query_emb, ref_emb):
        return torch.sum(query_emb * ref_emb, dim=1)
--------------------------------------------------------------------------------
/pytorch_metric_learning/distances/lp_distance.py:
--------------------------------------------------------------------------------
import torch

from ..utils import loss_and_miner_utils as lmu
from .base_distance import BaseDistance


class LpDistance(BaseDistance):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        assert not self.is_inverted

    def compute_mat(self, query_emb, ref_emb):
        dtype, device = query_emb.dtype, query_emb.device
        if ref_emb is None:
            ref_emb = query_emb
        if dtype == torch.float16:  # cdist doesn't work for float16
            rows, cols = lmu.meshgrid_from_sizes(query_emb, ref_emb, dim=0)
            output = torch.zeros(rows.size(), dtype=dtype, device=device)
            rows, cols = rows.flatten(), cols.flatten()
            distances = self.pairwise_distance(query_emb[rows], ref_emb[cols])
            output[rows, cols] = distances
            return output
        else:
            return torch.cdist(query_emb, ref_emb, p=self.p)

    def pairwise_distance(self, query_emb, ref_emb):
        return torch.nn.functional.pairwise_distance(query_emb, ref_emb, p=self.p)
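BaseDistance.forward above fixes the contract: optionally normalize, delegate to compute_mat, and return a query-by-reference matrix. A short usage sketch of the two concrete distances just defined (toy tensors, not repository code):

import torch
from pytorch_metric_learning.distances import CosineSimilarity, LpDistance

emb = torch.randn(5, 64)

lp = LpDistance()                   # p = 2; embeddings are L2-normalized first
dist_mat = lp(emb)                  # with one argument, compares emb to itself
assert dist_mat.shape == (5, 5)

cos = CosineSimilarity()            # is_inverted: larger values mean closer
sim_mat = cos(emb)
print(lp.smallest_dist(dist_mat))   # min for a true distance
print(cos.smallest_dist(sim_mat))   # max, because the metric is inverted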
--------------------------------------------------------------------------------
/pytorch_metric_learning/distances/snr_distance.py:
--------------------------------------------------------------------------------
import torch

from .base_distance import BaseDistance


# Signal to Noise Ratio
class SNRDistance(BaseDistance):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        assert not self.is_inverted

    def compute_mat(self, query_emb, ref_emb):
        anchor_variances = torch.var(query_emb, dim=1)
        pairwise_diffs = query_emb.unsqueeze(1) - ref_emb
        pairwise_variances = torch.var(pairwise_diffs, dim=2)
        return pairwise_variances / (anchor_variances.unsqueeze(1))

    def pairwise_distance(self, query_emb, ref_emb):
        query_var = torch.var(query_emb, dim=1)
        query_ref_var = torch.var(query_emb - ref_emb, dim=1)
        return query_ref_var / query_var
--------------------------------------------------------------------------------
/pytorch_metric_learning/losses/__init__.py:
--------------------------------------------------------------------------------
from .angular_loss import AngularLoss
from .arcface_loss import ArcFaceLoss
from .base_metric_loss_function import BaseMetricLossFunction, MultipleLosses
from .circle_loss import CircleLoss
from .contrastive_loss import ContrastiveLoss
from .cosface_loss import CosFaceLoss
from .cross_batch_memory import CrossBatchMemory
from .fast_ap_loss import FastAPLoss
from .generic_pair_loss import GenericPairLoss
from .intra_pair_variance_loss import IntraPairVarianceLoss
from .large_margin_softmax_loss import LargeMarginSoftmaxLoss
from .lifted_structure_loss import GeneralizedLiftedStructureLoss, LiftedStructureLoss
from .margin_loss import MarginLoss
from .mixins import EmbeddingRegularizerMixin, WeightRegularizerMixin
from .multi_similarity_loss import MultiSimilarityLoss
from .n_pairs_loss import NPairsLoss
from .nca_loss import NCALoss
from .normalized_softmax_loss import NormalizedSoftmaxLoss
from .ntxent_loss import NTXentLoss
from .proxy_anchor_loss import ProxyAnchorLoss
from .proxy_losses import ProxyNCALoss
from .signal_to_noise_ratio_losses import SignalToNoiseRatioContrastiveLoss
from .soft_triple_loss import SoftTripleLoss
from .sphereface_loss import SphereFaceLoss
from .supcon_loss import SupConLoss
from .triplet_margin_loss import TripletMarginLoss
from .tuplet_margin_loss import TupletMarginLoss
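The vendored losses follow pytorch-metric-learning's standard embeddings-plus-labels API. A hedged sketch with ContrastiveLoss; treating fog-pass-filter embeddings grouped by fog condition as the labels is an assumption on our part, since the training script is not part of this excerpt:

import torch
from pytorch_metric_learning import losses
from pytorch_metric_learning.distances import LpDistance

# Pull same-label embeddings together (pos_margin=0) and push different-label
# embeddings at least neg_margin apart, under an Lp distance.
loss_func = losses.ContrastiveLoss(pos_margin=0, neg_margin=1, distance=LpDistance())

embeddings = torch.randn(6, 64, requires_grad=True)   # e.g. outputs of a fog-pass filter
labels = torch.tensor([0, 0, 1, 1, 2, 2])             # one label per fog condition (assumed)
loss = loss_func(embeddings, labels)
loss.backward()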
https://raw.githubusercontent.com/sohyun-l/fifo/6eaecab298acd6e5c3e35a78516463e0ef17e709/pytorch_metric_learning/losses/__pycache__/arcface_loss.cpython-37.pyc -------------------------------------------------------------------------------- /pytorch_metric_learning/losses/__pycache__/base_metric_loss_function.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sohyun-l/fifo/6eaecab298acd6e5c3e35a78516463e0ef17e709/pytorch_metric_learning/losses/__pycache__/base_metric_loss_function.cpython-37.pyc -------------------------------------------------------------------------------- /pytorch_metric_learning/losses/__pycache__/circle_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sohyun-l/fifo/6eaecab298acd6e5c3e35a78516463e0ef17e709/pytorch_metric_learning/losses/__pycache__/circle_loss.cpython-37.pyc -------------------------------------------------------------------------------- /pytorch_metric_learning/losses/__pycache__/contrastive_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sohyun-l/fifo/6eaecab298acd6e5c3e35a78516463e0ef17e709/pytorch_metric_learning/losses/__pycache__/contrastive_loss.cpython-37.pyc -------------------------------------------------------------------------------- /pytorch_metric_learning/losses/__pycache__/cosface_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sohyun-l/fifo/6eaecab298acd6e5c3e35a78516463e0ef17e709/pytorch_metric_learning/losses/__pycache__/cosface_loss.cpython-37.pyc -------------------------------------------------------------------------------- /pytorch_metric_learning/losses/__pycache__/cross_batch_memory.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sohyun-l/fifo/6eaecab298acd6e5c3e35a78516463e0ef17e709/pytorch_metric_learning/losses/__pycache__/cross_batch_memory.cpython-37.pyc -------------------------------------------------------------------------------- /pytorch_metric_learning/losses/__pycache__/fast_ap_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sohyun-l/fifo/6eaecab298acd6e5c3e35a78516463e0ef17e709/pytorch_metric_learning/losses/__pycache__/fast_ap_loss.cpython-37.pyc -------------------------------------------------------------------------------- /pytorch_metric_learning/losses/__pycache__/generic_pair_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sohyun-l/fifo/6eaecab298acd6e5c3e35a78516463e0ef17e709/pytorch_metric_learning/losses/__pycache__/generic_pair_loss.cpython-37.pyc -------------------------------------------------------------------------------- /pytorch_metric_learning/losses/__pycache__/intra_pair_variance_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sohyun-l/fifo/6eaecab298acd6e5c3e35a78516463e0ef17e709/pytorch_metric_learning/losses/__pycache__/intra_pair_variance_loss.cpython-37.pyc -------------------------------------------------------------------------------- 
/pytorch_metric_learning/losses/angular_loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from ..distances import LpDistance 5 | from ..utils
import common_functions as c_f 6 | from ..utils import loss_and_miner_utils as lmu 7 | from .base_metric_loss_function import BaseMetricLossFunction 8 | 9 | 10 | class AngularLoss(BaseMetricLossFunction): 11 | """ 12 | Implementation of https://arxiv.org/abs/1708.01682 13 | Args: 14 | alpha: The angle (as described in the paper), specified in degrees. 15 | """ 16 | 17 | def __init__(self, alpha=40, **kwargs): 18 | super().__init__(**kwargs) 19 | c_f.assert_distance_type( 20 | self, LpDistance, p=2, power=1, normalize_embeddings=True 21 | ) 22 | self.alpha = torch.tensor(np.radians(alpha)) 23 | self.add_to_recordable_attributes(list_of_names=["alpha"], is_stat=False) 24 | self.add_to_recordable_attributes(list_of_names=["average_angle"], is_stat=True) 25 | 26 | def compute_loss(self, embeddings, labels, indices_tuple): 27 | anchors, positives, keep_mask, anchor_idx = self.get_pairs( 28 | embeddings, labels, indices_tuple 29 | ) 30 | if anchors is None: 31 | return self.zero_losses() 32 | 33 | sq_tan_alpha = torch.tan(self.alpha) ** 2 34 | ap_dot = torch.sum(anchors * positives, dim=1, keepdim=True) 35 | ap_matmul_embeddings = torch.matmul( 36 | (anchors + positives), (embeddings.unsqueeze(2)) 37 | ) 38 | ap_matmul_embeddings = ap_matmul_embeddings.squeeze(2).t() 39 | 40 | final_form = (4 * sq_tan_alpha * ap_matmul_embeddings) - ( 41 | 2 * (1 + sq_tan_alpha) * ap_dot 42 | ) 43 | losses = lmu.logsumexp(final_form, keep_mask=keep_mask, add_one=True) 44 | return { 45 | "loss": { 46 | "losses": losses, 47 | "indices": anchor_idx, 48 | "reduction_type": "element", 49 | } 50 | } 51 | 52 | def get_pairs(self, embeddings, labels, indices_tuple): 53 | a1, p, a2, _ = lmu.convert_to_pairs(indices_tuple, labels) 54 | if len(a1) == 0 or len(a2) == 0: 55 | return [None] * 4 56 | anchors = self.distance.normalize(embeddings[a1]) 57 | positives = self.distance.normalize(embeddings[p]) 58 | keep_mask = labels[a1].unsqueeze(1) != labels.unsqueeze(0) 59 | self.set_stats(anchors, positives, embeddings, keep_mask) 60 | return anchors, positives, keep_mask, a1 61 | 62 | def set_stats(self, anchors, positives, embeddings, keep_mask): 63 | if self.collect_stats: 64 | with torch.no_grad(): 65 | centers = (anchors + positives) / 2 66 | ap_dist = self.distance.pairwise_distance(anchors, positives) 67 | nc_dist = self.distance.get_norm( 68 | centers - embeddings.unsqueeze(1), dim=2 69 | ).t() 70 | angles = torch.atan(ap_dist.unsqueeze(1) / (2 * nc_dist)) 71 | average_angle = torch.sum(angles[keep_mask]) / torch.sum(keep_mask) 72 | self.average_angle = np.degrees(average_angle.item()) 73 | -------------------------------------------------------------------------------- /pytorch_metric_learning/losses/arcface_loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from ..utils import common_functions as c_f 5 | from .large_margin_softmax_loss import LargeMarginSoftmaxLoss 6 | 7 | 8 | class ArcFaceLoss(LargeMarginSoftmaxLoss): 9 | """ 10 | Implementation of https://arxiv.org/pdf/1801.07698.pdf 11 | """ 12 | 13 | def __init__(self, *args, margin=28.6, scale=64, **kwargs): 14 | super().__init__(*args, margin=margin, scale=scale, **kwargs) 15 | 16 | def init_margin(self): 17 | self.margin = np.radians(self.margin) 18 | 19 | def cast_types(self, dtype, device): 20 | self.W.data = c_f.to_device(self.W.data, device=device, dtype=dtype) 21 | 22 | def modify_cosine_of_target_classes(self, cosine_of_target_classes): 23 | angles = 
self.get_angles(cosine_of_target_classes) 24 | return torch.cos(angles + self.margin) 25 | 26 | def scale_logits(self, logits, *_): 27 | return logits * self.scale 28 | -------------------------------------------------------------------------------- /pytorch_metric_learning/losses/circle_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..distances import CosineSimilarity 4 | from ..reducers import AvgNonZeroReducer 5 | from ..utils import common_functions as c_f 6 | from ..utils import loss_and_miner_utils as lmu 7 | from .generic_pair_loss import GenericPairLoss 8 | 9 | 10 | class CircleLoss(GenericPairLoss): 11 | """ 12 | Circle loss for pairwise labels only. 13 | 14 | Args: 15 | m: The relaxation factor that controls the radius of the decision boundary. 16 | gamma: The scale factor that determines the largest scale of each similarity score. 17 | 18 | According to the paper, the suggested default values of m and gamma are: 19 | 20 | Face Recognition: m = 0.25, gamma = 256 21 | Person Re-identification: m = 0.25, gamma = 128 22 | Fine-grained Image Retrieval: m = 0.4, gamma = 80 23 | 24 | By default, we set m = 0.4 and gamma = 80 25 | """ 26 | 27 | def __init__(self, m=0.4, gamma=80, **kwargs): 28 | super().__init__(mat_based_loss=True, **kwargs) 29 | c_f.assert_distance_type(self, CosineSimilarity) 30 | self.m = m 31 | self.gamma = gamma 32 | self.soft_plus = torch.nn.Softplus(beta=1) 33 | self.op = 1 + self.m 34 | self.on = -self.m 35 | self.delta_p = 1 - self.m 36 | self.delta_n = self.m 37 | self.add_to_recordable_attributes( 38 | list_of_names=["m", "gamma", "op", "on", "delta_p", "delta_n"], 39 | is_stat=False, 40 | ) 41 | 42 | def _compute_loss(self, mat, pos_mask, neg_mask): 43 | pos_mask_bool = pos_mask.bool() 44 | neg_mask_bool = neg_mask.bool() 45 | anchor_positive = mat[pos_mask_bool] 46 | anchor_negative = mat[neg_mask_bool] 47 | new_mat = torch.zeros_like(mat) 48 | 49 | new_mat[pos_mask_bool] = ( 50 | -self.gamma 51 | * torch.relu(self.op - anchor_positive.detach()) 52 | * (anchor_positive - self.delta_p) 53 | ) 54 | new_mat[neg_mask_bool] = ( 55 | self.gamma 56 | * torch.relu(anchor_negative.detach() - self.on) 57 | * (anchor_negative - self.delta_n) 58 | ) 59 | 60 | logsumexp_pos = lmu.logsumexp( 61 | new_mat, keep_mask=pos_mask_bool, add_one=False, dim=1 62 | ) 63 | logsumexp_neg = lmu.logsumexp( 64 | new_mat, keep_mask=neg_mask_bool, add_one=False, dim=1 65 | ) 66 | 67 | losses = self.soft_plus(logsumexp_pos + logsumexp_neg) 68 | 69 | zero_rows = torch.where( 70 | (torch.sum(pos_mask, dim=1) == 0) | (torch.sum(neg_mask, dim=1) == 0) 71 | )[0] 72 | final_mask = torch.ones_like(losses) 73 | final_mask[zero_rows] = 0 74 | losses = losses * final_mask 75 | return { 76 | "loss": { 77 | "losses": losses, 78 | "indices": c_f.torch_arange_from_size(new_mat), 79 | "reduction_type": "element", 80 | } 81 | } 82 | 83 | def get_default_reducer(self): 84 | return AvgNonZeroReducer() 85 | 86 | def get_default_distance(self): 87 | return CosineSimilarity() 88 | --------------------------------------------------------------------------------
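For reference, a minimal CircleLoss sketch using the face-recognition settings quoted in the docstring above (tensor shapes and class count are illustrative assumptions):

import torch
from pytorch_metric_learning.losses import CircleLoss

embeddings = torch.randn(32, 128)  # raw features; the default CosineSimilarity distance normalizes them internally
labels = torch.randint(0, 10, (32,))
loss_func = CircleLoss(m=0.25, gamma=256)
loss = loss_func(embeddings, labels)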
/pytorch_metric_learning/losses/contrastive_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..reducers import AvgNonZeroReducer 4 | from ..utils import loss_and_miner_utils as lmu 5 | from .generic_pair_loss import GenericPairLoss 6 | 7 | 8 | class ContrastiveLoss(GenericPairLoss): 9 | def __init__(self, pos_margin=0, neg_margin=1, **kwargs): 10 | super().__init__(mat_based_loss=False, **kwargs) 11 | self.pos_margin = pos_margin 12 | self.neg_margin = neg_margin 13 | self.add_to_recordable_attributes( 14 | list_of_names=["pos_margin", "neg_margin"], is_stat=False 15 | ) 16 | 17 | def _compute_loss(self, pos_pair_dist, neg_pair_dist, indices_tuple): 18 | pos_loss, neg_loss = 0, 0 19 | if len(pos_pair_dist) > 0: 20 | pos_loss = self.get_per_pair_loss(pos_pair_dist, "pos") 21 | if len(neg_pair_dist) > 0: 22 | neg_loss = self.get_per_pair_loss(neg_pair_dist, "neg") 23 | pos_pairs = lmu.pos_pairs_from_tuple(indices_tuple) 24 | neg_pairs = lmu.neg_pairs_from_tuple(indices_tuple) 25 | return { 26 | "pos_loss": { 27 | "losses": pos_loss, 28 | "indices": pos_pairs, 29 | "reduction_type": "pos_pair", 30 | }, 31 | "neg_loss": { 32 | "losses": neg_loss, 33 | "indices": neg_pairs, 34 | "reduction_type": "neg_pair", 35 | }, 36 | } 37 | 38 | def get_per_pair_loss(self, pair_dists, pos_or_neg): 39 | loss_calc_func = self.pos_calc if pos_or_neg == "pos" else self.neg_calc 40 | margin = self.pos_margin if pos_or_neg == "pos" else self.neg_margin 41 | per_pair_loss = loss_calc_func(pair_dists, margin) 42 | return per_pair_loss 43 | 44 | def pos_calc(self, pos_pair_dist, margin): 45 | return torch.nn.functional.relu(self.distance.margin(pos_pair_dist, margin)) 46 | 47 | def neg_calc(self, neg_pair_dist, margin): 48 | return torch.nn.functional.relu(self.distance.margin(margin, neg_pair_dist)) 49 | 50 | def get_default_reducer(self): 51 | return AvgNonZeroReducer() 52 | 53 | def _sub_loss_names(self): 54 | return ["pos_loss", "neg_loss"] 55 | -------------------------------------------------------------------------------- /pytorch_metric_learning/losses/cosface_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..utils import common_functions as c_f 4 | from .large_margin_softmax_loss import LargeMarginSoftmaxLoss 5 | 6 | 7 | class CosFaceLoss(LargeMarginSoftmaxLoss): 8 | """ 9 | Implementation of https://arxiv.org/pdf/1801.09414.pdf 10 | """ 11 | 12 | def __init__(self, *args, margin=0.35, scale=64, **kwargs): 13 | super().__init__(*args, margin=margin, scale=scale, **kwargs) 14 | 15 | def init_margin(self): 16 | pass 17 | 18 | def cast_types(self, dtype, device): 19 | self.W.data = c_f.to_device(self.W.data, device=device, dtype=dtype) 20 | 21 | def modify_cosine_of_target_classes(self, cosine_of_target_classes): 22 | if self.collect_stats: 23 | with torch.no_grad(): 24 | self.get_angles( 25 | cosine_of_target_classes 26 | ) # For the purpose of collecting stats 27 | return cosine_of_target_classes - self.margin 28 | 29 | def scale_logits(self, logits, *_): 30 | return logits * self.scale 31 | -------------------------------------------------------------------------------- /pytorch_metric_learning/losses/fast_ap_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..distances import LpDistance 4 | from ..utils import common_functions as c_f 5 | from ..utils import loss_and_miner_utils as lmu 6 | from .base_metric_loss_function import BaseMetricLossFunction 7 | 8 | 9 | class FastAPLoss(BaseMetricLossFunction): 10 | def __init__(self, num_bins=10, **kwargs): 11 | super().__init__(**kwargs) 12 | c_f.assert_distance_type(self, LpDistance, normalize_embeddings=True, p=2) 13 | self.num_bins = int(num_bins) 14 | self.num_edges = self.num_bins + 1 15 |
self.add_to_recordable_attributes(list_of_names=["num_bins"], is_stat=False) 16 | 17 | """ 18 | Adapted from https://github.com/kunhe/FastAP-metric-learning 19 | """ 20 | 21 | def compute_loss(self, embeddings, labels, indices_tuple): 22 | dtype, device = embeddings.dtype, embeddings.device 23 | miner_weights = lmu.convert_to_weights(indices_tuple, labels, dtype=dtype) 24 | N = labels.size(0) 25 | a1_idx, p_idx, a2_idx, n_idx = lmu.get_all_pairs_indices(labels) 26 | I_pos = torch.zeros(N, N, dtype=dtype, device=device) 27 | I_neg = torch.zeros(N, N, dtype=dtype, device=device) 28 | I_pos[a1_idx, p_idx] = 1 29 | I_neg[a2_idx, n_idx] = 1 30 | N_pos = torch.sum(I_pos, dim=1) 31 | safe_N = N_pos > 0 32 | if torch.sum(safe_N) == 0: 33 | return self.zero_losses() 34 | dist_mat = self.distance(embeddings) 35 | 36 | histogram_max = 2 ** self.distance.power 37 | histogram_delta = histogram_max / self.num_bins 38 | mid_points = torch.linspace( 39 | 0.0, histogram_max, steps=self.num_edges, device=device, dtype=dtype 40 | ).view(-1, 1, 1) 41 | pulse = torch.nn.functional.relu( 42 | 1 - torch.abs(dist_mat - mid_points) / histogram_delta 43 | ) 44 | pos_hist = torch.t(torch.sum(pulse * I_pos, dim=2)) 45 | neg_hist = torch.t(torch.sum(pulse * I_neg, dim=2)) 46 | 47 | total_pos_hist = torch.cumsum(pos_hist, dim=1) 48 | total_hist = torch.cumsum(pos_hist + neg_hist, dim=1) 49 | 50 | h_pos_product = pos_hist * total_pos_hist 51 | safe_H = (h_pos_product > 0) & (total_hist > 0) 52 | if torch.sum(safe_H) > 0: 53 | FastAP = torch.zeros_like(pos_hist, device=device) 54 | FastAP[safe_H] = h_pos_product[safe_H] / total_hist[safe_H] 55 | FastAP = torch.sum(FastAP, dim=1) 56 | FastAP = FastAP[safe_N] / N_pos[safe_N] 57 | FastAP = (1 - FastAP) * miner_weights[safe_N] 58 | return { 59 | "loss": { 60 | "losses": FastAP, 61 | "indices": torch.where(safe_N)[0], 62 | "reduction_type": "element", 63 | } 64 | } 65 | return self.zero_losses() 66 | 67 | def get_default_distance(self): 68 | return LpDistance(power=2) 69 | -------------------------------------------------------------------------------- /pytorch_metric_learning/losses/generic_pair_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..utils import loss_and_miner_utils as lmu 4 | from .base_metric_loss_function import BaseMetricLossFunction 5 | 6 | 7 | class GenericPairLoss(BaseMetricLossFunction): 8 | def __init__(self, mat_based_loss, **kwargs): 9 | super().__init__(**kwargs) 10 | self.loss_method = ( 11 | self.mat_based_loss if mat_based_loss else self.pair_based_loss 12 | ) 13 | 14 | def compute_loss(self, embeddings, labels, indices_tuple): 15 | indices_tuple = lmu.convert_to_pairs(indices_tuple, labels) 16 | if all(len(x) <= 1 for x in indices_tuple): 17 | return self.zero_losses() 18 | mat = self.distance(embeddings) 19 | return self.loss_method(mat, labels, indices_tuple) 20 | 21 | def _compute_loss(self): 22 | raise NotImplementedError 23 | 24 | def mat_based_loss(self, mat, labels, indices_tuple): 25 | a1, p, a2, n = indices_tuple 26 | pos_mask, neg_mask = torch.zeros_like(mat), torch.zeros_like(mat) 27 | pos_mask[a1, p] = 1 28 | neg_mask[a2, n] = 1 29 | return self._compute_loss(mat, pos_mask, neg_mask) 30 | 31 | def pair_based_loss(self, mat, labels, indices_tuple): 32 | a1, p, a2, n = indices_tuple 33 | pos_pair, neg_pair = [], [] 34 | if len(a1) > 0: 35 | pos_pair = mat[a1, p] 36 | if len(a2) > 0: 37 | neg_pair = mat[a2, n] 38 | return self._compute_loss(pos_pair, neg_pair, 
indices_tuple) 39 | -------------------------------------------------------------------------------- /pytorch_metric_learning/losses/intra_pair_variance_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..distances import CosineSimilarity 4 | from ..losses import GenericPairLoss 5 | from ..utils import loss_and_miner_utils as lmu 6 | 7 | 8 | class IntraPairVarianceLoss(GenericPairLoss): 9 | def __init__(self, pos_eps=0.01, neg_eps=0.01, **kwargs): 10 | super().__init__(mat_based_loss=False, **kwargs) 11 | self.pos_eps = pos_eps 12 | self.neg_eps = neg_eps 13 | self.add_to_recordable_attributes( 14 | list_of_names=["pos_eps", "neg_eps"], is_stat=False 15 | ) 16 | 17 | # pos_pairs and neg_pairs already represent cos(theta) 18 | def _compute_loss(self, pos_pairs, neg_pairs, indices_tuple): 19 | pos_loss, neg_loss = 0, 0 20 | if len(pos_pairs) > 0: 21 | mean_pos_sim = torch.mean(pos_pairs) 22 | pos_var = self.variance_with_eps( 23 | pos_pairs, mean_pos_sim, self.pos_eps, self.distance.is_inverted 24 | ) 25 | pos_loss = torch.nn.functional.relu(pos_var) ** 2 26 | if len(neg_pairs) > 0: 27 | mean_neg_sim = torch.mean(neg_pairs) 28 | neg_var = self.variance_with_eps( 29 | neg_pairs, mean_neg_sim, self.neg_eps, not self.distance.is_inverted 30 | ) 31 | neg_loss = torch.nn.functional.relu(neg_var) ** 2 32 | pos_pairs_idx = lmu.pos_pairs_from_tuple(indices_tuple) 33 | neg_pairs_idx = lmu.neg_pairs_from_tuple(indices_tuple) 34 | return { 35 | "pos_loss": { 36 | "losses": pos_loss, 37 | "indices": pos_pairs_idx, 38 | "reduction_type": "pos_pair", 39 | }, 40 | "neg_loss": { 41 | "losses": neg_loss, 42 | "indices": neg_pairs_idx, 43 | "reduction_type": "neg_pair", 44 | }, 45 | } 46 | 47 | def variance_with_eps(self, pairs, mean_sim, eps, incentivize_increase): 48 | if incentivize_increase: 49 | return (1 - eps) * mean_sim - pairs 50 | return pairs - (1 + eps) * mean_sim 51 | 52 | def _sub_loss_names(self): 53 | return ["pos_loss", "neg_loss"] 54 | 55 | def get_default_distance(self): 56 | return CosineSimilarity() 57 | -------------------------------------------------------------------------------- /pytorch_metric_learning/losses/lifted_structure_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..utils import common_functions as c_f 4 | from ..utils import loss_and_miner_utils as lmu 5 | from .generic_pair_loss import GenericPairLoss 6 | 7 | 8 | class LiftedStructureLoss(GenericPairLoss): 9 | def __init__(self, neg_margin=1, pos_margin=0, **kwargs): 10 | super().__init__(mat_based_loss=False, **kwargs) 11 | self.neg_margin = neg_margin 12 | self.pos_margin = pos_margin 13 | self.add_to_recordable_attributes( 14 | list_of_names=["pos_margin", "neg_margin"], is_stat=False 15 | ) 16 | 17 | def _compute_loss(self, pos_pairs, neg_pairs, indices_tuple): 18 | a1, p, a2, _ = indices_tuple 19 | dtype = pos_pairs.dtype 20 | 21 | if len(a1) > 0 and len(a2) > 0: 22 | pos_pairs = pos_pairs.unsqueeze(1) 23 | n_per_p = c_f.to_dtype( 24 | (a2.unsqueeze(0) == a1.unsqueeze(1)) 25 | | (a2.unsqueeze(0) == p.unsqueeze(1)), 26 | dtype=dtype, 27 | ) 28 | neg_pairs = neg_pairs * n_per_p 29 | keep_mask = ~(n_per_p == 0) 30 | 31 | remaining_pos_margin = self.distance.margin(pos_pairs, self.pos_margin) 32 | remaining_neg_margin = self.distance.margin(self.neg_margin, neg_pairs) 33 | 34 | neg_pairs_loss = lmu.logsumexp( 35 | remaining_neg_margin, keep_mask=keep_mask, add_one=False, dim=1 36 | 
) 37 | loss_per_pos_pair = neg_pairs_loss + remaining_pos_margin 38 | loss_per_pos_pair = torch.relu(loss_per_pos_pair) ** 2 39 | loss_per_pos_pair /= ( 40 | 2 # divide by 2 since each positive pair will be counted twice 41 | ) 42 | return { 43 | "loss": { 44 | "losses": loss_per_pos_pair, 45 | "indices": (a1, p), 46 | "reduction_type": "pos_pair", 47 | } 48 | } 49 | return self.zero_losses() 50 | 51 | 52 | class GeneralizedLiftedStructureLoss(GenericPairLoss): 53 | # The 'generalized' lifted structure loss shown on page 4 54 | # of the "in defense of triplet loss" paper 55 | # https://arxiv.org/pdf/1703.07737.pdf 56 | def __init__(self, neg_margin=1, pos_margin=0, **kwargs): 57 | super().__init__(mat_based_loss=True, **kwargs) 58 | self.neg_margin = neg_margin 59 | self.pos_margin = pos_margin 60 | self.add_to_recordable_attributes( 61 | list_of_names=["pos_margin", "neg_margin"], is_stat=False 62 | ) 63 | 64 | def _compute_loss(self, mat, pos_mask, neg_mask): 65 | remaining_pos_margin = self.distance.margin(mat, self.pos_margin) 66 | remaining_neg_margin = self.distance.margin(self.neg_margin, mat) 67 | 68 | pos_loss = lmu.logsumexp( 69 | remaining_pos_margin, keep_mask=pos_mask.bool(), add_one=False 70 | ) 71 | neg_loss = lmu.logsumexp( 72 | remaining_neg_margin, keep_mask=neg_mask.bool(), add_one=False 73 | ) 74 | return { 75 | "loss": { 76 | "losses": torch.relu(pos_loss + neg_loss), 77 | "indices": c_f.torch_arange_from_size(mat), 78 | "reduction_type": "element", 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /pytorch_metric_learning/losses/margin_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..reducers import DivisorReducer 4 | from ..utils import common_functions as c_f 5 | from ..utils import loss_and_miner_utils as lmu 6 | from .base_metric_loss_function import BaseMetricLossFunction 7 | 8 | 9 | class MarginLoss(BaseMetricLossFunction): 10 | def __init__( 11 | self, 12 | margin=0.2, 13 | nu=0, 14 | beta=1.2, 15 | triplets_per_anchor="all", 16 | learn_beta=False, 17 | num_classes=None, 18 | **kwargs 19 | ): 20 | super().__init__(**kwargs) 21 | self.margin = margin 22 | self.nu = nu 23 | self.learn_beta = learn_beta 24 | self.initialize_beta(beta, num_classes) 25 | self.triplets_per_anchor = triplets_per_anchor 26 | self.add_to_recordable_attributes( 27 | list_of_names=["margin", "nu", "beta"], is_stat=False 28 | ) 29 | 30 | def compute_loss(self, embeddings, labels, indices_tuple): 31 | indices_tuple = lmu.convert_to_triplets( 32 | indices_tuple, labels, self.triplets_per_anchor 33 | ) 34 | anchor_idx, positive_idx, negative_idx = indices_tuple 35 | if len(anchor_idx) == 0: 36 | return self.zero_losses() 37 | 38 | beta = self.beta if len(self.beta) == 1 else self.beta[labels[anchor_idx]] 39 | beta = c_f.to_device(beta, device=embeddings.device, dtype=embeddings.dtype) 40 | 41 | mat = self.distance(embeddings) 42 | 43 | d_ap = mat[anchor_idx, positive_idx] 44 | d_an = mat[anchor_idx, negative_idx] 45 | 46 | pos_loss = torch.nn.functional.relu( 47 | self.distance.margin(d_ap, beta) + self.margin 48 | ) 49 | neg_loss = torch.nn.functional.relu( 50 | self.distance.margin(beta, d_an) + self.margin 51 | ) 52 | 53 | num_pos_pairs = torch.sum(pos_loss > 0.0) 54 | num_neg_pairs = torch.sum(neg_loss > 0.0) 55 | 56 | divisor = num_pos_pairs + num_neg_pairs 57 | 58 | margin_loss = pos_loss + neg_loss 59 | 60 | loss_dict = { 61 | "margin_loss": { 62 | "losses": 
margin_loss, 63 | "indices": indices_tuple, 64 | "reduction_type": "triplet", 65 | "divisor": divisor, 66 | }, 67 | "beta_reg_loss": self.compute_reg_loss(beta, anchor_idx, divisor), 68 | } 69 | 70 | return loss_dict 71 | 72 | def compute_reg_loss(self, beta, anchor_idx, divisor): 73 | if self.learn_beta: 74 | loss = beta * self.nu 75 | if len(self.beta) == 1: 76 | return { 77 | "losses": loss, 78 | "indices": None, 79 | "reduction_type": "already_reduced", 80 | } 81 | else: 82 | return { 83 | "losses": loss, 84 | "indices": anchor_idx, 85 | "reduction_type": "element", 86 | "divisor": divisor, 87 | } 88 | return self.zero_loss() 89 | 90 | def _sub_loss_names(self): 91 | return ["margin_loss", "beta_reg_loss"] 92 | 93 | def get_default_reducer(self): 94 | return DivisorReducer() 95 | 96 | def initialize_beta(self, beta, num_classes): 97 | self.beta = torch.tensor([float(beta)]) 98 | if num_classes: 99 | self.beta = torch.ones(num_classes) * self.beta 100 | if self.learn_beta: 101 | self.beta = torch.nn.Parameter(self.beta) 102 | -------------------------------------------------------------------------------- /pytorch_metric_learning/losses/mixins.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..utils import common_functions as c_f 4 | 5 | 6 | class WeightMixin: 7 | def __init__(self, weight_init_func=None, **kwargs): 8 | super().__init__(**kwargs) 9 | self.weight_init_func = weight_init_func 10 | if self.weight_init_func is None: 11 | self.weight_init_func = self.get_default_weight_init_func() 12 | 13 | def get_default_weight_init_func(self): 14 | return c_f.TorchInitWrapper(torch.nn.init.normal_) 15 | 16 | 17 | class WeightRegularizerMixin(WeightMixin): 18 | def __init__(self, weight_regularizer=None, weight_reg_weight=1, **kwargs): 19 | self.weight_regularizer = ( 20 | weight_regularizer is not None 21 | ) # hack needed to know whether reg will be in sub-loss names 22 | super().__init__(**kwargs) 23 | self.weight_regularizer = weight_regularizer 24 | self.weight_reg_weight = weight_reg_weight 25 | if self.weight_regularizer is not None: 26 | self.add_to_recordable_attributes( 27 | list_of_names=["weight_reg_weight"], is_stat=False 28 | ) 29 | 30 | def weight_regularization_loss(self, weights): 31 | if self.weight_regularizer is None: 32 | loss = 0 33 | else: 34 | loss = self.weight_regularizer(weights) * self.weight_reg_weight 35 | return {"losses": loss, "indices": None, "reduction_type": "already_reduced"} 36 | 37 | def add_weight_regularization_to_loss_dict(self, loss_dict, weights): 38 | if self.weight_regularizer is not None: 39 | loss_dict["weight_reg_loss"] = self.weight_regularization_loss(weights) 40 | 41 | def regularization_loss_names(self): 42 | return ["weight_reg_loss"] 43 | 44 | 45 | class EmbeddingRegularizerMixin: 46 | def __init__(self, embedding_regularizer=None, embedding_reg_weight=1, **kwargs): 47 | self.embedding_regularizer = ( 48 | embedding_regularizer is not None 49 | ) # hack needed to know whether reg will be in sub-loss names 50 | super().__init__(**kwargs) 51 | self.embedding_regularizer = embedding_regularizer 52 | self.embedding_reg_weight = embedding_reg_weight 53 | if self.embedding_regularizer is not None: 54 | self.add_to_recordable_attributes( 55 | list_of_names=["embedding_reg_weight"], is_stat=False 56 | ) 57 | 58 | def embedding_regularization_loss(self, embeddings): 59 | if self.embedding_regularizer is None: 60 | loss = 0 61 | else: 62 | loss = 
self.embedding_regularizer(embeddings) * self.embedding_reg_weight 63 | return {"losses": loss, "indices": None, "reduction_type": "already_reduced"} 64 | 65 | def add_embedding_regularization_to_loss_dict(self, loss_dict, embeddings): 66 | if self.embedding_regularizer is not None: 67 | loss_dict["embedding_reg_loss"] = self.embedding_regularization_loss( 68 | embeddings 69 | ) 70 | 71 | def regularization_loss_names(self): 72 | return ["embedding_reg_loss"] 73 | -------------------------------------------------------------------------------- /pytorch_metric_learning/losses/multi_similarity_loss.py: -------------------------------------------------------------------------------- 1 | from ..distances import CosineSimilarity 2 | from ..utils import common_functions as c_f 3 | from ..utils import loss_and_miner_utils as lmu 4 | from .generic_pair_loss import GenericPairLoss 5 | 6 | 7 | class MultiSimilarityLoss(GenericPairLoss): 8 | """ 9 | modified from https://github.com/MalongTech/research-ms-loss/ 10 | Args: 11 | alpha: The exponential weight for positive pairs 12 | beta: The exponential weight for negative pairs 13 | base: The shift in the exponent applied to both positive and negative pairs 14 | """ 15 | 16 | def __init__(self, alpha=2, beta=50, base=0.5, **kwargs): 17 | super().__init__(mat_based_loss=True, **kwargs) 18 | self.alpha = alpha 19 | self.beta = beta 20 | self.base = base 21 | self.add_to_recordable_attributes( 22 | list_of_names=["alpha", "beta", "base"], is_stat=False 23 | ) 24 | 25 | def _compute_loss(self, mat, pos_mask, neg_mask): 26 | pos_exp = self.distance.margin(mat, self.base) 27 | neg_exp = self.distance.margin(self.base, mat) 28 | pos_loss = (1.0 / self.alpha) * lmu.logsumexp( 29 | self.alpha * pos_exp, keep_mask=pos_mask.bool(), add_one=True 30 | ) 31 | neg_loss = (1.0 / self.beta) * lmu.logsumexp( 32 | self.beta * neg_exp, keep_mask=neg_mask.bool(), add_one=True 33 | ) 34 | return { 35 | "loss": { 36 | "losses": pos_loss + neg_loss, 37 | "indices": c_f.torch_arange_from_size(mat), 38 | "reduction_type": "element", 39 | } 40 | } 41 | 42 | def get_default_distance(self): 43 | return CosineSimilarity() 44 | -------------------------------------------------------------------------------- /pytorch_metric_learning/losses/n_pairs_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..distances import DotProductSimilarity 4 | from ..utils import common_functions as c_f 5 | from ..utils import loss_and_miner_utils as lmu 6 | from .base_metric_loss_function import BaseMetricLossFunction 7 | 8 | 9 | class NPairsLoss(BaseMetricLossFunction): 10 | def __init__(self, **kwargs): 11 | super().__init__(**kwargs) 12 | self.add_to_recordable_attributes(name="num_pairs", is_stat=True) 13 | self.cross_entropy = torch.nn.CrossEntropyLoss(reduction="none") 14 | 15 | def compute_loss(self, embeddings, labels, indices_tuple): 16 | anchor_idx, positive_idx = lmu.convert_to_pos_pairs_with_unique_labels( 17 | indices_tuple, labels 18 | ) 19 | self.num_pairs = len(anchor_idx) 20 | if self.num_pairs == 0: 21 | return self.zero_losses() 22 | anchors, positives = embeddings[anchor_idx], embeddings[positive_idx] 23 | targets = c_f.to_device(torch.arange(self.num_pairs), embeddings) 24 | sim_mat = self.distance(anchors, positives) 25 | if not self.distance.is_inverted: 26 | sim_mat = -sim_mat 27 | return { 28 | "loss": { 29 | "losses": self.cross_entropy(sim_mat, targets), 30 | "indices": anchor_idx, 31 | "reduction_type": 
"element", 32 | } 33 | } 34 | 35 | def get_default_distance(self): 36 | return DotProductSimilarity() 37 | -------------------------------------------------------------------------------- /pytorch_metric_learning/losses/nca_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..distances import LpDistance 4 | from ..utils import common_functions as c_f 5 | from ..utils import loss_and_miner_utils as lmu 6 | from .base_metric_loss_function import BaseMetricLossFunction 7 | 8 | 9 | class NCALoss(BaseMetricLossFunction): 10 | def __init__(self, softmax_scale=1, **kwargs): 11 | super().__init__(**kwargs) 12 | self.softmax_scale = softmax_scale 13 | self.add_to_recordable_attributes( 14 | list_of_names=["softmax_scale"], is_stat=False 15 | ) 16 | 17 | # https://www.cs.toronto.edu/~hinton/absps/nca.pdf 18 | def compute_loss(self, embeddings, labels, indices_tuple): 19 | if len(embeddings) <= 1: 20 | return self.zero_losses() 21 | return self.nca_computation( 22 | embeddings, embeddings, labels, labels, indices_tuple 23 | ) 24 | 25 | def nca_computation( 26 | self, query, reference, query_labels, reference_labels, indices_tuple 27 | ): 28 | dtype = query.dtype 29 | miner_weights = lmu.convert_to_weights(indices_tuple, query_labels, dtype=dtype) 30 | mat = self.distance(query, reference) 31 | if not self.distance.is_inverted: 32 | mat = -mat 33 | if query is reference: 34 | mat.fill_diagonal_(c_f.neg_inf(dtype)) 35 | same_labels = c_f.to_dtype( 36 | query_labels.unsqueeze(1) == reference_labels.unsqueeze(0), dtype=dtype 37 | ) 38 | exp = torch.nn.functional.softmax(self.softmax_scale * mat, dim=1) 39 | exp = torch.sum(exp * same_labels, dim=1) 40 | non_zero = exp != 0 41 | loss = -torch.log(exp[non_zero]) * miner_weights[non_zero] 42 | return { 43 | "loss": { 44 | "losses": loss, 45 | "indices": c_f.torch_arange_from_size(query)[non_zero], 46 | "reduction_type": "element", 47 | } 48 | } 49 | 50 | def get_default_distance(self): 51 | return LpDistance(power=2) 52 | -------------------------------------------------------------------------------- /pytorch_metric_learning/losses/normalized_softmax_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..distances import DotProductSimilarity 4 | from ..utils import common_functions as c_f 5 | from ..utils import loss_and_miner_utils as lmu 6 | from .base_metric_loss_function import BaseMetricLossFunction 7 | from .mixins import WeightRegularizerMixin 8 | 9 | 10 | class NormalizedSoftmaxLoss(WeightRegularizerMixin, BaseMetricLossFunction): 11 | def __init__(self, num_classes, embedding_size, temperature=0.05, **kwargs): 12 | super().__init__(**kwargs) 13 | self.temperature = temperature 14 | self.W = torch.nn.Parameter(torch.Tensor(embedding_size, num_classes)) 15 | self.weight_init_func(self.W) 16 | self.cross_entropy = torch.nn.CrossEntropyLoss(reduction="none") 17 | self.add_to_recordable_attributes( 18 | list_of_names=["embedding_size", "num_classes", "temperature"], 19 | is_stat=False, 20 | ) 21 | 22 | def cast_types(self, dtype, device): 23 | self.W.data = c_f.to_device(self.W.data, device=device, dtype=dtype) 24 | 25 | def compute_loss(self, embeddings, labels, indices_tuple): 26 | dtype, device = embeddings.dtype, embeddings.device 27 | self.cast_types(dtype, device) 28 | miner_weights = lmu.convert_to_weights(indices_tuple, labels, dtype=dtype) 29 | normalized_W = self.distance.normalize(self.W, dim=0) 30 | 
exponent = self.distance(embeddings, normalized_W.t()) / self.temperature 31 | if not self.distance.is_inverted: 32 | exponent = -exponent 33 | unweighted_loss = self.cross_entropy(exponent, labels) 34 | miner_weighted_loss = unweighted_loss * miner_weights 35 | loss_dict = { 36 | "loss": { 37 | "losses": miner_weighted_loss, 38 | "indices": c_f.torch_arange_from_size(embeddings), 39 | "reduction_type": "element", 40 | } 41 | } 42 | self.add_weight_regularization_to_loss_dict(loss_dict, self.W.t()) 43 | return loss_dict 44 | 45 | def get_default_distance(self): 46 | return DotProductSimilarity() 47 | -------------------------------------------------------------------------------- /pytorch_metric_learning/losses/ntxent_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..distances import CosineSimilarity 4 | from ..utils import common_functions as c_f 5 | from .generic_pair_loss import GenericPairLoss 6 | 7 | 8 | class NTXentLoss(GenericPairLoss): 9 | def __init__(self, temperature=0.07, **kwargs): 10 | super().__init__(mat_based_loss=False, **kwargs) 11 | self.temperature = temperature 12 | self.add_to_recordable_attributes(list_of_names=["temperature"], is_stat=False) 13 | 14 | def _compute_loss(self, pos_pairs, neg_pairs, indices_tuple): 15 | a1, p, a2, _ = indices_tuple 16 | 17 | if len(a1) > 0 and len(a2) > 0: 18 | dtype = neg_pairs.dtype 19 | # if dealing with actual distances, use negative distances 20 | if not self.distance.is_inverted: 21 | pos_pairs = -pos_pairs 22 | neg_pairs = -neg_pairs 23 | 24 | pos_pairs = pos_pairs.unsqueeze(1) / self.temperature 25 | neg_pairs = neg_pairs / self.temperature 26 | n_per_p = c_f.to_dtype(a2.unsqueeze(0) == a1.unsqueeze(1), dtype=dtype) 27 | neg_pairs = neg_pairs * n_per_p 28 | neg_pairs[n_per_p == 0] = c_f.neg_inf(dtype) 29 | 30 | max_val = torch.max( 31 | pos_pairs, torch.max(neg_pairs, dim=1, keepdim=True)[0] 32 | ).detach() 33 | numerator = torch.exp(pos_pairs - max_val).squeeze(1) 34 | denominator = torch.sum(torch.exp(neg_pairs - max_val), dim=1) + numerator 35 | log_exp = torch.log((numerator / denominator) + c_f.small_val(dtype)) 36 | return { 37 | "loss": { 38 | "losses": -log_exp, 39 | "indices": (a1, p), 40 | "reduction_type": "pos_pair", 41 | } 42 | } 43 | return self.zero_losses() 44 | 45 | def get_default_distance(self): 46 | return CosineSimilarity() 47 | -------------------------------------------------------------------------------- /pytorch_metric_learning/losses/proxy_anchor_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..distances import CosineSimilarity 4 | from ..reducers import DivisorReducer 5 | from ..utils import common_functions as c_f 6 | from ..utils import loss_and_miner_utils as lmu 7 | from .base_metric_loss_function import BaseMetricLossFunction 8 | from .mixins import WeightRegularizerMixin 9 | 10 | 11 | # adapted from 12 | # https://github.com/tjddus9597/Proxy-Anchor-CVPR2020/blob/master/code/losses.py 13 | # https://github.com/geonm/proxy-anchor-loss/blob/master/pytorch-proxy-anchor.py 14 | # suggested in this issue: https://github.com/KevinMusgrave/pytorch-metric-learning/issues/32 15 | class ProxyAnchorLoss(WeightRegularizerMixin, BaseMetricLossFunction): 16 | def __init__(self, num_classes, embedding_size, margin=0.1, alpha=32, **kwargs): 17 | super().__init__(**kwargs) 18 | self.proxies = torch.nn.Parameter(torch.Tensor(num_classes, embedding_size)) 19 | 
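# One trainable proxy vector per class; it is initialized just below via
# weight_init_func (kaiming_normal_ with mode="fan_out" by default, per
# get_default_weight_init_func at the bottom of this class).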
self.weight_init_func(self.proxies) 20 | self.num_classes = num_classes 21 | self.margin = margin 22 | self.alpha = alpha 23 | self.add_to_recordable_attributes( 24 | list_of_names=["num_classes", "alpha", "margin"], is_stat=False 25 | ) 26 | 27 | def cast_types(self, dtype, device): 28 | self.proxies.data = c_f.to_device(self.proxies.data, device=device, dtype=dtype) 29 | 30 | def compute_loss(self, embeddings, labels, indices_tuple): 31 | dtype, device = embeddings.dtype, embeddings.device 32 | self.cast_types(dtype, device) 33 | miner_weights = lmu.convert_to_weights( 34 | indices_tuple, labels, dtype=dtype 35 | ).unsqueeze(1) 36 | miner_weights = miner_weights - 1 37 | 38 | cos = self.distance(embeddings, self.proxies) 39 | 40 | pos_mask = torch.nn.functional.one_hot(labels, self.num_classes) 41 | neg_mask = 1 - pos_mask 42 | 43 | with_pos_proxies = torch.where(torch.sum(pos_mask, dim=0) != 0)[0] 44 | 45 | pos_exp = self.distance.margin(cos, self.margin) 46 | neg_exp = self.distance.margin(-self.margin, cos) 47 | 48 | pos_term = lmu.logsumexp( 49 | (self.alpha * pos_exp) + miner_weights, 50 | keep_mask=pos_mask.bool(), 51 | add_one=True, 52 | dim=0, 53 | ) 54 | neg_term = lmu.logsumexp( 55 | (self.alpha * neg_exp) + miner_weights, 56 | keep_mask=neg_mask.bool(), 57 | add_one=True, 58 | dim=0, 59 | ) 60 | 61 | loss_indices = c_f.torch_arange_from_size(self.proxies) 62 | 63 | loss_dict = { 64 | "pos_loss": { 65 | "losses": pos_term.squeeze(0), 66 | "indices": loss_indices, 67 | "reduction_type": "element", 68 | "divisor": len(with_pos_proxies), 69 | }, 70 | "neg_loss": { 71 | "losses": neg_term.squeeze(0), 72 | "indices": loss_indices, 73 | "reduction_type": "element", 74 | "divisor": self.num_classes, 75 | }, 76 | } 77 | 78 | self.add_weight_regularization_to_loss_dict(loss_dict, self.proxies) 79 | 80 | return loss_dict 81 | 82 | def get_default_reducer(self): 83 | return DivisorReducer() 84 | 85 | def get_default_distance(self): 86 | return CosineSimilarity() 87 | 88 | def get_default_weight_init_func(self): 89 | return c_f.TorchInitWrapper(torch.nn.init.kaiming_normal_, mode="fan_out") 90 | 91 | def _sub_loss_names(self): 92 | return ["pos_loss", "neg_loss"] 93 | -------------------------------------------------------------------------------- /pytorch_metric_learning/losses/proxy_losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..utils import common_functions as c_f 4 | from .mixins import WeightRegularizerMixin 5 | from .nca_loss import NCALoss 6 | 7 | 8 | class ProxyNCALoss(WeightRegularizerMixin, NCALoss): 9 | def __init__(self, num_classes, embedding_size, **kwargs): 10 | super().__init__(**kwargs) 11 | self.proxies = torch.nn.Parameter(torch.Tensor(num_classes, embedding_size)) 12 | self.weight_init_func(self.proxies) 13 | self.proxy_labels = torch.arange(num_classes) 14 | self.add_to_recordable_attributes(list_of_names=["num_classes"], is_stat=False) 15 | 16 | def cast_types(self, dtype, device): 17 | self.proxies.data = c_f.to_device(self.proxies.data, device=device, dtype=dtype) 18 | 19 | def compute_loss(self, embeddings, labels, indices_tuple): 20 | dtype, device = embeddings.dtype, embeddings.device 21 | self.cast_types(dtype, device) 22 | loss_dict = self.nca_computation( 23 | embeddings, 24 | self.proxies, 25 | labels, 26 | c_f.to_device(self.proxy_labels, labels), 27 | indices_tuple, 28 | ) 29 | self.add_weight_regularization_to_loss_dict(loss_dict, self.proxies) 30 | return loss_dict 31 | 
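Since both proxy losses above own trainable parameters (the proxies), those parameters must be handed to an optimizer alongside the model's. A minimal sketch with hypothetical sizes and learning rate:

import torch
from pytorch_metric_learning.losses import ProxyNCALoss

loss_func = ProxyNCALoss(num_classes=10, embedding_size=128)
# the proxies are an nn.Parameter, so they need their own optimizer entry
proxy_optimizer = torch.optim.SGD(loss_func.parameters(), lr=0.01)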
-------------------------------------------------------------------------------- /pytorch_metric_learning/losses/signal_to_noise_ratio_losses.py: -------------------------------------------------------------------------------- 1 | from ..distances import SNRDistance 2 | from ..utils import common_functions as c_f 3 | from .contrastive_loss import ContrastiveLoss 4 | 5 | 6 | class SignalToNoiseRatioContrastiveLoss(ContrastiveLoss): 7 | def __init__(self, **kwargs): 8 | super().__init__(**kwargs) 9 | c_f.assert_distance_type(self, SNRDistance) 10 | 11 | def get_default_distance(self): 12 | return SNRDistance() 13 | -------------------------------------------------------------------------------- /pytorch_metric_learning/losses/soft_triple_loss.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from ..distances import CosineSimilarity 7 | from ..utils import common_functions as c_f 8 | from ..utils import loss_and_miner_utils as lmu 9 | from .base_metric_loss_function import BaseMetricLossFunction 10 | from .mixins import WeightRegularizerMixin 11 | 12 | 13 | ###### modified from https://github.com/idstcv/SoftTriple/blob/master/loss/SoftTriple.py ###### 14 | ###### Original code is Copyright@Alibaba Group ###### 15 | ###### ICCV'19: "SoftTriple Loss: Deep Metric Learning Without Triplet Sampling" ###### 16 | class SoftTripleLoss(WeightRegularizerMixin, BaseMetricLossFunction): 17 | def __init__( 18 | self, 19 | num_classes, 20 | embedding_size, 21 | centers_per_class=10, 22 | la=20, 23 | gamma=0.1, 24 | margin=0.01, 25 | **kwargs 26 | ): 27 | super().__init__(**kwargs) 28 | assert self.distance.is_inverted 29 | self.la = la 30 | self.gamma = 1.0 / gamma 31 | self.margin = margin 32 | self.num_classes = num_classes 33 | self.centers_per_class = centers_per_class 34 | self.fc = torch.nn.Parameter( 35 | torch.Tensor(embedding_size, num_classes * centers_per_class) 36 | ) 37 | self.weight_init_func(self.fc) 38 | self.add_to_recordable_attributes( 39 | list_of_names=[ 40 | "la", 41 | "gamma", 42 | "margin", 43 | "centers_per_class", 44 | "num_classes", 45 | "embedding_size", 46 | ], 47 | is_stat=False, 48 | ) 49 | 50 | def cast_types(self, dtype, device): 51 | self.fc.data = c_f.to_device(self.fc.data, device=device, dtype=dtype) 52 | 53 | def compute_loss(self, embeddings, labels, indices_tuple): 54 | dtype, device = embeddings.dtype, embeddings.device 55 | self.cast_types(dtype, device) 56 | miner_weights = lmu.convert_to_weights(indices_tuple, labels, dtype=dtype) 57 | sim_to_centers = self.distance(embeddings, self.fc.t()) 58 | sim_to_centers = sim_to_centers.view( 59 | -1, self.num_classes, self.centers_per_class 60 | ) 61 | prob = F.softmax(sim_to_centers * self.gamma, dim=2) 62 | sim_to_classes = torch.sum(prob * sim_to_centers, dim=2) 63 | margin = torch.zeros( 64 | sim_to_classes.shape, dtype=dtype, device=embeddings.device 65 | ) 66 | margin[torch.arange(0, margin.shape[0]), labels] = self.margin 67 | loss = F.cross_entropy( 68 | self.la * (sim_to_classes - margin), labels, reduction="none" 69 | ) 70 | loss = loss * miner_weights 71 | loss_dict = { 72 | "loss": { 73 | "losses": loss, 74 | "indices": c_f.torch_arange_from_size(embeddings), 75 | "reduction_type": "element", 76 | } 77 | } 78 | self.add_weight_regularization_to_loss_dict(loss_dict, self.fc.t()) 79 | return loss_dict 80 | 81 | def get_default_distance(self): 82 | return CosineSimilarity() 83 | 84 | def 
get_default_weight_init_func(self): 85 | return c_f.TorchInitWrapper(torch.nn.init.kaiming_uniform_, a=math.sqrt(5)) 86 | -------------------------------------------------------------------------------- /pytorch_metric_learning/losses/sphereface_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .large_margin_softmax_loss import LargeMarginSoftmaxLoss 4 | 5 | 6 | class SphereFaceLoss(LargeMarginSoftmaxLoss): 7 | # implementation of https://arxiv.org/pdf/1704.08063.pdf 8 | def scale_logits(self, logits, embeddings): 9 | embedding_norms = torch.norm(embeddings, p=2, dim=1) 10 | return logits * embedding_norms.unsqueeze(1) * self.scale 11 | -------------------------------------------------------------------------------- /pytorch_metric_learning/losses/supcon_loss.py: -------------------------------------------------------------------------------- 1 | from ..distances import CosineSimilarity 2 | from ..reducers import AvgNonZeroReducer 3 | from ..utils import common_functions as c_f 4 | from ..utils import loss_and_miner_utils as lmu 5 | from .generic_pair_loss import GenericPairLoss 6 | 7 | 8 | # adapted from https://github.com/HobbitLong/SupContrast 9 | class SupConLoss(GenericPairLoss): 10 | def __init__(self, temperature=0.1, **kwargs): 11 | super().__init__(mat_based_loss=True, **kwargs) 12 | self.temperature = temperature 13 | self.add_to_recordable_attributes(list_of_names=["temperature"], is_stat=False) 14 | 15 | def _compute_loss(self, mat, pos_mask, neg_mask): 16 | if pos_mask.bool().any() and neg_mask.bool().any(): 17 | # if dealing with actual distances, use negative distances 18 | if not self.distance.is_inverted: 19 | mat = -mat 20 | mat = mat / self.temperature 21 | mat_max, _ = mat.max(dim=1, keepdim=True) 22 | mat = mat - mat_max.detach() # for numerical stability 23 | 24 | denominator = lmu.logsumexp( 25 | mat, keep_mask=(pos_mask + neg_mask).bool(), add_one=False, dim=1 26 | ) 27 | log_prob = mat - denominator 28 | mean_log_prob_pos = (pos_mask * log_prob).sum(dim=1) / ( 29 | pos_mask.sum(dim=1) + c_f.small_val(mat.dtype) 30 | ) 31 | 32 | return { 33 | "loss": { 34 | "losses": -mean_log_prob_pos, 35 | "indices": c_f.torch_arange_from_size(mat), 36 | "reduction_type": "element", 37 | } 38 | } 39 | return self.zero_losses() 40 | 41 | def get_default_reducer(self): 42 | return AvgNonZeroReducer() 43 | 44 | def get_default_distance(self): 45 | return CosineSimilarity() 46 | -------------------------------------------------------------------------------- /pytorch_metric_learning/losses/triplet_margin_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..reducers import AvgNonZeroReducer 4 | from ..utils import loss_and_miner_utils as lmu 5 | from .base_metric_loss_function import BaseMetricLossFunction 6 | 7 | 8 | class TripletMarginLoss(BaseMetricLossFunction): 9 | """ 10 | Args: 11 | margin: The desired difference between the anchor-positive distance and the 12 | anchor-negative distance. 13 | swap: Use the positive-negative distance instead of anchor-negative distance, 14 | if it violates the margin more. 
15 | smooth_loss: Use the log-exp version of the triplet loss 16 | """ 17 | 18 | def __init__( 19 | self, 20 | margin=0.05, 21 | swap=False, 22 | smooth_loss=False, 23 | triplets_per_anchor="all", 24 | **kwargs 25 | ): 26 | super().__init__(**kwargs) 27 | self.margin = margin 28 | self.swap = swap 29 | self.smooth_loss = smooth_loss 30 | self.triplets_per_anchor = triplets_per_anchor 31 | self.add_to_recordable_attributes(list_of_names=["margin"], is_stat=False) 32 | 33 | def compute_loss(self, embeddings, labels, indices_tuple): 34 | indices_tuple = lmu.convert_to_triplets( 35 | indices_tuple, labels, t_per_anchor=self.triplets_per_anchor 36 | ) 37 | anchor_idx, positive_idx, negative_idx = indices_tuple 38 | if len(anchor_idx) == 0: 39 | return self.zero_losses() 40 | mat = self.distance(embeddings) 41 | ap_dists = mat[anchor_idx, positive_idx] 42 | an_dists = mat[anchor_idx, negative_idx] 43 | if self.swap: 44 | pn_dists = mat[positive_idx, negative_idx] 45 | an_dists = self.distance.smallest_dist(an_dists, pn_dists) 46 | 47 | current_margins = self.distance.margin(ap_dists, an_dists) 48 | violation = current_margins + self.margin 49 | if self.smooth_loss: 50 | loss = torch.nn.functional.softplus(violation) 51 | else: 52 | loss = torch.nn.functional.relu(violation) 53 | 54 | return { 55 | "loss": { 56 | "losses": loss, 57 | "indices": indices_tuple, 58 | "reduction_type": "triplet", 59 | } 60 | } 61 | 62 | def get_default_reducer(self): 63 | return AvgNonZeroReducer() 64 | -------------------------------------------------------------------------------- /pytorch_metric_learning/losses/tuplet_margin_loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from ..distances import CosineSimilarity 5 | from ..utils import common_functions as c_f 6 | from ..utils import loss_and_miner_utils as lmu 7 | from .generic_pair_loss import GenericPairLoss 8 | 9 | 10 | class TupletMarginLoss(GenericPairLoss): 11 | def __init__(self, margin=5.73, scale=64, **kwargs): 12 | super().__init__(mat_based_loss=False, **kwargs) 13 | c_f.assert_distance_type(self, CosineSimilarity) 14 | self.margin = np.radians(margin) 15 | self.scale = scale 16 | self.add_to_recordable_attributes( 17 | list_of_names=["margin", "scale"], is_stat=False 18 | ) 19 | self.add_to_recordable_attributes( 20 | list_of_names=["avg_pos_angle", "avg_neg_angle"], is_stat=True 21 | ) 22 | 23 | # pos_pairs and neg_pairs already represent cos(theta) 24 | def _compute_loss(self, pos_pairs, neg_pairs, indices_tuple): 25 | a1, p, a2, _ = indices_tuple 26 | 27 | if len(a1) > 0 and len(a2) > 0: 28 | pos_angles = torch.acos(pos_pairs) 29 | self.set_stats(pos_angles, neg_pairs) 30 | pos_pairs = torch.cos(pos_angles - self.margin) 31 | pos_pairs = pos_pairs.unsqueeze(1) 32 | neg_pairs = neg_pairs.repeat(pos_pairs.size(0), 1) 33 | inside_exp = self.scale * (neg_pairs - pos_pairs) 34 | keep_mask = a2.unsqueeze(0) == a1.unsqueeze(1) 35 | loss = lmu.logsumexp(inside_exp, keep_mask=keep_mask, add_one=True, dim=1) 36 | return { 37 | "loss": { 38 | "losses": loss, 39 | "indices": (a1, p), 40 | "reduction_type": "pos_pair", 41 | } 42 | } 43 | return self.zero_losses() 44 | 45 | def get_default_distance(self): 46 | return CosineSimilarity() 47 | 48 | def set_stats(self, pos_angles, neg_pairs): 49 | if self.collect_stats: 50 | with torch.no_grad(): 51 | neg_angles = torch.acos(neg_pairs) 52 | self.avg_pos_angle = np.degrees(torch.mean(pos_angles).item()) 53 | self.avg_neg_angle = 
np.degrees(torch.mean(neg_angles).item()) 54 | -------------------------------------------------------------------------------- /pytorch_metric_learning/miners/__init__.py: -------------------------------------------------------------------------------- 1 | from .angular_miner import AngularMiner 2 | from .base_miner import BaseMiner, BaseSubsetBatchMiner, BaseTupleMiner 3 | from .batch_easy_hard_miner import BatchEasyHardMiner 4 | from .batch_hard_miner import BatchHardMiner 5 | from .distance_weighted_miner import DistanceWeightedMiner 6 | from .embeddings_already_packaged_as_triplets import EmbeddingsAlreadyPackagedAsTriplets 7 | from .hdc_miner import HDCMiner 8 | from .maximum_loss_miner import MaximumLossMiner 9 | from .multi_similarity_miner import MultiSimilarityMiner 10 | from .pair_margin_miner import PairMarginMiner 11 | from .triplet_margin_miner import TripletMarginMiner 12 | from .uniform_histogram_miner import UniformHistogramMiner 13 | -------------------------------------------------------------------------------- /pytorch_metric_learning/miners/angular_miner.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from ..distances import LpDistance 5 | from ..utils import common_functions as c_f 6 | from ..utils import loss_and_miner_utils as lmu 7 | from .base_miner import BaseTupleMiner 8 | 9 | 10 | class AngularMiner(BaseTupleMiner): 11 | """ 12 | Returns triplets that form an angle greater than some threshold (angle). 13 | The angle is computed as defined in the angular loss paper: 14 | https://arxiv.org/abs/1708.01682 15 | """ 16 | 17 | def __init__(self, angle=20, **kwargs): 18 | super().__init__(**kwargs) 19 | c_f.assert_distance_type( 20 | self, LpDistance, p=2, power=1, normalize_embeddings=True 21 | ) 22 | self.angle = np.radians(angle) 23 | self.add_to_recordable_attributes(list_of_names=["angle"], is_stat=False) 24 | self.add_to_recordable_attributes( 25 | list_of_names=[ 26 | "average_angle", 27 | "average_angle_above_threshold", 28 | "average_angle_below_threshold", 29 | "min_angle", 30 | "max_angle", 31 | "std_of_angle", 32 | ], 33 | is_stat=True, 34 | ) 35 | 36 | def mine(self, embeddings, labels, ref_emb, ref_labels): 37 | anchor_idx, positive_idx, negative_idx = lmu.get_all_triplets_indices( 38 | labels, ref_labels 39 | ) 40 | anchors, positives, negatives = ( 41 | embeddings[anchor_idx], 42 | ref_emb[positive_idx], 43 | ref_emb[negative_idx], 44 | ) 45 | centers = (anchors + positives) / 2 46 | ap_dist = self.distance.pairwise_distance(anchors, positives) 47 | nc_dist = self.distance.pairwise_distance(negatives, centers) 48 | angles = torch.atan(ap_dist / (2 * nc_dist)) 49 | threshold_condition = angles > self.angle 50 | self.set_stats(angles, threshold_condition) 51 | return ( 52 | anchor_idx[threshold_condition], 53 | positive_idx[threshold_condition], 54 | negative_idx[threshold_condition], 55 | ) 56 | 57 | def set_stats(self, angles, threshold_condition): 58 | if self.collect_stats: 59 | with torch.no_grad(): 60 | if len(angles) > 0: 61 | self.average_angle = np.degrees(torch.mean(angles).item()) 62 | self.min_angle = np.degrees(torch.min(angles).item()) 63 | self.max_angle = np.degrees(torch.max(angles).item()) 64 | self.std_of_angle = np.degrees(torch.std(angles).item()) 65 | if torch.sum(threshold_condition) > 0: 66 | self.average_angle_above_threshold = np.degrees( 67 | torch.mean(angles[threshold_condition]).item() 68 | ) 69 | negated_condition = ~threshold_condition 70 
| if torch.sum(negated_condition) > 0: 71 | self.average_angle_below_threshold = np.degrees( 72 | torch.mean(angles[~threshold_condition]).item() 73 | ) 74 | -------------------------------------------------------------------------------- /pytorch_metric_learning/miners/base_miner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..utils import common_functions as c_f 4 | from ..utils.module_with_records_and_reducer import ModuleWithRecordsAndDistance 5 | 6 | 7 | class BaseMiner(ModuleWithRecordsAndDistance): 8 | def mine(self, embeddings, labels, ref_emb, ref_labels): 9 | raise NotImplementedError 10 | 11 | def output_assertion(self, output): 12 | raise NotImplementedError 13 | 14 | def forward(self, embeddings, labels, ref_emb=None, ref_labels=None): 15 | """ 16 | Args: 17 | embeddings: tensor of size (batch_size, embedding_size) 18 | labels: tensor of size (batch_size) 19 | Does any necessary preprocessing, then does mining, and then checks the 20 | shape of the mining output before returning it 21 | """ 22 | self.reset_stats() 23 | with torch.no_grad(): 24 | c_f.check_shapes(embeddings, labels) 25 | labels = c_f.to_device(labels, embeddings) 26 | ref_emb, ref_labels = self.set_ref_emb( 27 | embeddings, labels, ref_emb, ref_labels 28 | ) 29 | mining_output = self.mine(embeddings, labels, ref_emb, ref_labels) 30 | self.output_assertion(mining_output) 31 | return mining_output 32 | 33 | def set_ref_emb(self, embeddings, labels, ref_emb, ref_labels): 34 | if ref_emb is not None: 35 | ref_labels = c_f.to_device(ref_labels, ref_emb) 36 | else: 37 | ref_emb, ref_labels = embeddings, labels 38 | c_f.check_shapes(ref_emb, ref_labels) 39 | return ref_emb, ref_labels 40 | 41 | 42 | class BaseTupleMiner(BaseMiner): 43 | def __init__(self, **kwargs): 44 | super().__init__(**kwargs) 45 | self.add_to_recordable_attributes( 46 | list_of_names=["num_pos_pairs", "num_neg_pairs", "num_triplets"], 47 | is_stat=True, 48 | ) 49 | 50 | def output_assertion(self, output): 51 | """ 52 | Args: 53 | output: the output of self.mine 54 | This asserts that the mining function is outputting 55 | properly formatted indices. The default is to require a tuple representing 56 | a,p,n indices or a1,p,a2,n indices within a batch of embeddings. 57 | For example, a tuple of (anchors, positives, negatives) will be 58 | (torch.tensor, torch.tensor, torch.tensor) 59 | """ 60 | if len(output) == 3: 61 | self.num_triplets = len(output[0]) 62 | assert self.num_triplets == len(output[1]) == len(output[2]) 63 | elif len(output) == 4: 64 | self.num_pos_pairs = len(output[0]) 65 | self.num_neg_pairs = len(output[2]) 66 | assert self.num_pos_pairs == len(output[1]) 67 | assert self.num_neg_pairs == len(output[3]) 68 | else: 69 | raise BaseException 70 | 71 | 72 | class BaseSubsetBatchMiner(BaseMiner): 73 | """ 74 | Args: 75 | output_batch_size: type int. The size of the subset that the miner 76 | will output. 
77 | """ 78 | 79 | def __init__(self, output_batch_size, **kwargs): 80 | super().__init__(**kwargs) 81 | self.output_batch_size = output_batch_size 82 | 83 | def output_assertion(self, output): 84 | assert len(output) == self.output_batch_size 85 | -------------------------------------------------------------------------------- /pytorch_metric_learning/miners/batch_hard_miner.py: -------------------------------------------------------------------------------- 1 | from .batch_easy_hard_miner import BatchEasyHardMiner 2 | 3 | 4 | class BatchHardMiner(BatchEasyHardMiner): 5 | def __init__(self, **kwargs): 6 | super().__init__( 7 | pos_strategy=BatchEasyHardMiner.HARD, 8 | neg_strategy=BatchEasyHardMiner.HARD, 9 | **kwargs 10 | ) 11 | 12 | def mine(self, *args, **kwargs): 13 | a1, p, a2, n = super().mine(*args, **kwargs) 14 | return a1, p, n 15 | -------------------------------------------------------------------------------- /pytorch_metric_learning/miners/distance_weighted_miner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..distances import LpDistance 4 | from ..utils import common_functions as c_f 5 | from ..utils import loss_and_miner_utils as lmu 6 | from .base_miner import BaseTupleMiner 7 | 8 | 9 | # adapted from 10 | # https://github.com/chaoyuaw/incubator-mxnet/blob/master/example/gluon/embedding_learning/model.py 11 | class DistanceWeightedMiner(BaseTupleMiner): 12 | def __init__(self, cutoff=0.5, nonzero_loss_cutoff=1.4, **kwargs): 13 | super().__init__(**kwargs) 14 | c_f.assert_distance_type( 15 | self, LpDistance, p=2, power=1, normalize_embeddings=True 16 | ) 17 | self.cutoff = float(cutoff) 18 | self.nonzero_loss_cutoff = float(nonzero_loss_cutoff) 19 | self.add_to_recordable_attributes( 20 | list_of_names=["cutoff", "nonzero_loss_cutoff"], is_stat=False 21 | ) 22 | 23 | def mine(self, embeddings, labels, ref_emb, ref_labels): 24 | dtype = embeddings.dtype 25 | d = float(embeddings.size(1)) 26 | mat = self.distance(embeddings, ref_emb) 27 | 28 | # Cut off to avoid high variance. 29 | mat = torch.clamp(mat, min=self.cutoff) 30 | 31 | # See the first equation from Section 4 of the paper 32 | log_weights = (2.0 - d) * torch.log(mat) - ((d - 3) / 2) * torch.log( 33 | 1.0 - 0.25 * (mat ** 2.0) 34 | ) 35 | 36 | inf_or_nan = torch.isinf(log_weights) | torch.isnan(log_weights) 37 | 38 | # Sample only negative examples by setting weights of 39 | # the same-class examples to 0. 40 | mask = torch.ones_like(log_weights) 41 | same_class = labels.unsqueeze(1) == ref_labels.unsqueeze(0) 42 | mask[same_class] = 0 43 | log_weights = log_weights * mask 44 | # Subtract max(log(distance)) for stability. 
45 | weights = torch.exp(log_weights - torch.max(log_weights[~inf_or_nan])) 46 | 47 | weights = ( 48 | weights * mask * (c_f.to_dtype(mat < self.nonzero_loss_cutoff, dtype=dtype)) 49 | ) 50 | weights[inf_or_nan] = 0 51 | 52 | weights = weights / torch.sum(weights, dim=1, keepdim=True) 53 | 54 | return lmu.get_random_triplet_indices( 55 | labels, ref_labels=ref_labels, weights=weights 56 | ) 57 | -------------------------------------------------------------------------------- /pytorch_metric_learning/miners/embeddings_already_packaged_as_triplets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .base_miner import BaseTupleMiner 4 | 5 | 6 | class EmbeddingsAlreadyPackagedAsTriplets(BaseTupleMiner): 7 | # If the embeddings are grouped by triplet, 8 | # then use this miner to force the loss function to use the already-formed triplets 9 | def mine(self, embeddings, labels, ref_emb, ref_labels): 10 | batch_size = embeddings.size(0) 11 | a = torch.arange(0, batch_size, 3) 12 | p = torch.arange(1, batch_size, 3) 13 | n = torch.arange(2, batch_size, 3) 14 | return a, p, n 15 | -------------------------------------------------------------------------------- /pytorch_metric_learning/miners/hdc_miner.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | 5 | from ..utils import loss_and_miner_utils as lmu 6 | from .base_miner import BaseTupleMiner 7 | 8 | 9 | # mining method used in Hard Aware Deeply Cascaded Embeddings 10 | # https://arxiv.org/abs/1611.05720 11 | class HDCMiner(BaseTupleMiner): 12 | def __init__(self, filter_percentage=0.5, **kwargs): 13 | super().__init__(**kwargs) 14 | self.filter_percentage = filter_percentage 15 | self.add_to_recordable_attributes( 16 | list_of_names=["filter_percentage"], is_stat=False 17 | ) 18 | self.reset_idx() 19 | 20 | def mine(self, embeddings, labels, ref_emb, ref_labels): 21 | mat = self.distance(embeddings, ref_emb) 22 | self.set_idx(labels, ref_labels) 23 | 24 | for name, (anchor, other) in { 25 | "pos": (self.a1, self.p), 26 | "neg": (self.a2, self.n), 27 | }.items(): 28 | if len(anchor) > 0: 29 | pairs = mat[anchor, other] 30 | num_pairs = len(pairs) 31 | k = int(math.ceil(self.filter_percentage * num_pairs)) 32 | largest = self.should_select_largest(name) 33 | _, idx = torch.topk(pairs, k=k, largest=largest) 34 | self.filter_original_indices(name, idx) 35 | 36 | return self.a1, self.p, self.a2, self.n 37 | 38 | def should_select_largest(self, name): 39 | if self.distance.is_inverted: 40 | return False if name == "pos" else True 41 | return True if name == "pos" else False 42 | 43 | def set_idx(self, labels, ref_labels): 44 | if not self.was_set_externally: 45 | self.a1, self.p, self.a2, self.n = lmu.get_all_pairs_indices( 46 | labels, ref_labels 47 | ) 48 | 49 | def set_idx_externally(self, external_indices_tuple, labels): 50 | self.a1, self.p, self.a2, self.n = lmu.convert_to_pairs( 51 | external_indices_tuple, labels 52 | ) 53 | self.was_set_externally = True 54 | 55 | def reset_idx(self): 56 | self.a1, self.p, self.a2, self.n = None, None, None, None 57 | self.was_set_externally = False 58 | 59 | def filter_original_indices(self, name, idx): 60 | if name == "pos": 61 | self.a1 = self.a1[idx] 62 | self.p = self.p[idx] 63 | else: 64 | self.a2 = self.a2[idx] 65 | self.n = self.n[idx] 66 | -------------------------------------------------------------------------------- 
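Editor's sketch (not a file in this repository): the tuple miners in this package all expose the same interface, so a mined index tuple can be passed straight into a loss. The tensors below are random placeholders, and any tuple miner can stand in for AngularMiner.

import torch
from pytorch_metric_learning.losses import TripletMarginLoss
from pytorch_metric_learning.miners import AngularMiner

embeddings = torch.randn(32, 128)    # placeholder batch: (batch_size, embedding_size)
labels = torch.randint(0, 8, (32,))  # placeholder integer class labels
miner = AngularMiner(angle=20)       # keep triplets whose angle exceeds 20 degrees
loss_fn = TripletMarginLoss()

indices_tuple = miner(embeddings, labels)          # (anchor_idx, positive_idx, negative_idx)
loss = loss_fn(embeddings, labels, indices_tuple)  # loss restricted to the mined triplets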
/pytorch_metric_learning/miners/maximum_loss_miner.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from ..utils import common_functions as c_f 5 | from .base_miner import BaseSubsetBatchMiner 6 | 7 | 8 | class MaximumLossMiner(BaseSubsetBatchMiner): 9 | def __init__(self, loss, miner=None, num_trials=5, **kwargs): 10 | super().__init__(**kwargs) 11 | self.loss = loss 12 | self.miner = miner 13 | self.num_trials = num_trials 14 | self.add_to_recordable_attributes(list_of_names=["num_trials"], is_stat=False) 15 | self.add_to_recordable_attributes( 16 | list_of_names=["min_loss", "avg_loss", "max_loss"], is_stat=True 17 | ) 18 | 19 | def mine(self, embeddings, labels, *_): 20 | losses = [] 21 | all_subset_idx = [] 22 | for i in range(self.num_trials): 23 | rand_subset_idx = c_f.NUMPY_RANDOM.choice( 24 | len(embeddings), size=self.output_batch_size, replace=False 25 | ) 26 | rand_subset_idx = torch.from_numpy(rand_subset_idx) 27 | all_subset_idx.append(rand_subset_idx) 28 | curr_embeddings, curr_labels = ( 29 | embeddings[rand_subset_idx], 30 | labels[rand_subset_idx], 31 | ) 32 | indices_tuple = self.inner_miner(curr_embeddings, curr_labels) 33 | losses.append(self.loss(curr_embeddings, curr_labels, indices_tuple).item()) 34 | max_idx = np.argmax(losses) 35 | self.min_loss = np.min(losses) 36 | self.avg_loss = np.mean(losses) 37 | self.max_loss = losses[max_idx] 38 | return all_subset_idx[max_idx] 39 | 40 | def inner_miner(self, embeddings, labels): 41 | if self.miner: 42 | return self.miner(embeddings, labels) 43 | return None 44 | -------------------------------------------------------------------------------- /pytorch_metric_learning/miners/multi_similarity_miner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..distances import CosineSimilarity 4 | from ..utils import common_functions as c_f 5 | from ..utils import loss_and_miner_utils as lmu 6 | from .base_miner import BaseTupleMiner 7 | 8 | 9 | class MultiSimilarityMiner(BaseTupleMiner): 10 | def __init__(self, epsilon=0.1, **kwargs): 11 | super().__init__(**kwargs) 12 | self.epsilon = epsilon 13 | self.add_to_recordable_attributes(name="epsilon", is_stat=False) 14 | 15 | def mine(self, embeddings, labels, ref_emb, ref_labels): 16 | mat = self.distance(embeddings, ref_emb) 17 | a1, p, a2, n = lmu.get_all_pairs_indices(labels, ref_labels) 18 | 19 | if len(a1) == 0 or len(a2) == 0: 20 | empty = torch.tensor([], device=labels.device, dtype=torch.long) 21 | return empty.clone(), empty.clone(), empty.clone(), empty.clone() 22 | 23 | mat_neg_sorting = mat 24 | mat_pos_sorting = mat.clone() 25 | 26 | dtype = mat.dtype 27 | pos_ignore = ( 28 | c_f.pos_inf(dtype) if self.distance.is_inverted else c_f.neg_inf(dtype) 29 | ) 30 | neg_ignore = ( 31 | c_f.neg_inf(dtype) if self.distance.is_inverted else c_f.pos_inf(dtype) 32 | ) 33 | 34 | mat_pos_sorting[a2, n] = pos_ignore 35 | mat_neg_sorting[a1, p] = neg_ignore 36 | if embeddings is ref_emb: 37 | mat_pos_sorting.fill_diagonal_(pos_ignore) 38 | mat_neg_sorting.fill_diagonal_(neg_ignore) 39 | 40 | pos_sorted, pos_sorted_idx = torch.sort(mat_pos_sorting, dim=1) 41 | neg_sorted, neg_sorted_idx = torch.sort(mat_neg_sorting, dim=1) 42 | 43 | if self.distance.is_inverted: 44 | hard_pos_idx = torch.where( 45 | pos_sorted - self.epsilon < neg_sorted[:, -1].unsqueeze(1) 46 | ) 47 | hard_neg_idx = torch.where( 48 | neg_sorted + self.epsilon > pos_sorted[:, 
0].unsqueeze(1) 49 | ) 50 | else: 51 | hard_pos_idx = torch.where( 52 | pos_sorted + self.epsilon > neg_sorted[:, 0].unsqueeze(1) 53 | ) 54 | hard_neg_idx = torch.where( 55 | neg_sorted - self.epsilon < pos_sorted[:, -1].unsqueeze(1) 56 | ) 57 | 58 | a1 = hard_pos_idx[0] 59 | p = pos_sorted_idx[a1, hard_pos_idx[1]] 60 | a2 = hard_neg_idx[0] 61 | n = neg_sorted_idx[a2, hard_neg_idx[1]] 62 | 63 | return a1, p, a2, n 64 | 65 | def get_default_distance(self): 66 | return CosineSimilarity() 67 | -------------------------------------------------------------------------------- /pytorch_metric_learning/miners/pair_margin_miner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..utils import loss_and_miner_utils as lmu 4 | from .base_miner import BaseTupleMiner 5 | 6 | 7 | class PairMarginMiner(BaseTupleMiner): 8 | """ 9 | Returns positive pairs that have distance greater than a margin and negative 10 | pairs that have distance less than a margin 11 | """ 12 | 13 | def __init__(self, pos_margin=0.2, neg_margin=0.8, **kwargs): 14 | super().__init__(**kwargs) 15 | self.pos_margin = pos_margin 16 | self.neg_margin = neg_margin 17 | self.add_to_recordable_attributes( 18 | list_of_names=["pos_margin", "neg_margin"], is_stat=False 19 | ) 20 | self.add_to_recordable_attributes( 21 | list_of_names=["pos_pair_dist", "neg_pair_dist"], is_stat=True 22 | ) 23 | 24 | def mine(self, embeddings, labels, ref_emb, ref_labels): 25 | mat = self.distance(embeddings, ref_emb) 26 | a1, p, a2, n = lmu.get_all_pairs_indices(labels, ref_labels) 27 | pos_pair = mat[a1, p] 28 | neg_pair = mat[a2, n] 29 | self.set_stats(pos_pair, neg_pair) 30 | pos_mask = ( 31 | pos_pair < self.pos_margin 32 | if self.distance.is_inverted 33 | else pos_pair > self.pos_margin 34 | ) 35 | neg_mask = ( 36 | neg_pair > self.neg_margin 37 | if self.distance.is_inverted 38 | else neg_pair < self.neg_margin 39 | ) 40 | return a1[pos_mask], p[pos_mask], a2[neg_mask], n[neg_mask] 41 | 42 | def set_stats(self, pos_pair, neg_pair): 43 | if self.collect_stats: 44 | with torch.no_grad(): 45 | self.pos_pair_dist = ( 46 | torch.mean(pos_pair).item() if len(pos_pair) > 0 else 0 47 | ) 48 | self.neg_pair_dist = ( 49 | torch.mean(neg_pair).item() if len(neg_pair) > 0 else 0 50 | ) 51 | -------------------------------------------------------------------------------- /pytorch_metric_learning/miners/triplet_margin_miner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..utils import loss_and_miner_utils as lmu 4 | from .base_miner import BaseTupleMiner 5 | 6 | 7 | class TripletMarginMiner(BaseTupleMiner): 8 | """ 9 | Returns triplets that violate the margin 10 | Args: 11 | margin 12 | type_of_triplets: options are "all", "hard", or "semihard". 
13 | "all" means all triplets that violate the margin 14 | "hard" is a subset of "all", but the negative is closer to the anchor than the positive 15 | "semihard" is a subset of "all", but the negative is further from the anchor than the positive 16 | "easy" is all triplets that are not in "all" 17 | """ 18 | 19 | def __init__(self, margin=0.2, type_of_triplets="all", **kwargs): 20 | super().__init__(**kwargs) 21 | self.margin = margin 22 | self.type_of_triplets = type_of_triplets 23 | self.add_to_recordable_attributes(list_of_names=["margin"], is_stat=False) 24 | self.add_to_recordable_attributes( 25 | list_of_names=["avg_triplet_margin", "pos_pair_dist", "neg_pair_dist"], 26 | is_stat=True, 27 | ) 28 | 29 | def mine(self, embeddings, labels, ref_emb, ref_labels): 30 | anchor_idx, positive_idx, negative_idx = lmu.get_all_triplets_indices( 31 | labels, ref_labels 32 | ) 33 | mat = self.distance(embeddings, ref_emb) 34 | ap_dist = mat[anchor_idx, positive_idx] 35 | an_dist = mat[anchor_idx, negative_idx] 36 | triplet_margin = ( 37 | ap_dist - an_dist if self.distance.is_inverted else an_dist - ap_dist 38 | ) 39 | 40 | if self.type_of_triplets == "easy": 41 | threshold_condition = triplet_margin > self.margin 42 | else: 43 | threshold_condition = triplet_margin <= self.margin 44 | if self.type_of_triplets == "hard": 45 | threshold_condition &= triplet_margin <= 0 46 | elif self.type_of_triplets == "semihard": 47 | threshold_condition &= triplet_margin > 0 48 | 49 | return ( 50 | anchor_idx[threshold_condition], 51 | positive_idx[threshold_condition], 52 | negative_idx[threshold_condition], 53 | ) 54 | 55 | def set_stats(self, ap_dist, an_dist, triplet_margin): 56 | if self.collect_stats: 57 | with torch.no_grad(): 58 | self.pos_pair_dist = torch.mean(ap_dist).item() 59 | self.neg_pair_dist = torch.mean(an_dist).item() 60 | self.avg_triplet_margin = torch.mean(triplet_margin).item() 61 | -------------------------------------------------------------------------------- /pytorch_metric_learning/miners/uniform_histogram_miner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..utils import loss_and_miner_utils as lmu 4 | from .base_miner import BaseTupleMiner 5 | 6 | 7 | class UniformHistogramMiner(BaseTupleMiner): 8 | def __init__(self, num_bins=100, pos_per_bin=10, neg_per_bin=10, **kwargs): 9 | super().__init__(**kwargs) 10 | self.num_bins = num_bins 11 | self.pos_per_bin = pos_per_bin 12 | self.neg_per_bin = neg_per_bin 13 | self.add_to_recordable_attributes( 14 | list_of_names=["pos_per_bin", "neg_per_bin"], is_stat=False 15 | ) 16 | 17 | def mine(self, embeddings, labels, ref_emb, ref_labels): 18 | mat = self.distance(embeddings, ref_emb) 19 | a1, p, a2, n = lmu.get_all_pairs_indices(labels, ref_labels) 20 | pos_pairs = mat[a1, p] 21 | neg_pairs = mat[a2, n] 22 | 23 | if len(pos_pairs) > 0: 24 | a1, p = self.get_uniformly_distributed_pairs( 25 | pos_pairs, a1, p, self.pos_per_bin 26 | ) 27 | 28 | if len(neg_pairs) > 0: 29 | a2, n = self.get_uniformly_distributed_pairs( 30 | neg_pairs, a2, n, self.neg_per_bin 31 | ) 32 | 33 | return a1, p, a2, n 34 | 35 | def get_bins(self, pairs): 36 | device, dtype = pairs.device, pairs.dtype 37 | return torch.linspace( 38 | torch.min(pairs), 39 | torch.max(pairs), 40 | steps=self.num_bins + 1, 41 | device=device, 42 | dtype=dtype, 43 | ) 44 | 45 | def filter_by_bin(self, distances, bins, num_pairs): 46 | range_max = len(bins) - 1 47 | all_idx = [] 48 | for i in range(range_max): 49 | s, e = 
bins[i], bins[i + 1] 50 | low_condition = s <= distances 51 | high_condition = distances < e if i != range_max - 1 else distances <= e 52 | condition = torch.where(low_condition & high_condition)[0] 53 | if len(condition) == 0: 54 | continue 55 | idx = torch.multinomial( 56 | torch.ones_like(condition, device=condition.device, dtype=torch.float), 57 | num_pairs, 58 | replacement=True, 59 | ) 60 | all_idx.append(condition[idx]) 61 | return torch.cat(all_idx, dim=0) 62 | 63 | def get_uniformly_distributed_pairs(self, distances, anchors, others, num_pairs): 64 | bins = self.get_bins(distances) 65 | idx = self.filter_by_bin(distances, bins, num_pairs) 66 | return anchors[idx], others[idx] 67 | -------------------------------------------------------------------------------- /pytorch_metric_learning/reducers/__init__.py: -------------------------------------------------------------------------------- 1 | from .avg_non_zero_reducer import AvgNonZeroReducer 2 | from .base_reducer import BaseReducer 3 | from .class_weighted_reducer import ClassWeightedReducer 4 | from .divisor_reducer import DivisorReducer 5 | from .do_nothing_reducer import DoNothingReducer 6 | from .mean_reducer import MeanReducer 7 | from .multiple_reducers import MultipleReducers 8 | from .per_anchor_reducer import PerAnchorReducer 9 | from .threshold_reducer import ThresholdReducer 10 | --------------------------------------------------------------------------------
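Editor's sketch (not a file in this repository): the reducers in this package turn a dictionary of per-element, per-pair, or per-triplet losses into a single scalar, and are swapped into a loss via its reducer argument. TripletMarginLoss is used here only as a convenient host, and the thresholds are arbitrary; see threshold_reducer.py below for the filtering logic.

import torch
from pytorch_metric_learning.losses import TripletMarginLoss
from pytorch_metric_learning.reducers import ThresholdReducer

# Average only the per-triplet losses lying inside (low, high); the rest are filtered out.
loss_fn = TripletMarginLoss(reducer=ThresholdReducer(low=0.0, high=0.3))
embeddings = torch.randn(16, 64)     # placeholder batch
labels = torch.randint(0, 4, (16,))  # placeholder labels
loss = loss_fn(embeddings, labels)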
/pytorch_metric_learning/reducers/avg_non_zero_reducer.py: -------------------------------------------------------------------------------- 1 | from .threshold_reducer import ThresholdReducer 2 | 3 | 4 | class AvgNonZeroReducer(ThresholdReducer): 5 | def __init__(self, **kwargs): 6 | super().__init__(low=0, **kwargs) 7 | -------------------------------------------------------------------------------- /pytorch_metric_learning/reducers/base_reducer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..utils import common_functions as c_f 4 | from ..utils.module_with_records import ModuleWithRecords 5 | 6 | 7 | class BaseReducer(ModuleWithRecords): 8 | def forward(self, loss_dict, embeddings, labels): 9 | self.reset_stats() 10 | assert len(loss_dict) == 1 11 | loss_name = list(loss_dict.keys())[0] 12 | loss_info = loss_dict[loss_name] 13 | self.add_to_recordable_attributes(name=loss_name, is_stat=True) 14 | losses, loss_indices, reduction_type, kwargs = self.unpack_loss_info(loss_info) 15 | loss_val = self.reduce_the_loss( 16 | losses, loss_indices, reduction_type, kwargs, embeddings, labels 17 | ) 18 | setattr(self, loss_name, loss_val.item()) 19 | return loss_val 20 | 21 | def unpack_loss_info(self, loss_info): 22 | return ( 23 | loss_info["losses"], 24 | loss_info["indices"], 25 | loss_info["reduction_type"], 26 | {}, 27 | ) 28 | 29 | def reduce_the_loss( 30 | self, losses, loss_indices, reduction_type, kwargs, embeddings,
labels 31 | ): 32 | self.set_losses_size_stat(losses) 33 | if self.input_is_zero_loss(losses): 34 | return self.zero_loss(embeddings) 35 | self.assert_sizes(losses, loss_indices, reduction_type) 36 | reduction_func = self.get_reduction_func(reduction_type) 37 | return reduction_func(losses, loss_indices, embeddings, labels, **kwargs) 38 | 39 | def already_reduced_reduction(self, losses, loss_indices, embeddings, labels): 40 | assert losses.ndim == 0 or len(losses) == 1 41 | return losses 42 | 43 | def element_reduction(self, losses, loss_indices, embeddings, labels): 44 | raise NotImplementedError 45 | 46 | def pos_pair_reduction(self, losses, loss_indices, embeddings, labels): 47 | raise NotImplementedError 48 | 49 | def neg_pair_reduction(self, losses, loss_indices, embeddings, labels): 50 | raise NotImplementedError 51 | 52 | def triplet_reduction(self, losses, loss_indices, embeddings, labels): 53 | raise NotImplementedError 54 | 55 | def get_reduction_func(self, reduction_type): 56 | return getattr(self, "{}_reduction".format(reduction_type)) 57 | 58 | def assert_sizes(self, losses, loss_indices, reduction_type): 59 | getattr(self, "assert_sizes_{}".format(reduction_type))(losses, loss_indices) 60 | 61 | def zero_loss(self, embeddings): 62 | return torch.sum(embeddings * 0) 63 | 64 | def input_is_zero_loss(self, losses): 65 | if (not torch.is_tensor(losses)) and (losses == 0): 66 | return True 67 | return False 68 | 69 | def assert_sizes_already_reduced(self, losses, loss_indices): 70 | pass 71 | 72 | def assert_sizes_element(self, losses, loss_indices): 73 | assert torch.is_tensor(losses) 74 | assert torch.is_tensor(loss_indices) 75 | assert len(losses) == len(loss_indices) 76 | 77 | def assert_sizes_pair(self, losses, loss_indices): 78 | assert torch.is_tensor(losses) 79 | assert c_f.is_list_or_tuple(loss_indices) 80 | assert len(loss_indices) == 2 81 | assert all(torch.is_tensor(x) for x in loss_indices) 82 | assert len(losses) == len(loss_indices[0]) == len(loss_indices[1]) 83 | 84 | def assert_sizes_pos_pair(self, losses, loss_indices): 85 | self.assert_sizes_pair(losses, loss_indices) 86 | 87 | def assert_sizes_neg_pair(self, losses, loss_indices): 88 | self.assert_sizes_pair(losses, loss_indices) 89 | 90 | def assert_sizes_triplet(self, losses, loss_indices): 91 | assert torch.is_tensor(losses) 92 | assert c_f.is_list_or_tuple(loss_indices) 93 | assert len(loss_indices) == 3 94 | assert all(len(x) == len(losses) for x in loss_indices) 95 | 96 | def set_losses_size_stat(self, losses): 97 | if self.collect_stats: 98 | self.add_to_recordable_attributes(name="losses_size", is_stat=True) 99 | if not torch.is_tensor(losses) or losses.ndim == 0: 100 | self.losses_size = 1 101 | else: 102 | self.losses_size = len(losses) 103 | -------------------------------------------------------------------------------- /pytorch_metric_learning/reducers/class_weighted_reducer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..utils import common_functions as c_f 4 | from .base_reducer import BaseReducer 5 | 6 | 7 | class ClassWeightedReducer(BaseReducer): 8 | def __init__(self, weights, **kwargs): 9 | super().__init__(**kwargs) 10 | self.weights = weights 11 | 12 | def element_reduction(self, losses, loss_indices, embeddings, labels): 13 | return self.element_reduction_helper(losses, loss_indices, labels) 14 | 15 | def pos_pair_reduction(self, losses, loss_indices, embeddings, labels): 16 | return self.element_reduction_helper(losses, 
loss_indices[0], labels) 17 | 18 | # based on anchor label 19 | def neg_pair_reduction(self, losses, loss_indices, embeddings, labels): 20 | return self.element_reduction_helper(losses, loss_indices[0], labels) 21 | 22 | # based on anchor label 23 | def triplet_reduction(self, losses, loss_indices, embeddings, labels): 24 | return self.element_reduction_helper(losses, loss_indices[0], labels) 25 | 26 | def element_reduction_helper(self, losses, indices, labels): 27 | self.weights = c_f.to_device(self.weights, losses, dtype=losses.dtype) 28 | return torch.mean(losses * self.weights[labels[indices]]) 29 | -------------------------------------------------------------------------------- /pytorch_metric_learning/reducers/divisor_reducer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .base_reducer import BaseReducer 4 | 5 | 6 | class DivisorReducer(BaseReducer): 7 | def unpack_loss_info(self, loss_info): 8 | losses, loss_indices, reduction_type, kwargs = super().unpack_loss_info( 9 | loss_info 10 | ) 11 | if reduction_type != "already_reduced": 12 | kwargs = {"divisor": loss_info["divisor"]} 13 | self.divisor = kwargs["divisor"] 14 | self.add_to_recordable_attributes(name="divisor", is_stat=True) 15 | return losses, loss_indices, reduction_type, kwargs 16 | 17 | def sum_and_divide(self, losses, embeddings, divisor): 18 | if divisor != 0: 19 | return torch.sum(losses) / divisor 20 | return self.zero_loss(embeddings) 21 | 22 | def element_reduction(self, losses, loss_indices, embeddings, labels, divisor=1): 23 | return self.sum_and_divide(losses, embeddings, divisor) 24 | 25 | def pos_pair_reduction(self, *args, **kwargs): 26 | return self.element_reduction(*args, **kwargs) 27 | 28 | def neg_pair_reduction(self, *args, **kwargs): 29 | return self.element_reduction(*args, **kwargs) 30 | 31 | def triplet_reduction(self, *args, **kwargs): 32 | return self.element_reduction(*args, **kwargs) 33 | -------------------------------------------------------------------------------- /pytorch_metric_learning/reducers/do_nothing_reducer.py: -------------------------------------------------------------------------------- 1 | from .base_reducer import BaseReducer 2 | 3 | 4 | class DoNothingReducer(BaseReducer): 5 | def forward(self, loss_dict, embeddings, labels): 6 | return loss_dict 7 | -------------------------------------------------------------------------------- /pytorch_metric_learning/reducers/mean_reducer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .base_reducer import BaseReducer 4 | 5 | 6 | class MeanReducer(BaseReducer): 7 | def element_reduction(self, losses, *_): 8 | return torch.mean(losses) 9 | 10 | def pos_pair_reduction(self, losses, *args): 11 | return self.element_reduction(losses, *args) 12 | 13 | def neg_pair_reduction(self, losses, *args): 14 | return self.element_reduction(losses, *args) 15 | 16 | def triplet_reduction(self, losses, *args): 17 | return self.element_reduction(losses, *args) 18 | -------------------------------------------------------------------------------- /pytorch_metric_learning/reducers/multiple_reducers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .base_reducer import BaseReducer 4 | from .mean_reducer import MeanReducer 5 | 6 | 7 | class MultipleReducers(BaseReducer): 8 | def __init__(self, reducers, default_reducer=None, **kwargs): 9 | 
super().__init__(**kwargs) 10 | self.reducers = torch.nn.ModuleDict(reducers) 11 | self.default_reducer = ( 12 | MeanReducer() if default_reducer is None else default_reducer 13 | ) 14 | 15 | def forward(self, loss_dict, embeddings, labels): 16 | self.reset_stats() 17 | sub_losses = torch.zeros( 18 | len(loss_dict), dtype=embeddings.dtype, device=embeddings.device 19 | ) 20 | loss_count = 0 21 | for loss_name, loss_info in loss_dict.items(): 22 | input_dict = {loss_name: loss_info} 23 | if loss_name in self.reducers: 24 | loss_val = self.reducers[loss_name](input_dict, embeddings, labels) 25 | else: 26 | loss_val = self.default_reducer(input_dict, embeddings, labels) 27 | sub_losses[loss_count] = loss_val 28 | loss_count += 1 29 | return self.sub_loss_reduction(sub_losses, embeddings, labels) 30 | 31 | def sub_loss_reduction(self, sub_losses, embeddings=None, labels=None): 32 | return torch.sum(sub_losses) 33 | -------------------------------------------------------------------------------- /pytorch_metric_learning/reducers/per_anchor_reducer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..utils import common_functions as c_f 4 | from .base_reducer import BaseReducer 5 | from .mean_reducer import MeanReducer 6 | 7 | 8 | def aggregation_func(x, num_per_row): 9 | zero_denom = num_per_row == 0 10 | x = torch.sum(x, dim=1) / num_per_row 11 | x[zero_denom] = 0 12 | return x 13 | 14 | 15 | class PerAnchorReducer(BaseReducer): 16 | def __init__(self, reducer=None, aggregation_func=aggregation_func, **kwargs): 17 | super().__init__(**kwargs) 18 | self.reducer = reducer if reducer is not None else MeanReducer() 19 | self.aggregation_func = aggregation_func 20 | 21 | def element_reduction(self, losses, loss_indices, embeddings, labels): 22 | loss_dict = { 23 | "loss": { 24 | "losses": losses, 25 | "indices": loss_indices, 26 | "reduction_type": "element", 27 | } 28 | } 29 | return self.reducer(loss_dict, embeddings, labels) 30 | 31 | def tuple_reduction_helper(self, losses, loss_indices, embeddings, labels): 32 | batch_size = embeddings.shape[0] 33 | device, dtype = losses.device, losses.dtype 34 | new_array = torch.zeros(batch_size, batch_size, device=device, dtype=dtype) 35 | pos_inf = c_f.pos_inf(dtype) 36 | new_array += pos_inf 37 | 38 | anchors, others = loss_indices 39 | new_array[anchors, others] = losses 40 | pos_inf_mask = new_array == pos_inf 41 | num_inf = torch.sum(pos_inf_mask, dim=1) 42 | 43 | new_array[pos_inf_mask] = 0 44 | num_per_row = batch_size - num_inf 45 | output = self.aggregation_func(new_array, num_per_row) 46 | 47 | loss_dict = { 48 | "loss": { 49 | "losses": output, 50 | "indices": c_f.torch_arange_from_size(embeddings), 51 | "reduction_type": "element", 52 | } 53 | } 54 | return self.reducer(loss_dict, embeddings, labels) 55 | 56 | def pos_pair_reduction(self, *args, **kwargs): 57 | return self.tuple_reduction_helper(*args, **kwargs) 58 | 59 | def neg_pair_reduction(self, *args, **kwargs): 60 | return self.tuple_reduction_helper(*args, **kwargs) 61 | 62 | def triplet_reduction(self, *args, **kwargs): 63 | raise NotImplementedError("Triplet reduction not supported") 64 | -------------------------------------------------------------------------------- /pytorch_metric_learning/reducers/threshold_reducer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .base_reducer import BaseReducer 4 | 5 | 6 | class ThresholdReducer(BaseReducer): 7 | def 
__init__(self, low=None, high=None, **kwargs): 8 | super().__init__(**kwargs) 9 | assert (low is not None) or ( 10 | high is not None 11 | ), "At least one of low or high must be specified" 12 | self.low = low 13 | self.high = high 14 | if self.low is not None: 15 | self.add_to_recordable_attributes(list_of_names=["low"], is_stat=False) 16 | if self.high is not None: 17 | self.add_to_recordable_attributes(list_of_names=["high"], is_stat=False) 18 | 19 | def element_reduction(self, losses, loss_indices, embeddings, labels): 20 | return self.element_reduction_helper(losses, embeddings, "elements") 21 | 22 | def pos_pair_reduction(self, losses, loss_indices, embeddings, labels): 23 | return self.element_reduction_helper(losses, embeddings, "pos_pairs") 24 | 25 | def neg_pair_reduction(self, losses, loss_indices, embeddings, labels): 26 | return self.element_reduction_helper(losses, embeddings, "neg_pairs") 27 | 28 | def triplet_reduction(self, losses, loss_indices, embeddings, labels): 29 | return self.element_reduction_helper(losses, embeddings, "triplets") 30 | 31 | def element_reduction_helper(self, losses, embeddings, attr_name): 32 | low_condition = losses > self.low if self.low is not None else True 33 | high_condition = losses < self.high if self.high is not None else True 34 | threshold_condition = low_condition & high_condition 35 | num_past_filter = torch.sum(threshold_condition) 36 | if num_past_filter >= 1: 37 | loss = torch.mean(losses[threshold_condition]) 38 | else: 39 | loss = self.zero_loss(embeddings) 40 | self.set_stats(low_condition, high_condition, num_past_filter, attr_name) 41 | return loss 42 | 43 | def set_stats(self, low_condition, high_condition, num_past_filter, attr_name): 44 | if self.collect_stats: 45 | curr_attr_name = "{}_past_filter".format(attr_name) 46 | self.add_to_recordable_attributes(name=curr_attr_name, is_stat=True) 47 | setattr(self, curr_attr_name, num_past_filter.item()) 48 | with torch.no_grad(): 49 | if self.low is not None: 50 | curr_attr_name = "{}_above_low".format(attr_name) 51 | self.add_to_recordable_attributes(name=curr_attr_name, is_stat=True) 52 | setattr(self, curr_attr_name, torch.sum(low_condition).item()) 53 | if self.high is not None: 54 | curr_attr_name = "{}_below_high".format(attr_name) 55 | self.add_to_recordable_attributes(name=curr_attr_name, is_stat=True) 56 | setattr(self, curr_attr_name, torch.sum(high_condition).item()) 57 | -------------------------------------------------------------------------------- /pytorch_metric_learning/regularizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_regularizer import BaseRegularizer 2 | from .center_invariant_regularizer import CenterInvariantRegularizer 3 | from .lp_regularizer import LpRegularizer 4 | from .regular_face_regularizer import RegularFaceRegularizer 5 | from .sparse_centers_regularizer import SparseCentersRegularizer 6 | from .zero_mean_regularizer import ZeroMeanRegularizer 7 | -------------------------------------------------------------------------------- /pytorch_metric_learning/regularizers/base_regularizer.py: -------------------------------------------------------------------------------- 1 | from ..utils import common_functions as c_f 2 | from ..utils.module_with_records_and_reducer import ModuleWithRecordsReducerAndDistance 3 | 4 | 5 | class BaseRegularizer(ModuleWithRecordsReducerAndDistance): 6 | def compute_loss(self, x): 7 | raise NotImplementedError 8 | 9 | def forward(self, x): 10 | """ 11 | x 
should have shape (N, embedding_size) 12 | """ 13 | self.reset_stats() 14 | loss_dict = self.compute_loss(x) 15 | return self.reducer(loss_dict, x, c_f.torch_arange_from_size(x)) 16 | -------------------------------------------------------------------------------- /pytorch_metric_learning/regularizers/center_invariant_regularizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..distances import LpDistance 4 | from ..utils import common_functions as c_f 5 | from .base_regularizer import BaseRegularizer 6 | 7 | 8 | class CenterInvariantRegularizer(BaseRegularizer): 9 | def __init__(self, **kwargs): 10 | super().__init__(**kwargs) 11 | c_f.assert_distance_type(self, LpDistance, power=1, normalize_embeddings=False) 12 | 13 | def compute_loss(self, weights): 14 | squared_weight_norms = self.distance.get_norm(weights) ** 2 15 | deviations_from_mean = squared_weight_norms - torch.mean(squared_weight_norms) 16 | return { 17 | "loss": { 18 | "losses": (deviations_from_mean ** 2) / 4, 19 | "indices": c_f.torch_arange_from_size(weights), 20 | "reduction_type": "element", 21 | } 22 | } 23 | 24 | def get_default_distance(self): 25 | return LpDistance(power=1, normalize_embeddings=False) 26 | -------------------------------------------------------------------------------- /pytorch_metric_learning/regularizers/lp_regularizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..utils import common_functions as c_f 4 | from .base_regularizer import BaseRegularizer 5 | 6 | 7 | class LpRegularizer(BaseRegularizer): 8 | def __init__(self, p=2, power=1, **kwargs): 9 | super().__init__(**kwargs) 10 | self.p = p 11 | self.power = power 12 | self.add_to_recordable_attributes(list_of_names=["p", "power"], is_stat=False) 13 | 14 | def compute_loss(self, embeddings): 15 | reg = torch.norm(embeddings, p=self.p, dim=1) 16 | if self.power != 1: 17 | reg = reg ** self.power 18 | return { 19 | "loss": { 20 | "losses": reg, 21 | "indices": c_f.torch_arange_from_size(embeddings), 22 | "reduction_type": "element", 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /pytorch_metric_learning/regularizers/regular_face_regularizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..distances import CosineSimilarity 4 | from ..utils import common_functions as c_f 5 | from .base_regularizer import BaseRegularizer 6 | 7 | 8 | # modified from http://kaizhao.net/regularface 9 | class RegularFaceRegularizer(BaseRegularizer): 10 | def __init__(self, **kwargs): 11 | super().__init__(**kwargs) 12 | assert self.distance.is_inverted 13 | 14 | def compute_loss(self, weights): 15 | dtype, device = weights.dtype, weights.device 16 | num_classes = weights.size(0) 17 | cos = self.distance(weights) 18 | with torch.no_grad(): 19 | cos1 = cos.clone() 20 | cos1.fill_diagonal_(c_f.neg_inf(dtype)) 21 | _, indices = self.distance.smallest_dist(cos1, dim=1) 22 | mask = torch.zeros((num_classes, num_classes), dtype=dtype, device=device) 23 | row_nums = torch.arange(num_classes, device=device, dtype=torch.long) 24 | mask[row_nums, indices] = 1 25 | losses = torch.sum(cos * mask, dim=1) 26 | return { 27 | "loss": { 28 | "losses": losses, 29 | "indices": c_f.torch_arange_from_size(weights), 30 | "reduction_type": "element", 31 | } 32 | } 33 | 34 | def get_default_distance(self): 35 | return CosineSimilarity() 36 | 
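Editor's sketch (not a file in this repository): regularizers are attached to a loss rather than called directly; embedding regularizers are applied to the batch embeddings, while weight regularizers act on a classification layer's weight matrix. ContrastiveLoss is used here only as a convenient host, assuming this package's standard embedding_regularizer keyword.

import torch
from pytorch_metric_learning.losses import ContrastiveLoss
from pytorch_metric_learning.regularizers import LpRegularizer

# Adds an L2-norm penalty on the embeddings alongside the contrastive loss.
loss_fn = ContrastiveLoss(embedding_regularizer=LpRegularizer(p=2))
embeddings = torch.randn(16, 64)     # placeholder batch
labels = torch.randint(0, 4, (16,))  # placeholder labels
loss = loss_fn(embeddings, labels)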
-------------------------------------------------------------------------------- /pytorch_metric_learning/regularizers/sparse_centers_regularizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..distances import CosineSimilarity 4 | from ..reducers import DivisorReducer 5 | from ..utils import common_functions as c_f 6 | from .base_regularizer import BaseRegularizer 7 | 8 | 9 | class SparseCentersRegularizer(BaseRegularizer): 10 | def __init__(self, num_classes, centers_per_class, **kwargs): 11 | super().__init__(**kwargs) 12 | assert centers_per_class > 1 13 | c_f.assert_distance_type(self, CosineSimilarity) 14 | self.set_class_masks(num_classes, centers_per_class) 15 | self.add_to_recordable_attributes( 16 | list_of_names=["num_classes", "centers_per_class"], is_stat=False 17 | ) 18 | self.add_to_recordable_attributes( 19 | list_of_names=["same_class_center_sim", "diff_class_center_sim"], 20 | is_stat=True, 21 | ) 22 | 23 | def compute_loss(self, weights): 24 | center_similarities = self.distance(weights) 25 | small_val = c_f.small_val(weights.dtype) 26 | center_similarities_masked = torch.clamp( 27 | 2.0 * center_similarities[self.same_class_mask], max=2 28 | ) 29 | divisor = 2 * torch.sum(self.same_class_mask) 30 | reg = torch.sqrt(2.0 + small_val - center_similarities_masked) 31 | self.set_stats(center_similarities) 32 | return { 33 | "loss": { 34 | "losses": reg, 35 | "indices": c_f.torch_arange_from_size(reg), 36 | "reduction_type": "element", 37 | "divisor": divisor, 38 | } 39 | } 40 | 41 | def set_class_masks(self, num_classes, centers_per_class): 42 | total_num_centers = num_classes * centers_per_class 43 | self.diff_class_mask = torch.ones( 44 | total_num_centers, total_num_centers, dtype=torch.bool 45 | ) 46 | self.same_class_mask = torch.zeros( 47 | total_num_centers, total_num_centers, dtype=torch.bool 48 | ) 49 | for i in range(num_classes): 50 | s, e = i * centers_per_class, (i + 1) * centers_per_class 51 | curr_block = torch.ones(centers_per_class, centers_per_class) 52 | curr_block = torch.triu(curr_block, diagonal=1) 53 | self.same_class_mask[s:e, s:e] = curr_block 54 | self.diff_class_mask[s:e, s:e] = 0 55 | 56 | def set_stats(self, center_similarities): 57 | if self.collect_stats: 58 | with torch.no_grad(): 59 | self.same_class_center_sim = torch.mean( 60 | center_similarities[self.same_class_mask] 61 | ).item() 62 | self.diff_class_center_sim = torch.mean( 63 | center_similarities[self.diff_class_mask] 64 | ).item() 65 | 66 | def get_default_distance(self): 67 | return CosineSimilarity() 68 | 69 | def get_default_reducer(self): 70 | return DivisorReducer() 71 | -------------------------------------------------------------------------------- /pytorch_metric_learning/regularizers/zero_mean_regularizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..utils import common_functions as c_f 4 | from .base_regularizer import BaseRegularizer 5 | 6 | 7 | class ZeroMeanRegularizer(BaseRegularizer): 8 | def compute_loss(self, embeddings): 9 | return { 10 | "loss": { 11 | "losses": torch.abs(torch.sum(embeddings, dim=1)), 12 | "indices": c_f.torch_arange_from_size(embeddings), 13 | "reduction_type": "element", 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /pytorch_metric_learning/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | from 
.fixed_set_of_triplets import FixedSetOfTriplets 2 | from .hierarchical_sampler import HierarchicalSampler 3 | from .m_per_class_sampler import MPerClassSampler 4 | from .tuples_to_weights_sampler import TuplesToWeightsSampler 5 | -------------------------------------------------------------------------------- /pytorch_metric_learning/samplers/fixed_set_of_triplets.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data.sampler import Sampler 4 | 5 | from ..utils import common_functions as c_f 6 | 7 | 8 | class FixedSetOfTriplets(Sampler): 9 | """ 10 | Upon initialization, this will create num_triplets triplets based on 11 | the labels provided in labels_to_indices. This is for experimental purposes, 12 | to see how algorithms perform when the only ground truth is a set of 13 | triplets, rather than having explicit labels. 14 | """ 15 | 16 | def __init__(self, labels, num_triplets): 17 | if isinstance(labels, torch.Tensor): 18 | labels = labels.numpy() 19 | self.labels_to_indices = c_f.get_labels_to_indices(labels) 20 | self.num_triplets = int(num_triplets) 21 | self.create_fixed_set_of_triplets() 22 | 23 | def __len__(self): 24 | return self.fixed_set_of_triplets.shape[0] * 3 25 | 26 | def __iter__(self): 27 | c_f.NUMPY_RANDOM.shuffle(self.fixed_set_of_triplets) 28 | flattened = self.fixed_set_of_triplets.flatten().tolist() 29 | return iter(flattened) 30 | 31 | def create_fixed_set_of_triplets(self): 32 | """ 33 | This creates self.fixed_set_of_triplets, which is a numpy array of size 34 | (num_triplets, 3). Each row is a triplet of indices: (a, p, n), where 35 | a=anchor, p=positive, and n=negative. Each triplet is created by first 36 | randomly sampling two classes, then randomly sampling an anchor, positive, 37 | and negative. 
38 | """ 39 | assert self.num_triplets > 0 40 | self.fixed_set_of_triplets = np.ones((self.num_triplets, 3), dtype=np.int) * -1 41 | label_list = list(self.labels_to_indices.keys()) 42 | for i in range(self.num_triplets): 43 | anchor_label, negative_label = c_f.NUMPY_RANDOM.choice( 44 | label_list, size=2, replace=False 45 | ) 46 | anchor_list = self.labels_to_indices[anchor_label] 47 | negative_list = self.labels_to_indices[negative_label] 48 | anchor, positive = c_f.safe_random_choice(anchor_list, size=2) 49 | negative = c_f.NUMPY_RANDOM.choice(negative_list, replace=False) 50 | self.fixed_set_of_triplets[i, :] = np.array([anchor, positive, negative]) 51 | -------------------------------------------------------------------------------- /pytorch_metric_learning/samplers/hierarchical_sampler.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from collections import defaultdict 3 | 4 | import torch 5 | from torch.utils.data.sampler import BatchSampler 6 | 7 | from ..utils import common_functions as c_f 8 | 9 | 10 | # Inspired by 11 | # https://github.com/kunhe/Deep-Metric-Learning-Baselines/blob/master/datasets.py 12 | class HierarchicalSampler(BatchSampler): 13 | def __init__( 14 | self, 15 | labels, 16 | batch_size, 17 | samples_per_class, 18 | batches_per_super_tuple=4, 19 | super_classes_per_batch=2, 20 | inner_label=0, 21 | outer_label=1, 22 | ): 23 | """ 24 | labels: 2D array, where rows correspond to elements, and columns correspond to the hierarchical labels 25 | batch_size: because this is a BatchSampler the batch size must be specified 26 | samples_per_class: number of instances to sample for a specific class. set to "all" if all element in a class 27 | batches_per_super_tuples: number of batches to create for a pair of categories (or super labels) 28 | inner_label: columns index corresponding to classes 29 | outer_label: columns index corresponding to the level of hierarchy for the pairs 30 | """ 31 | if torch.is_tensor(labels): 32 | labels = labels.cpu().numpy() 33 | 34 | self.batch_size = batch_size 35 | self.batches_per_super_tuple = batches_per_super_tuple 36 | self.samples_per_class = samples_per_class 37 | self.super_classes_per_batch = super_classes_per_batch 38 | 39 | # checks 40 | assert ( 41 | self.batch_size % super_classes_per_batch == 0 42 | ), f"batch_size should be a multiple of {super_classes_per_batch}" 43 | self.sub_batch_len = self.batch_size // super_classes_per_batch 44 | 45 | if self.samples_per_class != "all": 46 | assert self.samples_per_class > 0 47 | assert ( 48 | self.sub_batch_len % self.samples_per_class == 0 49 | ), "batch_size not a multiple of samples_per_class" 50 | 51 | all_super_labels = set(labels[:, outer_label]) 52 | self.super_image_lists = {slb: defaultdict(list) for slb in all_super_labels} 53 | for idx, instance in enumerate(labels): 54 | slb, lb = instance[outer_label], instance[inner_label] 55 | self.super_image_lists[slb][lb].append(idx) 56 | 57 | self.super_pairs = list( 58 | itertools.combinations(all_super_labels, super_classes_per_batch) 59 | ) 60 | self.reshuffle() 61 | 62 | def __iter__( 63 | self, 64 | ): 65 | self.reshuffle() 66 | for batch in self.batches: 67 | yield batch 68 | 69 | def __len__( 70 | self, 71 | ): 72 | return len(self.batches) 73 | 74 | def reshuffle(self): 75 | batches = [] 76 | for combinations in self.super_pairs: 77 | 78 | for b in range(self.batches_per_super_tuple): 79 | 80 | batch = [] 81 | for slb in combinations: 82 | 83 | sub_batch = [] 84 | 
all_classes = list(self.super_image_lists[slb].keys()) 85 | c_f.NUMPY_RANDOM.shuffle(all_classes) 86 | for cl in all_classes: 87 | if len(sub_batch) >= self.sub_batch_len: 88 | break 89 | instances = self.super_image_lists[slb][cl] 90 | samples_per_class = ( 91 | self.samples_per_class 92 | if self.samples_per_class != "all" 93 | else len(instances) 94 | ) 95 | if len(sub_batch) + samples_per_class > self.sub_batch_len: 96 | continue 97 | sub_batch.extend( 98 | c_f.safe_random_choice(instances, size=samples_per_class) 99 | ) 100 | 101 | batch.extend(sub_batch) 102 | 103 | c_f.NUMPY_RANDOM.shuffle(batch) 104 | batches.append(batch) 105 | 106 | c_f.NUMPY_RANDOM.shuffle(batches) 107 | self.batches = batches 108 | -------------------------------------------------------------------------------- /pytorch_metric_learning/samplers/m_per_class_sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data.sampler import Sampler 3 | 4 | from ..utils import common_functions as c_f 5 | 6 | 7 | # modified from 8 | # https://raw.githubusercontent.com/bnulihaixia/Deep_metric/master/utils/sampler.py 9 | class MPerClassSampler(Sampler): 10 | """ 11 | At every iteration, this will return m samples per class. For example, 12 | if dataloader's batchsize is 100, and m = 5, then 20 classes with 5 samples 13 | each will be returned 14 | """ 15 | 16 | def __init__(self, labels, m, batch_size=None, length_before_new_iter=100000): 17 | if isinstance(labels, torch.Tensor): 18 | labels = labels.numpy() 19 | self.m_per_class = int(m) 20 | self.batch_size = int(batch_size) if batch_size is not None else batch_size 21 | self.labels_to_indices = c_f.get_labels_to_indices(labels) 22 | self.labels = list(self.labels_to_indices.keys()) 23 | self.length_of_single_pass = self.m_per_class * len(self.labels) 24 | self.list_size = length_before_new_iter 25 | if self.batch_size is None: 26 | if self.length_of_single_pass < self.list_size: 27 | self.list_size -= (self.list_size) % (self.length_of_single_pass) 28 | else: 29 | assert self.list_size >= self.batch_size 30 | assert ( 31 | self.length_of_single_pass >= self.batch_size 32 | ), "m * (number of unique labels) must be >= batch_size" 33 | assert ( 34 | self.batch_size % self.m_per_class 35 | ) == 0, "m_per_class must divide batch_size without any remainder" 36 | self.list_size -= self.list_size % self.batch_size 37 | 38 | def __len__(self): 39 | return self.list_size 40 | 41 | def __iter__(self): 42 | idx_list = [0] * self.list_size 43 | i = 0 44 | num_iters = self.calculate_num_iters() 45 | for _ in range(num_iters): 46 | c_f.NUMPY_RANDOM.shuffle(self.labels) 47 | if self.batch_size is None: 48 | curr_label_set = self.labels 49 | else: 50 | curr_label_set = self.labels[: self.batch_size // self.m_per_class] 51 | for label in curr_label_set: 52 | t = self.labels_to_indices[label] 53 | idx_list[i : i + self.m_per_class] = c_f.safe_random_choice( 54 | t, size=self.m_per_class 55 | ) 56 | i += self.m_per_class 57 | return iter(idx_list) 58 | 59 | def calculate_num_iters(self): 60 | divisor = ( 61 | self.length_of_single_pass if self.batch_size is None else self.batch_size 62 | ) 63 | return self.list_size // divisor if divisor < self.list_size else 1 64 | -------------------------------------------------------------------------------- /pytorch_metric_learning/samplers/tuples_to_weights_sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 
import torch 3 | from torch.utils.data.sampler import Sampler 4 | 5 | from ..testers import BaseTester 6 | from ..utils import common_functions as c_f 7 | from ..utils import loss_and_miner_utils as lmu 8 | 9 | 10 | class TuplesToWeightsSampler(Sampler): 11 | def __init__(self, model, miner, dataset, subset_size=None, **tester_kwargs): 12 | self.model = model 13 | self.miner = miner 14 | self.dataset = dataset 15 | self.subset_size = subset_size 16 | self.tester = BaseTester(**tester_kwargs) 17 | self.device = self.tester.data_device 18 | self.weights = None 19 | 20 | def __len__(self): 21 | if self.subset_size: 22 | return self.subset_size 23 | return len(self.dataset) 24 | 25 | def __iter__(self): 26 | c_f.LOGGER.info("Computing embeddings in {}".format(self.__class__.__name__)) 27 | 28 | if self.subset_size: 29 | indices = c_f.safe_random_choice( 30 | np.arange(len(self.dataset)), size=self.subset_size 31 | ) 32 | curr_dataset = torch.utils.data.Subset(self.dataset, indices) 33 | else: 34 | indices = torch.arange(len(self.dataset), device=self.device) 35 | curr_dataset = self.dataset 36 | 37 | embeddings, labels = self.tester.get_all_embeddings(curr_dataset, self.model) 38 | labels = labels.squeeze(1) 39 | hard_tuples = self.miner(embeddings, labels) 40 | 41 | self.weights = torch.zeros(len(self.dataset), device=self.device) 42 | self.weights[indices] = lmu.convert_to_weights( 43 | hard_tuples, labels, dtype=torch.float32 44 | ) 45 | return iter( 46 | torch.utils.data.WeightedRandomSampler( 47 | self.weights, self.__len__(), replacement=True 48 | ) 49 | ) 50 | -------------------------------------------------------------------------------- /pytorch_metric_learning/testers/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_tester import BaseTester 2 | from .global_embedding_space import GlobalEmbeddingSpaceTester 3 | from .global_twostream_embedding_space import GlobalTwoStreamEmbeddingSpaceTester 4 | from .with_same_parent_label import WithSameParentLabelTester 5 | -------------------------------------------------------------------------------- /pytorch_metric_learning/testers/global_embedding_space.py: -------------------------------------------------------------------------------- 1 | from .base_tester import BaseTester 2 | 3 | 4 | class GlobalEmbeddingSpaceTester(BaseTester): 5 | def do_knn_and_accuracies( 6 | self, accuracies, embeddings_and_labels, query_split_name, reference_split_names 7 | ): 8 | ( 9 | query_embeddings, 10 | query_labels, 11 | reference_embeddings, 12 | reference_labels, 13 | ) = self.set_reference_and_query( 14 | embeddings_and_labels, query_split_name, reference_split_names 15 | ) 16 | self.label_levels = self.label_levels_to_evaluate(query_labels) 17 | 18 | for L in self.label_levels: 19 | curr_query_labels = query_labels[:, L] 20 | curr_reference_labels = reference_labels[:, L] 21 | a = self.accuracy_calculator.get_accuracy( 22 | query_embeddings, 23 | reference_embeddings, 24 | curr_query_labels, 25 | curr_reference_labels, 26 | self.embeddings_come_from_same_source( 27 | query_split_name, reference_split_names 28 | ), 29 | ) 30 | for metric, v in a.items(): 31 | keyname = self.accuracies_keyname(metric, label_hierarchy_level=L) 32 | accuracies[keyname] = v 33 | if len(self.label_levels) > 1: 34 | self.calculate_average_accuracies( 35 | accuracies, 36 | self.accuracy_calculator.get_curr_metrics(), 37 | self.label_levels, 38 | ) 39 | 
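Editor's sketch (not a file in this repository): GlobalEmbeddingSpaceTester computes embeddings for whole dataset splits and reports k-NN accuracy metrics over the full embedding space. Here val_dataset and model are hypothetical placeholders, which is why the calls that need them are left commented out; get_all_embeddings is the same call used by TuplesToWeightsSampler above.

from pytorch_metric_learning.testers import GlobalEmbeddingSpaceTester
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator

tester = GlobalEmbeddingSpaceTester(accuracy_calculator=AccuracyCalculator(k=1))
# embeddings, labels = tester.get_all_embeddings(val_dataset, model)   # val_dataset/model are placeholders
# accuracies = tester.test({"val": val_dataset}, epoch, model)         # returns a dict of accuracy metrics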
-------------------------------------------------------------------------------- /pytorch_metric_learning/testers/global_twostream_embedding_space.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tqdm 3 | 4 | from ..utils import common_functions as c_f 5 | from .global_embedding_space import GlobalEmbeddingSpaceTester 6 | 7 | 8 | class GlobalTwoStreamEmbeddingSpaceTester(GlobalEmbeddingSpaceTester): 9 | def compute_all_embeddings(self, dataloader, trunk_model, embedder_model): 10 | s, e = 0, 0 11 | with torch.no_grad(): 12 | for i, data in enumerate(tqdm.tqdm(dataloader)): 13 | anchors, posnegs, label = self.data_and_label_getter(data) 14 | label = c_f.process_label( 15 | label, self.label_hierarchy_level, self.label_mapper 16 | ) 17 | a = self.get_embeddings_for_eval(trunk_model, embedder_model, anchors) 18 | pns = self.get_embeddings_for_eval(trunk_model, embedder_model, posnegs) 19 | if label.dim() == 1: 20 | label = label.unsqueeze(1) 21 | if i == 0: 22 | labels = torch.zeros(len(dataloader.dataset), label.size(1)) 23 | all_anchors = torch.zeros(len(dataloader.dataset), pns.size(1)) 24 | all_posnegs = torch.zeros(len(dataloader.dataset), pns.size(1)) 25 | 26 | e = s + pns.size(0) 27 | all_anchors[s:e] = a 28 | all_posnegs[s:e] = pns 29 | labels[s:e] = label 30 | s = e 31 | return all_anchors, all_posnegs, labels 32 | 33 | def get_all_embeddings(self, dataset, trunk_model, embedder_model, collate_fn): 34 | dataloader = c_f.get_eval_dataloader( 35 | dataset, self.batch_size, self.dataloader_num_workers, collate_fn 36 | ) 37 | anchor_embeddings, posneg_embeddings, labels = self.compute_all_embeddings( 38 | dataloader, trunk_model, embedder_model 39 | ) 40 | anchor_embeddings, posneg_embeddings = ( 41 | self.maybe_normalize(anchor_embeddings), 42 | self.maybe_normalize(posneg_embeddings), 43 | ) 44 | return ( 45 | torch.cat([anchor_embeddings, posneg_embeddings], dim=0), 46 | torch.cat([labels, labels], dim=0), 47 | ) 48 | 49 | def set_reference_and_query( 50 | self, embeddings_and_labels, query_split_name, reference_split_names 51 | ): 52 | assert ( 53 | query_split_name == reference_split_names[0] 54 | and len(reference_split_names) == 1 55 | ), "{} does not support different reference and query splits".format( 56 | self.__class__.__name__ 57 | ) 58 | embeddings, labels = embeddings_and_labels[query_split_name] 59 | half = int(embeddings.shape[0] / 2) 60 | anchors_embeddings = embeddings[:half] 61 | posneg_embeddings = embeddings[half:] 62 | query_labels = labels[:half] 63 | return anchors_embeddings, query_labels, posneg_embeddings, query_labels 64 | 65 | def embeddings_come_from_same_source(self, query_split_name, reference_split_names): 66 | return False 67 | -------------------------------------------------------------------------------- /pytorch_metric_learning/testers/with_same_parent_label.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from ..utils import common_functions as c_f 7 | from .base_tester import BaseTester 8 | 9 | 10 | class WithSameParentLabelTester(BaseTester): 11 | def do_knn_and_accuracies( 12 | self, 13 | accuracies, 14 | embeddings_and_labels, 15 | query_split_name, 16 | reference_split_names, 17 | tag_suffix="", 18 | ): 19 | ( 20 | query_embeddings, 21 | query_labels, 22 | reference_embeddings, 23 | reference_labels, 24 | ) = self.set_reference_and_query( 25 | 
embeddings_and_labels, query_split_name, reference_split_names 26 | ) 27 | self.label_levels = [ 28 | L 29 | for L in self.label_levels_to_evaluate(query_labels) 30 | if L < query_labels.shape[1] - 1 31 | ] 32 | assert ( 33 | len(self.label_levels) > 0 34 | ), """There are no valid label levels to evaluate. Make sure you've set label_hierarchy_level correctly. 35 | If it is an integer, it must be less than the number of hierarchy levels minus 1.""" 36 | 37 | for L in self.label_levels: 38 | curr_query_parent_labels = query_labels[:, L + 1] 39 | curr_reference_parent_labels = reference_labels[:, L + 1] 40 | average_accuracies = defaultdict(list) 41 | for parent_label in torch.unique(curr_query_parent_labels): 42 | c_f.LOGGER.info( 43 | "Label level {} and parent label {}".format(L, parent_label) 44 | ) 45 | query_match = curr_query_parent_labels == parent_label 46 | reference_match = curr_reference_parent_labels == parent_label 47 | curr_query_labels = query_labels[:, L][query_match] 48 | curr_reference_labels = reference_labels[:, L][reference_match] 49 | curr_query_embeddings = query_embeddings[query_match] 50 | curr_reference_embeddings = reference_embeddings[reference_match] 51 | a = self.accuracy_calculator.get_accuracy( 52 | curr_query_embeddings, 53 | curr_reference_embeddings, 54 | curr_query_labels, 55 | curr_reference_labels, 56 | self.embeddings_come_from_same_source( 57 | query_split_name, reference_split_names 58 | ), 59 | ) 60 | for metric, v in a.items(): 61 | average_accuracies[metric].append(v) 62 | for metric, v in average_accuracies.items(): 63 | keyname = self.accuracies_keyname(metric, label_hierarchy_level=L) 64 | accuracies[keyname] = np.mean(v) 65 | 66 | if len(self.label_levels) > 1: 67 | self.calculate_average_accuracies( 68 | accuracies, 69 | self.accuracy_calculator.get_curr_metrics(), 70 | self.label_levels, 71 | ) 72 | -------------------------------------------------------------------------------- /pytorch_metric_learning/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_trainer import BaseTrainer 2 | from .cascaded_embeddings import CascadedEmbeddings 3 | from .deep_adversarial_metric_learning import DeepAdversarialMetricLearning 4 | from .metric_loss_only import MetricLossOnly 5 | from .train_with_classifier import TrainWithClassifier 6 | from .twostream_metric_loss import TwoStreamMetricLoss 7 | from .unsupervised_embeddings_using_augmentations import ( 8 | UnsupervisedEmbeddingsUsingAugmentations, 9 | ) 10 | -------------------------------------------------------------------------------- /pytorch_metric_learning/trainers/cascaded_embeddings.py: -------------------------------------------------------------------------------- 1 | from .. 
import miners 2 | from ..utils import common_functions as c_f 3 | from .base_trainer import BaseTrainer 4 | 5 | 6 | class CascadedEmbeddings(BaseTrainer): 7 | def __init__(self, embedding_sizes, *args, **kwargs): 8 | super().__init__(*args, **kwargs) 9 | self.embedding_sizes = embedding_sizes 10 | 11 | def calculate_loss(self, curr_batch): 12 | data, labels = curr_batch 13 | embeddings = self.compute_embeddings(data) 14 | s = 0 15 | logits = [] 16 | indices_tuple = None 17 | for i, curr_size in enumerate(self.embedding_sizes): 18 | curr_loss_name = "metric_loss_%d" % i 19 | curr_miner_name = "tuple_miner_%d" % i 20 | curr_classifier_name = "classifier_%d" % i 21 | 22 | e = embeddings[:, s : s + curr_size] 23 | indices_tuple = self.maybe_mine_embeddings( 24 | e, labels, indices_tuple, curr_miner_name 25 | ) 26 | self.losses[curr_loss_name] += self.maybe_get_metric_loss( 27 | e, labels, indices_tuple, curr_loss_name 28 | ) 29 | logits.append(self.maybe_get_logits(e, curr_classifier_name)) 30 | s += curr_size 31 | 32 | for i, L in enumerate(logits): 33 | if L is None: 34 | continue 35 | curr_loss_name = "classifier_loss_%d" % i 36 | self.losses[curr_loss_name] += self.maybe_get_classifier_loss( 37 | L, labels, curr_loss_name 38 | ) 39 | 40 | def maybe_get_metric_loss(self, embeddings, labels, indices_tuple, curr_loss_name): 41 | if self.loss_weights.get(curr_loss_name, 0) > 0: 42 | return self.loss_funcs[curr_loss_name](embeddings, labels, indices_tuple) 43 | return 0 44 | 45 | def maybe_mine_embeddings( 46 | self, embeddings, labels, prev_indices_tuple, curr_miner_name 47 | ): 48 | if curr_miner_name in self.mining_funcs: 49 | curr_miner = self.mining_funcs[curr_miner_name] 50 | if isinstance(curr_miner, miners.HDCMiner): 51 | curr_miner.set_idx_externally(prev_indices_tuple, labels) 52 | curr_indices_tuple = curr_miner(embeddings, labels) 53 | curr_miner.reset_idx() 54 | else: 55 | curr_indices_tuple = curr_miner(embeddings, labels) 56 | return curr_indices_tuple 57 | return None 58 | 59 | def maybe_get_logits(self, embeddings, curr_classifier_name): 60 | if self.models.get(curr_classifier_name, None): 61 | return self.models[curr_classifier_name](embeddings) 62 | return None 63 | 64 | def maybe_get_classifier_loss(self, logits, labels, curr_loss_name): 65 | if self.loss_weights.get(curr_loss_name, 0) > 0: 66 | return self.loss_funcs[curr_loss_name]( 67 | logits, c_f.to_device(labels, logits) 68 | ) 69 | return 0 70 | 71 | def modify_schema(self): 72 | self.schema["models"].keys += ["classifier_[0-9]+"] 73 | self.schema["loss_funcs"].keys = [ 74 | "metric_loss_[0-9]+", 75 | "classifier_loss_[0-9]+", 76 | ] 77 | self.schema["mining_funcs"].keys = ["tuple_miner_[0-9]+"] 78 | -------------------------------------------------------------------------------- /pytorch_metric_learning/trainers/metric_loss_only.py: -------------------------------------------------------------------------------- 1 | from .base_trainer import BaseTrainer 2 | 3 | 4 | class MetricLossOnly(BaseTrainer): 5 | def calculate_loss(self, curr_batch): 6 | data, labels = curr_batch 7 | embeddings = self.compute_embeddings(data) 8 | indices_tuple = self.maybe_mine_embeddings(embeddings, labels) 9 | self.losses["metric_loss"] = self.maybe_get_metric_loss( 10 | embeddings, labels, indices_tuple 11 | ) 12 | 13 | def maybe_get_metric_loss(self, embeddings, labels, indices_tuple): 14 | if self.loss_weights.get("metric_loss", 0) > 0: 15 | return self.loss_funcs["metric_loss"](embeddings, labels, indices_tuple) 16 | return 0 17 | 
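
MetricLossOnly above is the distilled core of the trainers in this package: embed the batch, optionally mine informative tuples, then hand (embeddings, labels, indices_tuple) to the metric loss. Outside the trainer framework the same step collapses to a few lines (an editor's sketch; the particular miner and loss are illustrative choices, not this repo's configuration):

import torch
from pytorch_metric_learning import losses, miners

loss_func = losses.TripletMarginLoss(margin=0.2)
miner = miners.MultiSimilarityMiner(epsilon=0.1)

# Stand-ins for model(data) and the corresponding batch labels.
embeddings = torch.randn(32, 128, requires_grad=True)
labels = torch.randint(0, 8, (32,))

indices_tuple = miner(embeddings, labels)  # mined hard tuples; None would use all tuples
loss = loss_func(embeddings, labels, indices_tuple)
loss.backward()
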
-------------------------------------------------------------------------------- /pytorch_metric_learning/trainers/train_with_classifier.py: -------------------------------------------------------------------------------- 1 | from ..utils import common_functions as c_f 2 | from .metric_loss_only import MetricLossOnly 3 | 4 | 5 | class TrainWithClassifier(MetricLossOnly): 6 | def calculate_loss(self, curr_batch): 7 | data, labels = curr_batch 8 | embeddings = self.compute_embeddings(data) 9 | logits = self.maybe_get_logits(embeddings) 10 | indices_tuple = self.maybe_mine_embeddings(embeddings, labels) 11 | self.losses["metric_loss"] = self.maybe_get_metric_loss( 12 | embeddings, labels, indices_tuple 13 | ) 14 | self.losses["classifier_loss"] = self.maybe_get_classifier_loss(logits, labels) 15 | 16 | def maybe_get_classifier_loss(self, logits, labels): 17 | if logits is not None: 18 | return self.loss_funcs["classifier_loss"]( 19 | logits, c_f.to_device(labels, logits) 20 | ) 21 | return 0 22 | 23 | def maybe_get_logits(self, embeddings): 24 | if ( 25 | self.models.get("classifier", None) 26 | and self.loss_weights.get("classifier_loss", 0) > 0 27 | ): 28 | return self.models["classifier"](embeddings) 29 | return None 30 | 31 | def modify_schema(self): 32 | self.schema["models"].keys += ["classifier"] 33 | self.schema["loss_funcs"].keys += ["classifier_loss"] 34 | -------------------------------------------------------------------------------- /pytorch_metric_learning/trainers/twostream_metric_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..utils import common_functions as c_f 4 | from ..utils import loss_and_miner_utils as lmu 5 | from .base_trainer import BaseTrainer 6 | 7 | 8 | class TwoStreamMetricLoss(BaseTrainer): 9 | def calculate_loss(self, curr_batch): 10 | (anchors, posnegs), labels = curr_batch 11 | embeddings = ( 12 | self.compute_embeddings(anchors), 13 | self.compute_embeddings(posnegs), 14 | ) 15 | 16 | indices_tuple = self.maybe_mine_embeddings(embeddings, labels) 17 | self.losses["metric_loss"] = self.maybe_get_metric_loss( 18 | embeddings, labels, indices_tuple 19 | ) 20 | 21 | def get_batch(self): 22 | self.dataloader_iter, curr_batch = c_f.try_next_on_generator( 23 | self.dataloader_iter, self.dataloader 24 | ) 25 | anchors, posnegs, labels = self.data_and_label_getter(curr_batch) 26 | data = (anchors, posnegs) 27 | labels = c_f.process_label( 28 | labels, self.label_hierarchy_level, self.label_mapper 29 | ) 30 | return self.maybe_do_batch_mining(data, labels) 31 | 32 | def maybe_get_metric_loss(self, embeddings, labels, indices_tuple): 33 | if self.loss_weights.get("metric_loss", 0) > 0: 34 | current_batch_size = embeddings[0].shape[0] 35 | indices_tuple = c_f.shift_indices_tuple(indices_tuple, current_batch_size) 36 | all_labels = torch.cat([labels, labels], dim=0) 37 | all_embeddings = torch.cat(embeddings, dim=0) 38 | return self.loss_funcs["metric_loss"]( 39 | all_embeddings, all_labels, indices_tuple 40 | ) 41 | return 0 42 | 43 | def maybe_mine_embeddings(self, embeddings, labels): 44 | # for both get_all_triplets_indices and mining_funcs 45 | # we need to clone labels and pass them as ref_labels 46 | # to ensure triplets are generated between anchors and posnegs 47 | if "tuple_miner" in self.mining_funcs: 48 | (anchors_embeddings, posnegs_embeddings) = embeddings 49 | return self.mining_funcs["tuple_miner"]( 50 | anchors_embeddings, labels, posnegs_embeddings, labels.clone() 51 | ) 52 | 
else: 53 | return lmu.get_all_triplets_indices(labels, labels.clone()) 54 | 55 | def modify_schema(self): 56 | self.schema["mining_funcs"].keys = ["tuple_miner"] 57 | -------------------------------------------------------------------------------- /pytorch_metric_learning/trainers/unsupervised_embeddings_using_augmentations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..utils import common_functions as c_f 4 | from .metric_loss_only import MetricLossOnly 5 | 6 | 7 | class UnsupervisedEmbeddingsUsingAugmentations(MetricLossOnly): 8 | def __init__(self, transforms, data_and_label_setter=None, *args, **kwargs): 9 | super().__init__(*args, **kwargs) 10 | self.data_and_label_setter = data_and_label_setter 11 | self.initialize_data_and_label_setter() 12 | self.transforms = transforms 13 | self.collate_fn = self.custom_collate_fn 14 | self.initialize_dataloader() 15 | c_f.LOGGER.info("Transforms: %s" % transforms) 16 | 17 | def initialize_data_and_label_setter(self): 18 | if self.data_and_label_setter is None: 19 | self.data_and_label_setter = c_f.return_input 20 | 21 | def custom_collate_fn(self, data): 22 | transformed_data, labels = [], [] 23 | for i, d in enumerate(data): 24 | img, _ = self.data_and_label_getter(d) 25 | for t in self.transforms: 26 | transformed_data.append(t(img)) 27 | labels.append(i) 28 | return self.data_and_label_setter( 29 | (torch.stack(transformed_data, dim=0), torch.LongTensor(labels)) 30 | ) 31 | -------------------------------------------------------------------------------- /pytorch_metric_learning/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sohyun-l/fifo/6eaecab298acd6e5c3e35a78516463e0ef17e709/pytorch_metric_learning/utils/__init__.py -------------------------------------------------------------------------------- /pytorch_metric_learning/utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sohyun-l/fifo/6eaecab298acd6e5c3e35a78516463e0ef17e709/pytorch_metric_learning/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /pytorch_metric_learning/utils/__pycache__/common_functions.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sohyun-l/fifo/6eaecab298acd6e5c3e35a78516463e0ef17e709/pytorch_metric_learning/utils/__pycache__/common_functions.cpython-37.pyc -------------------------------------------------------------------------------- /pytorch_metric_learning/utils/__pycache__/loss_and_miner_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sohyun-l/fifo/6eaecab298acd6e5c3e35a78516463e0ef17e709/pytorch_metric_learning/utils/__pycache__/loss_and_miner_utils.cpython-37.pyc -------------------------------------------------------------------------------- /pytorch_metric_learning/utils/__pycache__/module_with_records.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sohyun-l/fifo/6eaecab298acd6e5c3e35a78516463e0ef17e709/pytorch_metric_learning/utils/__pycache__/module_with_records.cpython-37.pyc -------------------------------------------------------------------------------- 
/pytorch_metric_learning/utils/__pycache__/module_with_records_and_reducer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sohyun-l/fifo/6eaecab298acd6e5c3e35a78516463e0ef17e709/pytorch_metric_learning/utils/__pycache__/module_with_records_and_reducer.cpython-37.pyc
--------------------------------------------------------------------------------
/pytorch_metric_learning/utils/distributed.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn.parallel import DistributedDataParallel as DDP
3 | 
4 | from ..utils import common_functions as c_f
5 | 
6 | 
7 | # modified from https://github.com/allenai/allennlp
8 | def is_distributed():
9 |     return torch.distributed.is_available() and torch.distributed.is_initialized()
10 | 
11 | 
12 | # modified from https://github.com/JohnGiorgi/DeCLUTR
13 | def all_gather(embeddings, labels):
14 |     labels = c_f.to_device(labels, embeddings)
15 |     # If we are not using distributed training, this is a no-op.
16 |     if not is_distributed():
17 |         return embeddings, labels
18 |     world_size = torch.distributed.get_world_size()
19 |     rank = torch.distributed.get_rank()
20 |     # Gather the embeddings on all replicas
21 |     embeddings_list = [torch.ones_like(embeddings) for _ in range(world_size)]
22 |     labels_list = [torch.ones_like(labels) for _ in range(world_size)]
23 |     torch.distributed.all_gather(embeddings_list, embeddings.contiguous())
24 |     torch.distributed.all_gather(labels_list, labels.contiguous())
25 |     # The gathered copy of the current replica's embeddings has no gradients, so we overwrite
26 |     # it with the embeddings generated on this replica, which DO have gradients.
27 |     embeddings_list[rank] = embeddings
28 |     labels_list[rank] = labels
29 |     # Finally, we concatenate the embeddings
30 |     embeddings = torch.cat(embeddings_list)
31 |     labels = torch.cat(labels_list)
32 |     return embeddings, labels
33 | 
34 | 
35 | def all_gather_embeddings_labels(embeddings, labels):
36 |     if c_f.is_list_or_tuple(embeddings):
37 |         assert c_f.is_list_or_tuple(labels)
38 |         all_embeddings, all_labels = [], []
39 |         for i in range(len(embeddings)):
40 |             E, L = all_gather(embeddings[i], labels[i])
41 |             all_embeddings.append(E)
42 |             all_labels.append(L)
43 |         embeddings = torch.cat(all_embeddings, dim=0)
44 |         labels = torch.cat(all_labels, dim=0)
45 |     else:
46 |         embeddings, labels = all_gather(embeddings, labels)
47 | 
48 |     return embeddings, labels
49 | 
50 | 
51 | class DistributedLossWrapper(torch.nn.Module):
52 |     def __init__(self, loss, **kwargs):
53 |         super().__init__()
54 |         has_parameters = len([p for p in loss.parameters()]) > 0
55 |         self.loss = DDP(loss, **kwargs) if has_parameters else loss
56 | 
57 |     def forward(self, embeddings, labels, *args, **kwargs):
58 |         embeddings, labels = all_gather_embeddings_labels(embeddings, labels)
59 |         return self.loss(embeddings, labels, *args, **kwargs)
60 | 
61 | 
62 | class DistributedMinerWrapper(torch.nn.Module):
63 |     def __init__(self, miner):
64 |         super().__init__()
65 |         self.miner = miner
66 | 
67 |     def forward(self, embeddings, labels, ref_emb=None, ref_labels=None):
68 |         embeddings, labels = all_gather_embeddings_labels(embeddings, labels)
69 |         if ref_emb is not None:
70 |             ref_emb, ref_labels = all_gather_embeddings_labels(ref_emb, ref_labels)
71 |         return self.miner(embeddings, labels, ref_emb, ref_labels)
72 | 
--------------------------------------------------------------------------------
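
How DistributedLossWrapper is meant to be wired into a DDP worker (an editor's sketch: it assumes torch.distributed.init_process_group has already run in this process, and batch / batch_labels are placeholders for this rank's shard of the data):

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

from pytorch_metric_learning import losses
from pytorch_metric_learning.utils import distributed as pml_dist

rank = torch.distributed.get_rank()
embedder = torch.nn.Linear(512, 128)  # stand-in for the real trunk model
model = DDP(embedder.to(rank), device_ids=[rank])

# ContrastiveLoss has no trainable parameters, so the wrapper leaves it outside DDP.
loss_func = pml_dist.DistributedLossWrapper(loss=losses.ContrastiveLoss())

embeddings = model(batch)                    # each rank embeds only its own shard
loss = loss_func(embeddings, batch_labels)   # all_gather above merges every rank's tensors
loss.backward()
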
/pytorch_metric_learning/utils/key_checker.py: -------------------------------------------------------------------------------- 1 | from . import common_functions as c_f 2 | 3 | 4 | class KeyCheckerDict: 5 | def __init__(self, children): 6 | self.children = children 7 | 8 | def __getitem__(self, key): 9 | return self.children[key] 10 | 11 | def __setitem__(self, key, value): 12 | self.children[key] = value 13 | 14 | def verify(self, obj): 15 | for k, v in self.children.items(): 16 | self._verify_prop(getattr(obj, k, None), k, v) 17 | 18 | def _verify_prop(self, obj, obj_name, s): 19 | val = lambda x: x(s, self.children) if callable(x) else x 20 | 21 | if s.warn_empty and obj in [None, {}]: 22 | c_f.LOGGER.warning("%s is empty" % obj_name) 23 | if obj is not None: 24 | keys = val(s.keys) 25 | for k in obj.keys(): 26 | assert any( 27 | pattern.match(k) for pattern in c_f.regex_wrapper(keys) 28 | ), "%s keys must be one of %s" % (obj_name, ", ".join(keys)) 29 | for imp_key in val(s.important): 30 | if not any(c_f.regex_wrapper(imp_key).match(k) for k in obj): 31 | c_f.LOGGER.warning('%s is missing "%s"' % (obj_name, imp_key)) 32 | for ess_key in val(s.essential): 33 | assert any( 34 | c_f.regex_wrapper(ess_key).match(k) for k in obj 35 | ), '%s must contain "%s"' % (obj_name, ess_key) 36 | 37 | 38 | def default_important(s, d): 39 | return c_f.exclude(s.keys, s.essential) 40 | 41 | 42 | # We can make this a dataclass, if we target 3.7+ 43 | class KeyChecker: 44 | def __init__( 45 | self, 46 | keys, 47 | warn_empty=True, 48 | important=default_important, 49 | essential=None, 50 | ): 51 | self.keys = keys 52 | self.warn_empty = warn_empty 53 | self.important = important 54 | self.essential = essential 55 | if self.essential is None: 56 | self.essential = [] 57 | -------------------------------------------------------------------------------- /pytorch_metric_learning/utils/loss_tracker.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | 3 | 4 | class LossTracker: 5 | def __init__(self, loss_names): 6 | if "total_loss" not in loss_names: 7 | loss_names.append("total_loss") 8 | self.losses = {key: 0 for key in loss_names} 9 | self.loss_weights = {key: 1 for key in loss_names} 10 | 11 | def weight_the_losses(self, exclude_loss=("total_loss",)): 12 | for k, _ in self.losses.items(): 13 | if k not in exclude_loss: 14 | self.losses[k] *= self.loss_weights[k] 15 | 16 | def get_total_loss(self, exclude_loss=("total_loss",)): 17 | self.losses["total_loss"] = 0 18 | for k, v in self.losses.items(): 19 | if k not in exclude_loss: 20 | self.losses["total_loss"] += v 21 | 22 | def set_loss_weights(self, loss_weight_dict): 23 | for k, _ in self.losses.items(): 24 | if k in loss_weight_dict: 25 | w = loss_weight_dict[k] 26 | else: 27 | w = 1.0 28 | self.loss_weights[k] = w 29 | 30 | def update(self, loss_weight_dict): 31 | self.set_loss_weights(loss_weight_dict) 32 | self.weight_the_losses() 33 | self.get_total_loss() 34 | -------------------------------------------------------------------------------- /pytorch_metric_learning/utils/module_with_records.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from . 
import common_functions as c_f 4 | 5 | 6 | class ModuleWithRecords(torch.nn.Module): 7 | def __init__(self, collect_stats=c_f.COLLECT_STATS): 8 | super().__init__() 9 | self.collect_stats = collect_stats 10 | 11 | def add_to_recordable_attributes( 12 | self, name=None, list_of_names=None, is_stat=False 13 | ): 14 | if is_stat and not self.collect_stats: 15 | pass 16 | else: 17 | c_f.add_to_recordable_attributes( 18 | self, name=name, list_of_names=list_of_names, is_stat=is_stat 19 | ) 20 | 21 | def reset_stats(self): 22 | c_f.reset_stats(self) 23 | -------------------------------------------------------------------------------- /pytorch_metric_learning/utils/module_with_records_and_reducer.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | from ..distances import LpDistance 4 | from ..reducers import DoNothingReducer, MeanReducer, MultipleReducers 5 | from .module_with_records import ModuleWithRecords 6 | 7 | 8 | class ModuleWithRecordsAndReducer(ModuleWithRecords): 9 | def __init__(self, reducer=None, **kwargs): 10 | super().__init__(**kwargs) 11 | self.set_reducer(reducer) 12 | 13 | def get_default_reducer(self): 14 | return MeanReducer() 15 | 16 | def set_reducer(self, reducer): 17 | if isinstance(reducer, (MultipleReducers, DoNothingReducer)): 18 | self.reducer = reducer 19 | elif len(self.sub_loss_names()) == 1: 20 | self.reducer = ( 21 | self.get_default_reducer() 22 | if reducer is None 23 | else copy.deepcopy(reducer) 24 | ) 25 | else: 26 | reducer_dict = {} 27 | for k in self.sub_loss_names(): 28 | reducer_dict[k] = ( 29 | self.get_default_reducer() 30 | if reducer is None 31 | else copy.deepcopy(reducer) 32 | ) 33 | self.reducer = MultipleReducers(reducer_dict) 34 | 35 | def sub_loss_names(self): 36 | return ["loss"] 37 | 38 | 39 | class ModuleWithRecordsAndDistance(ModuleWithRecords): 40 | def __init__(self, distance=None, **kwargs): 41 | super().__init__(**kwargs) 42 | self.distance = self.get_distance() if distance is None else distance 43 | 44 | def get_default_distance(self): 45 | return LpDistance(p=2) 46 | 47 | def get_distance(self): 48 | return self.get_default_distance() 49 | 50 | 51 | class ModuleWithRecordsReducerAndDistance( 52 | ModuleWithRecordsAndReducer, ModuleWithRecordsAndDistance 53 | ): 54 | def __init__(self, **kwargs): 55 | super().__init__(**kwargs) 56 | -------------------------------------------------------------------------------- /pytorch_metric_learning/utils/stat_utils.py: -------------------------------------------------------------------------------- 1 | from . import common_functions as c_f 2 | 3 | try: 4 | import faiss 5 | except ModuleNotFoundError: 6 | c_f.LOGGER.warning( 7 | """The pytorch-metric-learning testing module requires faiss. You can install the GPU version with the command 'conda install faiss-gpu -c pytorch' 8 | or the CPU version with 'conda install faiss-cpu -c pytorch'. 
Learn more at https://github.com/facebookresearch/faiss/blob/master/INSTALL.md""" 9 | ) 10 | 11 | import numpy as np 12 | import torch 13 | 14 | 15 | def add_to_index_and_search(index, reference_embeddings, test_embeddings, k): 16 | index.add(reference_embeddings) 17 | return index.search(test_embeddings, k) 18 | 19 | 20 | def try_gpu(cpu_index, reference_embeddings, test_embeddings, k): 21 | # https://github.com/facebookresearch/faiss/blob/master/faiss/gpu/utils/DeviceDefs.cuh 22 | gpu_index = None 23 | gpus_are_available = faiss.get_num_gpus() > 0 24 | if gpus_are_available: 25 | max_k_for_gpu = 1024 if float(torch.version.cuda) < 9.5 else 2048 26 | if k <= max_k_for_gpu: 27 | gpu_index = faiss.index_cpu_to_all_gpus(cpu_index) 28 | try: 29 | return add_to_index_and_search( 30 | gpu_index, reference_embeddings, test_embeddings, k 31 | ) 32 | except (AttributeError, RuntimeError) as e: 33 | if gpus_are_available: 34 | c_f.LOGGER.warning( 35 | f"Using CPU for k-nn search because k = {k} > {max_k_for_gpu}, which is the maximum allowable on GPU." 36 | ) 37 | return add_to_index_and_search( 38 | cpu_index, reference_embeddings, test_embeddings, k 39 | ) 40 | 41 | 42 | # modified from https://github.com/facebookresearch/deepcluster 43 | def get_knn( 44 | reference_embeddings, test_embeddings, k, embeddings_come_from_same_source=False 45 | ): 46 | if embeddings_come_from_same_source: 47 | k = k + 1 48 | device = reference_embeddings.device 49 | reference_embeddings = c_f.to_numpy(reference_embeddings).astype(np.float32) 50 | test_embeddings = c_f.to_numpy(test_embeddings).astype(np.float32) 51 | 52 | d = reference_embeddings.shape[1] 53 | c_f.LOGGER.info("running k-nn with k=%d" % k) 54 | c_f.LOGGER.info("embedding dimensionality is %d" % d) 55 | cpu_index = faiss.IndexFlatL2(d) 56 | distances, indices = try_gpu(cpu_index, reference_embeddings, test_embeddings, k) 57 | distances = c_f.to_device(torch.from_numpy(distances), device=device) 58 | indices = c_f.to_device(torch.from_numpy(indices), device=device) 59 | if embeddings_come_from_same_source: 60 | return indices[:, 1:], distances[:, 1:] 61 | return indices, distances 62 | 63 | 64 | # modified from https://raw.githubusercontent.com/facebookresearch/deepcluster/ 65 | def run_kmeans(x, nmb_clusters): 66 | device = x.device 67 | x = c_f.to_numpy(x).astype(np.float32) 68 | n_data, d = x.shape 69 | c_f.LOGGER.info("running k-means clustering with k=%d" % nmb_clusters) 70 | c_f.LOGGER.info("embedding dimensionality is %d" % d) 71 | 72 | # faiss implementation of k-means 73 | clus = faiss.Clustering(d, nmb_clusters) 74 | clus.niter = 20 75 | clus.max_points_per_centroid = 10000000 76 | index = faiss.IndexFlatL2(d) 77 | if faiss.get_num_gpus() > 0: 78 | index = faiss.index_cpu_to_all_gpus(index) 79 | # perform the training 80 | clus.train(x, index) 81 | _, idxs = index.search(x, 1) 82 | 83 | return torch.tensor([int(n[0]) for n in idxs], dtype=int, device=device) 84 | 85 | 86 | # modified from https://github.com/facebookresearch/faiss/wiki/Faiss-building-blocks:-clustering,-PCA,-quantization 87 | def run_pca(x, output_dimensionality): 88 | device = x.device 89 | x = c_f.to_numpy(x).astype(np.float32) 90 | mat = faiss.PCAMatrix(x.shape[1], output_dimensionality) 91 | mat.train(x) 92 | assert mat.is_trained 93 | return c_f.to_device(torch.from_numpy(mat.apply_py(x)), device=device) 94 | -------------------------------------------------------------------------------- /utils/__init__.py: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sohyun-l/fifo/6eaecab298acd6e5c3e35a78516463e0ef17e709/utils/__init__.py
--------------------------------------------------------------------------------
/utils/helpers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | 
4 | IMG_SCALE = 1. / 255
5 | IMG_MEAN = np.array([0.485, 0.456, 0.406]).reshape((1, 1, 3))
6 | IMG_STD = np.array([0.229, 0.224, 0.225]).reshape((1, 1, 3))
7 | 
8 | 
9 | def maybe_download(model_name, model_url, model_dir=None, map_location=None):
10 |     import os
11 |     import sys
12 |     from six.moves import urllib
13 | 
14 |     if model_dir is None:
15 |         torch_home = os.path.expanduser(os.getenv("TORCH_HOME", "~/.torch"))
16 |         model_dir = os.getenv("TORCH_MODEL_ZOO", os.path.join(torch_home, "models"))
17 |     if not os.path.exists(model_dir):
18 |         os.makedirs(model_dir)
19 |     filename = "{}.pth.tar".format(model_name)
20 |     cached_file = os.path.join(model_dir, filename)
21 |     if not os.path.exists(cached_file):
22 |         url = model_url
23 |         sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
24 |         urllib.request.urlretrieve(url, cached_file)
25 |     return torch.load(cached_file, map_location=map_location)
26 | 
27 | 
28 | def prepare_img(img):
29 |     return (img * IMG_SCALE - IMG_MEAN) / IMG_STD
--------------------------------------------------------------------------------
/utils/layer_factory.py:
--------------------------------------------------------------------------------
1 | """RefineNet-LightWeight-CRP Block
2 | RefineNet-LightWeight PyTorch for non-commercial purposes
3 | Copyright (c) 2018, Vladimir Nekrasov (vladimir.nekrasov@adelaide.edu.au)
4 | All rights reserved.
5 | Redistribution and use in source and binary forms, with or without
6 | modification, are permitted provided that the following conditions are met:
7 | * Redistributions of source code must retain the above copyright notice, this
8 | list of conditions and the following disclaimer.
9 | * Redistributions in binary form must reproduce the above copyright notice,
10 | this list of conditions and the following disclaimer in the documentation
11 | and/or other materials provided with the distribution.
12 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
13 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
14 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
15 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
16 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
17 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
18 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
19 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
20 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
21 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
22 | """ 23 | 24 | import torch.nn as nn 25 | 26 | 27 | def batchnorm(in_planes): 28 | "batch norm 2d" 29 | return nn.BatchNorm2d(in_planes, affine=True, eps=1e-5, momentum=0.1) 30 | 31 | 32 | def conv3x3(in_planes, out_planes, stride=1, bias=False): 33 | "3x3 convolution with padding" 34 | return nn.Conv2d( 35 | in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=bias 36 | ) 37 | 38 | 39 | def conv1x1(in_planes, out_planes, stride=1, bias=False): 40 | "1x1 convolution" 41 | return nn.Conv2d( 42 | in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=bias 43 | ) 44 | 45 | 46 | def convbnrelu(in_planes, out_planes, kernel_size, stride=1, groups=1, act=True): 47 | "conv-batchnorm-relu" 48 | if act: 49 | return nn.Sequential( 50 | nn.Conv2d( 51 | in_planes, 52 | out_planes, 53 | kernel_size, 54 | stride=stride, 55 | padding=int(kernel_size / 2.0), 56 | groups=groups, 57 | bias=False, 58 | ), 59 | batchnorm(out_planes), 60 | nn.ReLU6(inplace=True), 61 | ) 62 | else: 63 | return nn.Sequential( 64 | nn.Conv2d( 65 | in_planes, 66 | out_planes, 67 | kernel_size, 68 | stride=stride, 69 | padding=int(kernel_size / 2.0), 70 | groups=groups, 71 | bias=False, 72 | ), 73 | batchnorm(out_planes), 74 | ) 75 | 76 | 77 | class CRPBlock(nn.Module): 78 | def __init__(self, in_planes, out_planes, n_stages): 79 | super(CRPBlock, self).__init__() 80 | for i in range(n_stages): 81 | setattr( 82 | self, 83 | "{}_{}".format(i + 1, "outvar_dimred"), 84 | conv1x1( 85 | in_planes if (i == 0) else out_planes, 86 | out_planes, 87 | stride=1, 88 | bias=False, 89 | ), 90 | ) 91 | self.stride = 1 92 | self.n_stages = n_stages 93 | self.maxpool = nn.MaxPool2d(kernel_size=5, stride=1, padding=2) 94 | 95 | def forward(self, x): 96 | top = x 97 | for i in range(self.n_stages): 98 | top = self.maxpool(top) 99 | top = getattr(self, "{}_{}".format(i + 1, "outvar_dimred"))(top) 100 | x = top + x 101 | return x -------------------------------------------------------------------------------- /utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | 6 | 7 | class CrossEntropy2d(nn.Module): 8 | 9 | def __init__(self, size_average=True, ignore_label=255): 10 | super(CrossEntropy2d, self).__init__() 11 | self.size_average = size_average 12 | self.ignore_label = ignore_label 13 | 14 | def forward(self, predict, target, weight=None): 15 | """ 16 | Args: 17 | predict:(n, c, h, w) 18 | target:(n, h, w) 19 | weight (Tensor, optional): a manual rescaling weight given to each class. 
20 |                 If given, has to be a Tensor of size "nclasses"
21 |         """
22 |         assert not target.requires_grad
23 |         assert predict.dim() == 4
24 |         assert target.dim() == 3
25 |         n, c, h, w = predict.size()
26 |         n1, h1, w1 = target.size()
27 |         target_mask = (target >= 0) * (target != self.ignore_label)
28 |         target = target[target_mask]
29 |         if not target.data.dim():
30 |             return Variable(torch.zeros(1))
31 |         predict = predict.transpose(1, 2).transpose(2, 3).contiguous()
32 |         predict = predict[target_mask.view(n, h, w, 1).repeat(1, 1, 1, c)].view(-1, c)
33 |         loss = F.cross_entropy(predict, target, weight=weight, size_average=self.size_average)
34 |         return loss
35 | 
36 | 
--------------------------------------------------------------------------------
/utils/network.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import re
3 | # rf_lw50 and rf_lw152 are assumed to be defined in model.refinenetlw alongside rf_lw101 (they are in upstream Light-Weight RefineNet); the original import brought in only rf_lw101, leaving the "50" and "152" branches below to fail with a NameError.
4 | from model.refinenetlw import rf_lw50, rf_lw101, rf_lw152
5 | 
6 | 
7 | def get_segmenter(
8 |     enc_backbone, enc_pretrained, num_classes,
9 | ):
10 |     """Create Encoder-Decoder; for now only ResNet [50,101,152] Encoders are supported"""
11 |     if enc_backbone == "50":
12 |         return rf_lw50(num_classes, imagenet=enc_pretrained)
13 |     elif enc_backbone == "101":
14 |         return rf_lw101(num_classes, imagenet=enc_pretrained)
15 |     elif enc_backbone == "152":
16 |         return rf_lw152(num_classes, imagenet=enc_pretrained)
17 |     else:
18 |         raise ValueError("{} is not supported".format(str(enc_backbone)))
19 | 
20 | 
21 | def get_encoder_and_decoder_params(model):
22 |     """Filter model parameters into two groups: encoder and decoder."""
23 |     logger = logging.getLogger(__name__)
24 |     enc_params = []
25 |     dec_params = []
26 |     for k, v in model.named_parameters():
27 |         if bool(re.match(".*conv1.*|.*bn1.*|.*layer.*", k)):
28 |             # print(k)
29 |             enc_params.append(v)
30 |             logger.info(" Enc. parameter: {}".format(k))
31 |         else:
32 |             # print(k)
33 |             dec_params.append(v)
34 |             logger.info(" Dec. parameter: {}".format(k))
35 |     return enc_params, dec_params
--------------------------------------------------------------------------------
/utils/optimisers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | import densetorch as dt
4 | 
5 | from utils.network import get_encoder_and_decoder_params
6 | 
7 | 
8 | def get_lr_schedulers(
9 |     enc_optim,
10 |     dec_optim,
11 |     enc_lr_gamma,
12 |     dec_lr_gamma,
13 |     enc_scheduler_type,
14 |     dec_scheduler_type,
15 |     epochs_per_stage,
16 | ):
17 |     milestones = np.cumsum(epochs_per_stage)
18 |     max_epochs = milestones[-1]
19 |     schedulers = [
20 |         dt.misc.create_scheduler(
21 |             scheduler_type=enc_scheduler_type,
22 |             optim=enc_optim,
23 |             gamma=enc_lr_gamma,
24 |             milestones=milestones,
25 |             max_epochs=max_epochs,
26 |         ),
27 |         dt.misc.create_scheduler(
28 |             scheduler_type=dec_scheduler_type,
29 |             optim=dec_optim,
30 |             gamma=dec_lr_gamma,
31 |             milestones=milestones,
32 |             max_epochs=max_epochs,
33 |         ),
34 |     ]
35 |     return schedulers
36 | 
37 | 
38 | def get_optimisers(
39 |     model,
40 |     enc_optim_type,
41 |     enc_lr,
42 |     enc_weight_decay,
43 |     enc_momentum,
44 |     dec_optim_type,
45 |     dec_lr,
46 |     dec_weight_decay,
47 |     dec_momentum,
48 | ):
49 |     enc_params, dec_params = get_encoder_and_decoder_params(model)
50 |     optimisers = [
51 |         dt.misc.create_optim(
52 |             optim_type=enc_optim_type,
53 |             parameters=enc_params,
54 |             lr=enc_lr,
55 |             weight_decay=enc_weight_decay,
56 |             momentum=enc_momentum,
57 |         ),
58 |         dt.misc.create_optim(
59 |             optim_type=dec_optim_type,
60 |             parameters=dec_params,
61 |             lr=dec_lr,
62 |             weight_decay=dec_weight_decay,
63 |             momentum=dec_momentum,
64 |         ),
65 |     ]
66 |     return optimisers
--------------------------------------------------------------------------------
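
Putting utils/network.py and utils/optimisers.py together, the intended call pattern is roughly the following (an editor's sketch: the hyperparameter values are placeholders, and the optim/scheduler type strings must be ones that densetorch's create_optim / create_scheduler accept):

from utils.network import get_segmenter
from utils.optimisers import get_optimisers, get_lr_schedulers

model = get_segmenter(enc_backbone="101", enc_pretrained=True, num_classes=19)

enc_optim, dec_optim = get_optimisers(
    model,
    enc_optim_type="sgd", enc_lr=5e-4, enc_weight_decay=1e-5, enc_momentum=0.9,
    dec_optim_type="sgd", dec_lr=5e-3, dec_weight_decay=1e-5, dec_momentum=0.9,
)
enc_sched, dec_sched = get_lr_schedulers(
    enc_optim, dec_optim,
    enc_lr_gamma=0.5, dec_lr_gamma=0.5,
    enc_scheduler_type="multistep", dec_scheduler_type="multistep",
    epochs_per_stage=(100, 100, 100),  # np.cumsum -> milestones [100, 200, 300]
)

for epoch in range(300):
    # ... one epoch of training, stepping enc_optim and dec_optim per batch ...
    enc_sched.step()
    dec_sched.step()
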