├── Imgs
│   ├── DAFOS.png
│   └── ICCV_poster_PID3497.pdf
├── LICENSE
├── README.md
├── accuracy_values.npy
├── accuracy_variance_values.npy
├── config.json
├── data
│   └── save_data_to_numpy.py
├── finetune.py
├── logs
│   └── gradient_tape
│       ├── 20230822-180101
│       │   └── train
│       │       └── events.out.tfevents.1692727261.73c426b94e33.2297.0.v2
│       ├── 20230822-180428
│       │   └── finetune
│       │       └── events.out.tfevents.1692727468.73c426b94e33.3399.0.v2
│       ├── 20230822-180821
│       │   └── test
│       │       └── events.out.tfevents.1692727701.73c426b94e33.4383.0.v2
│       ├── 20230822-181123
│       │   └── test
│       │       └── events.out.tfevents.1692727883.73c426b94e33.5136.0.v2
│       ├── 20230822-181519
│       │   └── finetune
│       │       └── events.out.tfevents.1692728119.73c426b94e33.6090.0.v2
│       └── 20230822-184454
│           └── train
│               └── events.out.tfevents.1692729894.73c426b94e33.13220.0.v2
├── loss
│   ├── AOCMC.py
│   ├── cosine_loss.py
│   ├── cross_domain_alignment.py
│   ├── euclidian_dist.py
│   └── prototype_diversification_loss.py
├── models
│   ├── Feature_extractor.py
│   ├── domain_network.py
│   ├── high_noise_GAN.py
│   ├── low_noise_GAN.py
│   └── outlier_network.py
├── optimizers
│   ├── __pycache__
│   │   └── gan_optimizers.cpython-310.pyc
│   └── gan_optimizers.py
├── test.py
├── train.py
└── utils
    ├── adamatch_aug.py
    ├── get_gamma_beta_from_layer.py
    ├── proto_test.py
    ├── proto_train.py
    ├── save_accuracy_values.py
    ├── test_episode.py
    ├── train_episode.py
    └── variance_metric.py


/Imgs/DAFOS.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DebabrataPal7/DAFOSNET/7cfef39c8d6c5ec59c099c320e87efdfdc3c6302/Imgs/DAFOS.png
--------------------------------------------------------------------------------
/Imgs/ICCV_poster_PID3497.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DebabrataPal7/DAFOSNET/7cfef39c8d6c5ec59c099c320e87efdfdc3c6302/Imgs/ICCV_poster_PID3497.pdf
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Debabrata Pal

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# DAFOS-NET

The code repository for "Domain Adaptive Few-Shot Open-Set Learning" [[paper]](https://openaccess.thecvf.com/content/ICCV2023/papers/Pal_Domain_Adaptive_Few-Shot_Open-Set_Learning_ICCV_2023_paper.pdf) (ICCV'23) in TensorFlow.

### Abstract

Few-shot learning has made impressive strides in addressing the crucial challenges of recognizing unknown samples from novel classes in target query sets and managing visual shifts between domains.
However, existing techniques fall short when it comes to identifying target outliers under domain shifts by learning to reject pseudo-outliers from the source domain, resulting in an incomplete solution to both problems. To address these challenges comprehensively, we propose a novel approach called Domain Adaptive Few-Shot Open Set Recognition (DA-FSOS) and introduce a meta-learning-based architecture named DAFOSNET. During training, our model learns a shared and discriminative embedding space while creating a pseudo open-space decision boundary, given a fully-supervised source domain and a label-disjoint few-shot target domain. To enhance data density, we use a pair of conditional adversarial networks with tunable noise variances to augment both domains’ closed and pseudo-open spaces. Furthermore, we propose a domain-specific batch-normalized class prototypes alignment strategy to align both domains globally while ensuring class-discriminativeness through metric objectives. Our training approach ensures that DAFOS-NET can generalize well to new scenarios in the target domain. We present three benchmarks for DA-FSOS based on the Office-Home, mini-ImageNet/CUB, and DomainNet datasets and demonstrate the efficacy of DAFOSNet through extensive experimentation.

### Dependencies
This code requires the following:
* Python 3.8+ (TensorFlow 2.12 does not support Python 3.7)
* TensorFlow v2.12.0+
* keras_cv v1.4.0
* NumPy, OpenCV (`cv2`), scikit-learn and SciPy, which the scripts import

### Data
The domain-adaptive experiments pair two domains taken from one of the Office-Home, mini-ImageNet/CUB or DomainNet datasets.

Create two folders, `data/raw` and `data/processed`.

Place the two domain folders of your choice in `data/raw`, with one sub-folder per class; the first `num_classes` class folders in sorted order are used. Note that `data/save_data_to_numpy.py` currently reads the source domain from `data/raw/Real World` and the target domain from `data/raw/Clipart` (the Office-Home folder names), so other domains must be renamed accordingly or the paths in the script edited.

Preprocessing is implemented in `data/save_data_to_numpy.py` and is controlled by the `"pre_process"` flag in `config.json`, which defaults to `"True"`. Set it to `"False"` if the arrays do not need to be regenerated.
You don't have to run this script separately.

If the flag is `"True"`, the NumPy arrays are saved to `data/processed` as `source_<num_classes>.npy` and `target_<num_classes>.npy`.

### Usage
Adjust the settings in `config.json`.

##### Checkpoints:
Training checkpoints are saved to `./checkpoints/train` and fine-tuning checkpoints to `./checkpoints/finetune`. These are the `existing_checkpoint_dir` and `current_checkpoint_dir` entries in `config.json`, which `finetune.py` restores from and saves to, respectively.

##### For training:
`python3 train.py`
##### For finetuning:
`python3 finetune.py`
##### For testing:
`python3 test.py`

## Citation
If you use any content of this repository in your work, please cite the following BibTeX entry:

    @InProceedings{Pal_2023_ICCV,
        author    = {Pal, Debabrata and More, Deeptej and Bhargav, Sai and Tamboli, Dipesh and Aggarwal, Vaneet and Banerjee, Biplab},
        title     = {Domain Adaptive Few-Shot Open-Set Learning},
        booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
        month     = {October},
        year      = {2023},
        pages     = {18831-18840}
    }

## Licence
DAFOS-NET is released under the MIT license.

Copyright (c) 2023 Debabrata Pal. All rights reserved.
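
## Example: nearest-prototype classification (sketch)

The abstract describes class prototypes learned with metric objectives. As a rough, hypothetical illustration of that general idea (a generic prototypical-network step, not the repository's `proto_train`/`proto_test` code), the sketch below averages support embeddings into one prototype per class and labels each query by its closest prototype. It uses the same per-pair distance as `loss/euclidian_dist.py` (mean of squared differences) but ignores the open-set (outlier) and domain-adaptation components. The helper names `mean_sq_dist`, `class_prototypes` and `predict` are invented for this example.

```python
# Hypothetical illustration: a generic nearest-prototype classifier in plain Python.
# It is NOT the repository's proto_train/proto_test code and ignores the open-set
# (outlier) and domain-adaptation components.


def mean_sq_dist(a, b):
    """Mean of squared element-wise differences, the per-pair distance that
    loss/euclidian_dist.py computes (squared Euclidean distance / dimension)."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def class_prototypes(support, labels):
    """Average the support embeddings of each class into one prototype per class."""
    sums, counts = {}, {}
    for emb, lab in zip(support, labels):
        if lab not in sums:
            sums[lab] = [0.0] * len(emb)
            counts[lab] = 0
        sums[lab] = [s + e for s, e in zip(sums[lab], emb)]
        counts[lab] += 1
    return {lab: [s / counts[lab] for s in sums[lab]] for lab in sums}


def predict(queries, prototypes):
    """Assign each query embedding to the class with the closest prototype."""
    return [min(prototypes, key=lambda lab: mean_sq_dist(q, prototypes[lab]))
            for q in queries]


if __name__ == "__main__":
    # Toy 2-D "embeddings": two support shots for each of classes 0 and 1
    support = [[0.0, 0.0], [0.2, 0.0], [1.0, 1.0], [1.0, 0.8]]
    labels = [0, 0, 1, 1]
    protos = class_prototypes(support, labels)
    print(predict([[0.1, 0.1], [0.9, 1.0]], protos))  # [0, 1]
```

In the repository the embeddings would instead be the 64-dimensional outputs of `CDFSOSR_Model` (`emb_dim` in `config.json`), with distances computed on TensorFlow tensors by a batched function such as `calc_euclidian_dists2`.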

--------------------------------------------------------------------------------
/accuracy_values.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DebabrataPal7/DAFOSNET/7cfef39c8d6c5ec59c099c320e87efdfdc3c6302/accuracy_values.npy
--------------------------------------------------------------------------------
/accuracy_variance_values.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DebabrataPal7/DAFOSNET/7cfef39c8d6c5ec59c099c320e87efdfdc3c6302/accuracy_variance_values.npy
--------------------------------------------------------------------------------
/config.json:
--------------------------------------------------------------------------------
{
    "base_dir":"./data/raw",
    "save_path_source":"./data/processed",
    "save_path_target":"./data/processed",
    "existing_checkpoint_dir":"./checkpoints/train",
    "current_checkpoint_dir":"./checkpoints/finetune",
    "num_classes":12,
    "num_images_per_class":8,
    "train_epochs":2,
    "finetune_epochs":2,
    "test_epochs":2,
    "lr":0.0001,
    "initial_learning_rate":0.0001,
    "decay_rate":0.001,
    "decay_steps":1000,
    "finetune_lr":0.0001,
    "finetune_initial_learning_rate":0.0001,
    "finetune_decay_rate":0.001,
    "finetune_decay_steps":1000,
    "pre_process":"True",
    "save_every":2,
    "CS":3,
    "CQ":3,
    "Ss":4,
    "Sd":1,
    "Qsk":2,
    "Qdk":2,
    "Qsu":2,
    "Qdu":2,
    "train_classes": 12,
    "test_classes": 12,
    "emb_dim":64,
    "ntimes":1
}

--------------------------------------------------------------------------------
/data/save_data_to_numpy.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# coding: utf-8

# Import necessary libraries
import tensorflow as tf
import numpy as np
import os
import cv2

#==================================================Load Data from Image Folders and Save to Numpy Files==================================================
def save_data(base_dir, save_path_source, save_path_target, num_classes, num_images_per_class):
    # base_dir must contain the source ('Real World') and target ('Clipart') domain folders

    # GPU info
    physical_devices = tf.config.list_physical_devices('GPU')
    print("GPU: ", physical_devices)

    # Pre-process data
    #
    # 1. create 5-d array source_domain: `[class, image_no, height, width, channels]`
    # 2. create 5-d array target_domain: `[class, image_no, height, width, channels]`
    #
    # -----
    # Note: images are read with cv2, hence they are in BGR channel order

    # Path to the Real World (source) folder
    realworld_folder = os.path.join(base_dir, 'Real World')
    # Path to the Clipart (target) folder
    clipart_folder = os.path.join(base_dir, 'Clipart')
    # Numeric class labels
    class_labels = []
    # Not all classes have the same number of images in Office-Home Real World and Clipart,
    # so for standardization only the first num_images_per_class images of each class are used
    num_channels = 3  # colour images (BGR order, see note above)
    image_height = 224
    image_width = 224
    # np.zeros rather than np.empty: a class with fewer than num_images_per_class images
    # is zero-padded instead of being left with uninitialised memory
    source_domain = np.zeros((num_classes, num_images_per_class, image_height, image_width, num_channels))
    target_domain = np.zeros((num_classes, num_images_per_class, image_height, image_width, num_channels))

    def load_domain(domain_folder, domain_array, labels=None):
        # Class folders are visited in sorted order; only the first num_classes are used
        class_folders = sorted(os.listdir(domain_folder))
        print(class_folders)
        for i, folder in enumerate(class_folders[:num_classes]):
            class_folder = os.path.join(domain_folder, folder)
            # Print the number of images in the class
            print(class_folder, len(os.listdir(class_folder)))
            # Sorted so that the same images are selected on every run
            image_names = sorted(os.listdir(class_folder))[:num_images_per_class]
            for j, image_name in enumerate(image_names):
                image_file = os.path.join(class_folder, image_name)
                # Load the image and store it in the array
                image = cv2.imread(image_file)
                if image is None:
                    raise ValueError(f"cv2 could not read image: {image_file}")
                # Resize to 224x224 (cv2.resize expects dsize as (width, height))
                image = cv2.resize(image, dsize=(image_width, image_height), interpolation=cv2.INTER_CUBIC)
                domain_array[i, j, :, :, :] = image
            if labels is not None:
                # One numeric label per class
                labels.append(i)

    # Load the images from the Real World (source) folder
    load_domain(realworld_folder, source_domain, class_labels)
    # Load the images from the Clipart (target) folder
    load_domain(clipart_folder, target_domain)

    np.save(f'{save_path_source}/source_{num_classes}.npy', source_domain)
    np.save(f'{save_path_target}/target_{num_classes}.npy', target_domain)
--------------------------------------------------------------------------------
/finetune.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# coding: utf-8

# Import necessary libraries
import tensorflow as tf
import tensorflow.keras.backend as K
import os
import json
import numpy as np
import datetime
from sklearn.metrics import confusion_matrix
from scipy.stats import hmean

# save image batches to numpy files for faster processing
from data.save_data_to_numpy import save_data
# feature extractor
from models.Feature_extractor import CDFSOSR_Model
# outlier prediction network
from models.outlier_network import outlier_nn
# domain prediction network
from models.domain_network import domain_nn
# low noise GANs
from models.low_noise_GAN import generator_nn_low_s
from models.low_noise_GAN import dis_nn_low_s
# high noise GANs
from models.high_noise_GAN import generator_nn_high_s
from models.high_noise_GAN import dis_nn_high_s
# custom metrics (variance)
from utils.variance_metric import ComputeIterationVaraince
# save accuracy values
from utils.save_accuracy_values import save_accuracy_values
# proto train
from utils.proto_train import proto_train
# new train episode
from utils.train_episode import new_episode
# adamatch augmentation
# from utils.adamatch_aug import adamatch_aug

with open('./config.json') as json_file:
    config = json.load(json_file)

base_dir = config['base_dir']
save_path_source = config['save_path_source']
save_path_target = config['save_path_target']
num_classes = config['num_classes']  # train_classes + test_classes
num_images_per_class = config['num_images_per_class']

# Hyperparameters for support and query images
CS = config['CS']
CQ = config['CQ']  # not necessarily equal to CS
Ss = config['Ss']
Sd = config['Sd']
Qsk = config['Qsk']
Qdk = config['Qdk']  # N
Qsu = config['Qsu']
Qdu = config['Qdu']  # X

# Finetune classes
test_classes = config['test_classes']

# embedding dimensions
emb_dim = config['emb_dim']

# load the saved arrays into memory
source_domain = np.load(f'{save_path_source}/source_{num_classes}.npy')
target_domain = np.load(f'{save_path_target}/target_{num_classes}.npy')

# Numeric class labels
class_labels = []

for i in range(len(source_domain)):
    class_labels.append(i)

print("source_domain shape: ", source_domain.shape)
print("target_domain shape: ", target_domain.shape)
print("class_labels: ", class_labels)

#====================================================================Optimizers================================================================
lr = config['finetune_lr']
initial_learning_rate = config['finetune_initial_learning_rate']
decay_rate = config['finetune_decay_rate']
decay_steps = config['finetune_decay_steps']

# Create the learning rate schedule
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate,
    decay_steps=decay_steps,
    decay_rate=decay_rate,
    staircase=True)

optim_d_low_s = tf.keras.optimizers.legacy.Adam(lr)
optim_g_low_s = tf.keras.optimizers.legacy.Adam(lr)
optim_d_high_s = tf.keras.optimizers.legacy.Adam(lr)
optim_g_high_s = tf.keras.optimizers.legacy.Adam(lr)
optim_model_Outlier = tf.keras.optimizers.legacy.Adam(initial_learning_rate)
optim_model_Domain = tf.keras.optimizers.legacy.Adam(initial_learning_rate)
optim2 = tf.keras.optimizers.legacy.Adam(lr)  # multitask outlier network
optim3 = tf.keras.optimizers.legacy.Adam(initial_learning_rate)  # feature extractor (FE)

#====================================================================Finetune checkpoint================================================================
# Directory where the existing (training) checkpoint is saved
existing_checkpoint_dir = config["existing_checkpoint_dir"]
existing_checkpoint_prefix = os.path.join(existing_checkpoint_dir, "ckpt")
# Variable that keeps track of the current epoch
epoch_var = tf.Variable(0, dtype=tf.int64, name="epoch")
# Load the existing checkpoint variables into the current model
existing_checkpoint = tf.train.Checkpoint(epoch_var,
                                          optim3=optim3, optim2=optim2,
                                          optim_d_low_s=optim_d_low_s, optim_g_low_s=optim_g_low_s, optim_d_high_s=optim_d_high_s, optim_g_high_s=optim_g_high_s,
                                          optim_model_Outlier=optim_model_Outlier,
                                          optim_model_Domain=optim_model_Domain,
                                          CDFSOSR_Model=CDFSOSR_Model,
                                          generator_nn_low_s=generator_nn_low_s, generator_nn_high_s=generator_nn_high_s, dis_nn_low_s=dis_nn_low_s, dis_nn_high_s=dis_nn_high_s,
                                          outlier_nn=outlier_nn,
                                          domain_nn=domain_nn)
existing_checkpoint.restore(tf.train.latest_checkpoint(existing_checkpoint_dir))
# Note: CheckpointManager treats its second argument as a directory, so this manager
# tracks checkpoints under <existing_checkpoint_dir>/ckpt/ rather than <existing_checkpoint_dir>
existing_ckpt_manager = tf.train.CheckpointManager(existing_checkpoint, existing_checkpoint_prefix, max_to_keep=5)


# New checkpoint to save the current values
current_checkpoint_dir = config["current_checkpoint_dir"]
current_checkpoint_prefix = os.path.join(current_checkpoint_dir, "ckpt")
current_checkpoint = tf.train.Checkpoint(epoch_var,
                                         optim3=optim3, optim2=optim2,
                                         optim_d_low_s=optim_d_low_s, optim_g_low_s=optim_g_low_s, optim_d_high_s=optim_d_high_s, optim_g_high_s=optim_g_high_s,
                                         optim_model_Outlier=optim_model_Outlier,
                                         optim_model_Domain=optim_model_Domain,
                                         CDFSOSR_Model=CDFSOSR_Model,
                                         generator_nn_low_s=generator_nn_low_s, generator_nn_high_s=generator_nn_high_s, dis_nn_low_s=dis_nn_low_s, dis_nn_high_s=dis_nn_high_s,
                                         outlier_nn=outlier_nn,
                                         domain_nn=domain_nn)
current_ckpt_manager = tf.train.CheckpointManager(current_checkpoint, current_checkpoint_prefix, max_to_keep=5)

#===============================================================Finetune Step========================================================
# Metrics to gather
finetune_loss = tf.metrics.Mean(name='finetune_loss')
finetune_acc = tf.metrics.Mean(name='finetune_accuracy')
finetune_closedoa_cm = tf.metrics.Mean(name='finetune_closedoa_cm')
finetune_openoa = tf.metrics.Mean(name='finetune_openoa')
finetune_outlier_acc = tf.metrics.Mean(name='finetune_outlier_acc')
finetune_domain_acc = tf.metrics.Mean(name='finetune_domain_acc')
finetune_closedoa_cm_variance = ComputeIterationVaraince(name="finetune_closedoa_cm_variance")
finetune_openoa_variance = ComputeIterationVaraince(name="finetune_openoa_variance")

ntimes = config['ntimes']

def finetune_step(tsupport_patches, tquery_patches, support_labels, query_labels, support_classes, CS, CQ, Ss, Sd, Qsk, Qdk, Qsu, Qdu, query_dom_labels, support_dom_labels):
    # Forward pass: proto_train returns the episode loss and metrics
    with tf.GradientTape() as tape:
        loss, accuracy, closed_oa_cm, outlier_det_acc, openoa, domain_acc = proto_train(
            tsupport_patches, tquery_patches, support_labels, query_labels, support_classes,
            CS, CQ, Ss, Sd, Qsk, Qdk, Qsu, Qdu, query_dom_labels, support_dom_labels,
            CDFSOSR_Model, outlier_nn, domain_nn, generator_nn_low_s, dis_nn_low_s, generator_nn_high_s, dis_nn_high_s,
            ntimes, lr_schedule, optim_model_Outlier, optim_model_Domain,
            optim_d_low_s, optim_g_low_s, optim_d_high_s, optim_g_high_s)
    # Backward pass: here optim3 updates the feature extractor; the optimizers of the
    # other networks are passed to proto_train above
    gradients = tape.gradient(loss, CDFSOSR_Model.trainable_variables)
    # Compute the learning rate using the schedule
    learning_rate = lr_schedule(optim3.iterations)
    optim3.learning_rate = learning_rate
    optim3.apply_gradients(zip(gradients, CDFSOSR_Model.trainable_variables))
    finetune_loss(loss)
    finetune_acc(accuracy)
    finetune_closedoa_cm(closed_oa_cm)
    finetune_openoa(openoa)
    finetune_outlier_acc(outlier_det_acc)
    finetune_domain_acc(domain_acc)
    finetune_closedoa_cm_variance.add(closed_oa_cm)
    finetune_openoa_variance.add(openoa)

current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
finetune_log_dir = 'logs/gradient_tape/' + current_time + '/finetune'
finetune_summary_writer = tf.summary.create_file_writer(finetune_log_dir)

#===============================================================Finetune Loop========================================================
start_epoch = 0
if existing_ckpt_manager.latest_checkpoint:
    start_epoch = int(existing_ckpt_manager.latest_checkpoint.split('-')[-1])
    # restore the latest checkpoint found by the checkpoint manager
    existing_checkpoint.restore(existing_ckpt_manager.latest_checkpoint)
    print("Restarted from: ", start_epoch)
else:
    print("Training from scratch...")

for epoch in range(start_epoch, config["finetune_epochs"]):  # 80 train + 80 tune + 100 train + 160 tune + 40 train
    finetune_loss.reset_states()
    finetune_acc.reset_states()
    finetune_closedoa_cm.reset_states()
    finetune_openoa.reset_states()
    finetune_outlier_acc.reset_states()
    finetune_domain_acc.reset_states()
    finetune_closedoa_cm_variance.reset_states()
    finetune_openoa_variance.reset_states()
    epoch_var.assign_add(1)
    # 10 episodes per epoch, sampled from the last test_classes classes of both domains
    for epi in range(10):
        tquery_patches, tsupport_patches, query_labels, support_labels, support_classes, query_dom_labels, support_dom_labels = new_episode(
            source_domain[-test_classes:, :, :, :, :], target_domain[-test_classes:, :, :, :, :],
            CS, CQ, Ss, Sd, Qsk, Qdk, Qsu, Qdu, class_labels[-test_classes:])
        # tsupport_patches_aug, support_labels_aug, Ss_aug, Sd_aug, support_dom_labels_aug = adamatch_aug(tsupport_patches, support_labels, CS, Ss, Sd, support_dom_labels)
        finetune_step(tsupport_patches, tquery_patches, support_labels, query_labels, support_classes,
                      CS, CQ, Ss, Sd, Qsk, Qdk, Qsu, Qdu, query_dom_labels,
                      support_dom_labels)
    with finetune_summary_writer.as_default():
        tf.summary.scalar('loss', finetune_loss.result(), step=epoch)
        tf.summary.scalar('accuracy', finetune_acc.result(), step=epoch)
        tf.summary.scalar('closedoa_cm', finetune_closedoa_cm.result(), step=epoch)
        tf.summary.scalar('openoa', finetune_openoa.result(), step=epoch)
        tf.summary.scalar('outlier_det_acc', finetune_outlier_acc.result(), step=epoch)
        tf.summary.scalar('domain_det_acc', finetune_domain_acc.result(), step=epoch)

    template = 'Epoch {}, Finetune Loss: {:.2f}, Finetune Accuracy: {:.2f}, Finetune Closed OA: {:.2f}, Finetune Open OA: {:.2f}, Finetune Outlier Det. Acc: {:.2f}, Finetune Domain Det. Acc: {:.2f}'
    print(template.format(epoch + 1, finetune_loss.result(), finetune_acc.result() * 100, finetune_closedoa_cm.result() * 100, finetune_openoa.result() * 100, finetune_outlier_acc.result() * 100, finetune_domain_acc.result() * 100))
    # save mean accuracy values per epoch
    save_accuracy_values(finetune_closedoa_cm.result(), finetune_openoa.result(), 'accuracy_values.npy')
    # save the variance of the accuracies
    save_accuracy_values(finetune_closedoa_cm_variance.compute_variance(), finetune_openoa_variance.compute_variance(), 'accuracy_variance_values.npy')

    # save a checkpoint after every save_every completed epochs (the previous test,
    # epoch % save_every == 0 and epoch != 0, never saved with the default 2 epochs)
    if (epoch + 1) % config["save_every"] == 0:
        current_checkpoint.save(file_prefix=current_checkpoint_prefix)
--------------------------------------------------------------------------------
/logs/gradient_tape/20230822-180101/train/events.out.tfevents.1692727261.73c426b94e33.2297.0.v2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DebabrataPal7/DAFOSNET/7cfef39c8d6c5ec59c099c320e87efdfdc3c6302/logs/gradient_tape/20230822-180101/train/events.out.tfevents.1692727261.73c426b94e33.2297.0.v2
--------------------------------------------------------------------------------
/logs/gradient_tape/20230822-180428/finetune/events.out.tfevents.1692727468.73c426b94e33.3399.0.v2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DebabrataPal7/DAFOSNET/7cfef39c8d6c5ec59c099c320e87efdfdc3c6302/logs/gradient_tape/20230822-180428/finetune/events.out.tfevents.1692727468.73c426b94e33.3399.0.v2
--------------------------------------------------------------------------------
/logs/gradient_tape/20230822-180821/test/events.out.tfevents.1692727701.73c426b94e33.4383.0.v2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DebabrataPal7/DAFOSNET/7cfef39c8d6c5ec59c099c320e87efdfdc3c6302/logs/gradient_tape/20230822-180821/test/events.out.tfevents.1692727701.73c426b94e33.4383.0.v2
--------------------------------------------------------------------------------
/logs/gradient_tape/20230822-181123/test/events.out.tfevents.1692727883.73c426b94e33.5136.0.v2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DebabrataPal7/DAFOSNET/7cfef39c8d6c5ec59c099c320e87efdfdc3c6302/logs/gradient_tape/20230822-181123/test/events.out.tfevents.1692727883.73c426b94e33.5136.0.v2
--------------------------------------------------------------------------------
/logs/gradient_tape/20230822-181519/finetune/events.out.tfevents.1692728119.73c426b94e33.6090.0.v2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DebabrataPal7/DAFOSNET/7cfef39c8d6c5ec59c099c320e87efdfdc3c6302/logs/gradient_tape/20230822-181519/finetune/events.out.tfevents.1692728119.73c426b94e33.6090.0.v2
--------------------------------------------------------------------------------
/logs/gradient_tape/20230822-184454/train/events.out.tfevents.1692729894.73c426b94e33.13220.0.v2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DebabrataPal7/DAFOSNET/7cfef39c8d6c5ec59c099c320e87efdfdc3c6302/logs/gradient_tape/20230822-184454/train/events.out.tfevents.1692729894.73c426b94e33.13220.0.v2
--------------------------------------------------------------------------------
/loss/AOCMC.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# coding: utf-8
import tensorflow as tf

#==========================AOCMC Loss==========================
def anti_open_close_mode_collapse_loss(sl, sh, zl, zh, epsilon=1e-8):
    # Per sample: 1 + log((1 - cos(zl, zh) + eps) / (1 - cos(sl, sh) + eps)), averaged over the batch
    # Row-wise cosine similarity between sl and sh, and between zl and zh
    cos_sl_sh = tf.reduce_sum(sl * sh, axis=-1) / (tf.norm(sl, axis=-1) * tf.norm(sh, axis=-1) + epsilon)
    cos_zl_zh = tf.reduce_sum(zl * zh, axis=-1) / (tf.norm(zl, axis=-1) * tf.norm(zh, axis=-1) + epsilon)

    numerator = 1 - cos_zl_zh + epsilon
    denominator = 1 - cos_sl_sh + epsilon
    LAOCMC = 1 + tf.math.log(numerator / denominator)

    loss = LAOCMC
    return tf.reduce_mean(loss)
--------------------------------------------------------------------------------
/loss/cosine_loss.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# coding: utf-8
import tensorflow as tf

#==========================Cosine Loss==========================
# AOL Regularizer
def cosine_loss(s1, s2, z1, z2):
    # L2-normalise every row (axis=1 replaces the deprecated dim=1 argument)
    s1 = tf.nn.l2_normalize(s1, axis=1)
    s2 = tf.nn.l2_normalize(s2, axis=1)
    z1 = tf.nn.l2_normalize(z1, axis=1)
    z2 = tf.nn.l2_normalize(z2, axis=1)
    # Summing these element-wise products over axis 1 gives row-wise cosine similarities
    cos_s = s1 * s2
    cos_z = z1 * z2
    loss = (1 + tf.reduce_sum(cos_z, axis=1)) * (tf.math.maximum(0.0000001, tf.reduce_sum(cos_s, axis=1)))
    return tf.reduce_mean(loss)
--------------------------------------------------------------------------------
/loss/cross_domain_alignment.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# coding: utf-8
import tensorflow as tf
import numpy as np

#==========================P Loss==========================
# calculate P_loss (cross-domain prototype alignment)
def calc_weighted_gamma_dists(z_prototypes_s_combined, z_prototypes_d_combined, Gamma_D1, Gamma_D2, Beta_D1, Beta_D2):
    prot_bn_dists = []
    # For every domain-specific BN layer i and every class prototype j, compare the prototype means
    # after applying each domain's mean BN scale (gamma) and shift (beta); the difference d is signed
    for i in range(len(Gamma_D1)):
        for j in range(len(z_prototypes_s_combined)):
            d = (tf.reduce_mean(z_prototypes_s_combined[j]) * tf.cast(tf.reduce_mean(Gamma_D1[i]), tf.float64) + tf.cast(tf.reduce_mean(Beta_D1[i]), tf.float64)) - (tf.reduce_mean(z_prototypes_d_combined[j]) * tf.cast(tf.reduce_mean(Gamma_D2[i]), tf.float64) + tf.cast(tf.reduce_mean(Beta_D2[i]), tf.float64))
            prot_bn_dists.append(d)
    return tf.reduce_mean(prot_bn_dists)
--------------------------------------------------------------------------------
/loss/euclidian_dist.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# coding: utf-8
import tensorflow as tf
import numpy as np

# calculate the (mean) squared Euclidean distance between the rows of two 2-D tensors
def calc_euclidian_dists2(x, y):
    # x : (n, d)
    # y : (m, d)
    # returns (n, m): mean over d of the squared differences, i.e. the squared Euclidean distance divided by d
    n = x.shape[0]
    m = y.shape[0]
    x = tf.cast(tf.tile(tf.expand_dims(x, 1), [1, m, 1]), dtype=tf.float64)
    y = tf.cast(tf.tile(tf.expand_dims(y, 0), [n, 1, 1]), dtype=tf.float64)
    return tf.reduce_mean(tf.math.pow(x - y, 2), 2)
--------------------------------------------------------------------------------
/loss/prototype_diversification_loss.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# coding: utf-8
import tensorflow as tf
import numpy as np

#==========================Prototype Diversification Loss==========================
# calculate prototype diversification loss (triplet loss)
def calc_triplet_dists(Proto, query, query_labels, Alpha=0.5):  # [3,64], [90,64], [90,3]

    nbprototypes = Proto.shape[0]  # 3
    nbqueries = query.shape[0]  # 90
    Triplet = []
    for i in range(nbprototypes):
        pos_ind = np.where(query_labels[:, i] == 1)  # indexes of positive queries for the i-th prototype
        pos_ind = list(pos_ind[0])  # as a list
        Neg_ind = list(set(np.arange(nbqueries)).difference(set(pos_ind)))  # indexes of negative queries for the i-th prototype
        Anchor = tf.expand_dims(Proto[i], 0)  # [1, 64] # i-th prototype / anchor

        Positive_d = []
        for j in range(len(pos_ind)):  # 15
            Pos = tf.expand_dims(query[pos_ind[j]], 0)  # [1, 64]
            Pos_dist = tf.reduce_mean(tf.math.pow(Anchor - Pos, 2), 1)  # shape [1]
            Positive_d.append(Pos_dist)  # list
        Positive_d = tf.reduce_mean(Positive_d)  # scalar

        Negative_d = []
        for j in range(len(Neg_ind)):
            Neg = tf.expand_dims(query[Neg_ind[j]], 0)  # [1, 64]
            Neg_dist = tf.reduce_mean(tf.math.pow(Anchor - Neg, 2), 1)  # shape [1]
            Negative_d.append(Neg_dist)  # list
        Negative_d = tf.reduce_mean(Negative_d)  # scalar

        P_N = Positive_d - Negative_d
        Triplet.append(P_N)  # list

    P_N_Proto_mean = tf.reduce_mean(Triplet)  # p_dist - n_dist, averaged over all prototypes # scalar
    return tf.math.maximum(P_N_Proto_mean + Alpha, 0)  # scalar
--------------------------------------------------------------------------------
/models/Feature_extractor.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# coding: utf-8
import tensorflow as tf
from tensorflow.keras import layers

# Adapted from ResNet18 (with skip connections); the conventional batch-norm layer of conv block 3 is replaced with DSBN (Domain-Specific Batch Norm)
Model_input = tf.keras.Input(shape=(224, 224, 3), batch_size=None)
X = layers.ZeroPadding2D((1, 1))(Model_input)
Conv1 = layers.Conv2D(256, (3, 3), activation='relu')(X)
BN1 = layers.BatchNormalization(name='BN1')(Conv1)
Maxpool1 = layers.MaxPooling2D()(BN1)

Usampled_Maxpool1 = layers.UpSampling2D()(Maxpool1)

Skipped_1_2 = layers.Add()([Usampled_Maxpool1, Conv1])
Activation_skipped_1_2 = layers.Activation('relu')(Skipped_1_2)

Conv2 = layers.Conv2D(128, (3, 3), activation='relu', strides=(2, 2), padding="same")(Activation_skipped_1_2)
BN2 = layers.BatchNormalization(name='BN2')(Conv2)
Maxpool2 = layers.MaxPooling2D()(BN2)

Usampled_Maxpool2 = layers.UpSampling2D()(Maxpool2)

Skipped_2_3 = layers.Add()([Usampled_Maxpool2, Conv2])
Activation_skipped_2_3 = layers.Activation('relu')(Skipped_2_3)

# conv block 3
Conv3 = layers.Conv2D(32, (3, 3), activation='relu',strides=(3, 3))(Activation_skipped_2_3)
# Domain-Specific Batch Norm: the batch is split at a hard-coded index, so the
# first 36 samples are normalized by BN3_D1 and the remaining ones by BN3_D2
subsection3_D1 = layers.Lambda(lambda x: x[:36, :, :, :])(Conv3)
subsection3_D2 = layers.Lambda(lambda x: x[36:, :, :, :])(Conv3)
BN3_D1 = layers.BatchNormalization(name='BN3_D1')(subsection3_D1)
BN3_D2 = layers.BatchNormalization(name='BN3_D2')(subsection3_D2)
Concat3 = layers.Concatenate(axis=0)([BN3_D1, BN3_D2])
BN3 = layers.BatchNormalization()(Concat3)
Maxpool3 = layers.MaxPooling2D((3,3))(BN3)

Flatten_l = layers.Flatten()(Maxpool3)

Dense1 = layers.Dense(1024, activation='relu')(Flatten_l)
BN4 = layers.BatchNormalization(name='BN4')(Dense1)
Dense2 = layers.Dense(512, activation='relu')(BN4)
BN5 = layers.BatchNormalization(name='BN5')(Dense2)
Dense3 = layers.Dense(128, activation='relu')(BN5)
BN6 = layers.BatchNormalization(name='BN6')(Dense3)
Dense4 = layers.Dense(64, activation='relu')(BN6)

CDFSOSR_Model = tf.keras.Model(inputs=Model_input, outputs=Dense4, name='CDFSOSR_Model')
# CDFSOSR_Model.summary()

--------------------------------------------------------------------------------
/models/domain_network.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# coding: utf-8
import tensorflow as tf
from tensorflow.keras import layers

#============================Domain network====================
input_Feature = layers.Input(shape = (3,))
encoded_L1 = layers.Dense(128, activation='relu',kernel_regularizer=tf.keras.regularizers.L1(0.01),activity_regularizer=tf.keras.regularizers.L2(0.001))(input_Feature)
encoded_L2 = layers.Dense(64, activation='relu',kernel_regularizer=tf.keras.regularizers.L1(0.01),activity_regularizer=tf.keras.regularizers.L2(0.001))(encoded_L1)
encoded_L3 = layers.Dense(32, activation='relu',kernel_regularizer=tf.keras.regularizers.L1(0.01),activity_regularizer=tf.keras.regularizers.L2(0.001))(encoded_L2)
encoded_L4 = layers.Dense(16, activation='relu',kernel_regularizer=tf.keras.regularizers.L1(0.01),activity_regularizer=tf.keras.regularizers.L2(0.001))(encoded_L3)
encoded_L5 = layers.Dense(8, activation='relu',kernel_regularizer=tf.keras.regularizers.L1(0.01),activity_regularizer=tf.keras.regularizers.L2(0.001))(encoded_L4)
decoded = layers.Dense(2, activation='softmax')(encoded_L5)
domain_nn = tf.keras.Model(inputs=input_Feature,outputs=decoded,name='DomainNN')
# domain_nn.summary()

--------------------------------------------------------------------------------
/models/high_noise_GAN.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# coding: utf-8
import tensorflow as tf
from tensorflow.keras import layers

#==================================================High noise GAN========================================================
# GAN2: generates pseudo-unknown samples to augment the query set; driven by high-variance noise in the SOURCE DOMAIN
#Generator (High)
generator_input_size=10+8 # 8-dim noise vector + 10-dim one-hot label
input_feature = layers.Input(shape=(generator_input_size,))
layer_h_g1 = layers.Dense(32,activation='relu')(input_feature)
layer_h_g2 = layers.Dense(48,activation='relu')(layer_h_g1)
layer_h_g3 = layers.Dense(64,activation='relu')(layer_h_g2)
generator_nn_high_s = tf.keras.Model(inputs=input_feature,outputs=layer_h_g3, name='generator_high_source')
# generator_nn_high_s.summary()

# Discriminator (High)
input_feature = layers.Input(shape=(74,)) # 64-dim embedding + 10-dim one-hot label
layer_h_d1 = layers.Dense(48,activation='relu')(input_feature)
layer_h_d2 = layers.Dense(32,activation='relu')(layer_h_d1)
layer_h_d3 = layers.Dense(16,activation='relu')(layer_h_d2)
layer_h_d4 = layers.Dense(8,activation='relu')(layer_h_d3)
layer_h_d5 = layers.Dense(1)(layer_h_d4)
dis_nn_high_s = tf.keras.Model(inputs=input_feature,outputs=layer_h_d5, name='discriminator_high_source')
# dis_nn_high_s.summary()

--------------------------------------------------------------------------------
/models/low_noise_GAN.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# coding: utf-8
import tensorflow as tf
from tensorflow.keras import layers

#==============================================================Low noise GAN========================================================
# GAN1: generates pseudo-known samples; driven by low-variance noise in the SOURCE DOMAIN

# Generator (Low) SOURCE DOMAIN
generator_input_size=10+8 # 8-dim noise vector + 10-dim one-hot label
input_feature = layers.Input(shape=(generator_input_size,))
layer_low_g1 = layers.Dense(32,activation='relu')(input_feature)
layer_low_g2 = layers.Dense(48,activation='relu')(layer_low_g1)
layer_low_g3 = layers.Dense(64,activation='relu')(layer_low_g2)
generator_nn_low_s = tf.keras.Model(inputs=input_feature,outputs=layer_low_g3, name='generator_low_source')
# generator_nn_low_s.summary()

#Discriminator (Low) SOURCE DOMAIN
input_feature = layers.Input(shape=(74,)) # 64-dim embedding + 10-dim one-hot label
layer_l_d1 = layers.Dense(48,activation='relu')(input_feature)
layer_l_d2 = layers.Dense(32,activation='relu')(layer_l_d1)
layer_l_d3 = layers.Dense(16,activation='relu')(layer_l_d2)
layer_l_d4 = layers.Dense(8,activation='relu')(layer_l_d3)
layer_l_d5 = layers.Dense(1)(layer_l_d4)
dis_nn_low_s = tf.keras.Model(inputs=input_feature,outputs=layer_l_d5, name='discriminator_low_source')
# dis_nn_low_s.summary()

--------------------------------------------------------------------------------
/models/outlier_network.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# coding: utf-8
import tensorflow as tf
from tensorflow.keras import layers

#============================Outlier network====================
input_Feature = layers.Input(shape = (3,))
encoded_L1 = layers.Dense(128, activation='relu',kernel_regularizer=tf.keras.regularizers.L1(0.01),activity_regularizer=tf.keras.regularizers.L2(0.001))(input_Feature)
encoded_L2 = layers.Dense(64, activation='relu',kernel_regularizer=tf.keras.regularizers.L1(0.01),activity_regularizer=tf.keras.regularizers.L2(0.001))(encoded_L1)
encoded_L3 = layers.Dense(32, activation='relu',kernel_regularizer=tf.keras.regularizers.L1(0.01),activity_regularizer=tf.keras.regularizers.L2(0.001))(encoded_L2)
encoded_L4 = layers.Dense(16, activation='relu',kernel_regularizer=tf.keras.regularizers.L1(0.01),activity_regularizer=tf.keras.regularizers.L2(0.001))(encoded_L3)
encoded_L5 = layers.Dense(8, activation='relu',kernel_regularizer=tf.keras.regularizers.L1(0.01),activity_regularizer=tf.keras.regularizers.L2(0.001))(encoded_L4)
decoded = layers.Dense(2, activation='softmax')(encoded_L5)
outlier_nn = tf.keras.Model(inputs=input_Feature,outputs=decoded,name='ONN')
# outlier_nn.summary()

--------------------------------------------------------------------------------
/optimizers/__pycache__/gan_optimizers.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DebabrataPal7/DAFOSNET/7cfef39c8d6c5ec59c099c320e87efdfdc3c6302/optimizers/__pycache__/gan_optimizers.cpython-310.pyc

--------------------------------------------------------------------------------
/optimizers/gan_optimizers.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# coding: utf-8

import tensorflow as tf
from loss.cosine_loss import cosine_loss
from loss.AOCMC import anti_open_close_mode_collapse_loss

#==========================Functions to optimize GAN==========================
#Function to optimize low noise GAN (GAN 1)
def gan_reptile(sembed,ep_class_labels,ep_domain_labels,generator,dis,optim_d,optim_g,stdev,alpha1):
    sembed = tf.cast(sembed,dtype=tf.float64)
    ep_class_dom_labels = []
    uniques_len = len(set(ep_class_labels))
    for i in range(len(ep_domain_labels)):
        if ep_domain_labels[i] == 0:
            # domain 0: keep the original class index (e.g. 0, 1, 2)
            ep_class_dom_labels.append(ep_class_labels[i])
        else:
            # other domain: map the same class to a pseudo class offset by the number of unique classes (e.g. 3, 4, 5, ...)
            ep_class_dom_labels.append(ep_class_labels[i]+uniques_len)
    batch_size=sembed.shape[0]
    one_hot_labels=tf.one_hot(ep_class_dom_labels, depth=10, axis=-1)
    one_hot_labels=tf.cast(tf.reshape(one_hot_labels,(batch_size,10)),dtype=tf.float64)

    # NOTE: this overwrites the domain-aware labels above, so GAN 1 is conditioned on the plain class labels only
    one_hot_labels=tf.one_hot(ep_class_labels, depth=10, axis=-1)
    one_hot_labels=tf.cast(tf.reshape(one_hot_labels,(batch_size,10)),dtype=tf.float64)

    vector=tf.cast(tf.random.normal(shape=(batch_size,8), stddev=stdev),dtype=tf.float64)
    latent=tf.concat([vector,one_hot_labels], axis=1)

    # loss_function
    loss_fn=tf.keras.losses.BinaryCrossentropy(from_logits=True)

    # low-noise generator forward pass
    generated_emb=tf.cast(generator(latent),dtype=tf.float64)

    fake_embs=tf.concat([generated_emb, one_hot_labels], axis=1) #[,74]
    real_embs=tf.concat([sembed, one_hot_labels], axis=1)
    combined_embs=tf.concat([fake_embs, real_embs], axis=0)
    labels = tf.concat([tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0) # fake = 1, real = 0

    # First-order meta-learning - Reptile
    # Discriminator update
    old_vars_gen=generator.get_weights()
    old_vars_dis=dis.get_weights()
    with tf.GradientTape() as dis_tape:
        predictions=dis(combined_embs)
        d_loss=loss_fn(labels,predictions)
    grads=dis_tape.gradient(d_loss,dis.trainable_variables)
    optim_d.apply_gradients(zip(grads, dis.trainable_variables))
    new_vars_dis=dis.get_weights()

    # Reptile interpolation: theta <- theta_old + alpha1 * (theta_adapted - theta_old)
    for var in range(len(new_vars_dis)):
        new_vars_dis[var]=tf.cast(old_vars_dis[var],dtype=tf.float64) + tf.cast(((new_vars_dis[var]-old_vars_dis[var])*alpha1),dtype=tf.float64)
    dis.set_weights(new_vars_dis)

    # Generator update: fakes are pushed towards the "real" label (0)
    mis_labels=tf.zeros((batch_size,1))
    old_gen=generator.get_weights()
    with tf.GradientTape() as gen_tape:
        vector=tf.cast(tf.random.normal(shape=(batch_size,8), stddev=0.2),dtype=tf.float64)
        latent=tf.concat([vector,one_hot_labels], axis=1)
        fake_emb=tf.cast(generator(latent),dtype=tf.float64)
        fake_emb_and_labels=tf.concat([fake_emb, one_hot_labels], axis=-1)
        predictions= tf.cast(dis(fake_emb_and_labels),dtype=tf.float64)
        g_loss=loss_fn(mis_labels,predictions)
    g_grads=gen_tape.gradient(g_loss, generator.trainable_variables)
    optim_g.apply_gradients(zip(g_grads, generator.trainable_variables))
    new_gen=generator.get_weights()

    for var in range(len(new_gen)):
        new_gen[var]= tf.cast(old_gen[var],dtype=tf.float64) + tf.cast(((new_gen[var]-old_gen[var])* alpha1),dtype=tf.float64)
    generator.set_weights(new_gen)

    vector=tf.cast(tf.random.normal(shape=(batch_size,8), stddev=stdev),dtype=tf.float64)
    latent=tf.concat([vector,one_hot_labels], axis=1)
    sembed_gen=tf.cast(generator(latent),dtype=tf.float64)
    return sembed_gen,vector

# Function to optimize high noise GAN (GAN 2)
def gan_reptile_high(sembed,sembed_low,z1,ep_class_labels,ep_domain_labels,generator,dis,optim_d,optim_g,stdev,alpha1):
    sembed = tf.cast(sembed,dtype=tf.float64)
    ep_class_dom_labels = []
    uniques_len = len(set(ep_class_labels))
    for i in range(len(ep_domain_labels)):
        if ep_domain_labels[i] == 0:
            # domain 0: keep the original class index (e.g. 0, 1, 2)
            ep_class_dom_labels.append(ep_class_labels[i])
        else:
            # other domain: map the same class to a pseudo class offset by the number of unique classes (e.g. 3, 4, 5, ...)
            ep_class_dom_labels.append(ep_class_labels[i]+uniques_len)
    batch_size=sembed.shape[0]
    one_hot_labels=tf.one_hot(ep_class_dom_labels, depth=10, axis=-1)
    one_hot_labels=tf.cast(tf.reshape(one_hot_labels,(batch_size,10)),dtype=tf.float64)

    vector=tf.cast(tf.random.normal(shape=(batch_size,8), stddev=stdev),dtype=tf.float64)
    latent=tf.concat([vector,one_hot_labels], axis=1)

    # loss_function
    loss_fn=tf.keras.losses.BinaryCrossentropy(from_logits=True)

    # high-noise generator forward pass
    generated_emb=tf.cast(generator(latent),dtype=tf.float64)
    # AOL regularizer loss
    c_lossd=cosine_loss(generated_emb,sembed_low,vector,z1)
    # AOCMC regularizer loss
    # l_AOCMC = anti_open_close_mode_collapse_loss(sembed_low,generated_emb,z1,vector)

    fake_embs=tf.concat([generated_emb, one_hot_labels], axis=1) #[,74]
    real_embs=tf.concat([sembed, one_hot_labels], axis=1)
    combined_embs=tf.concat([fake_embs, real_embs], axis=0)
    labels = tf.concat([tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0) # fake = 1, real = 0

    # First-order meta-learning - Reptile
    # discriminator
    old_vars_gen=generator.get_weights()
    old_vars_dis=dis.get_weights()
    with tf.GradientTape() as dis_tape:
        predictions=tf.cast(dis(combined_embs),dtype=tf.float64)
        # c_lossd is computed outside this tape and does not depend on the discriminator
        # weights, so it shifts d_loss but contributes no gradient to the discriminator
        d_loss=loss_fn(labels,predictions) + c_lossd
    grads=dis_tape.gradient(d_loss,dis.trainable_variables)
    optim_d.apply_gradients(zip(grads, dis.trainable_variables))
    new_vars_dis=dis.get_weights()

    # Reptile interpolation: theta <- theta_old + alpha1 * (theta_adapted - theta_old)
    for var in range(len(new_vars_dis)):
        new_vars_dis[var]=old_vars_dis[var] + ((new_vars_dis[var]-old_vars_dis[var])*alpha1)
    dis.set_weights(new_vars_dis)

    # generator
    mis_labels=tf.zeros((batch_size,1))
    old_gen=generator.get_weights()
    with tf.GradientTape() as gen_tape:
        vector=tf.cast(tf.random.normal(shape=(batch_size,8), stddev=0.2),dtype=tf.float64)
        latent=tf.concat([vector,one_hot_labels], axis=1)
        fake_emb=tf.cast(generator(latent),dtype=tf.float64)
        c_lossg=cosine_loss(fake_emb,sembed_low,vector,z1)
        fake_emb_and_labels=tf.concat([fake_emb, one_hot_labels], axis=-1)
        predictions=dis(fake_emb_and_labels)
        g_loss=tf.cast(loss_fn(mis_labels,predictions),dtype=tf.float64) + tf.cast(c_lossg,dtype=tf.float64)
    g_grads=gen_tape.gradient(g_loss, generator.trainable_variables)
    optim_g.apply_gradients(zip(g_grads, generator.trainable_variables))
    new_gen=generator.get_weights()
    for var in range(len(new_gen)):
        new_gen[var]= tf.cast(old_gen[var],dtype=tf.float64) + tf.cast(((new_gen[var]-old_gen[var])* alpha1),dtype=tf.float64)
    generator.set_weights(new_gen)

    vector=tf.cast(tf.random.normal(shape=(batch_size,8), stddev=stdev),dtype=tf.float64)
    latent=tf.concat([vector,one_hot_labels], axis=1)
    sembed_gen=tf.cast(generator(latent),dtype=tf.float64)
    return sembed_gen

--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# coding: utf-8

# Import necessary libraries
import tensorflow as tf
import tensorflow.keras.backend as K
import os
import json
import numpy as np
import datetime
from sklearn.metrics import confusion_matrix
from scipy.stats import hmean

#save image batches to numpy files for faster processing
from data.save_data_to_numpy import save_data
#feature extractor
from models.Feature_extractor import CDFSOSR_Model
#outlier prediction network
from models.outlier_network import outlier_nn
#domain prediction network
from models.domain_network import domain_nn
#low noise GANs
from models.low_noise_GAN import generator_nn_low_s
from models.low_noise_GAN import dis_nn_low_s
#high noise GANs
from models.high_noise_GAN import generator_nn_high_s
from models.high_noise_GAN import dis_nn_high_s
#custom metrics (Variance)
from utils.variance_metric import ComputeIterationVaraince
#save accuracy values
from utils.save_accuracy_values import save_accuracy_values
#proto test
from utils.proto_test import proto_test
#test episode sampler
from utils.test_episode import test_episode

with open('./config.json') as json_file:
    config = json.load(json_file)

base_dir = config['base_dir']
save_path_source = config['save_path_source']
save_path_target = config['save_path_target']
num_classes = config['num_classes'] # train_classes + test_classes
num_images_per_class = config['num_images_per_class']

# Hyperparameters for Support and Query images
CS=config['CS']
CQ=config['CQ'] # not necessarily equal to CS
Ss=config['Ss']
Sd=config['Sd']
Qsk=config['Qsk']
Qdk=config['Qdk']#N
Qsu=config['Qsu']
Qdu=config['Qdu']#X

# Test Classes
test_classes = config['test_classes']

#embedding dimensions
emb_dim = config['emb_dim']

#load the saved arrays into the memory
source_domain = np.load(f'{save_path_source}/source_{num_classes}.npy')
target_domain = np.load(f'{save_path_target}/target_{num_classes}.npy')

# Numeric class labels
class_labels = []

for i in range(len(source_domain)):
    class_labels.append(i)

print("source_domain shape: ", source_domain.shape)
print("target_domain shape: ", target_domain.shape)
print("class_labels: ", class_labels)

current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
test_log_dir = 'logs/gradient_tape/' + current_time + '/test'
test_summary_writer = tf.summary.create_file_writer(test_log_dir)

# Specify the directory where the existing checkpoint is saved
existing_checkpoint_dir = config["current_checkpoint_dir"]
existing_checkpoint_prefix = os.path.join(existing_checkpoint_dir, "ckpt")
# Load the existing checkpoint variables into the current model
existing_checkpoint = tf.train.Checkpoint(
    CDFSOSR_Model = CDFSOSR_Model,
    generator_nn_low_s=generator_nn_low_s, generator_nn_high_s=generator_nn_high_s, dis_nn_low_s=dis_nn_low_s, dis_nn_high_s=dis_nn_high_s,
    outlier_nn=outlier_nn,
    domain_nn=domain_nn)
existing_checkpoint.restore(tf.train.latest_checkpoint(existing_checkpoint_dir))

# Metrics to gather
test_acc = tf.metrics.Mean(name='test_accuracy')
test_openoa = tf.metrics.Mean(name='test_openoa')
test_closedoa_cm = tf.metrics.Mean(name='test_closedoa_cm')
test_outlier_acc = tf.metrics.Mean(name='test_outlier_acc')
test_domain_acc = tf.metrics.Mean(name='test_domain_acc')

def test_step(tsupport_patches,tquery_patches,support_labels,query_labels,support_classes,CS, CQ, Ss, Sd, Qsk, Qdk, Qsu, Qdu):
    accuracy, closed_oa_cm, outlier_det_acc, openoa, domain_acc = proto_test(tsupport_patches,tquery_patches,support_labels,query_labels,support_classes,CS, CQ, Ss, Sd, Qsk, Qdk, Qsu, Qdu)
    test_acc(accuracy)
    test_closedoa_cm(closed_oa_cm)
    test_openoa(openoa)
    test_outlier_acc(outlier_det_acc)
    test_domain_acc(domain_acc)

for epoch in range(config["test_epochs"]): # 80 train + 80 tune + 100 train + 160 tune + 40 train
    test_acc.reset_states()
    test_closedoa_cm.reset_states()
    test_openoa.reset_states()
    test_outlier_acc.reset_states()
    test_domain_acc.reset_states()
    for epi in range(10):
        tquery_patches, tsupport_patches, query_labels, support_labels, support_classes = test_episode(source_domain[-test_classes:,:,:,:,:], target_domain[-test_classes:,:,:,:,:], CS, CQ, Ss, Sd, Qsk, Qdk, Qsu, Qdu, class_labels[-test_classes:])
        test_step(tsupport_patches,tquery_patches,support_labels,query_labels,support_classes,CS, CQ, Ss, Sd, Qsk, Qdk, Qsu, Qdu)
        # print('query_labels:', query_labels)  # debug output
    with test_summary_writer.as_default():
        #tf.summary.scalar('accuracy', test_acc.result(), step=epoch)
        tf.summary.scalar('closedoa_cm', test_closedoa_cm.result(), step=epoch)
        tf.summary.scalar('openoa',test_openoa.result(), step=epoch)
        tf.summary.scalar('outlier_acc',test_outlier_acc.result(), step=epoch)
        tf.summary.scalar('domain_acc',test_domain_acc.result(), step=epoch)

    template = 'Epoch {}, Test Closed OA CM: {:.2f}, Test Open OA: {:.2f}, Test Outlier Det. Acc: {:.2f}, Test Domain Det. Acc: {:.2f}'
    print(template.format(epoch+1,test_closedoa_cm.result()*100, test_openoa.result()*100,test_outlier_acc.result()*100,test_domain_acc.result()*100))

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# coding: utf-8

# Import necessary libraries
import tensorflow as tf
import tensorflow.keras.backend as K
import os
import json
import numpy as np
import datetime
from sklearn.metrics import confusion_matrix
from scipy.stats import hmean

#save image batches to numpy files for faster processing
from data.save_data_to_numpy import save_data
#feature extractor
from models.Feature_extractor import CDFSOSR_Model
#outlier prediction network
from models.outlier_network import outlier_nn
#domain prediction network
from models.domain_network import domain_nn
#low noise GANs
from models.low_noise_GAN import generator_nn_low_s
from models.low_noise_GAN import dis_nn_low_s
#high noise GANs
from models.high_noise_GAN import generator_nn_high_s
from models.high_noise_GAN import dis_nn_high_s
#custom metrics (Variance)
from utils.variance_metric import ComputeIterationVaraince
#save accuracy values
from utils.save_accuracy_values import save_accuracy_values
#proto train
from utils.proto_train import proto_train
#new train episode
from utils.train_episode import new_episode
#adamatch augmentation
# from utils.adamatch_aug import adamatch_aug

with open('./config.json') as json_file:
    config = json.load(json_file)

base_dir = config['base_dir']
save_path_source = config['save_path_source']
save_path_target = config['save_path_target']
num_classes = config['num_classes'] # train_classes + test_classes
num_images_per_class = config['num_images_per_class']

# Hyperparameters for Support and Query images
CS=config['CS']
CQ=config['CQ'] # not necessarily equal to CS
Ss=config['Ss']
Sd=config['Sd']
Qsk=config['Qsk']
Qdk=config['Qdk']#N
Qsu=config['Qsu']
Qdu=config['Qdu']#X

# Train Class split
train_classes = config['train_classes']

#embedding dimensions
emb_dim = config['emb_dim']

#flag (JSON bool or the string "True") to check if we need to save the input images to a numpy file
pre_process=config['pre_process']

#make numpy arrays of the image files
if str(pre_process) == "True":
    save_data(base_dir, save_path_source, save_path_target, num_classes, num_images_per_class)
else:
    print("Not saving input images into Numpy files")

#load the saved arrays into the memory
source_domain = np.load(f'{save_path_source}/source_{num_classes}.npy')
target_domain = np.load(f'{save_path_target}/target_{num_classes}.npy')


# Numeric class labels
class_labels = []

for i in range(len(source_domain)):
    class_labels.append(i)

print("source_domain shape: ", source_domain.shape)
print("target_domain shape: ", target_domain.shape)
print("class_labels: ", class_labels)

#====================================================================Optimizers================================================================
lr = config['lr']
initial_learning_rate = config['initial_learning_rate']
decay_rate = config['decay_rate']
decay_steps = config['decay_steps']

# Create the learning rate schedule
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate,
    decay_steps=decay_steps,
    decay_rate=decay_rate,
    staircase=True)

optim_d_low_s=tf.keras.optimizers.legacy.Adam(lr)
optim_g_low_s=tf.keras.optimizers.legacy.Adam(lr)
optim_d_high_s=tf.keras.optimizers.legacy.Adam(lr)
optim_g_high_s=tf.keras.optimizers.legacy.Adam(lr)
optim_model_Outlier = tf.keras.optimizers.legacy.Adam(initial_learning_rate)
optim_model_Domain = tf.keras.optimizers.legacy.Adam(initial_learning_rate)
optim2 = tf.keras.optimizers.legacy.Adam(lr) ## multitask outlier network
optim3 = tf.keras.optimizers.legacy.Adam(initial_learning_rate) ## feature extractor (FE)
#====================================================================Train checkpoint================================================================
checkpoint_dir = config["existing_checkpoint_dir"]
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

# Define a variable to keep track of the current epoch
epoch_var = tf.Variable(0, dtype=tf.int64, name="epoch")

checkpoint = tf.train.Checkpoint(epoch_var,
    optim3=optim3, optim2=optim2,
    optim_d_low_s=optim_d_low_s,optim_g_low_s=optim_g_low_s, optim_d_high_s=optim_d_high_s,optim_g_high_s=optim_g_high_s,
    optim_model_Outlier = optim_model_Outlier,
    optim_model_Domain = optim_model_Domain,
    CDFSOSR_Model = CDFSOSR_Model,
    generator_nn_low_s=generator_nn_low_s, generator_nn_high_s=generator_nn_high_s, dis_nn_low_s=dis_nn_low_s, dis_nn_high_s=dis_nn_high_s,
    outlier_nn=outlier_nn,
    domain_nn=domain_nn)
ckpt_manager = tf.train.CheckpointManager(checkpoint,checkpoint_dir, max_to_keep=5)

#===============================================================Train Step========================================================
# Metrics to gather
train_loss = tf.metrics.Mean(name='train_loss')
train_acc = tf.metrics.Mean(name='train_accuracy')
train_closedoa_cm = tf.metrics.Mean(name='train_closedoa_cm')
train_openoa = tf.metrics.Mean(name='train_openoa')
train_outlier_acc = tf.metrics.Mean(name='train_outlier_acc')
train_domain_acc = tf.metrics.Mean(name='train_domain_acc')
train_closedoa_cm_variance = ComputeIterationVaraince(name="train_closedoa_cm_variance")
train_openoa_variance = ComputeIterationVaraince(name="train_openoa_variance")

ntimes=config['ntimes']

def train_step(tsupport_patches,tquery_patches,support_labels,query_labels,support_classes,CS, CQ, Ss, Sd, Qsk, Qdk, Qsu, Qdu, query_dom_labels, support_dom_labels):
    # Forward & update gradients
    with tf.GradientTape() as tape:
        loss, accuracy, closed_oa_cm, outlier_det_acc, openoa, domain_acc = proto_train(tsupport_patches,tquery_patches,support_labels,query_labels,support_classes,CS, CQ, Ss, Sd, Qsk, Qdk, Qsu, Qdu, query_dom_labels, support_dom_labels, CDFSOSR_Model, outlier_nn, domain_nn, generator_nn_low_s, dis_nn_low_s,generator_nn_high_s, dis_nn_high_s, ntimes, lr_schedule, optim_model_Outlier, optim_model_Domain, optim_d_low_s,optim_g_low_s, optim_d_high_s,optim_g_high_s)
    gradients = tape.gradient(loss, CDFSOSR_Model.trainable_variables)
    # Compute the learning rate using the schedule
    learning_rate = lr_schedule(optim3.iterations)
    optim3.learning_rate = learning_rate
    optim3.apply_gradients(zip(gradients, CDFSOSR_Model.trainable_variables))
    train_loss(loss)
    train_acc(accuracy)
    train_closedoa_cm(closed_oa_cm)
    train_openoa(openoa)
    train_outlier_acc(outlier_det_acc)
    train_domain_acc(domain_acc)
    train_closedoa_cm_variance.add(closed_oa_cm)
    train_openoa_variance.add(openoa)
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
train_log_dir = 'logs/gradient_tape/' + current_time + '/train'
train_summary_writer = tf.summary.create_file_writer(train_log_dir)

#===============================================================Train Loop========================================================
start_epoch = 0
if ckpt_manager.latest_checkpoint:
    # restore the latest checkpoint in checkpoint_dir (this also restores epoch_var)
    checkpoint.restore(ckpt_manager.latest_checkpoint)
    # the "ckpt-N" suffix counts saves, not epochs, so resume from the restored epoch counter
    start_epoch = int(epoch_var.numpy())
    print("Restarted from: ", start_epoch)
else:
    print("Training from scratch...")

for epoch in range(start_epoch, config["train_epochs"]): # 80 train + 80 tune + 100 train + 160 tune + 40 train
    train_loss.reset_states()
    train_acc.reset_states()
    train_closedoa_cm.reset_states()
    train_openoa.reset_states()
    train_outlier_acc.reset_states()
    train_domain_acc.reset_states()
    train_closedoa_cm_variance.reset_states()
    train_openoa_variance.reset_states()
    epoch_var.assign_add(1)
    for epi in range(10):
        tquery_patches, tsupport_patches, query_labels, support_labels, support_classes, query_dom_labels, support_dom_labels = new_episode(source_domain[:train_classes,:,:,:,:], target_domain[:train_classes,:,:,:,:], CS, CQ, Ss, Sd, Qsk, Qdk, Qsu, Qdu, class_labels[:train_classes])
        # tsupport_patches_aug, support_labels_aug, Ss_aug, Sd_aug, support_dom_labels_aug = adamatch_aug(tsupport_patches, support_labels, CS, Ss, Sd, support_dom_labels)
        train_step(tsupport_patches,tquery_patches,support_labels,query_labels,support_classes,CS, CQ, Ss, Sd, Qsk, Qdk, Qsu, Qdu, query_dom_labels, support_dom_labels)
    with train_summary_writer.as_default():
        tf.summary.scalar('loss', train_loss.result(), step=epoch)
        #tf.summary.scalar('accuracy', train_acc.result(), step=epoch)
        tf.summary.scalar('closedoa_cm', train_closedoa_cm.result(), step=epoch)
        tf.summary.scalar('openoa', train_openoa.result(), step=epoch)
        tf.summary.scalar('outlier_det_acc', train_outlier_acc.result(), step=epoch)
        tf.summary.scalar('domain_det_acc', train_domain_acc.result(), step=epoch)

    template = 'Epoch {}, Train Loss: {:.2f}, Train Closed OA: {:.2f}, Train Open OA: {:.2f}, Train Outlier Det. Acc: {:.2f}, Train Domain Det. Acc: {:.2f}'
    print(template.format(epoch+1,train_loss.result(),train_closedoa_cm.result()*100,train_openoa.result()*100,train_outlier_acc.result()*100,train_domain_acc.result()*100))
    #save mean accuracy values per epoch
    save_accuracy_values(train_closedoa_cm.result(), train_openoa.result(), 'accuracy_values.npy')
    #save variance of the accuracies
    save_accuracy_values(train_closedoa_cm_variance.compute_variance(), train_openoa_variance.compute_variance(), 'accuracy_variance_values.npy')

    if epoch % config["save_every"] == 0 and epoch != 0 :
        # save through the manager so that max_to_keep=5 is enforced
        save_path = ckpt_manager.save()
        print("Checkpoint saved at: ", save_path)

--------------------------------------------------------------------------------
/utils/adamatch_aug.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# coding: utf-8

# Import necessary libraries
import tensorflow as tf
from keras_cv.layers import RandAugment

#===================================Adamatch Augmentation===================================

RESIZE_TO=224
# Initialize a `RandAugment` layer that applies 2 augmentation
# transforms per image with magnitude 0.5.
augmenter = RandAugment(value_range=(0, 255), augmentations_per_image=2, magnitude=0.5)

def weak_augment(image, source=True):
    if image.dtype != tf.float64:
        image = tf.cast(image, tf.float64)
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_crop(image, (RESIZE_TO, RESIZE_TO, 3))
    return image

def strong_augment(image, source=True):
    if image.dtype != tf.float64:
        image = tf.cast(image, tf.float64)
    image = augmenter(image)
    return image

def concat_images_to_same_class(tensor1,tensor2,num_classes,Ss1,Ss2):
    '''Concatenate two support tensors domain by domain.

    Each input holds its source-domain images first (num_classes*Ss1 rows in tensor1,
    num_classes*Ss2 rows in tensor2), followed by its target-domain images. The output is
    [tensor1 source, tensor2 source, tensor1 target, tensor2 target], so images from the
    same domain stay contiguous and the rows of each input keep their original order.'''
    tensor1_domain1 = tensor1[:num_classes*Ss1,:,:,:]
    tensor2_domain1 = tensor2[:num_classes*Ss2,:,:,:]
    tensor_domain1 = tf.concat((tensor1_domain1, tensor2_domain1),axis=0)
    tensor1_domain2 = tensor1[num_classes*Ss1:,:,:,:]
    tensor2_domain2 = tensor2[num_classes*Ss2:,:,:,:]
    tensor_domain2 = tf.concat((tensor1_domain2, tensor2_domain2),axis=0)
    new_tensor = tf.concat((tensor_domain1,tensor_domain2),axis=0)
    return new_tensor

def adamatch_aug(tsupport_patches, support_labels, CS, Ss, Sd, support_dom_labels):
    '''Return original + weakly augmented + strongly augmented support samples (Ss_aug = Ss*3, Sd_aug = Sd*3).'''
    #patches
    tsupport_patches_w = tf.convert_to_tensor(list(map(weak_augment,tsupport_patches)), dtype=tf.float64)
    tsupport_patches_s = tf.convert_to_tensor(list(map(strong_augment,tsupport_patches)), dtype=tf.float64)

    #concat weak and strong augmented samples
    tsupport_concat_w_s = concat_images_to_same_class(tsupport_patches_w, tsupport_patches_s, CS, Ss, Ss)
    #concat weak and strong augmented samples with original samples
    tsupport_patches_concat_original_w_s = concat_images_to_same_class(tsupport_patches, tsupport_concat_w_s, CS, Ss, Ss*2)

    Ss_aug=Ss*3
    Sd_aug=Sd*3
    #labels
    support_labels_source = support_labels[:CS*Ss]
    support_labels_target = support_labels[CS*Ss:CS*Ss+CS*Sd]

    support_dom_labels_source = support_dom_labels[:CS*Ss]
    support_dom_labels_target = support_dom_labels[CS*Ss:CS*Ss+CS*Sd]

    # list repetition (labels must be Python lists): one copy each for the original, weak and strong images
    support_labels_aug = support_labels_source*3 + support_labels_target*3
    support_dom_labels_aug = support_dom_labels_source*3 + support_dom_labels_target*3

    return tsupport_patches_concat_original_w_s, support_labels_aug, Ss_aug, Sd_aug, support_dom_labels_aug
--------------------------------------------------------------------------------
/utils/get_gamma_beta_from_layer.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# coding: utf-8

import tensorflow as tf

#get the current gamma (scale) and beta (offset) values of the domain-specific BN3 layers of the feature extractor
def get_gamma_beta(model):
    # BatchNormalization.get_weights() returns [gamma, beta, moving_mean, moving_variance]
    Gamma3_D1, Beta3_D1 = model.get_layer('BN3_D1').get_weights()[:2]
    Gamma3_D2, Beta3_D2 = model.get_layer('BN3_D2').get_weights()[:2]

    return Gamma3_D1, Gamma3_D2, Beta3_D1, Beta3_D2
--------------------------------------------------------------------------------
/utils/proto_test.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# coding: utf-8
import tensorflow as tf
import numpy as np

#feature extractor
from models.Feature_extractor import CDFSOSR_Model
#outlier prediction network
from models.outlier_network import outlier_nn
#domain prediction network
from models.domain_network import domain_nn
#loss
from loss.cross_domain_alignment import calc_weighted_gamma_dists
from loss.prototype_diversification_loss import calc_triplet_dists
#compute distances
from loss.euclidian_dist import calc_euclidian_dists2
#batch norm parameters
from utils.get_gamma_beta_from_layer import get_gamma_beta
#confusion matrix
from sklearn.metrics import confusion_matrix
#harmonic mean
from scipy.stats import hmean

def proto_test(support_images_tensor,query_images_tensor,support_labels_tensor,query_labels_tensor,support_classes,CS, CQ, Ss, Sd, Qsk, Qdk, Qsu, Qdu):#3,3,4,1,5,5,5,5

    sembed = CDFSOSR_Model(support_images_tensor) # [x, 64]-No Aug-> [x, 64] (support)
    qembed = CDFSOSR_Model(query_images_tensor) # [p, 64] --> [p, 64] (query)

    qembed = tf.cast(qembed, dtype=tf.float64)

    sembed_s = []
    sembed_s_labels = []
    sembed_d = []
    sembed_d_labels = []
    for i in range(len(sembed)): #Support set is in the format [Ss1,Ss1,Ss1,Ss1,Sd1, Ss2,Ss2.....]
        if i % (Ss + Sd) < Ss:
            sembed_s.append(sembed[i])
            sembed_s_labels.append(support_labels_tensor[i]) #source domain support images
        else:
            sembed_d.append(sembed[i])
            sembed_d_labels.append(support_labels_tensor[i]) #target domain support images

    sembed_s = tf.convert_to_tensor(sembed_s, dtype=tf.float64) #(CS*Ss, 64)
    sembed_d = tf.convert_to_tensor(sembed_d, dtype=tf.float64) #(CS*Sd, 64)

    #==========================================================================================
    # Calculate the number of elements in each of the 4 query subsets
    num_elements_qsk = CS * Qsk
    num_elements_qdk = CS * Qdk
    num_elements_qsu = CQ * Qsu
    num_elements_qdu = CQ * Qdu

    # Slice the query embeddings into the 4 subsets (source known, target known, source unknown, target unknown)
    query_Qsk = qembed[:num_elements_qsk]
    query_Qdk = qembed[num_elements_qsk:num_elements_qsk+num_elements_qdk]
    query_Qsu = qembed[num_elements_qsk+num_elements_qdk:num_elements_qsk+num_elements_qdk+num_elements_qsu]
    query_Qdu = qembed[num_elements_qsk+num_elements_qdk+num_elements_qsu:]

    # Slice the query labels of the known-class subsets
    query_Qsk_labels = query_labels_tensor[:num_elements_qsk]
    query_Qdk_labels = query_labels_tensor[num_elements_qsk:num_elements_qsk+num_elements_qdk]

    #=======================================Prototypes===================================================
    #support SOURCE DOMAIN
    z_prototypes_s = tf.reshape(sembed_s,[CS, Ss, sembed_s.shape[-1]])
    # prototypes SOURCE DOMAIN: per-class mean of the support embeddings
    z_prototypes_s = tf.reduce_mean(z_prototypes_s, axis=1)

    #TARGET DOMAIN
    z_prototypes_d = tf.reshape(sembed_d,[CS, Sd, sembed_d.shape[-1]])
    # prototypes TARGET DOMAIN
    z_prototypes_d = tf.reduce_mean(z_prototypes_d, axis=1)

    #=======================================Adder 2===================================================
    #Adder 2 SOURCE DOMAIN: only the unknown-class queries are used here, not the known-class ones
    query_Qsu = tf.convert_to_tensor(query_Qsu,dtype=tf.float64)
    #Adder 2 TARGET DOMAIN
    query_Qdu = tf.convert_to_tensor(query_Qdu,dtype=tf.float64)
    #=======================================Distances===================================================
    #Euclidean distances to the prototypes:
    dists_s_k = calc_euclidian_dists2(query_Qsk, z_prototypes_s)
    dists_d_k = calc_euclidian_dists2(query_Qdk, z_prototypes_d)
    dists_s_u = calc_euclidian_dists2(query_Qsu, z_prototypes_s)
    dists_d_u = calc_euclidian_dists2(query_Qdu, z_prototypes_d)

    #for domain net
    dists_s = tf.concat((dists_s_k,dists_s_u),axis=0)
    dists_d = tf.concat((dists_d_k,dists_d_u),axis=0)
    #concatenate distances
    dists = tf.concat((dists_s,dists_d),axis=0) #[CS*Qsk + CQ*Qsu + CS*Qdk + CQ*Qdu, CS]

    log_p_y = tf.nn.log_softmax(-dists,axis=-1) #same shape as dists

    #for outlier net
    dists_inlier = tf.concat((dists_s_k,dists_d_k),axis=0)
    dists_outlier = tf.concat((dists_s_u,dists_d_u),axis=0)
    #concatenate distances: outliers first, then inliers
    dists_o = tf.concat((dists_outlier,dists_inlier),axis=0)

    #=================================================================================================
    #outlier prediction labels
    outliers_source_o = [0]*len(query_Qsu) #1-->inlier 0-->outlier
    outliers_target_o = [0]*len(query_Qdu)
    inliers_source_o = [1]*len(query_Qsk)
    inliers_target_o = [1]*len(query_Qdk)
    y_outlier = outliers_source_o + outliers_target_o + inliers_source_o + inliers_target_o

    #domain prediction labels:
    inliers_source_d = [0]*len(query_Qsk) #0-->source 1-->target
    outliers_source_d = [0]*len(query_Qsu)
    inliers_target_d = [1]*len(query_Qdk)
    outliers_target_d = [1]*len(query_Qdu)
    y_domain = inliers_source_d + outliers_source_d + inliers_target_d + outliers_target_d

    #======================================Accuracy===================================================
    #accuracy calculation
    outlier_pred = outlier_nn(dists_o)
    domain_pred = domain_nn(dists)

    outlier_index = tf.squeeze(tf.argmax(outlier_pred,axis=-1)) # 1 --> inlier, 0 --> outlier

    #check the dimensions here before using this for calculating accuracy later
    predictions = tf.nn.softmax(-dists_o, axis=-1) # class probabilities over the CS prototypes, one row per query

    #return the index of the maximum value
    pred_index = tf.squeeze(tf.argmax(predictions,axis=-1)) # for all query samples

    true_query_labels_source_k = query_Qsk_labels #true labels for known class source domain
    query_Qsu_labels = [-1]*len(query_Qsu) #unknown-class label; any value works as long as it is not one of the support classes
    true_query_labels_target_k = query_Qdk_labels #true labels for known class target domain
    query_Qdu_labels = [-1]*len(query_Qdu) #unknown-class label; any value works as long as it is not one of the support classes

    #row order of dists_o (and of everything derived from it): source unknown + target unknown + source known + target known
    true_query_labels_s = query_Qsu_labels + true_query_labels_source_k #source out + source in
    true_query_labels_d = query_Qdu_labels + true_query_labels_target_k #target out + target in

    #source
    pred_index_s_out = pred_index[:len(outliers_source_o)]
    pred_index_s_in = pred_index[len(outliers_source_o)+len(outliers_target_o):len(outliers_source_o)+len(outliers_target_o)+len(inliers_source_o)]
    pred_index_s = tf.concat((pred_index_s_out,pred_index_s_in),axis=0) #source out + source in
    #target
    pred_index_d_out = pred_index[len(outliers_source_o):len(outliers_source_o)+len(outliers_target_o)]
    pred_index_d_in = pred_index[len(outliers_source_o)+len(outliers_target_o)+len(inliers_source_o):]
    pred_index_d = tf.concat((pred_index_d_out,pred_index_d_in),axis=0) #target out + target in

    #source
    outlier_index_s_out = outlier_index[:len(outliers_source_o)]
    outlier_index_s_in = outlier_index[len(outliers_source_o)+len(outliers_target_o):len(outliers_source_o)+len(outliers_target_o)+len(inliers_source_o)]
    outlier_index_s = tf.concat((outlier_index_s_out, outlier_index_s_in),axis=0) #source out + source in
    #target
    outlier_index_d_out = outlier_index[len(outliers_source_o):len(outliers_source_o)+len(outliers_target_o)]
    outlier_index_d_in = outlier_index[len(outliers_source_o)+len(outliers_target_o)+len(inliers_source_o):]
    outlier_index_d = tf.concat((outlier_index_d_out,outlier_index_d_in),axis=0) #target out + target in

    #source
    predictions_s_out = predictions[:len(outliers_source_o),:]
    predictions_s_in = predictions[len(outliers_source_o)+len(outliers_target_o):len(outliers_source_o)+len(outliers_target_o)+len(inliers_source_o),:]
    predictions_s = tf.concat((predictions_s_out, predictions_s_in),axis=0) #source out + source in
    #target
    predictions_d_out = predictions[len(outliers_source_o):len(outliers_source_o)+len(outliers_target_o),:]
    predictions_d_in = predictions[len(outliers_source_o)+len(outliers_target_o)+len(inliers_source_o):,:]
    predictions_d = tf.concat((predictions_d_out, predictions_d_in),axis=0) #target out + target in
    #rearranged predictions ==> source + target
    predictions_rearranged = tf.concat((predictions_s, predictions_d),axis=0)

    correct_pred_s = 0
    outlier_s = 0
    correct_pred_d = 0
    outlier_d = 0

    for i in range(len(true_query_labels_s)) : #SOURCE DOMAIN
        #if class is from support set
        if true_query_labels_s[i] in support_classes : #known class
            if outlier_index_s[i] == 1 :
                x = support_classes.index(true_query_labels_s[i])
                if x == pred_index_s[i] :
                    correct_pred_s += 1
        #if it is open set class
        else : #unknown class
            if outlier_index_s[i] == 0 :
                outlier_s = outlier_s + 1

    #only for query samples
    accuracy_s = correct_pred_s/(len(true_query_labels_source_k)) #only known samples from SOURCE DOMAIN
    outlier_det_acc_s = outlier_s/(len(query_Qsu_labels)) #only unknown samples

    for i in range(len(true_query_labels_d)) : #TARGET DOMAIN
        #if class is from support set
        if true_query_labels_d[i] in support_classes : #known class
            if outlier_index_d[i] == 1 :
                x = support_classes.index(true_query_labels_d[i])
                if x == pred_index_d[i] :
                    correct_pred_d += 1
        #if it is open set class
        else : #unknown class
            if outlier_index_d[i] == 0:
                outlier_d = outlier_d + 1

    #only for query samples
    accuracy_d = correct_pred_d/(len(true_query_labels_target_k)) #only known samples from TARGET DOMAIN
    outlier_det_acc_d = outlier_d/(len(query_Qdu_labels)) #only unknown samples

    accuracy = hmean([accuracy_s,accuracy_d])
    outlier_det_acc = hmean([outlier_det_acc_s,outlier_det_acc_d])

    y_pred_c_s = np.zeros((len(true_query_labels_source_k))) #all known-class queries from SOURCE
    for i in range(len(true_query_labels_source_k)) :
        if outlier_index_s_in[i] == 1 :
            y_pred_c_s[i] = pred_index_s_in[i]
        else :
            y_pred_c_s[i] = -1 #-1 marks a predicted outlier
    cm1 = confusion_matrix(true_query_labels_source_k,y_pred_c_s)
    FP = cm1.sum(axis=0) - np.diag(cm1)
    FN = cm1.sum(axis=1) - np.diag(cm1)
    TP = np.diag(cm1)
    TN = cm1.sum() - (FP + FN + TP)
    closed_oa_s = (sum(TP)+sum(TN))/(sum(TP)+sum(TN)+sum(FP)+sum(FN))

    y_pred_c_d = np.zeros((len(true_query_labels_target_k))) #all known-class queries from TARGET
    for i in range(len(true_query_labels_target_k)) :
        if outlier_index_d_in[i] == 1 :
            y_pred_c_d[i] = pred_index_d_in[i]
        else :
            y_pred_c_d[i] = -1 #-1 marks a predicted outlier; make sure -1 is not a support label
    cm2 = confusion_matrix(true_query_labels_target_k,y_pred_c_d)
    FP = cm2.sum(axis=0) - np.diag(cm2)
    FN = cm2.sum(axis=1) - np.diag(cm2)
    TP = np.diag(cm2)
    TN = cm2.sum() - (FP + FN + TP)
    closed_oa_d = (sum(TP)+sum(TN))/(sum(TP)+sum(TN)+sum(FP)+sum(FN))

    closed_oa_cm = hmean([closed_oa_s, closed_oa_d])

    #open oa
    y_pred_o_s = np.zeros((len(true_query_labels_s))) #all known + unknown queries from SOURCE
    for i in range(len(true_query_labels_s)) :
        if outlier_index_s[i] == 1 :
            y_pred_o_s[i] = pred_index_s[i]
        else :
            y_pred_o_s[i] = -1 #-1 marks a predicted outlier; make sure -1 is not a support label
    cm3 = confusion_matrix(true_query_labels_s, y_pred_o_s)
    FP = cm3.sum(axis=0) - np.diag(cm3)
    FN = cm3.sum(axis=1) - np.diag(cm3)
    TP = np.diag(cm3)
    TN = cm3.sum() - (FP + FN + TP)
    open_oa_s = (sum(TP)+sum(TN))/(sum(TP)+sum(TN)+sum(FP)+sum(FN))

    y_pred_o_d = np.zeros((len(true_query_labels_d))) #all known + unknown queries from TARGET
    for i in range(len(true_query_labels_d)) :
        if outlier_index_d[i] == 1 :
            y_pred_o_d[i] = pred_index_d[i]
        else :
            y_pred_o_d[i] = -1 #-1 marks a predicted outlier; make sure -1 is not a support label
    cm4 = confusion_matrix(true_query_labels_d, y_pred_o_d)
    FP = cm4.sum(axis=0) - np.diag(cm4)
    FN = cm4.sum(axis=1) - np.diag(cm4)
    TP = np.diag(cm4)
    TN = cm4.sum() - (FP + FN + TP)
    open_oa_d = (sum(TP)+sum(TN))/(sum(TP)+sum(TN)+sum(FP)+sum(FN))

    open_oa = hmean([open_oa_s,open_oa_d])

    #domain prediction accuracy
    correct_domain=0
    dom_bin = tf.argmax(domain_pred,axis=-1)
    for i,j in zip(y_domain,dom_bin):
        if i==j:
            correct_domain+=1
    domain_acc = correct_domain/(len(dom_bin))

    return accuracy, closed_oa_cm, outlier_det_acc, open_oa, domain_acc # all scalars
--------------------------------------------------------------------------------
/utils/proto_train.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# coding: utf-8
import tensorflow as tf
import numpy as np

#Optimizers for GANs
from optimizers.gan_optimizers import gan_reptile
from optimizers.gan_optimizers import gan_reptile_high
#loss
from loss.cross_domain_alignment import calc_weighted_gamma_dists
from loss.prototype_diversification_loss import calc_triplet_dists
#compute distances
from loss.euclidian_dist import calc_euclidian_dists2
#batch norm parameters
from utils.get_gamma_beta_from_layer import get_gamma_beta
#confusion matrix
from sklearn.metrics import confusion_matrix
#harmonic mean
from scipy.stats import hmean



#===============================================================Proto Train function===========================================================
def proto_train(support_images_tensor,query_images_tensor,support_labels_tensor,query_labels_tensor,support_classes,CS, CQ, Ss, Sd, Qsk, Qdk, Qsu, Qdu, query_dom_labels, support_dom_labels, CDFSOSR_Model, outlier_nn, domain_nn, generator_nn_low_s, dis_nn_low_s,generator_nn_high_s, dis_nn_high_s, ntimes, lr_schedule, optim_model_Outlier, optim_model_Domain, optim_d_low_s,optim_g_low_s, optim_d_high_s,optim_g_high_s):#3,3,4,1,5,5,5,5
    sembed = CDFSOSR_Model(support_images_tensor) # [x, 64]-Aug-> [3x, 64] (support)
    qembed = CDFSOSR_Model(query_images_tensor) # [p, 64] --> [p, 64] (query)

    qembed = tf.cast(qembed, dtype=tf.float64)
    #================================GAN PART==================================================
    # gan --> reptile
    #source domain+target domain
    ngan = 5
    for _ in range(ngan):
        sembed_low,vector = gan_reptile(sembed,support_labels_tensor,support_dom_labels,generator_nn_low_s,dis_nn_low_s,optim_d_low_s,optim_g_low_s,stdev=0.2,alpha1=0.003)
        sembed_high = gan_reptile_high(sembed,sembed_low,vector,support_labels_tensor,support_dom_labels,generator_nn_high_s,dis_nn_high_s,optim_d_high_s,optim_g_high_s,stdev=1.0,alpha1=0.003)

    #==========================================================================================
    sembed_low_s = []
    sembed_high_s = []
    sembed_s_labels = []
    sembed_s = []
    sembed_low_d = []
    sembed_high_d = []
    sembed_d_labels = []
    sembed_d = []
    for i in range(len(sembed)): #the first CS*Ss rows are source-domain support embeddings, the remaining CS*Sd rows are target-domain
        if i < CS*Ss:
            sembed_low_s.append(sembed_low[i])
            sembed_high_s.append(sembed_high[i])
            sembed_s.append(sembed[i])
            sembed_s_labels.append(support_labels_tensor[i]) #source domain support images
        else:
            sembed_low_d.append(sembed_low[i])
            sembed_high_d.append(sembed_high[i])
            sembed_d.append(sembed[i])
            sembed_d_labels.append(support_labels_tensor[i]) #target domain support images

    sembed_s = tf.convert_to_tensor(sembed_s, dtype=tf.float64) #(CS*Ss, 64)
    sembed_d = tf.convert_to_tensor(sembed_d, dtype=tf.float64) #(CS*Sd, 64)
    sembed_low_s = tf.convert_to_tensor(sembed_low_s, dtype=tf.float64) #(CS*Ss, 64)
    sembed_low_d = tf.convert_to_tensor(sembed_low_d, dtype=tf.float64) #(CS*Sd, 64)
    sembed_high_s = tf.convert_to_tensor(sembed_high_s, dtype=tf.float64) #(CS*Ss, 64)
    sembed_high_d = tf.convert_to_tensor(sembed_high_d, dtype=tf.float64) #(CS*Sd, 64)
    sembed_s_labels = tf.convert_to_tensor(sembed_s_labels, dtype=tf.float64) #(CS*Ss,)
    sembed_d_labels = tf.convert_to_tensor(sembed_d_labels, dtype=tf.float64) #(CS*Sd,)

    # Calculate the number of elements in each of the 4 query subsets
    num_elements_qsk = CS * Qsk
    num_elements_qdk = CS * Qdk
    num_elements_qsu = CQ * Qsu
    num_elements_qdu = CQ * Qdu

    # Slice the query embeddings into the 4 subsets (source known, target known, source unknown, target unknown)
    query_Qsk = qembed[:num_elements_qsk]
    query_Qdk = qembed[num_elements_qsk:num_elements_qsk+num_elements_qdk]
    query_Qsu = qembed[num_elements_qsk+num_elements_qdk:num_elements_qsk+num_elements_qdk+num_elements_qsu]
    query_Qdu = qembed[num_elements_qsk+num_elements_qdk+num_elements_qsu:]

    # Slice the query labels of the known-class subsets
    query_Qsk_labels = query_labels_tensor[:num_elements_qsk]
    query_Qdk_labels = query_labels_tensor[num_elements_qsk:num_elements_qsk+num_elements_qdk]

    #=======================================Prototypes===================================================
    #support + low noise gan SOURCE DOMAIN
    z_proto_low_s = tf.reshape(sembed_low_s,[CS, Ss, sembed_low_s.shape[-1]])
    z_prototypes_s = tf.reshape(sembed_s,[CS, Ss, sembed_s.shape[-1]]) #weak + strong + original support images
    z_prototypes_s_combined = np.empty((CS,Ss*2,sembed_s.shape[-1]), dtype=np.float64)

    #combine the embeddings of the original images and the GAN output for SOURCE DOMAIN
    #Adder 1 source
    for i in range(z_prototypes_s_combined.shape[0]):
        z_prototypes_s_combined[i,:,:] = tf.concat((tf.cast(z_prototypes_s[i,:,:], dtype=tf.float64),tf.cast(z_proto_low_s[i,:,:], dtype=tf.float64)),axis=0) #[original source support + low noise GAN output for source support]

    # prototypes SOURCE DOMAIN
    z_prototypes_s_combined = tf.reshape(z_prototypes_s_combined, [CS,Ss*2, sembed_s.shape[-1]])
    z_prototypes_s_combined = tf.reduce_mean(z_prototypes_s_combined, axis=1)

    #support + low noise gan TARGET DOMAIN
    z_proto_low_d = tf.reshape(sembed_low_d,[CS, Sd, sembed_low_d.shape[-1]])
    z_prototypes_d = tf.reshape(sembed_d,[CS, Sd, sembed_d.shape[-1]])
    z_prototypes_d_combined = np.empty((CS,Sd*2,sembed_d.shape[-1]), dtype=np.float64)

    #combine the embeddings of the original images and the GAN output for TARGET DOMAIN
    #Adder 1 target
    for i in range(z_prototypes_d_combined.shape[0]):
        z_prototypes_d_combined[i,:,:] = tf.concat((tf.cast(z_prototypes_d[i,:,:], dtype=tf.float64),tf.cast(z_proto_low_d[i,:,:], dtype=tf.float64)),axis=0) #[original target support + low noise GAN output for target support]

    # prototypes TARGET DOMAIN
    z_prototypes_d_combined = tf.reshape(z_prototypes_d_combined, [CS,Sd*2, sembed_d.shape[-1]])
    z_prototypes_d_combined = tf.reduce_mean(z_prototypes_d_combined, axis=1)

    #=======================================Adder 2===================================================
    #Adder 2 SOURCE DOMAIN: only the unknown-class queries and the high-noise GAN samples are used here, not the known-class queries
    qembedU_high_s = tf.convert_to_tensor(tf.concat((sembed_high_s,query_Qsu),axis=0),dtype=tf.float64)
    #Adder 2 TARGET DOMAIN
    qembedU_high_d = tf.convert_to_tensor(tf.concat((sembed_high_d,query_Qdu),axis=0),dtype=tf.float64)
    #=======================================Distances===================================================
    #Euclidean distances to the combined prototypes:
    dists_s_k = calc_euclidian_dists2(query_Qsk, z_prototypes_s_combined) #known-class source queries
    dists_d_k = calc_euclidian_dists2(query_Qdk, z_prototypes_d_combined) #known-class target queries
    dists_s_u = calc_euclidian_dists2(qembedU_high_s, z_prototypes_s_combined) #high-noise GAN samples + unknown-class source queries
    dists_d_u = calc_euclidian_dists2(qembedU_high_d, z_prototypes_d_combined) #high-noise GAN samples + unknown-class target queries

    #for domain net
    dists_s = tf.concat((dists_s_k,dists_s_u),axis=0)
    dists_d = tf.concat((dists_d_k,dists_d_u),axis=0)

    #concatenate distances
    dists = tf.concat((dists_s,dists_d),axis=0) #[CS*Qsk + CS*Ss + CQ*Qsu + CS*Qdk + CS*Sd + CQ*Qdu, CS]

    log_p_y = tf.nn.log_softmax(-dists,axis=-1) #same shape as dists

    #for outlier net
    dists_inlier = tf.concat((dists_s_k,dists_d_k),axis=0)
    dists_outlier = tf.concat((dists_s_u,dists_d_u),axis=0)
    #concatenate distances: outliers first, then inliers
    dists_o = tf.concat((dists_outlier,dists_inlier),axis=0)

    #=====================================Domain + Outlier separate Networks==============================
    #two types of labels: 1) outlier prediction and 2) domain prediction
    #outlier prediction labels
    outliers_source_o = [0]*len(qembedU_high_s) #1-->inlier 0-->outlier
    outliers_target_o = [0]*len(qembedU_high_d)
    inliers_source_o = [1]*len(query_Qsk)
    inliers_target_o = [1]*len(query_Qdk)
    y_outlier = outliers_source_o + outliers_target_o + inliers_source_o + inliers_target_o

    #domain prediction labels:
    inliers_source_d = [0]*len(query_Qsk) #0-->source 1-->target
    outliers_source_d = [0]*len(qembedU_high_s)
    inliers_target_d = [1]*len(query_Qdk)
    outliers_target_d = [1]*len(qembedU_high_d)
    y_domain = inliers_source_d + outliers_source_d + inliers_target_d + outliers_target_d

    for i in range(ntimes):
        with tf.GradientTape() as Outlier_tape:
            outlier_pred = outlier_nn(dists_o)
            outlier_pred = tf.squeeze(outlier_pred)
            loss_outlier = tf.keras.losses.SparseCategoricalCrossentropy()(y_outlier, outlier_pred) #0,1

        gradients_out = Outlier_tape.gradient(loss_outlier, outlier_nn.trainable_variables)
        # Compute the learning rate using the schedule
        learning_rate = lr_schedule(optim_model_Outlier.iterations)
        optim_model_Outlier.learning_rate = learning_rate
        optim_model_Outlier.apply_gradients(zip(gradients_out, outlier_nn.trainable_variables))

    for i in range(ntimes):
        with tf.GradientTape() as Domain_tape:
            domain_pred = domain_nn(dists)
            domain_pred = tf.squeeze(domain_pred)
            loss_domain = tf.keras.losses.SparseCategoricalCrossentropy()(y_domain, domain_pred) #0,1

        gradients_dom = Domain_tape.gradient(loss_domain, domain_nn.trainable_variables)
        # Compute the learning rate using the schedule
        learning_rate = lr_schedule(optim_model_Domain.iterations)
        optim_model_Domain.learning_rate = learning_rate
        optim_model_Domain.apply_gradients(zip(gradients_dom, domain_nn.trainable_variables))

    #======================================Accuracy===================================================
    #accuracy calculation
    outlier_pred = outlier_nn(dists_o)
    domain_pred = domain_nn(dists)

    outlier_index = tf.squeeze(tf.argmax(outlier_pred,axis=-1)) # 1 --> inlier, 0 --> outlier

    #check the dimensions here before using this for calculating accuracy later
    predictions = tf.nn.softmax(-dists_o, axis=-1) # class probabilities over the CS prototypes, one row per sample

    #return the index of the maximum value
    pred_index = tf.squeeze(tf.argmax(predictions,axis=-1)) # for all query samples

    true_query_labels_source_k = query_Qsk_labels #true labels for known class source domain
    qembedU_high_s_labels = [-1]*len(qembedU_high_s) #unknown-class label; any value works as long as it is not one of the support classes
    true_query_labels_target_k = query_Qdk_labels #true labels for known class target domain
    qembedU_high_d_labels = [-1]*len(qembedU_high_d) #unknown-class label; any value works as long as it is not one of the support classes

    #row order of dists_o (and of everything derived from it): source unknown + target unknown + source known + target known
    true_query_labels_s = qembedU_high_s_labels + true_query_labels_source_k #source out + source in
    true_query_labels_d = qembedU_high_d_labels + true_query_labels_target_k #target out + target in

    #source
    pred_index_s_out = pred_index[:len(outliers_source_o)]
    pred_index_s_in = pred_index[len(outliers_source_o)+len(outliers_target_o):len(outliers_source_o)+len(outliers_target_o)+len(inliers_source_o)]
    pred_index_s = tf.concat((pred_index_s_out,pred_index_s_in),axis=0) #source out + source in
    #target
    pred_index_d_out = pred_index[len(outliers_source_o):len(outliers_source_o)+len(outliers_target_o)]
    pred_index_d_in = pred_index[len(outliers_source_o)+len(outliers_target_o)+len(inliers_source_o):]
    pred_index_d = tf.concat((pred_index_d_out,pred_index_d_in),axis=0) #target out + target in

    #source
    outlier_index_s_out = outlier_index[:len(outliers_source_o)]
    outlier_index_s_in = outlier_index[len(outliers_source_o)+len(outliers_target_o):len(outliers_source_o)+len(outliers_target_o)+len(inliers_source_o)]
    outlier_index_s = tf.concat((outlier_index_s_out, outlier_index_s_in),axis=0) #source out + source in
    #target
    outlier_index_d_out = outlier_index[len(outliers_source_o):len(outliers_source_o)+len(outliers_target_o)]
    outlier_index_d_in = outlier_index[len(outliers_source_o)+len(outliers_target_o)+len(inliers_source_o):]
    outlier_index_d = tf.concat((outlier_index_d_out,outlier_index_d_in),axis=0) #target out + target in

    #source
    predictions_s_out = predictions[:len(outliers_source_o),:]
    predictions_s_in = predictions[len(outliers_source_o)+len(outliers_target_o):len(outliers_source_o)+len(outliers_target_o)+len(inliers_source_o),:]
    predictions_s = tf.concat((predictions_s_out, predictions_s_in),axis=0) #source out + source in
    #target
    predictions_d_out = predictions[len(outliers_source_o):len(outliers_source_o)+len(outliers_target_o),:]
    predictions_d_in = predictions[len(outliers_source_o)+len(outliers_target_o)+len(inliers_source_o):,:]
    predictions_d = tf.concat((predictions_d_out, predictions_d_in),axis=0) #target out + target in
    #rearranged predictions ==> source + target
    predictions_rearranged = tf.concat((predictions_s, predictions_d),axis=0)

correct_pred_s = 0 235 | outlier_s = 0 236 | correct_pred_d = 0 237 | outlier_d = 0 238 | 239 | for i in range(len(true_query_labels_s)) : #known and unknown samples from SOURCE DOMAIN 240 | #if class is from support set 241 | if true_query_labels_s[i] in support_classes : #known class 242 | if outlier_index_s[i] == 1 : 243 | x = support_classes.index(true_query_labels_s[i]) 244 | if x == pred_index_s[i] : 245 | correct_pred_s += 1 246 | #if it is open set class 247 | else : #unknown class 248 | if outlier_index_s[i] == 0 : 249 | outlier_s = outlier_s + 1 250 | 251 | #only for query samples 252 | accuracy_s = correct_pred_s/(len(true_query_labels_source_k)) #only known samples from SOURCE DOMAIN 253 | outlier_det_acc_s = outlier_s/(len(qembedU_high_s_labels)) #only unknown samples 254 | 255 | for i in range(len(true_query_labels_d)) : #known and unknown samples from TARGET DOMAIN 256 | #if class is from support set 257 | if true_query_labels_d[i] in support_classes : #known class 258 | if outlier_index_d[i] == 1 : 259 | x = support_classes.index(true_query_labels_d[i]) 260 | if x == pred_index_d[i] : 261 | correct_pred_d += 1 262 | #if it is open set class 263 | else : #unknown class 264 | if outlier_index_d[i] == 0 : 265 | outlier_d = outlier_d + 1 266 | 267 | #only for query samples 268 | accuracy_d = correct_pred_d/(len(true_query_labels_target_k)) #only known samples from TARGET DOMAIN 269 | outlier_det_acc_d = outlier_d/(len(qembedU_high_d_labels)) #only unknown samples 270 | 271 | accuracy = hmean([accuracy_s,accuracy_d]) 272 | outlier_det_acc = hmean([outlier_det_acc_s,outlier_det_acc_d]) 273 | 274 | y_pred_c_s = np.zeros((len(true_query_labels_source_k))) #all quries known from SOURCE 275 | for i in range(len(true_query_labels_source_k)) : 276 | if outlier_index_s_in[i] == 1 : 277 | y_pred_c_s[i] = pred_index_s_in[i] 278 | else : 279 | y_pred_c_s[i] = -1 280 | cm1 = confusion_matrix(true_query_labels_source_k,y_pred_c_s) 281 | FP = cm1.sum(axis=0) - 
np.diag(cm1) 282 | FN = cm1.sum(axis=1) - np.diag(cm1) 283 | TP = np.diag(cm1) 284 | TN = cm1.sum() - (FP + FN + TP) 285 | closed_oa_s = (sum(TP)+sum(TN))/(sum(TP)+sum(TN)+sum(FP)+sum(FN)) 286 | 287 | y_pred_c_d = np.zeros((len(true_query_labels_target_k))) #all queries known from TARGET 288 | for i in range(len(true_query_labels_target_k)) : 289 | if outlier_index_d_in[i] == 1 : 290 | y_pred_c_d[i] = pred_index_d_in[i] 291 | else : 292 | y_pred_c_d[i] = -1 #-1 marks samples rejected as unknown; make sure -1 is not a support label 293 | cm2 = confusion_matrix(true_query_labels_target_k,y_pred_c_d) 294 | FP = cm2.sum(axis=0) - np.diag(cm2) 295 | FN = cm2.sum(axis=1) - np.diag(cm2) 296 | TP = np.diag(cm2) 297 | TN = cm2.sum() - (FP + FN + TP) 298 | closed_oa_d = (sum(TP)+sum(TN))/(sum(TP)+sum(TN)+sum(FP)+sum(FN)) 299 | 300 | closed_oa_cm = hmean([closed_oa_s, closed_oa_d]) 301 | 302 | #open oa 303 | y_pred = np.zeros((len(true_query_labels_s))) #all queries known+unknown from SOURCE 304 | for i in range(len(true_query_labels_s)) : 305 | if outlier_index_s[i] == 1 : 306 | y_pred[i] = pred_index_s[i] 307 | else : 308 | y_pred[i] = -1 #-1 marks samples rejected as unknown; make sure -1 is not a support label 309 | cm2 = confusion_matrix(true_query_labels_s,y_pred) 310 | FP = cm2.sum(axis=0) - np.diag(cm2) 311 | FN = cm2.sum(axis=1) - np.diag(cm2) 312 | TP = np.diag(cm2) 313 | TN = cm2.sum() - (FP + FN + TP) 314 | open_oa_s = (sum(TP)+sum(TN))/(sum(TP)+sum(TN)+sum(FP)+sum(FN)) 315 | 316 | y_pred = np.zeros((len(true_query_labels_d))) #all queries known+unknown from TARGET 317 | for i in range(len(true_query_labels_d)) : 318 | if outlier_index_d[i] == 1 : #target-domain outlier flags 319 | y_pred[i] = pred_index_d[i] 320 | else : 321 | y_pred[i] = -1 #-1 marks samples rejected as unknown; make sure -1 is not a support label 322 | cm2 = confusion_matrix(true_query_labels_d,y_pred) #ground truth vs prediction, as for SOURCE 323 | FP = cm2.sum(axis=0) - np.diag(cm2) 324 | FN = cm2.sum(axis=1) - np.diag(cm2) 325 | TP = np.diag(cm2) 326 | TN = cm2.sum() - (FP + FN + TP) 327 | open_oa_d = (sum(TP)+sum(TN))/(sum(TP)+sum(TN)+sum(FP)+sum(FN)) 328 | 329 | open_oa = 
hmean([open_oa_s,open_oa_d]) 330 | 331 | one_hot_labels_s = np.asarray(np.zeros((len(true_query_labels_s),CS)),dtype=np.float64) # (*, 3) 332 | one_hot_labels_d = np.asarray(np.zeros((len(true_query_labels_d),CS)),dtype=np.float64) # (*, 3) 333 | #map source query labels to one-hot rows over the CS source support classes (unknown-class queries stay all-zero): [0. 0. 0.],[0. 0. 1.],[0. 1. 0.],[1. 0. 0.] 334 | for i in range(len(true_query_labels_s)) : 335 | if true_query_labels_s[i] in support_classes : 336 | x = support_classes[:CS].index(true_query_labels_s[i]) 337 | one_hot_labels_s[i][x] = 1. # [[0., 0., 1.], [0., 0., 0.], ... (*,3) 338 | # print(i,x) #debug output, disabled 339 | #map target query labels to one-hot rows over the CS target support classes (unknown-class queries stay all-zero): [0. 0. 0.],[0. 0. 1.],[0. 1. 0.],[1. 0. 0.] 340 | for i in range(len(true_query_labels_d)) : 341 | if true_query_labels_d[i] in support_classes : 342 | x = support_classes[CS:].index(true_query_labels_d[i]) 343 | one_hot_labels_d[i][x] = 1. # [[0., 0., 1.], [0., 0., 0.], ... (*,3) 344 | 345 | #domain prediction accuracy 346 | correct_domain=0 347 | dom_bin = tf.argmax(domain_pred,axis=-1) 348 | for i,j in zip(y_domain,dom_bin): 349 | if i==j: 350 | correct_domain+=1 351 | domain_acc = correct_domain/(len(dom_bin)) 352 | 353 | #FE loss 354 | quad_dis_s = calc_triplet_dists(z_prototypes_s_combined, tf.concat((qembedU_high_s,query_Qsk),axis=0), one_hot_labels_s) 355 | quad_dis_d = calc_triplet_dists(z_prototypes_d_combined, tf.concat((qembedU_high_d,query_Qdk),axis=0), one_hot_labels_d) 356 | #add P_loss 357 | Gamma_D1,Gamma_D2,Beta_D1,Beta_D2 = get_gamma_beta(CDFSOSR_Model) 358 | 359 | cross_domain_align_loss = calc_weighted_gamma_dists(z_prototypes_s_combined,z_prototypes_d_combined,Gamma_D1,Gamma_D2,Beta_D1,Beta_D2) 360 | 361 | cec_loss = -tf.reduce_mean((tf.reduce_sum(tf.multiply(tf.cast(tf.concat((one_hot_labels_s,one_hot_labels_d),axis=0),dtype=tf.float64), predictions_rearranged), axis=-1))) 362 | 363 | ## Choose appropriate weighting scalars for the loss according to your dataset 364 | loss = tf.cast(0.4*100*cec_loss, 
tf.float64) + tf.cast(0.3*loss_outlier, tf.float64) + tf.cast(0.05*loss_domain, tf.float64) + tf.cast(0.075*quad_dis_s, tf.float64) + tf.cast(0.125*quad_dis_d, tf.float64) + tf.cast(0.05*1000*cross_domain_align_loss, tf.float64) 365 | 366 | return loss, accuracy, closed_oa_cm, outlier_det_acc, open_oa, domain_acc # all scalars -------------------------------------------------------------------------------- /utils/save_accuracy_values.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | #save closed oa and open oa values per epoch 4 | def save_accuracy_values(closed_oa, open_oa, filepath): 5 | """ 6 | Appends one row of closed and open overall accuracy values (one call per epoch) to a numpy file. 7 | 8 | Args: 9 | closed_oa: A scalar value representing the closed overall accuracy for the epoch. 10 | open_oa: A scalar value representing the open overall accuracy for the epoch. 11 | filepath: A string representing the file path to save the accuracy values to. 12 | 
13 | """ 14 | data = np.array([[closed_oa, open_oa]]) 15 | 16 | # Check if the file exists 17 | if os.path.isfile(filepath): 18 | # Load the existing numpy file 19 | npfile = np.load(filepath) 20 | # Append the new data to the numpy file 21 | npfile = np.concatenate((npfile, data), axis=0) 22 | # Save the modified numpy file 23 | np.save(filepath, npfile) 24 | else: 25 | # Create a new numpy file with the data 26 | np.save(filepath, data) 27 | 28 | print(f"Accuracy values saved to {filepath}") -------------------------------------------------------------------------------- /utils/test_episode.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # Import necessary libraries 5 | import tensorflow as tf 6 | import numpy as np 7 | 8 | def test_episode(source_domain, target_domain, CS, CQ, Ss, Sd, Qsk, Qdk, Qsu, Qdu, class_labels): 9 | # Select CS classes from source_domain 10 | support_classes = list(class_labels[:CS]) 11 | # Select Ss samples from the selected classes in source domain as part of the support set 12 | tsupport_patches_d, support_labels_d = [], [] 13 | # Select Sd samples from the selected classes in target domain as part of the support set 14 | tsupport_patches_t, support_labels_t = [], [] 15 | # Select Qsk images from the previously selected class in (1a) which were NOT SELECTED EARLIER from domain 1 16 | known_query_patches_1, known_query_labels_1 = [], [] 17 | # Select the same images (Qdk=Qsk) in (2a) for the same previously selected class in (1a) from domain 2 18 | known_query_patches_2, known_query_labels_2 = [], [] 19 | if Ss>0: 20 | for x in range(len(support_classes)): 21 | sran_indices = np.random.choice(source_domain[x].shape[0], Ss, replace=False) 22 | support_patches = source_domain[x][sran_indices,:,:,:] 23 | tsupport_patches_d.extend(support_patches) 24 | for i in range(Ss): 25 | support_labels_d.append(x) 26 | 27 | unselected_indices = [j for j in 
list(range(source_domain[x].shape[0])) if j not in sran_indices] # to make sure there is no overlap between support and query images 28 | qran_indices_known = np.random.choice(unselected_indices, Qsk, replace=False) 29 | query_patches_1 = source_domain[x][qran_indices_known,:,:,:] 30 | known_query_patches_1.extend(query_patches_1) 31 | for i in range(Qsk): 32 | known_query_labels_1.append(x) 33 | 34 | if Sd>0: 35 | sran_indices = np.random.choice(target_domain[x].shape[0], Sd, replace=False) 36 | support_patches = target_domain[x][sran_indices,:,:,:] 37 | tsupport_patches_t.extend(support_patches) 38 | for i in range(Sd): 39 | support_labels_t.append(x) 40 | 41 | unselected_indices = [j for j in list(range(target_domain[x].shape[0])) if j not in sran_indices] # to make sure there is no overlap between support and query images 42 | qran_indices_known = np.random.choice(unselected_indices, Qdk, replace=False) 43 | query_patches_2 = target_domain[x][qran_indices_known,:,:,:] 44 | known_query_patches_2.extend(query_patches_2) 45 | for i in range(Qdk): 46 | known_query_labels_2.append(x) 47 | 48 | # 1c and 1d together form the support set 49 | tsupport_patches = tsupport_patches_d + tsupport_patches_t # They are Python lists, so + concatenates them (not tf.tensors) 50 | support_labels = support_labels_d + support_labels_t 51 | 52 | # 2a and 2b together form the known classes query set 53 | tquery_patches = known_query_patches_1 + known_query_patches_2 # They are Python lists, so + concatenates them (not tf.tensors) 54 | query_labels = known_query_labels_1 + known_query_labels_2 55 | 56 | # Select CQ other classes that were NOT SELECTED EARLIER in (1a) as part of the unknown query set 57 | other_classes = [c for c in class_labels if c not in support_classes] 58 | selected_classes = list(np.random.choice(other_classes, CQ, replace=False)) 59 | 60 | # Randomly select Qsu images from the selected classes (CQ) in (2d) from domain 1 as a part of unknown query set 61 | 
unknown_query_patches_1, unknown_query_labels_1 = [], [] 62 | # Randomly select Qdu images from the selected classes (CQ) in (2d) from domain 2 as a part of unknown query set 63 | unknown_query_patches_2, unknown_query_labels_2 = [], [] 64 | if Qsu>0: 65 | for x in selected_classes: #iterate over the selected unknown classes, not their positions 66 | qran_indices_unknown = np.random.choice(source_domain[x].shape[0], Qsu, replace=False) 67 | query_patches_1 = source_domain[x][qran_indices_unknown,:,:,:] 68 | unknown_query_patches_1.extend(query_patches_1) 69 | for i in range(Qsu): 70 | unknown_query_labels_1.append(x) 71 | if Qdu>0: 72 | query_patches_2 = target_domain[x][qran_indices_unknown,:,:,:] 73 | unknown_query_patches_2.extend(query_patches_2) 74 | for i in range(Qdu): 75 | unknown_query_labels_2.append(x) 76 | 77 | # 2e and 2f together form the unknown classes query set 78 | unknown_query_patches = unknown_query_patches_1 + unknown_query_patches_2 # They are Python lists, so + concatenates them (not tf.tensors) 79 | unknown_query_labels = unknown_query_labels_1 + unknown_query_labels_2 80 | 81 | # Concatenate the known and unknown query sets 82 | tquery_patches += unknown_query_patches 83 | query_labels += unknown_query_labels 84 | 85 | # Convert the lists to tensors 86 | tquery_patches = tf.convert_to_tensor(tquery_patches, dtype=tf.float64)/255.0 # Normalize # ((Qsk+Qdk)*CS + (Qsu+Qdu)*CQ, image_height, image_width, channels) 87 | tsupport_patches = tf.convert_to_tensor(tsupport_patches, dtype=tf.float64)/255.0 # Normalize # ((CS*Ss+CS*Sd), image_height, image_width, channels) 88 | 89 | return tquery_patches, tsupport_patches, query_labels, support_labels, support_classes -------------------------------------------------------------------------------- /utils/train_episode.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # Import necessary libraries 5 | import tensorflow as tf 6 | import numpy as np 7 | 8 | 
#====================================Data Loader====================================: 9 | def new_episode(source_domain, target_domain, CS, CQ, Ss, Sd, Qsk, Qdk, Qsu, Qdu, class_labels): 10 | # Select CS(source) + CS(target) support classes from the universal set 11 | support_classes = list(np.random.choice(class_labels, CS*2, replace=False)) 12 | # Of these, assign CS classes to the source domain; the remaining CS go to the target domain 13 | source_support_classes = list(np.random.choice(support_classes, CS, replace=False)) 14 | # Select Ss samples from the selected classes in source domain as part of the support set 15 | tsupport_patches_d, support_labels_d, support_dom_labels_d = [], [], [] 16 | # Select Sd samples from the selected classes in target domain as part of the support set 17 | tsupport_patches_t, support_labels_t, support_dom_labels_t = [], [], [] 18 | # Select Qsk images from the previously selected source classes in (1a) which were NOT SELECTED EARLIER from domain 1 19 | known_query_patches_1, known_query_labels_1, known_query_dom_labels_1 = [], [], [] 20 | # Select Qdk images from the target support classes which were NOT SELECTED EARLIER from domain 2 21 | known_query_patches_2, known_query_labels_2, known_query_dom_labels_2 = [], [], [] 22 | 23 | if Ss>0: 24 | for x in source_support_classes: 25 | sran_indices = np.random.choice(source_domain[x].shape[0], Ss, replace=False) 26 | support_patches = source_domain[x][sran_indices,:,:,:] 27 | tsupport_patches_d.extend(support_patches) 28 | support_dom_labels_d.extend([0]*len(sran_indices)) 29 | for i in range(Ss): 30 | support_labels_d.append(x) 31 | 32 | unselected_indices = [j for j in list(range(source_domain[x].shape[0])) if j not in sran_indices] #to make sure there is no overlap between support and query images 33 | qran_indices_known = np.random.choice(unselected_indices, Qsk, replace=False) 34 | query_patches_1 = source_domain[x][qran_indices_known,:,:,:] 35 | known_query_patches_1.extend(query_patches_1) 36 | 
known_query_dom_labels_1.extend([0]*len(qran_indices_known)) 37 | for i in range(Qsk): 38 | known_query_labels_1.append(x) 39 | 40 | target_support_classes = [c for c in support_classes if c not in source_support_classes] 41 | if Sd>0: 42 | for x in target_support_classes: 43 | sran_indices = np.random.choice(target_domain[x].shape[0], Sd, replace=False) 44 | support_patches = target_domain[x][sran_indices,:,:,:] 45 | support_dom_labels_t.extend([1]*len(sran_indices)) 46 | tsupport_patches_t.extend(support_patches) 47 | for i in range(Sd): 48 | support_labels_t.append(x) 49 | 50 | unselected_indices = [j for j in list(range(target_domain[x].shape[0])) if j not in sran_indices] #to make sure there is no overlap between support and query images 51 | qran_indices_known = np.random.choice(unselected_indices, Qdk, replace=False) 52 | query_patches_2 = target_domain[x][qran_indices_known,:,:,:] 53 | known_query_patches_2.extend(query_patches_2) 54 | known_query_dom_labels_2.extend([1]*len(qran_indices_known)) 55 | for i in range(Qdk): 56 | known_query_labels_2.append(x) 57 | 58 | # 1c and 1d together form the support set 59 | tsupport_patches = tsupport_patches_d + tsupport_patches_t # They are Python lists, so + concatenates them (not tf.tensors) 60 | support_labels = support_labels_d + support_labels_t 61 | support_dom_labels = support_dom_labels_d + support_dom_labels_t 62 | 63 | # 2a and 2b together form the known classes query set 64 | tquery_patches = known_query_patches_1 + known_query_patches_2 # They are Python lists, so + concatenates them (not tf.tensors) 65 | query_labels = known_query_labels_1 + known_query_labels_2 66 | query_dom_labels = known_query_dom_labels_1 + known_query_dom_labels_2 67 | 68 | # Select 2*CQ other classes that were NOT SELECTED EARLIER in (1a) as candidates for the unknown query set 69 | other_classes = [c for c in class_labels if c not in support_classes] 70 | # print('Unselected classes:', other_classes) 71 | unknown_classes = 
list(np.random.choice(other_classes, CQ*2, replace=False)) 72 | source_unknown_classes = list(np.random.choice(unknown_classes, CQ, replace=False)) 73 | # Randomly select Qsu images from the selected classes (CQ) in (2d) from domain 1 as a part of unknown query set 74 | unknown_query_patches_1, unknown_query_labels_1, unknown_query_labels_dom_1 = [], [], [] 75 | # Randomly select Qdu images from the selected classes (CQ) in (2d) from domain 2 as a part of unknown query set 76 | unknown_query_patches_2, unknown_query_labels_2, unknown_query_labels_dom_2 = [], [], [] 77 | if Qsu>0: 78 | for x in source_unknown_classes: 79 | qran_indices_unknown = np.random.choice(source_domain[x].shape[0], Qsu, replace=False) 80 | query_patches_1 = source_domain[x][qran_indices_unknown,:,:,:] 81 | unknown_query_patches_1.extend(query_patches_1) 82 | unknown_query_labels_dom_1.extend([0]*len(qran_indices_unknown)) 83 | for i in range(Qsu): 84 | unknown_query_labels_1.append(x) 85 | 86 | source_unknown_and_support_classes = source_unknown_classes + support_classes 87 | target_unknown_classes = [c for c in unknown_classes if c not in source_unknown_and_support_classes] #the other CQ unknown classes go to the target domain 88 | 89 | if Qdu>0: 90 | for x in target_unknown_classes: 91 | query_patches_2 = target_domain[x][qran_indices_unknown,:,:,:] 92 | unknown_query_patches_2.extend(query_patches_2) 93 | unknown_query_labels_dom_2.extend([1]*len(qran_indices_unknown)) 94 | for i in range(Qdu): 95 | unknown_query_labels_2.append(x) 96 | 97 | # 2e and 2f together form the unknown classes query set 98 | unknown_query_patches = unknown_query_patches_1 + unknown_query_patches_2 # They are Python lists, so + concatenates them (not tf.tensors) 99 | unknown_query_labels = unknown_query_labels_1 + unknown_query_labels_2 100 | unknown_query_labels_dom = unknown_query_labels_dom_1 + unknown_query_labels_dom_2 101 | 102 | # Concatenate the known and unknown query sets 103 | tquery_patches += unknown_query_patches 104 | query_labels += unknown_query_labels 105 | 
query_dom_labels += unknown_query_labels_dom 106 | 107 | # Convert the lists to tensors 108 | #rearrange support classes in following format [:CS]-->source [CS:]-->target 109 | support_classes_rearranged = source_support_classes + target_support_classes 110 | tquery_patches = tf.convert_to_tensor(tquery_patches, dtype=tf.float64)/255.0 # Normalize # ((Qsk+Qdk)*CS + (Qsu+Qdu)*CQ, image_height, image_width, channels) 111 | tsupport_patches = tf.convert_to_tensor(tsupport_patches, dtype=tf.float64)/255.0 # Normalize # ((CS*Ss+CS*Sd), image_height, image_width, channels) 112 | 113 | return tquery_patches, tsupport_patches, query_labels, support_labels, support_classes_rearranged, query_dom_labels, support_dom_labels -------------------------------------------------------------------------------- /utils/variance_metric.py: -------------------------------------------------------------------------------- 1 | #custom population variance (divides by n) of the values collected per epoch 2 | class ComputeIterationVaraince: 3 | def __init__(self,name=""): 4 | self.name = name 5 | self.array = [] 6 | 7 | def add(self,number): 8 | self.array.append(number) 9 | 10 | def compute_variance(self): 11 | numbers = self.array 12 | n = len(numbers) 13 | mean = sum(numbers) / n 14 | variance = sum((x - mean) ** 2 for x in numbers) / n 15 | return variance 16 | 17 | def reset_states(self): 18 | self.array = [] --------------------------------------------------------------------------------
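The overall-accuracy values in `proto_train.py` (`closed_oa_s`, `open_oa_s` and their target counterparts) are computed as (sum TP + sum TN) / (sum TP + sum TN + sum FP + sum FN) from one-vs-rest counts on a confusion matrix, and the source and target values are then combined with a harmonic mean. Each sample is counted once per class in TP + TN + FP + FN, so this ratio reduces to 1 - 2 * errors / (K * N), where K is the number of distinct labels (including the -1 rejection label) and N is the number of queries. For K > 2 it therefore exceeds the plain fraction of correct predictions whenever there is at least one error. The sketch below checks that identity in pure Python. `confusion_matrix`, `one_vs_rest_oa`, `one_vs_rest_oa_closed_form` and `combine_domains` are hypothetical helpers standing in for the scikit-learn/NumPy calls in the repository, and `statistics.harmonic_mean` stands in for `scipy.stats.hmean`.

```python
from statistics import harmonic_mean


def confusion_matrix(y_true, y_pred):
    """Count matrix (rows = true label, columns = predicted label) over the
    sorted union of labels, the same ordering sklearn uses by default."""
    labels = sorted(set(y_true) | set(y_pred))
    index = {label: k for k, label in enumerate(labels)}
    cm = [[0] * len(labels) for _ in labels]
    for t, p in zip(y_true, y_pred):
        cm[index[t]][index[p]] += 1
    return cm


def one_vs_rest_oa(y_true, y_pred):
    """(sum TP + sum TN) / (sum TP + sum TN + sum FP + sum FN) with
    per-class one-vs-rest counts, following the formula in proto_train.py."""
    cm = confusion_matrix(y_true, y_pred)
    k = len(cm)
    n = sum(sum(row) for row in cm)
    tp = [cm[i][i] for i in range(k)]
    fp = [sum(cm[r][i] for r in range(k)) - cm[i][i] for i in range(k)]  # column sum minus diagonal
    fn = [sum(cm[i]) - cm[i][i] for i in range(k)]  # row sum minus diagonal
    tn = [n - (fp[i] + fn[i] + tp[i]) for i in range(k)]
    return (sum(tp) + sum(tn)) / (sum(tp) + sum(tn) + sum(fp) + sum(fn))


def one_vs_rest_oa_closed_form(y_true, y_pred):
    """Same value via 1 - 2 * errors / (K * N)."""
    k = len(set(y_true) | set(y_pred))
    n = len(y_true)
    errors = sum(t != p for t, p in zip(y_true, y_pred))
    return 1 - 2 * errors / (k * n)


def combine_domains(value_s, value_d):
    """Harmonic mean of the source and target values."""
    return harmonic_mean([value_s, value_d])
```

With `y_true = [0, 0, 1, 1, 2]` and `y_pred = [0, 1, 1, 1, -1]` (three correct out of five, K = 4 labels), both functions give 0.8, while the plain accuracy is 0.6.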
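`new_episode` in `utils/train_episode.py` draws 2*CS support classes and splits them into source and target halves, then draws 2*CQ unknown classes from the remaining classes; within each class it draws support indices first and query indices from the unselected ones, so the two never overlap. The sketch below shows that partitioning, assuming the target unknown classes are meant to be the CQ classes of the 2*CQ draw that were not given to the source domain. It uses Python's `random.Random` in place of `np.random.choice`, and `split_episode_classes` and `disjoint_support_query` are hypothetical names. Drawing the support and query indices in one sample without replacement and then splitting it gives the same no-overlap guarantee without the `j not in sran_indices` scan.

```python
import random


def split_episode_classes(class_labels, CS, CQ, rng):
    """Partition classes into four disjoint groups: source/target support
    classes (CS each) and source/target unknown classes (CQ each)."""
    support = rng.sample(list(class_labels), 2 * CS)
    source_support = rng.sample(support, CS)
    target_support = [c for c in support if c not in source_support]
    others = [c for c in class_labels if c not in support]
    unknown = rng.sample(others, 2 * CQ)
    source_unknown = rng.sample(unknown, CQ)
    target_unknown = [c for c in unknown if c not in source_unknown]
    return source_support, target_support, source_unknown, target_unknown


def disjoint_support_query(n_images, n_support, n_query, rng):
    """Draw support and query image indices for one class with no overlap:
    a single draw without replacement, split into two parts."""
    picked = rng.sample(range(n_images), n_support + n_query)
    return picked[:n_support], picked[n_support:]


rng = random.Random(0)
groups = split_episode_classes(list(range(20)), CS=3, CQ=2, rng=rng)
support_idx, query_idx = disjoint_support_query(30, 5, 15, rng)
```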
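`ComputeIterationVaraince` stores every value in a list and computes the population variance (dividing by n) at the end, which raises `ZeroDivisionError` if `compute_variance` is called before any `add`. As an alternative sketch, not the repository's implementation, Welford's online algorithm yields the same population variance with constant memory. The class name `StreamingVariance` is hypothetical; it keeps the `name` argument and the `add`, `compute_variance` and `reset_states` method names so it could be swapped in.

```python
class StreamingVariance:
    """Population variance (divide by n) via Welford's online update."""

    def __init__(self, name=""):
        self.name = name
        self.reset_states()

    def reset_states(self):
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # running sum of squared deviations from the current mean

    def add(self, number):
        self.count += 1
        delta = number - self.mean
        self.mean += delta / self.count
        # Uses the updated mean, which keeps the update numerically stable.
        self.m2 += delta * (number - self.mean)

    def compute_variance(self):
        if self.count == 0:
            raise ValueError("compute_variance() called before any add()")
        return self.m2 / self.count


tracker = StreamingVariance("open_oa")
for value in [2, 4, 4, 4, 5, 5, 7, 9]:
    tracker.add(value)
```

For these eight values the mean is 5 and the population variance is 4.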