├── VR_Encoder ├── dataloader │ ├── vg150_dataset.py │ └── vrr_vg_dataset.py ├── configs │ ├── model_config.yaml │ ├── train_config.yaml │ └── data_config.yaml └── model │ ├── concat.py │ └── vtranse.py ├── .gitignore ├── VR_SimilarityNetwork ├── dataloader │ ├── VG150Dataset.py │ ├── VrRVGDatasetTrain.py │ └── VrRVGDatasetTest.py ├── configs │ ├── model_config.yaml │ ├── train_config.yaml │ ├── data_config_train.yaml │ ├── data_config_test.yaml │ └── test_config.yaml └── model │ ├── SimilarityNetworkConcat.py │ └── SimilarityNetworkVREncoder.py ├── requirements.txt ├── data ├── vrrvg_predicates_test.json └── vrrvg_predicates_train.json ├── data_preparation ├── README.md └── vrc_extract_frcnn_feats.py ├── utils ├── sampling_utils.py ├── utils.py └── logger.py ├── README.md ├── SimilarityNetworkTrain.py ├── ConcatplusSimilarityNetworkTrain.py ├── FullModelTest.py └── train_vr_encoder.py /VR_Encoder/dataloader/vg150_dataset.py: -------------------------------------------------------------------------------- 1 | # TODO: include vg150 dataset here -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.npy 2 | *.pyc 3 | *.txt 4 | !requirements.txt 5 | *.pt 6 | *.pth 7 | -------------------------------------------------------------------------------- /VR_SimilarityNetwork/dataloader/VG150Dataset.py: -------------------------------------------------------------------------------- 1 | # TODO: Implement datasetloader for vg150 dataset for similarity network 2 | -------------------------------------------------------------------------------- /VR_SimilarityNetwork/configs/model_config.yaml: -------------------------------------------------------------------------------- 1 | SimilarityNetworkVREncoderInputSize: 1000 2 | SimilarityNetworkConcatInputSize: 7306 -------------------------------------------------------------------------------- /VR_Encoder/configs/model_config.yaml: -------------------------------------------------------------------------------- 1 | model_name: VTransE 2 | index_sp: True 3 | index_cls: True 4 | num_pred: 100 5 | output_size: 500 6 | input_size: 500 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm==4.62.0 2 | scipy==1.7.1 3 | omegaconf==2.1.0 4 | numpy==1.20.3 5 | opencv_python==4.4.0.42 6 | Pillow==8.3.1 7 | PyYAML==5.4.1 8 | -------------------------------------------------------------------------------- /VR_Encoder/configs/train_config.yaml: -------------------------------------------------------------------------------- 1 | saved_checkpoints : VR_encoder_ckeckpoints 2 | logs : VR_encoder_logs 3 | num_train_epochs : 50 4 | 5 | optimizer: 6 | params: 7 | eps: 1.0e-08 8 | lr: 3e-4 9 | weight_decay: 0.0 10 | type: Adam 11 | -------------------------------------------------------------------------------- /VR_SimilarityNetwork/configs/train_config.yaml: -------------------------------------------------------------------------------- 1 | epochs: 10 2 | batch_size: 1 3 | NETWORK: SimilarityNetworkVREncoder # options are : SimilarityNetworkVREncoder and SimilarityNetworkConcat 4 | NETWORK_WEIGHTS_PATH: SimilarityNetworkWeights/ 5 | Dataset: VrRVG 6 | scheduler: 7 | mode: max 8 | factor: 0.1 9 | patience: 0 10 | verbose: True 11 | -------------------------------------------------------------------------------- 
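The YAML files above (and the data config that follows) are plain OmegaConf configs. A minimal sketch, assuming the repository layout shown in the tree, of how they are loaded and merged — mirroring `load_config_file` in `utils/utils.py` and the merge done in `train_vr_encoder.py`; the printed values are the ones listed in the configs:

```python
# Illustrative only -- not a file in the repository.
from omegaconf import OmegaConf

def load_config_file(file_path):
    # same helper as utils/utils.py
    with open(file_path, 'r') as fp:
        return OmegaConf.load(fp)

train_config = load_config_file("VR_Encoder/configs/train_config.yaml")
data_config = load_config_file("VR_Encoder/configs/data_config.yaml")
config = OmegaConf.merge(train_config, data_config)   # as in train_vr_encoder.py

print(config.num_train_epochs)           # 50
print(config.optimizer.type)             # Adam
print(config.per_gpu_train_batch_size)   # 64
```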
/VR_Encoder/configs/data_config.yaml: -------------------------------------------------------------------------------- 1 | per_gpu_train_batch_size: 64 2 | xml_file_path: VrR-VG # path to vrr-vg xml files directory 3 | npy_file_path: all_vg_frcnn # path to faster r-cnn features directory 4 | saved_dir: saved_vr_encoder_relations # path to save vr during training 5 | train_predicates_path: data/vrrvg_predicates_train.json 6 | saved_vtranse_input: True # save relations while training 7 | training_split_ratio : 0.8 -------------------------------------------------------------------------------- /VR_Encoder/model/concat.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | class Concat(nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | 9 | def forward(self, sub_sp, sub_cls, sub_fc, ob_sp, ob_cls, ob_fc): 10 | sub_emb = torch.cat([sub_sp, sub_cls, sub_fc], dim=1) 11 | ob_emb = torch.cat([ob_sp, ob_cls, ob_fc], dim=1) 12 | 13 | vr_emb = torch.cat([sub_emb, ob_emb], dim=1) 14 | return vr_emb 15 | -------------------------------------------------------------------------------- /data/vrrvg_predicates_test.json: -------------------------------------------------------------------------------- 1 | { 2 | "biting": 0, 3 | "petting": 1, 4 | "sniffing": 2, 5 | "pointing": 3, 6 | "placed on": 4, 7 | "stacked on": 5, 8 | "balancing on": 6, 9 | "drawn on": 7, 10 | "sewn on": 8, 11 | "sticking out of": 9, 12 | "at bottom of": 10, 13 | "following": 11, 14 | "entering": 12, 15 | "leaning on": 13, 16 | "in corner of": 14, 17 | "surrounded by": 15, 18 | "in center of": 16 19 | } -------------------------------------------------------------------------------- /VR_SimilarityNetwork/configs/data_config_train.yaml: -------------------------------------------------------------------------------- 1 | SAMPLE_PER_RELATION: 1 # samples/bags for each predicates , total train predicates = 100 2 | XML_FILE_PATH_VrRVG: VrR-VG # xml files for VrRVG 3 | NPY_FILE_PATH: all_vg_frcnn 4 | DATASET: VrRVG # options are VrRVG and VG150 5 | VREncoder_Net_Checkpoint: VR_encoder_checkpoints/checkpoint_2.pt 6 | train_predicates_path : data/vrrvg_predicates_train.json 7 | VREncoderEmbeddings: VTransE # VTransE and VRConcat 8 | VREncoderConfig: VR_Encoder/configs/model_config.yaml 9 | -------------------------------------------------------------------------------- /VR_SimilarityNetwork/configs/data_config_test.yaml: -------------------------------------------------------------------------------- 1 | XML_FILE_PATH_VrRVG: VrR-VG # xml files for VrRVG 2 | NPY_FILE_PATH: all_vg_frcnn 3 | DATASET: VrRVG # options are VrRVG and VG150 4 | VREncoderConfig: VR_Encoder/configs/model_config.yaml 5 | VREncoder_Net_Checkpoint: VR_encoder_checkpoints/checkpoint_2.pt 6 | SAMPLE_PER_RELATION: 1 # samples/bags for each predicates , total train predicates = 100 7 | test_predicates_path : data/vrrvg_predicates_test.json 8 | VisualGenomeImageDir1: visual_genome/VG_100K/ 9 | VisualGenomeImageDir2: visual_genome/VG_100K_2/ -------------------------------------------------------------------------------- /VR_SimilarityNetwork/configs/test_config.yaml: -------------------------------------------------------------------------------- 1 | RESULT_FOLDER: SimilarityNetworkResults/ 2 | SUBJECT_ANCHORED: False # subject anchored for each image 3 | ANCHOR_IMAGE: False 4 | BATCH_SIZE: 2 5 | RelationNET_CHECKPOINT: 
SimilarityNetworkWeights/SimilarityNetworkVREncodercheckpoint_1.pth #/DATA/trevant/Vaibhav/tempVRC/SimilarityNetworkWeights/SimilarityNetworkConcatcheckpoint.pth #/DATA/trevant/Vaibhav/tempVRC/SimilarityNetworkWeights/SimilarityNetworkVREncodercheckpoint_1.pth 6 | SIMILARITY_NET_CONCAT_CHECKPOINT : SimilarityNetworkWeights/SimilarityNetworkConcatcheckpoint.pth 7 | BAG_SIZE: 2 8 | SIMILARITY: relation_net # options are relation_net and cosine 9 | CONCAT: False # use concat network or not 10 | top_k: 5 11 | SAVE_OUTPUT: False 12 | -------------------------------------------------------------------------------- /VR_SimilarityNetwork/model/SimilarityNetworkConcat.py: -------------------------------------------------------------------------------- 1 | # Defining Network 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class SimilarityNetworkConcat(nn.Module): 8 | def __init__(self, model_config): 9 | super().__init__() 10 | input_size = model_config.SimilarityNetworkConcatInputSize * 2 11 | self.fc1 = nn.Linear(input_size, 7306) 12 | self.fc2 = nn.Linear(input_size, 7306) 13 | self.fc3 = nn.Linear(7306, 1) 14 | 15 | 16 | def forward(self,rela1, rela2): 17 | # Assuming that both relations are of shape [batch_size, 500], we keep the axis as 1 if shape is not like this please change axis. 18 | x = torch.cat((rela1, rela2)) 19 | x1 = F.tanh((self.fc1(x))) 20 | x2 = torch.sigmoid((self.fc2(x))) 21 | x = (x1 * x2 ) 22 | x = x + ((rela1 + rela2)/2) 23 | x = self.fc3(x) 24 | return x 25 | 26 | -------------------------------------------------------------------------------- /data_preparation/README.md: -------------------------------------------------------------------------------- 1 | # Instructions to extract faster r-cnn features: 2 | Check the script [vrc_extract_frcnn_feats.py](vrc_extract_frcnn_feats.py) and follow these instructions (also written as comments in the script) 3 | 4 | 0. Activate vrc conda environment 5 | ``` 6 | $ conda activate vrc 7 | ``` 8 | 9 | 1. Install maskrcnn-benchmark : FRCNN Model 10 | ``` 11 | $ git clone https://gitlab.com/meetshah1995/vqa-maskrcnn-benchmark.git 12 | $ cd vqa-maskrcnn-benchmark 13 | $ python setup.py build 14 | $ python setup.py develop 15 | ``` 16 | 2. 
download pre-trained detectron weights 17 | ``` 18 | $ mkdir detectron_weights 19 | $ wget -O detectron_weights/detectron_model.pth https://dl.fbaipublicfiles.com/pythia/detectron_model/detectron_model.pth 20 | $ wget -O detectron_weights/detectron_model.yaml https://dl.fbaipublicfiles.com/pythia/detectron_model/detectron_model.yaml 21 | ``` 22 | 23 | NOTE: just modify the code in /content/vqa-maskrcnn-benchmark/maskrcnn_benchmark/utils/imports.py, change PY3 to PY37 24 | 25 | to run the script 26 | ``` 27 | $ python vrc_extract_frcnn_feats.py --image_dir= 28 | ``` 29 | -------------------------------------------------------------------------------- /VR_SimilarityNetwork/model/SimilarityNetworkVREncoder.py: -------------------------------------------------------------------------------- 1 | # Defining Network 2 | from utils.utils import load_config_file 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | use_cuda = torch.cuda.is_available() 8 | print(use_cuda) 9 | cuda = torch.device('cuda') 10 | 11 | 12 | MODEL_CONFIG_PATH = "/DATA/trevant/Vaibhav/tempVRC/VR_SimilarityNetwork/configs/model_config.yaml" 13 | 14 | model_config = load_config_file(MODEL_CONFIG_PATH) 15 | 16 | 17 | class SimilarityNetworkVREncoder(nn.Module): 18 | def __init__(self, model_config): 19 | super().__init__() 20 | input_size = model_config.SimilarityNetworkVREncoderInputSize 21 | self.fc1 = nn.Linear(input_size, 500) 22 | self.fc2 = nn.Linear(input_size, 500) 23 | self.fc3 = nn.Linear(500, 1) 24 | self.tan_bn = nn.BatchNorm1d(500) 25 | self.sig_bn = nn.BatchNorm1d(500) 26 | 27 | def forward(self,rela1, rela2): 28 | # Assuming that both relations are of shape [batch_size, 500], we keep the axis as 1 if shape is not like this please change axis. 
29 | x = torch.cat((rela1, rela2)) 30 | x1 = F.tanh((self.fc1(x))) 31 | x2 = torch.sigmoid((self.fc2(x))) 32 | x = (x1 * x2 ) + ((rela1 + rela2)/2) 33 | x = self.fc3(x) 34 | return x 35 | -------------------------------------------------------------------------------- /utils/sampling_utils.py: -------------------------------------------------------------------------------- 1 | def get_iou(bb1, bb2): 2 | assert bb1['xmin'] < bb1['xmax'] 3 | assert bb1['ymin'] < bb1['ymax'] 4 | assert bb2['xmin'] < bb2['xmax'] 5 | assert bb2['ymin'] < bb2['ymax'] 6 | 7 | # determine the coordinates of the intersection rectangle 8 | x_left = max(bb1['xmin'], bb2['xmin']) 9 | y_top = max(bb1['ymin'], bb2['ymin']) 10 | x_right = min(bb1['xmax'], bb2['xmax']) 11 | y_bottom = min(bb1['ymax'], bb2['ymax']) 12 | 13 | if x_right < x_left or y_bottom < y_top: 14 | return 0.0 15 | 16 | intersection_area = (x_right - x_left) * (y_bottom - y_top) 17 | bb1_area = (bb1['xmax'] - bb1['xmin']) * (bb1['ymax'] - bb1['ymin']) 18 | bb2_area = (bb2['xmax'] - bb2['xmin']) * (bb2['ymax'] - bb2['ymin']) 19 | iou = intersection_area / float(bb1_area + bb2_area - intersection_area) 20 | assert iou >= 0.0 21 | assert iou <= 1.0 22 | return iou 23 | 24 | 25 | def get_roi_index(gt_bbox, info): 26 | indexes = [] 27 | rois_info = info.item().get('bbox') 28 | rois = rois_info.shape[0] 29 | subj_roi_iou = [] 30 | for i in range(0, rois): 31 | bbox_roi = rois_info[i] 32 | roi_bbox_dict = {} 33 | roi_bbox_dict["xmin"] = float(bbox_roi[0]) 34 | roi_bbox_dict["ymin"] = float(bbox_roi[1]) 35 | roi_bbox_dict["xmax"] = float(bbox_roi[2]) 36 | roi_bbox_dict["ymax"] = float(bbox_roi[3]) 37 | 38 | iou = get_iou(gt_bbox, roi_bbox_dict) 39 | if(iou > 0.60): 40 | indexes.append(i) 41 | subj_roi_iou.append(iou) 42 | 43 | return indexes, subj_roi_iou 44 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | # import pandas as pd 4 | from pathlib import Path 5 | from collections import OrderedDict 6 | import errno 7 | import os 8 | import os.path as op 9 | import yaml 10 | import random 11 | import numpy as np 12 | from omegaconf import OmegaConf 13 | 14 | 15 | def mkdir(path): 16 | # if it is the current folder, skip. 
17 | if path == '': 18 | return 19 | try: 20 | os.makedirs(path) 21 | except OSError as e: 22 | if e.errno != errno.EEXIST: 23 | raise 24 | 25 | 26 | def load_config_file(file_path): 27 | with open(file_path, 'r') as fp: 28 | return OmegaConf.load(fp) 29 | 30 | 31 | def set_seed(seed, n_gpu): 32 | random.seed(seed) 33 | np.random.seed(seed) 34 | torch.manual_seed(seed) 35 | if n_gpu > 0: 36 | torch.cuda.manual_seed_all(seed) 37 | 38 | 39 | def load_from_yaml_file(yaml_file): 40 | with open(yaml_file, 'r') as fp: 41 | return yaml.load(fp, Loader=yaml.FullLoader) 42 | 43 | 44 | def find_file_path_in_yaml(fname, root): 45 | if fname is not None: 46 | if op.isfile(fname): 47 | return fname 48 | elif op.isfile(op.join(root, fname)): 49 | return op.join(root, fname) 50 | else: 51 | raise FileNotFoundError( 52 | errno.ENOENT, os.strerror(errno.ENOENT), op.join(root, fname) 53 | ) 54 | 55 | 56 | def ensure_dir(dirname): 57 | dirname = Path(dirname) 58 | if not dirname.is_dir(): 59 | dirname.mkdir(parents=True, exist_ok=False) 60 | 61 | 62 | def read_json(fname): 63 | fname = Path(fname) 64 | with fname.open('rt') as handle: 65 | return json.load(handle, object_hook=OrderedDict) 66 | 67 | 68 | def write_json(content, fname): 69 | fname = Path(fname) 70 | with fname.open('wt') as handle: 71 | json.dump(content, handle, indent=4, sort_keys=False) 72 | -------------------------------------------------------------------------------- /VR_Encoder/model/vtranse.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | class VTransE(nn.Module): 6 | def __init__(self, index_sp=True, index_cls=True, num_pred=100, output_size=500, input_size=500): 7 | super().__init__() 8 | self.index_sp = index_sp 9 | self.index_cls = index_cls 10 | if(index_sp == True): 11 | input_size += 4 # 2028 12 | if(index_cls == True): 13 | input_size += 1601 # 1601 14 | self.roi_contract_sub = nn.Linear(2048, 500) 15 | self.roi_contract_obj = nn.Linear(2048, 500) 16 | 17 | self.relu = nn.ReLU() 18 | self.sub_fc_layer = nn.Linear(input_size, output_size) 19 | self.obj_fc_layer = nn.Linear(input_size, output_size) 20 | self.rela_layer = nn.Linear(output_size, num_pred) 21 | 22 | def forward(self, sub_sp, sub_cls, sub_fc, ob_sp, ob_cls, ob_fc): 23 | sub_fc = self.relu(self.roi_contract_sub(sub_fc)) 24 | ob_fc = self.relu(self.roi_contract_obj(ob_fc)) 25 | if self.index_sp: 26 | sub_fc = torch.cat([sub_fc, sub_sp], axis=1) 27 | ob_fc = torch.cat([ob_fc, ob_sp], axis=1) 28 | if self.index_cls: 29 | sub_fc = torch.cat([sub_fc, sub_cls], axis=1) 30 | ob_fc = torch.cat([ob_fc, ob_cls], axis=1) 31 | sub_emb = torch.relu(self.sub_fc_layer(sub_fc)) 32 | ob_emb = torch.relu(self.obj_fc_layer(ob_fc)) 33 | 34 | vr_emb = ob_emb - sub_emb 35 | 36 | rela_score = self.rela_layer(vr_emb) 37 | 38 | return rela_score, vr_emb 39 | 40 | def forward_inference(self, sub_sp, sub_cls, sub_fc, ob_sp, ob_cls, ob_fc): 41 | sub_fc = self.relu(self.roi_contract_sub(sub_fc)) 42 | ob_fc = self.relu(self.roi_contract_obj(ob_fc)) 43 | if self.index_sp: 44 | sub_fc = torch.cat([sub_fc, sub_sp]) 45 | ob_fc = torch.cat([ob_fc, ob_sp]) 46 | if self.index_cls: 47 | sub_fc = torch.cat([sub_fc, sub_cls]) 48 | ob_fc = torch.cat([ob_fc, ob_cls]) 49 | sub_emb = torch.relu(self.sub_fc_layer(sub_fc)) 50 | ob_emb = torch.relu(self.obj_fc_layer(ob_fc)) 51 | 52 | vr_emb = ob_emb - sub_emb 53 | 54 | rela_score = self.rela_layer(vr_emb) 55 | 56 | return rela_score, vr_emb 
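# --- Illustrative usage sketch: not part of the repository's vtranse.py ---
# VTransE takes, for the subject and the object each, a 4-d normalized box,
# a 1601-d Faster R-CNN class-score vector and a 2048-d ROI feature, and
# returns predicate scores over num_pred classes together with the 500-d
# visual relationship embedding vr_emb = ob_emb - sub_emb. Shapes follow the
# defaults in VR_Encoder/configs/model_config.yaml; inputs here are random.
import torch
from VR_Encoder.model.vtranse import VTransE

model = VTransE(index_sp=True, index_cls=True, num_pred=100,
                output_size=500, input_size=500)

batch = 4
sub_sp, ob_sp = torch.rand(batch, 4), torch.rand(batch, 4)          # normalized [xmin, ymin, xmax, ymax]
sub_cls, ob_cls = torch.rand(batch, 1601), torch.rand(batch, 1601)  # class scores
sub_fc, ob_fc = torch.rand(batch, 2048), torch.rand(batch, 2048)    # ROI features

rela_score, vr_emb = model(sub_sp, sub_cls, sub_fc, ob_sp, ob_cls, ob_fc)
print(rela_score.shape, vr_emb.shape)  # torch.Size([4, 100]) torch.Size([4, 500])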
-------------------------------------------------------------------------------- /data/vrrvg_predicates_train.json: -------------------------------------------------------------------------------- 1 | { 2 | "talking on": 0, 3 | "printed on": 1, 4 | "bordering": 2, 5 | "leading to": 3, 6 | "alongside": 4, 7 | "grabbing": 5, 8 | "besides": 6, 9 | "on other side of": 7, 10 | "on back of": 8, 11 | "reflected on": 9, 12 | "surrounding": 10, 13 | "are inside": 11, 14 | "working on": 12, 15 | "skiing down": 13, 16 | "swimming in": 14, 17 | "tied around": 15, 18 | "swinging": 16, 19 | "among": 17, 20 | "stopped at": 18, 21 | "seen in": 19, 22 | "beside": 20, 23 | "appearing in": 21, 24 | "gripping": 22, 25 | "standing with": 23, 26 | "surfing in": 24, 27 | "catching": 25, 28 | "containing": 26, 29 | "touching": 27, 30 | "built into": 28, 31 | "running in": 29, 32 | "atop": 30, 33 | "belonging to": 31, 34 | "written on": 32, 35 | "across": 33, 36 | "draped over": 34, 37 | "wrapped around": 35, 38 | "line": 36, 39 | "boarding": 37, 40 | "brown": 38, 41 | "laying in": 39, 42 | "beneath": 40, 43 | "on head of": 41, 44 | "walking towards": 42, 45 | "hitting": 43, 46 | "crashing on": 44, 47 | "painted": 45, 48 | "part of": 46, 49 | "eating from": 47, 50 | "sleeping on": 48, 51 | "against": 49, 52 | "apart of": 50, 53 | "hanging on a": 51, 54 | "on surface of": 52, 55 | "selling": 53, 56 | "mounted on": 54, 57 | "beyond": 55, 58 | "waiting at": 56, 59 | "adorning": 57, 60 | "licking": 58, 61 | "displayed on": 59, 62 | "located on": 60, 63 | "flying in": 61, 64 | "overlooking": 62, 65 | "matches": 63, 66 | "down": 64, 67 | "chasing": 65, 68 | "leaving": 66, 69 | "marking": 67, 70 | "to a": 68, 71 | "separating": 69, 72 | "in middle of": 70, 73 | "before": 71, 74 | "playing with": 72, 75 | "stuck in": 73, 76 | "decorating": 74, 77 | "reflecting off": 75, 78 | "served on": 76, 79 | "train": 77, 80 | "driving down": 78, 81 | "dressed in": 79, 82 | "sitting on a": 80, 83 | "on end of": 81, 84 | "securing": 82, 85 | "facing": 83, 86 | "light": 84, 87 | "at edge of": 85, 88 | "jumping on": 86, 89 | "supporting": 87, 90 | "visible on": 88, 91 | "grazing on": 89, 92 | "approaching": 90, 93 | "between": 91, 94 | "moving": 92, 95 | "attached": 93, 96 | "formed in": 94, 97 | "top of": 95, 98 | "resting on": 96, 99 | "under a": 97, 100 | "floating in": 98, 101 | "standing near": 99 102 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VRC 2 | **Official implementation of the Few-shot Visual Relationship Co-localization (ICCV 2021) paper** 3 | 4 | [project page](https://vl2g.github.io/projects/vrc/) | [paper](https://vl2g.github.io/projects/vrc/docs/VRC-ICCV2021.pdf) 5 | 6 | ## Requirements 7 | * Use **python >= 3.8.5**. 
Conda recommended : [https://docs.anaconda.com/anaconda/install/linux/](https://docs.anaconda.com/anaconda/install/linux/) 8 | 9 | * Use **pytorch 1.7.0 CUDA 10.2** 10 | 11 | * Other requirements from 'requirements.txt' 12 | 13 | **To setup environment** 14 | ``` 15 | # create new env vrc 16 | $ conda create -n vrc python=3.8.5 17 | 18 | # activate vrc 19 | $ conda activate vrc 20 | 21 | # install pytorch, torchvision 22 | $ conda install pytorch==1.7.0 torchvision==0.8.0 cudatoolkit=10.2 -c pytorch 23 | 24 | # install other dependencies 25 | $ pip install -r requirements.txt 26 | ``` 27 | 28 | ## Training 29 | 30 | ### Preparing dataset 31 | - Download VG images from [https://visualgenome.org/](https://visualgenome.org/) 32 | 33 | - Extract faster_rcnn features of VG images using [data_preparation/vrc_extract_frcnn_feats.py](data_preparation/vrc_extract_frcnn_feats.py). Please follow instructions [here](data_preparation/README.md). 34 | 35 | - Download VrR-VG dataset from [http://vrr-vg.com/](http://vrr-vg.com/) or [Google Drive Link](https://drive.google.com/file/d/1X7lYDviVKJI9bGmQAbQikTM271P3aoWZ/view?usp=sharing) 36 | 37 | ### Training VR Encoder (VTransE) 38 | 39 | #### Training parameters 40 | To check and update training, model and dataset parameters see [VR_Encoder/configs](VR_Encoder/configs) 41 | 42 | #### To train VR Encoder: 43 | ``` 44 | $ python train_vr_encoder.py 45 | ``` 46 | 47 | ### Training VR Similarity Network (Relation Network) 48 | 49 | #### Training parameters 50 | To check and update training, testing, model and dataset parameters see [VR_SimilarityNetwork/configs](VR_SimilarityNetwork/configs) 51 | 52 | #### To train VR Similarity Network: 53 | ``` 54 | $ python SimilarityNetworkTrain.py 55 | ``` 56 | 57 | #### To train VR Similarity Network (w/ concat as VR Encoding): 58 | ``` 59 | $ python ConcatplusSimilarityNetworkTrain.py 60 | ``` 61 | 62 | #### To evaluate (set eval setting in [test_config.yaml](VR_SimilarityNetwork/configs/test_config.yaml)) 63 | ``` 64 | $ python FullModelTest.py 65 | ``` 66 | 67 | ## Cite 68 | If you find this code/paper useful for your research, please consider citing. 69 | ``` 70 | @InProceedings{teotiaMMM2021, 71 | author = "Teotia, Revant and Mishra, Vaibhav and Maheshwari, Mayank and Mishra, Anand", 72 | title = "Few-shot Visual Relationship Co-Localization", 73 | booktitle = "ICCV", 74 | year = "2021", 75 | } 76 | ``` 77 | 78 | ## Acknowledgements 79 | This repo uses https://gitlab.com/meetshah1995/vqa-maskrcnn-benchmark and scripts from https://github.com/facebookresearch/mmf for Faster R-CNN feature extraction. 80 | 81 | Code provided by https://github.com/zawlin/cvpr17_vtranse and https://github.com/yangxuntu/vrd helped in implementing VR encoder. 82 | 83 | 84 | ### Contact 85 | For any clarification, comment, or suggestion please create an issue or contact [Revant](https://revantteotia.github.io/), [Vaibhav](https://www.linkedin.com/in/vaibhav-mishra-iitj/) or [Mayank](https://www.linkedin.com/in/maheshwarimayank333/). -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from logging import StreamHandler, Handler, getLevelName 3 | import os 4 | import sys 5 | 6 | 7 | # this class is a copy of logging.FileHandler except we end self.close() 8 | # at the end of each emit. 
While closing file and reopening file after each 9 | # write is not efficient, it allows us to see partial logs when writing to 10 | # fused Azure blobs, which is very convenient 11 | class FileHandler(StreamHandler): 12 | """ 13 | A handler class which writes formatted logging records to disk files. 14 | """ 15 | def __init__(self, filename, mode='a', encoding=None, delay=False): 16 | """ 17 | Open the specified file and use it as the stream for logging. 18 | """ 19 | # Issue #27493: add support for Path objects to be passed in 20 | filename = os.fspath(filename) 21 | #keep the absolute path, otherwise derived classes which use this 22 | #may come a cropper when the current directory changes 23 | self.baseFilename = os.path.abspath(filename) 24 | self.mode = mode 25 | self.encoding = encoding 26 | self.delay = delay 27 | if delay: 28 | #We don't open the stream, but we still need to call the 29 | #Handler constructor to set level, formatter, lock etc. 30 | Handler.__init__(self) 31 | self.stream = None 32 | else: 33 | StreamHandler.__init__(self, self._open()) 34 | 35 | def close(self): 36 | """ 37 | Closes the stream. 38 | """ 39 | self.acquire() 40 | try: 41 | try: 42 | if self.stream: 43 | try: 44 | self.flush() 45 | finally: 46 | stream = self.stream 47 | self.stream = None 48 | if hasattr(stream, "close"): 49 | stream.close() 50 | finally: 51 | # Issue #19523: call unconditionally to 52 | # prevent a handler leak when delay is set 53 | StreamHandler.close(self) 54 | finally: 55 | self.release() 56 | 57 | def _open(self): 58 | """ 59 | Open the current base file with the (original) mode and encoding. 60 | Return the resulting stream. 61 | """ 62 | return open(self.baseFilename, self.mode, encoding=self.encoding) 63 | 64 | def emit(self, record): 65 | """ 66 | Emit a record. 67 | If the stream was not opened because 'delay' was specified in the 68 | constructor, open it before calling the superclass's emit. 
69 | """ 70 | if self.stream is None: 71 | self.stream = self._open() 72 | StreamHandler.emit(self, record) 73 | self.close() 74 | 75 | def __repr__(self): 76 | level = getLevelName(self.level) 77 | return '<%s %s (%s)>' % (self.__class__.__name__, self.baseFilename, level) 78 | 79 | 80 | def setup_logger(name, save_dir, distributed_rank, filename="log.txt"): 81 | logger = logging.getLogger(name) 82 | logger.setLevel(logging.DEBUG) 83 | # don't log results for the non-master process 84 | if distributed_rank > 0: 85 | return logger 86 | ch = logging.StreamHandler(stream=sys.stdout) 87 | ch.setLevel(logging.DEBUG) 88 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") 89 | ch.setFormatter(formatter) 90 | logger.addHandler(ch) 91 | 92 | if save_dir: 93 | fh = FileHandler(os.path.join(save_dir, filename)) 94 | fh.setLevel(logging.DEBUG) 95 | fh.setFormatter(formatter) 96 | logger.addHandler(fh) 97 | 98 | return logger -------------------------------------------------------------------------------- /SimilarityNetworkTrain.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | from torch.utils.data import DataLoader 4 | from torch.optim.lr_scheduler import ReduceLROnPlateau 5 | from utils.utils import load_config_file 6 | from VR_SimilarityNetwork.model.SimilarityNetworkConcat import SimilarityNetworkConcat 7 | from VR_SimilarityNetwork.model.SimilarityNetworkVREncoder import SimilarityNetworkVREncoder 8 | from VR_SimilarityNetwork.dataloader.VrRVGDatasetTrain import VrRVGDatasetTrain 9 | from VR_Encoder.model.vtranse import VTransE 10 | from VR_Encoder.model.concat import Concat 11 | from tqdm import tqdm 12 | import time 13 | 14 | DATA_CONFIG_PATH = "/DATA/trevant/Vaibhav/tempVRC/VR_SimilarityNetwork/configs/data_config_train.yaml" 15 | TRAINER_CONFIG_PATH = "/DATA/trevant/Vaibhav/tempVRC/VR_SimilarityNetwork/configs//train_config.yaml" 16 | MODEL_CONFIG_PATH = "/DATA/trevant/Vaibhav/tempVRC/VR_SimilarityNetwork/configs/model_config.yaml" 17 | 18 | ####################################### 19 | # Defining the loss 20 | def episodic_loss(r, R): 21 | return torch.log(1+torch.exp(-R*r)) 22 | ####################################### 23 | 24 | def save_checkpoint(checkpoint, train_config): 25 | time.sleep(10) 26 | path = train_config.NETWORK + '_checkpoint.pth' 27 | torch.save(checkpoint, path) 28 | 29 | def per_sample_training(bag_size, ith_bag, net): 30 | cnt_pairs = 0 31 | bag_loss = 0.0 32 | 33 | for j in range(bag_size): 34 | image_ind_1=j 35 | image_ind_2=(j+1)%bag_size 36 | 37 | n_positive_1=len(ith_bag["relations"][image_ind_1]["positive_relations"]) # count positive relations in image 1 38 | n_negative_1=len(ith_bag["relations"][image_ind_1]["negative_relations"]) # count negative relations in image 1 39 | n_positive_2=len(ith_bag["relations"][image_ind_2]["positive_relations"]) # count positive relations in image 2 40 | n_negative_2=len(ith_bag["relations"][image_ind_2]["negative_relations"])# count negative relations in image 2 41 | 42 | cnt=0 43 | for a in range(n_positive_1): 44 | if(cnt > 10): 45 | break 46 | for b in range(n_positive_2): 47 | cnt+=1 48 | positive_example_1=torch.tensor(ith_bag["relations"][image_ind_1]["positive_relations"][a]).cuda() 49 | positive_example_2=torch.tensor(ith_bag["relations"][image_ind_2]["positive_relations"][b]).cuda() 50 | label=1 51 | r=net(positive_example_1,positive_example_2) 52 | 53 | loss=episodic_loss(r,label) 54 | cnt_pairs+=1 55 | 
bag_loss+=loss 56 | 57 | sample=cnt//2 58 | itr=0 59 | for a in range(n_positive_1): 60 | for b in range(n_negative_2): 61 | if itr>sample: 62 | break 63 | itr+=1 64 | positive_example_1=torch.tensor(ith_bag["relations"][image_ind_1]["positive_relations"][a]).cuda() 65 | negative_example_2=torch.tensor(ith_bag["relations"][image_ind_2]["negative_relations"][b]).cuda() 66 | label=-1 67 | r=net(positive_example_1,negative_example_2) 68 | 69 | loss=episodic_loss(r,label) 70 | bag_loss+=loss 71 | cnt_pairs+=1 72 | 73 | itr=0 74 | for a in range(n_positive_2): 75 | for b in range(n_negative_1): 76 | if itr>sample: 77 | break 78 | itr+=1 79 | positive_example_1=torch.tensor(ith_bag["relations"][image_ind_2]["positive_relations"][a]).cuda() 80 | negative_example_2=torch.tensor(ith_bag["relations"][image_ind_1]["negative_relations"][b]).cuda() 81 | label=-1 82 | r=net(positive_example_1,negative_example_2) 83 | 84 | loss=episodic_loss(r,label) 85 | bag_loss+=loss 86 | cnt_pairs+=1 87 | 88 | return bag_loss, cnt_pairs 89 | 90 | def per_epoch_train(net, train_dataloader, optimizer, scheduler): 91 | epoch_loss = 0.0 92 | 93 | for batch_data in tqdm(train_dataloader, desc="Training an epoch"): 94 | batch_size=len(batch_data) 95 | for i in range(batch_size): 96 | ith_bag=batch_data[i] 97 | bag_size=len(ith_bag["relations"]) 98 | optimizer.zero_grad() 99 | bag_loss, cnt_pairs = per_sample_training(bag_size , ith_bag, net) 100 | 101 | bag_loss/=cnt_pairs 102 | bag_loss.backward() 103 | optimizer.step() 104 | 105 | epoch_loss+=bag_loss.item() 106 | 107 | scheduler.step(epoch_loss) 108 | 109 | return epoch_loss 110 | 111 | def train(train_config, dataset, net): 112 | epochs = train_config.epochs 113 | batch_size=train_config.batch_size 114 | 115 | optimizer = optim.Adam(net.parameters()) 116 | 117 | train_dataloader = DataLoader(dataset, batch_size=batch_size, 118 | shuffle=True, num_workers=0, collate_fn=lambda x:x) 119 | 120 | scheduler = ReduceLROnPlateau(optimizer, mode= train_config.scheduler.mode , factor=train_config.scheduler.factor, patience= train_config.scheduler.patience, verbose= train_config.scheduler.verbose) 121 | net.train() 122 | 123 | epoch_loss_min=100000000 124 | for epoch in range(epochs): 125 | epoch_loss = per_epoch_train(net, train_dataloader, optimizer, scheduler) 126 | print('Epoch input: {} \tTraining Loss: {:.6f} '.format(epoch, epoch_loss)) 127 | 128 | if epoch_loss < epoch_loss_min: 129 | print('training loss decreased ({:.6f} --> {:.6f}). 
Saving model ...'.format(epoch_loss_min,epoch_loss)) 130 | checkpoint = { 131 | 'epoch': epoch, 132 | 'model': net.state_dict(), 133 | 'optimizer': optimizer.state_dict()} 134 | 135 | print("dataset sample size=",len(dataset)) 136 | save_checkpoint(checkpoint,train_config) 137 | epoch_loss_min=epoch_loss 138 | 139 | def main(): 140 | 141 | data_config = load_config_file(DATA_CONFIG_PATH) 142 | train_config = load_config_file(TRAINER_CONFIG_PATH) 143 | model_config = load_config_file(MODEL_CONFIG_PATH) 144 | 145 | 146 | if data_config.VREncoderEmbeddings == 'VTransE': 147 | vrNetwork_config = load_config_file(data_config.VREncoderConfig) 148 | vrNetwork = VTransE(index_sp=vrNetwork_config.index_sp, 149 | index_cls=vrNetwork_config.index_cls, 150 | num_pred=vrNetwork_config.num_pred, 151 | output_size=vrNetwork_config.output_size, 152 | input_size=vrNetwork_config.input_size) 153 | elif data_config.VREncoderEmbeddings == 'VRConcat': 154 | vrNetwork = Concat() 155 | 156 | dataset = VrRVGDatasetTrain(data_config, vrNetwork) 157 | dataset_len=(dataset.__len__()) 158 | print("dataset_length=",dataset_len) 159 | 160 | if( train_config.NETWORK == "SimilarityNetworkVREncoder"): 161 | net = SimilarityNetworkVREncoder(model_config) 162 | 163 | if( train_config.NETWORK == "SimilarityNetworkConcat"): 164 | net = SimilarityNetworkConcat(model_config) 165 | 166 | net = net.cuda() 167 | 168 | train(train_config, dataset, net) 169 | print("Training Done") 170 | 171 | if __name__ == "__main__": 172 | main() 173 | -------------------------------------------------------------------------------- /ConcatplusSimilarityNetworkTrain.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | from torch.utils.data import DataLoader 4 | from torch.optim.lr_scheduler import ReduceLROnPlateau 5 | from utils.utils import load_config_file 6 | from VR_SimilarityNetwork.model.SimilarityNetworkConcat import SimilarityNetworkConcat 7 | from VR_SimilarityNetwork.model.SimilarityNetworkVREncoder import SimilarityNetworkVREncoder 8 | from VR_SimilarityNetwork.dataloader.VrRVGDatasetTrain import VrRVGDatasetTrain 9 | from VR_Encoder.model.vtranse import VTransE 10 | from VR_Encoder.model.concat import Concat 11 | from tqdm import tqdm 12 | import time 13 | 14 | 15 | DATA_CONFIG_PATH = "/DATA/trevant/Vaibhav/tempVRC/VR_SimilarityNetwork/configs/data_config_train.yaml" 16 | TRAINER_CONFIG_PATH = "/DATA/trevant/Vaibhav/tempVRC/VR_SimilarityNetwork/configs//train_config.yaml" 17 | MODEL_CONFIG_PATH = "/DATA/trevant/Vaibhav/tempVRC/VR_SimilarityNetwork/configs/model_config.yaml" 18 | 19 | ####################################### 20 | # Defining the loss 21 | def episodic_loss(r, R): 22 | return torch.log(1+torch.exp(-R*r)) 23 | ####################################### 24 | 25 | def save_checkpoint(checkpoint, train_config): 26 | time.sleep(10) 27 | path = train_config.NETWORK + '_checkpoint.pth' 28 | torch.save(checkpoint, path) 29 | 30 | def per_img_pair_training(ith_bag, image_ind_1, image_ind_2, net): 31 | bag_loss=0.0 32 | 33 | n_positive_1=len(ith_bag["relations"][image_ind_1]["positive_relations"]) # count positive relations in image 1 34 | n_negative_1=len(ith_bag["relations"][image_ind_1]["negative_relations"]) # count negative relations in image 1 35 | n_positive_2=len(ith_bag["relations"][image_ind_2]["positive_relations"]) # count positive relations in image 2 36 | n_negative_2=len(ith_bag["relations"][image_ind_2]["negative_relations"])# count 
negative relations in image 2 37 | 38 | total_train_pairs = 2*n_positive_1*n_positive_2 # total number of training pairs in a bag 39 | 40 | pos_cnt=0 # counts positive pairs 41 | # taking positive relations from both images 42 | loss =0.0 43 | for a in range(n_positive_1): 44 | if(pos_cnt>10): 45 | break 46 | for b in range(n_positive_2): 47 | pos_cnt+=1 48 | positive_example_1=torch.tensor(ith_bag["relations"][image_ind_1]["positive_relations"][a]).cuda() 49 | positive_example_2=torch.tensor(ith_bag["relations"][image_ind_2]["positive_relations"][b]).cuda() 50 | 51 | label=1 52 | r=net(positive_example_1,positive_example_2) 53 | 54 | loss= loss + episodic_loss(r,label) 55 | 56 | sample=pos_cnt//2 # so that same number of negative samples are taken 57 | itr=0 58 | for a in range(n_positive_1): 59 | if itr>sample: 60 | break 61 | for b in range(n_negative_2): 62 | itr+=1 63 | positive_example_1=torch.tensor(ith_bag["relations"][image_ind_1]["positive_relations"][a]).cuda() 64 | negative_example_2=torch.tensor(ith_bag["relations"][image_ind_2]["negative_relations"][b]).cuda() 65 | 66 | label=-1 67 | r=net(positive_example_1,negative_example_2) 68 | 69 | loss= loss + episodic_loss(r,label) 70 | 71 | itr=0 72 | for a in range(n_positive_2): 73 | if itr>sample: 74 | break 75 | for b in range(n_negative_1): 76 | itr+=1 77 | positive_example_1=torch.tensor(ith_bag["relations"][image_ind_2]["positive_relations"][a]).cuda() 78 | negative_example_2=torch.tensor(ith_bag["relations"][image_ind_1]["negative_relations"][b]).cuda() 79 | 80 | label=-1 81 | r=net(positive_example_1,negative_example_2) 82 | 83 | loss= loss + episodic_loss(r,label) 84 | loss = loss / total_train_pairs 85 | loss.backward() 86 | bag_loss+=loss.item() 87 | return bag_loss 88 | 89 | def per_sample_training(bag_size, ith_bag, net, optimizer): 90 | bag_loss = 0.0 91 | for j in range(bag_size): 92 | image_ind_1=j 93 | image_ind_2=(j+1)%bag_size 94 | # torch.cuda.empty_cache() 95 | optimizer.zero_grad() 96 | pair_loss = per_img_pair_training(ith_bag, image_ind_1, image_ind_2, net) 97 | bag_loss =bag_loss + pair_loss 98 | optimizer.step() 99 | return bag_loss 100 | 101 | 102 | def per_epoch_train(net, train_dataloader, optimizer, scheduler): 103 | epoch_loss = 0.0 104 | 105 | for batch_data in tqdm(train_dataloader, desc="Training an epoch"): 106 | batch_size=len(batch_data) 107 | for i in range(batch_size): 108 | ith_bag=batch_data[i] 109 | bag_size=len(ith_bag["relations"]) 110 | 111 | bag_loss = per_sample_training(bag_size , ith_bag, net, optimizer) 112 | epoch_loss+=bag_loss 113 | 114 | scheduler.step(epoch_loss) 115 | 116 | return epoch_loss 117 | 118 | def train(train_config, dataset, net): 119 | epochs = train_config.epochs 120 | batch_size=train_config.batch_size 121 | 122 | optimizer = optim.Adam(net.parameters()) 123 | 124 | train_dataloader = DataLoader(dataset, batch_size=batch_size, 125 | shuffle=True, num_workers=0, collate_fn=lambda x:x) 126 | 127 | scheduler = ReduceLROnPlateau(optimizer, mode= train_config.scheduler.mode , factor=train_config.scheduler.factor, patience= train_config.scheduler.patience, verbose= train_config.scheduler.verbose) 128 | net.train() 129 | 130 | epoch_loss_min=100000000 131 | for epoch in range(epochs): 132 | epoch_loss = per_epoch_train(net, train_dataloader, optimizer, scheduler) 133 | print('Epoch input: {} \tTraining Loss: {:.6f} '.format(epoch, epoch_loss)) 134 | 135 | if epoch_loss < epoch_loss_min: 136 | print('training loss decreased ({:.6f} --> {:.6f}). 
Saving model ...'.format(epoch_loss_min,epoch_loss)) 137 | checkpoint = { 138 | 'epoch': epoch, 139 | 'model': net.state_dict(), 140 | 'optimizer': optimizer.state_dict()} 141 | 142 | print("dataset sample size=",len(dataset)) 143 | save_checkpoint(checkpoint, train_config) 144 | epoch_loss_min=epoch_loss 145 | 146 | def main(): 147 | 148 | data_config = load_config_file(DATA_CONFIG_PATH) 149 | train_config = load_config_file(TRAINER_CONFIG_PATH) 150 | model_config = load_config_file(MODEL_CONFIG_PATH) 151 | 152 | 153 | if data_config.VREncoderEmbeddings == 'VTransE': 154 | vrNetwork_config = load_config_file(data_config.VREncoderConfig) 155 | vrNetwork = VTransE(index_sp=vrNetwork_config.index_sp, 156 | index_cls=vrNetwork_config.index_cls, 157 | num_pred=vrNetwork_config.num_pred, 158 | output_size=vrNetwork_config.output_size, 159 | input_size=vrNetwork_config.input_size) 160 | elif data_config.VREncoderEmbeddings == 'VRConcat': 161 | vrNetwork = Concat() 162 | 163 | dataset = VrRVGDatasetTrain(data_config, vrNetwork) 164 | dataset_len=(dataset.__len__()) 165 | print("dataset_length=",dataset_len) 166 | 167 | # Creating 168 | if( train_config.NETWORK == "SimilarityNetworkVREncoder"): 169 | net = SimilarityNetworkVREncoder(model_config) 170 | 171 | if( train_config.NETWORK == "SimilarityNetworkConcat"): 172 | net = SimilarityNetworkConcat(model_config) 173 | 174 | net = net.cuda() 175 | 176 | train(train_config, dataset, net) 177 | print("Training Done") 178 | 179 | if __name__ == "__main__": 180 | main() 181 | -------------------------------------------------------------------------------- /VR_Encoder/dataloader/vrr_vg_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | import glob 5 | import xml.etree.ElementTree as ET 6 | from utils.utils import read_json, mkdir 7 | from utils.sampling_utils import get_roi_index 8 | from tqdm import tqdm 9 | 10 | 11 | class VrRVG_train_dataset(Dataset): 12 | def __init__(self, xml_file_path, npy_file_path, saved_vtranse_input, saved_dir, train_predicates_path): 13 | self.xml_path = xml_file_path 14 | self.npy_file_path = npy_file_path 15 | self.relations = [] 16 | self.predicates_name_to_id = read_json(train_predicates_path) 17 | self.class_imbalance = np.zeros((101)) 18 | 19 | self.saved_dir = saved_dir 20 | mkdir(self.saved_dir) 21 | 22 | if saved_vtranse_input: 23 | self.loadrelations() 24 | else: 25 | self.getrelations() 26 | self.print_class_imbalance() 27 | 28 | def print_class_imbalance(self): 29 | np.save("class_imbalance.npy", self.class_imbalance) 30 | for i, val in enumerate(self.class_imbalance): 31 | print("class #", i, " and freq=", val) 32 | return 33 | 34 | def loadrelations(self): 35 | lis = os.listdir(self.saved_dir) 36 | for item in tqdm(lis, desc="loading relations..."): 37 | try: 38 | item_path = os.path.join(self.saved_dir, item) 39 | arr = np.load(item_path, allow_pickle=True) 40 | arr = arr.item() 41 | predic_id = int(arr["predicate"]) 42 | 43 | if(self.class_imbalance[predic_id] < 500): 44 | self.class_imbalance[predic_id] += 1 45 | self.relations.append(arr) 46 | except: 47 | pass 48 | return 49 | 50 | def getrelations(self): 51 | cnt_relations = 0 52 | all_xml_files_path = glob.glob(self.xml_path+"/*.xml") 53 | for i, xml_file_path in enumerate(all_xml_files_path): 54 | if i % 1000 == 999: 55 | print("file #", i) 56 | print("dataset size=", len(self.relations)) 57 | data = ET.parse(xml_file_path) 58 | 
root = data.getroot() 59 | img_file = root.find('filename').text 60 | img_name = img_file.split(".")[0] 61 | npy_info_name = img_name+"_info.npy" 62 | npy_feat_name = img_name+".npy" 63 | try: 64 | info = np.load(os.path.join(self.npy_file_path, 65 | npy_info_name), allow_pickle=True) 66 | feat = np.load(os.path.join(self.npy_file_path, 67 | npy_feat_name), allow_pickle=True) 68 | except: 69 | continue 70 | 71 | for sub in root.findall('./object'): 72 | for obj in root.findall('./object'): 73 | sub_id = int(sub.find('object_id').text) 74 | obj_id = int(obj.find('object_id').text) 75 | relation = {} 76 | for rel in root.findall('./relation'): 77 | predicate = str(rel.find("predicate").text) 78 | try: 79 | rel_sub_id = int(rel.find('./subject_id').text) 80 | rel_obj_id = int(rel.find('./object_id').text) 81 | 82 | if(rel_sub_id == sub_id and rel_obj_id == obj_id): 83 | subject_bbox = {} 84 | object_bbox = {} 85 | 86 | subject_bbox["xmin"] = float( 87 | sub.find('bndbox').find('xmin').text) 88 | subject_bbox["xmax"] = float( 89 | sub.find('bndbox').find('xmax').text) 90 | subject_bbox["ymin"] = float( 91 | sub.find('bndbox').find('ymin').text) 92 | subject_bbox["ymax"] = float( 93 | sub.find('bndbox').find('ymax').text) 94 | 95 | object_bbox["xmin"] = float( 96 | obj.find('bndbox').find('xmin').text) 97 | object_bbox["xmax"] = float( 98 | obj.find('bndbox').find('xmax').text) 99 | object_bbox["ymin"] = float( 100 | obj.find('bndbox').find('ymin').text) 101 | object_bbox["ymax"] = float( 102 | obj.find('bndbox').find('ymax').text) 103 | 104 | predicate_id = self.predicates_name_to_id[predicate] 105 | relation["predicate"] = predicate_id 106 | 107 | subject_roi_index, subj_roi_iou = get_roi_index( 108 | subject_bbox, info) 109 | object_roi_index, obj_roi_iou = get_roi_index( 110 | object_bbox, info) 111 | 112 | image_width = int( 113 | info.item().get('image_width')) 114 | image_height = int( 115 | info.item().get('image_height')) 116 | for i, subjs in enumerate(subject_roi_index): 117 | for j, objs in enumerate(object_roi_index): 118 | results = {} 119 | results.update(relation) 120 | bnd_boxx = info.item().get( 121 | 'bbox')[subjs] # [xmin ymin xmax ymax] 122 | bnd_box = bnd_boxx.copy() 123 | 124 | bnd_box[0] = float( 125 | bnd_box[0]/image_width) 126 | bnd_box[2] /= image_width 127 | bnd_box[1] /= image_height 128 | bnd_box[3] /= image_height 129 | 130 | results["sub_bnd_box"] = bnd_box 131 | results["sub_roi_iou"] = subj_roi_iou[i] 132 | results["obj_roi_iou"] = obj_roi_iou[j] 133 | results["sub_class_scores"] = info.item( 134 | )["class_scores"][subjs] # 1601-d vector 135 | results["obj_class_scores"] = info.item( 136 | )["class_scores"][objs] # 1601-d vector 137 | 138 | bnd_boxxx = info.item().get( 139 | 'bbox')[objs] # [xmin ymin xmax ymax] 140 | bnd_box = bnd_boxxx.copy() 141 | bnd_box[0] /= image_width 142 | bnd_box[2] /= image_width 143 | bnd_box[1] /= image_height 144 | bnd_box[3] /= image_height 145 | results["obj_bnd_box"] = bnd_box 146 | 147 | results["sub_roi_features"] = feat[subjs] 148 | results["obj_roi_features"] = feat[objs] 149 | 150 | self.relations.append(results) 151 | filename = os.path.join( 152 | self.saved_dir, str(cnt_relations)+".npy") 153 | with open(filename, 'wb') as f: 154 | np.save( 155 | f, results, allow_pickle=True) 156 | cnt_relations += 1 157 | 158 | except: 159 | pass 160 | 161 | def __len__(self): 162 | return len(self.relations) 163 | 164 | def __getitem__(self, idx): 165 | return self.relations[idx] 166 | 
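# --- Illustrative usage sketch: not part of the repository's vrr_vg_dataset.py ---
# Building the training dataset with the paths from
# VR_Encoder/configs/data_config.yaml and batching it the way
# train_vr_encoder.py does. Each item is a dict holding the predicate id,
# normalized subject/object boxes (4-d), class scores (1601-d) and ROI
# features (2048-d); the shapes shown below assume that layout.
from torch.utils.data import DataLoader
from VR_Encoder.dataloader.vrr_vg_dataset import VrRVG_train_dataset

dataset = VrRVG_train_dataset(
    xml_file_path="VrR-VG",
    npy_file_path="all_vg_frcnn",
    saved_vtranse_input=True,   # True -> load relations already saved in saved_dir
    saved_dir="saved_vr_encoder_relations",
    train_predicates_path="data/vrrvg_predicates_train.json")

loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)
sample = next(iter(loader))
print(sample["predicate"].shape)         # e.g. torch.Size([4])
print(sample["sub_roi_features"].shape)  # e.g. torch.Size([4, 2048])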
-------------------------------------------------------------------------------- /FullModelTest.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import Text 3 | from utils.utils import load_config_file 4 | import torch 5 | import os 6 | import numpy as np 7 | from torch.utils.data import DataLoader 8 | import os 9 | from collections import defaultdict 10 | from shutil import copyfile 11 | from VR_SimilarityNetwork.dataloader.VrRVGDatasetTest import VrRVGDatasetTest 12 | from VR_SimilarityNetwork.model.SimilarityNetworkVREncoder import SimilarityNetworkVREncoder 13 | from VR_SimilarityNetwork.model.SimilarityNetworkConcat import SimilarityNetworkConcat 14 | import cv2 15 | from VR_Encoder.model.vtranse import VTransE 16 | from VR_Encoder.model.concat import Concat 17 | DATA_CONFIG_PATH = "/DATA/trevant/Vaibhav/tempVRC/VR_SimilarityNetwork/configs/data_config_test.yaml" 18 | TESTER_CONFIG_PATH = "/DATA/trevant/Vaibhav/tempVRC/VR_SimilarityNetwork/configs//test_config.yaml" 19 | MODEL_CONFIG_PATH = "/DATA/trevant/Vaibhav/tempVRC/VR_SimilarityNetwork/configs/model_config.yaml" 20 | 21 | data_config = load_config_file(DATA_CONFIG_PATH) 22 | test_config = load_config_file(TESTER_CONFIG_PATH) 23 | model_config = load_config_file(MODEL_CONFIG_PATH) 24 | 25 | RESULT_FOLDER = test_config.RESULT_FOLDER 26 | BATCH_SIZE = test_config.BATCH_SIZE 27 | RELATION_NET_CHECKPOINT = test_config.RelationNET_CHECKPOINT 28 | SIMILARITY_NET_CONCAT_CHECKPOINT = test_config.SIMILARITY_NET_CONCAT_CHECKPOINT 29 | BAG_SIZE = test_config.BAG_SIZE 30 | SIMILARITY= test_config.SIMILARITY 31 | CONCAT = test_config.CONCAT 32 | SAVE_OUTPUT= test_config.SAVE_OUTPUT 33 | ANCHOR_IMAGE = test_config.ANCHOR_IMAGE 34 | SUBJECT_ANCHORED = test_config.SUBJECT_ANCHORED 35 | top_k= test_config.top_k 36 | 37 | def printParams(): 38 | print("bag size =", BAG_SIZE) 39 | print("similarity =", SIMILARITY) 40 | print("concat =", CONCAT) 41 | print("IMAGE ANCHORED = ", ANCHOR_IMAGE) 42 | print("Subject anchored= ", SUBJECT_ANCHORED) 43 | 44 | def load_dataset(vrNetwork): 45 | dataset= VrRVGDatasetTest(data_config, test_config, vrNetwork) 46 | dataset_len=(dataset.__len__()) 47 | print("dataset_length=",dataset_len) 48 | train_sz=int(0*dataset_len) 49 | val_size=dataset_len-train_sz 50 | train_dataset,val_dataset= torch.utils.data.random_split(dataset, [train_sz, val_size]) 51 | val_dataloader= DataLoader(val_dataset,batch_size=BATCH_SIZE,shuffle=True,num_workers=0,collate_fn=lambda x:x) 52 | return val_dataloader 53 | 54 | def save_output(relation_id, n_samples, image_ids, bboxes_sub, bboxes_obj): 55 | path=os.path.join(RESULT_FOLDER,str(relation_id)) 56 | try: 57 | os.mkdir(path) 58 | except: 59 | pass 60 | 61 | path=os.path.join(path,str(n_samples)) 62 | try: 63 | os.mkdir(path) 64 | except: 65 | pass 66 | 67 | for i,img in enumerate(image_ids): 68 | try: 69 | im=cv2.imread(data_config.VisualGenomeImageDir1+str(img)+".jpg") 70 | h,w,d=im.shape 71 | except: 72 | im=cv2.imread(data_config.VisualGenomeImageDir2+str(img)+".jpg") 73 | h,w,d=im.shape 74 | 75 | 76 | [x_min,y_min,x_max,y_max]=(bboxes_sub[i]).tolist() 77 | 78 | x_min=int(x_min*w) 79 | x_max=int(x_max*w) 80 | y_min=int(y_min*h) 81 | y_max=int(y_max*h) 82 | cv2.rectangle(im,(x_min,y_min),(x_max,y_max),(0,255,0),2) 83 | [x_min,y_min,x_max,y_max]=(bboxes_obj[i]).tolist() 84 | x_min=int(x_min*w) 85 | x_max=int(x_max*w) 86 | y_min=int(y_min*h) 87 | y_max=int(y_max*h) 88 | cv2.rectangle(im,(x_min,y_min),(x_max,y_max),(0,255,255),2) 
89 | 90 | path_new=os.path.join(path,str(img)+".jpg") 91 | cv2.imwrite(path_new,im) 92 | 93 | 94 | 95 | def binary_tree(l,r,data, net): 96 | mid=(l+r)//2 97 | if(l==r): 98 | ranks=[] 99 | tupls=[] 100 | relation_sz=len(data["relations"][l]) 101 | for i in range(relation_sz): 102 | ranks.append(0) 103 | tupls.append([data["relations"][l]["relations"][i]]) 104 | return tupls,ranks 105 | relations_1,rank_1=binary_tree(l,mid,data, net) 106 | relations_2,rank_2=binary_tree(mid+1,r,data, net) 107 | sz_r1=len(relations_1) 108 | sz_r2=len(relations_2) 109 | relations_final=[] 110 | rank_final=[] 111 | for i in range(sz_r1): 112 | for j in range(sz_r2): 113 | val_1=rank_1[i] 114 | val_2=rank_2[j] 115 | rela=[] 116 | rank=float(val_1+val_2) 117 | 118 | for k in range(len(relations_1[i])): 119 | for l in range(len(relations_2[j])): 120 | tup1=(relations_1[i][k])[1] 121 | tup2=relations_2[j][l][1] 122 | 123 | tup1=torch.tensor(tup1).cuda() 124 | tup2=torch.tensor(tup2).cuda() 125 | 126 | if(SIMILARITY=="cosine"): 127 | cos=torch.nn.CosineSimilarity(dim=0) 128 | calc = cos(tup1,tup2) 129 | 130 | elif(SIMILARITY=="relation_net"): 131 | calc=net(tup1,tup2) 132 | 133 | rank+=float(calc) 134 | rela=relations_1[i]+relations_2[j] 135 | relations_final.append(rela) 136 | rank_final.append(rank) 137 | rank_final=np.array(rank_final) 138 | top_k_indices=rank_final.argsort()[-top_k:][::-1] 139 | relations_top_k=[] 140 | rank_top_k=[] 141 | 142 | for ind in top_k_indices: 143 | relations_top_k.append(relations_final[ind]) 144 | rank_top_k.append(rank_final[ind]) 145 | 146 | return relations_top_k,rank_top_k 147 | 148 | ''' WHOLE VISUALIZE CODE IS COMMENTED OUT ''' 149 | def test(val_dataloader, net): 150 | n_samples=0 151 | n_correct=0 152 | n_correct_frac=0.0 153 | m_iou = 0.0 154 | n_pred=0 155 | n_samples_class = defaultdict(float) 156 | image_corloc_class = defaultdict(float) 157 | bag_corloc_class = defaultdict(float) 158 | 159 | for i_batch, data in enumerate(val_dataloader): 160 | get_size=len(data) 161 | for j in range(get_size): 162 | try: 163 | image_n=len(data[j]["relations"]) 164 | relation_id=data[j]["relation_id"] 165 | relations,ranks=binary_tree(0,image_n-1,data[j], net) 166 | chk=0 167 | n_samples_class[relation_id] += 1 168 | for ii in range(1): 169 | len_tupl=len(relations[ii]) 170 | verify=True 171 | streak=[] 172 | image_ids=[] 173 | bboxes_sub=[] 174 | bboxes_obj=[] 175 | frac=float(1/len_tupl) 176 | sum_tupl=0.0 177 | for jj in range(len_tupl): 178 | op=relations[ii][jj][0] 179 | image_id=relations[ii][jj][3] 180 | sub_iou = relations[ii][jj][4] 181 | obj_iou = relations[ii][jj][5] 182 | 183 | n_pred+=1 184 | m_iou += ((sub_iou-m_iou)/n_pred) 185 | n_pred+=1 186 | m_iou += ((obj_iou-m_iou)/n_pred) 187 | 188 | image_ids.append(image_id) 189 | bb_sub=relations[ii][jj][2]["sub_bnd_box"] 190 | 191 | bb_obj=relations[ii][jj][2]["obj_bnd_box"] 192 | bboxes_sub.append(bb_sub) 193 | bboxes_obj.append(bb_obj) 194 | sum_tupl+=float(op*frac) 195 | 196 | streak.append(op) 197 | verify=verify & op 198 | 199 | if(SAVE_OUTPUT ==True): 200 | save_output(relation_id, n_samples, image_ids, bboxes_sub, bboxes_obj) 201 | if(verify==True): 202 | chk=1 203 | 204 | 205 | n_samples+=1 206 | n_correct_frac+=sum_tupl 207 | image_corloc_class[relation_id] += sum_tupl 208 | if(chk==1): 209 | n_correct+=1 210 | bag_corloc_class[relation_id] += 1 211 | else: 212 | bag_corloc_class[relation_id] += 0 213 | except: 214 | pass 215 | 216 | return n_correct, n_correct_frac, n_samples 217 | 218 | 219 | def 
printResult(n_correct, n_correct_frac, n_samples): 220 | print("n_correct=", n_correct) 221 | print("n correct fraction=", n_correct_frac) 222 | print("n_samples=", n_samples) 223 | 224 | def main(): 225 | printParams() 226 | if test_config.CONCAT ==False: 227 | vrNetwork_config = load_config_file(data_config.VREncoderConfig) 228 | vrNetwork = VTransE(index_sp=vrNetwork_config.index_sp, 229 | index_cls=vrNetwork_config.index_cls, 230 | num_pred=vrNetwork_config.num_pred, 231 | output_size=vrNetwork_config.output_size, 232 | input_size=vrNetwork_config.input_size) 233 | else: 234 | vrNetwork = Concat() 235 | 236 | val_dataloader = load_dataset(vrNetwork) # loading dataset 237 | 238 | #loading network 239 | ################################################# 240 | if( CONCAT == False): 241 | net = SimilarityNetworkVREncoder(model_config) 242 | net = net.cuda() 243 | chkpt=torch.load(RELATION_NET_CHECKPOINT) 244 | 245 | else: 246 | net = SimilarityNetworkConcat(model_config) 247 | net = net.cuda() 248 | chkpt=torch.load(SIMILARITY_NET_CONCAT_CHECKPOINT) 249 | 250 | net.load_state_dict(chkpt["model"]) 251 | net.eval() 252 | torch.no_grad() 253 | ################################################## 254 | 255 | n_correct, n_correct_frac, n_samples =test(val_dataloader, net) 256 | printResult(n_correct, n_correct_frac, n_samples) 257 | 258 | if __name__ == "__main__": 259 | main() 260 | 261 | 262 | 263 | -------------------------------------------------------------------------------- /train_vr_encoder.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.optim import Adam 5 | from torch.optim.lr_scheduler import ReduceLROnPlateau 6 | from torch.utils.data import DataLoader 7 | from torch.utils.data import random_split 8 | 9 | from VR_Encoder.dataloader.vrr_vg_dataset import VrRVG_train_dataset 10 | 11 | from VR_Encoder.model.vtranse import VTransE 12 | from VR_Encoder.model.concat import Concat 13 | 14 | from utils.utils import set_seed, mkdir, load_config_file 15 | from utils.logger import setup_logger 16 | 17 | from omegaconf import OmegaConf 18 | 19 | DATA_CONFIG_PATH = "VR_Encoder/configs/data_config.yaml" 20 | TRAINER_CONFIG_PATH = "VR_Encoder/configs/train_config.yaml" 21 | MODEL_CONFIG_PATH = "VR_Encoder/configs/model_config.yaml" 22 | 23 | 24 | def save_checkpoint(config, epoch, model, optimizer): 25 | ''' 26 | Checkpointing. Saves model and optimizer state_dict() and current epoch and global training steps. 27 | ''' 28 | checkpoint_path = os.path.join( 29 | config.saved_checkpoints, f'checkpoint_{epoch}.pt') 30 | save_num = 0 31 | while (save_num < 10): 32 | try: 33 | 34 | if config.n_gpu > 1: 35 | torch.save({ 36 | 'epoch': epoch, 37 | 'model_state_dict': model.module.state_dict(), 38 | 'optimizer_state_dict': optimizer.state_dict() 39 | }, checkpoint_path) 40 | else: 41 | torch.save({ 42 | 'epoch': epoch, 43 | 'model_state_dict': model.state_dict(), 44 | 'optimizer_state_dict': optimizer.state_dict() 45 | }, checkpoint_path) 46 | 47 | logger.info("Save checkpoint to {}".format(checkpoint_path)) 48 | break 49 | except: 50 | save_num += 1 51 | if save_num == 10: 52 | logger.info("Failed to save checkpoint after 10 trails.") 53 | return 54 | 55 | 56 | def train(config, train_dataset, model): 57 | ''' 58 | Trains the model. 
59 | ''' 60 | 61 | config.train_batch_size = config.per_gpu_train_batch_size * \ 62 | max(1, config.n_gpu) 63 | 64 | # creating val set from train dataset and dataloaders 65 | train_size = int(config.training_split_ratio*len(train_dataset)) 66 | val_size = len(train_dataset)-train_size 67 | 68 | train_dataset, val_dataset = random_split( 69 | train_dataset, [train_size, val_size]) 70 | 71 | train_dataloader = DataLoader(train_dataset, batch_size=4, 72 | shuffle=True, num_workers=0) 73 | val_dataloader = DataLoader( 74 | val_dataset, batch_size=4, shuffle=True, num_workers=0) 75 | 76 | # total training iterations 77 | t_total = len(train_dataloader) * config.num_train_epochs 78 | 79 | criterion = torch.nn.CrossEntropyLoss() 80 | optimizer = Adam(model.parameters(), lr=config.optimizer.params.lr, 81 | eps=config.optimizer.params.eps, weight_decay=config.optimizer.params.weight_decay) 82 | scheduler = ReduceLROnPlateau( 83 | optimizer, mode='max', factor=0.5, patience=3, verbose=True) 84 | 85 | if config.n_gpu > 1: 86 | model = torch.nn.DataParallel(model) 87 | 88 | model = model.to(torch.device(config.device)) 89 | 90 | logger.info("***** Running training *****") 91 | logger.info(" Num examples = %d", len(train_dataset)) 92 | logger.info(" Num Epochs = %d", config.num_train_epochs) 93 | logger.info(" Number of GPUs = %d", config.n_gpu) 94 | 95 | logger.info(" Batch size per GPU = %d", config.per_gpu_train_batch_size) 96 | logger.info(" Total train batch size (w. parallel) = %d", 97 | config.train_batch_size) 98 | logger.info(" Total optimization steps = %d", t_total) 99 | 100 | max_val_acc = 0 101 | epoch_val_loss_min = 1000 102 | val_acc_for_min_loss = 0 103 | 104 | for epoch in range(int(config.num_train_epochs)): 105 | epoch_train_loss, epoch_val_loss = 0.0, 0.0 106 | 107 | # train for the epoch 108 | model.train() 109 | for step, sample_batched in enumerate(train_dataloader): 110 | model.zero_grad() 111 | 112 | subj_sp = sample_batched["sub_bnd_box"].to( 113 | torch.device(config.device)) 114 | obj_sp = sample_batched["obj_bnd_box"].to( 115 | torch.device(config.device)) 116 | subj_cls = sample_batched["sub_class_scores"].to( 117 | torch.device(config.device)) 118 | obj_cls = sample_batched["obj_class_scores"].to( 119 | torch.device(config.device)) 120 | sub_feat = sample_batched["sub_roi_features"].to( 121 | torch.device(config.device)) 122 | obj_feat = sample_batched["obj_roi_features"].to( 123 | torch.device(config.device)) 124 | labels = sample_batched["predicate"].to( 125 | torch.device(config.device)) 126 | 127 | rela_score, _ = model( 128 | subj_sp, subj_cls, sub_feat, obj_sp, obj_cls, obj_feat) 129 | loss = criterion(rela_score, labels) 130 | if config.n_gpu > 1: 131 | loss = loss.mean() # mean() to average on multi-gpu parallel training 132 | 133 | loss.backward() 134 | optimizer.step() 135 | epoch_train_loss += loss.item() 136 | 137 | # eval after the epoch 138 | with torch.no_grad(): 139 | model.eval() 140 | num_correct, num_samples = 0, 0 141 | for step, sample_batched in enumerate(val_dataloader): 142 | subj_sp = sample_batched["sub_bnd_box"].to( 143 | torch.device(config.device)) 144 | obj_sp = sample_batched["obj_bnd_box"].to( 145 | torch.device(config.device)) 146 | subj_cls = sample_batched["sub_class_scores"].to( 147 | torch.device(config.device)) 148 | obj_cls = sample_batched["obj_class_scores"].to( 149 | torch.device(config.device)) 150 | sub_feat = sample_batched["sub_roi_features"].to( 151 | torch.device(config.device)) 152 | obj_feat = 
sample_batched["obj_roi_features"].to( 153 | torch.device(config.device)) 154 | labels = sample_batched["predicate"].to( 155 | torch.device(config.device)) 156 | 157 | rela_score, _ = model( 158 | subj_sp, subj_cls, sub_feat, obj_sp, obj_cls, obj_feat) 159 | max_index = rela_score.argmax(dim=1) 160 | num_correct += (max_index == labels).sum() 161 | num_samples += labels.size(0) 162 | 163 | # seeing val loss 164 | loss = criterion(rela_score, labels) 165 | if config.n_gpu > 1: 166 | loss = loss.mean() # mean() to average on multi-gpu parallel training 167 | 168 | epoch_val_loss += loss.item() 169 | 170 | val_acc = (num_correct/num_samples)*100 171 | max_val_acc = max(val_acc, max_val_acc) 172 | # logger.info(f"Epoch {epoch}:Got {num_correct} / {num_samples} correct with val accuracy: {val_acc}") 173 | 174 | scheduler.step(val_acc) 175 | 176 | epoch_train_loss = epoch_train_loss / len(train_dataloader) 177 | epoch_val_loss = epoch_val_loss / len(val_dataloader) 178 | 179 | logger.info( 180 | f"Epoch {epoch} | Train Loss={epoch_train_loss} | Val Loss={epoch_val_loss} | Val acc. {val_acc}") 181 | 182 | if epoch_val_loss < epoch_val_loss_min: 183 | logger.info("Epoch Val loss decreased({:.6f} --> {:.6f}).Saving model ...".format( 184 | epoch_val_loss_min, epoch_val_loss)) 185 | save_checkpoint(config, epoch, model, optimizer) 186 | epoch_val_loss_min = epoch_val_loss 187 | val_acc_for_min_loss = val_acc 188 | 189 | return epoch_val_loss_min, val_acc_for_min_loss 190 | 191 | 192 | def main(): 193 | 194 | data_config = load_config_file(DATA_CONFIG_PATH) 195 | train_config = load_config_file(TRAINER_CONFIG_PATH) 196 | model_config = load_config_file(MODEL_CONFIG_PATH) 197 | 198 | # merging data and train configs to be given to train() 199 | config = OmegaConf.merge(train_config, data_config) 200 | 201 | global logger 202 | # creating directories for saving checkpoints and logs 203 | mkdir(path=config.saved_checkpoints) 204 | mkdir(path=config.logs) 205 | 206 | logger = setup_logger(config.logs, config.logs, 0, 207 | filename="training_logs.txt") 208 | 209 | config.device = "cuda" if torch.cuda.is_available() else "cpu" 210 | config.n_gpu = torch.cuda.device_count() # config.n_gpu 211 | set_seed(seed=42, n_gpu=config.n_gpu) 212 | 213 | # creating model 214 | if model_config.model_name == "VTransE": 215 | model = VTransE(index_sp=model_config.index_sp, 216 | index_cls=model_config.index_cls, 217 | num_pred=model_config.num_pred, 218 | output_size=model_config.output_size, 219 | input_size=model_config.input_size) 220 | elif model_config.model_name == "Concat": 221 | model = Concat() 222 | else: 223 | logger.info(f"{model_config.model_name} model not supported") 224 | 225 | # getting dataset for training 226 | logger.info(f"Initializing dataset ...") 227 | train_dataset = VrRVG_train_dataset(xml_file_path=data_config.xml_file_path, 228 | npy_file_path=data_config.npy_file_path, 229 | saved_vtranse_input=data_config.saved_vtranse_input, 230 | saved_dir=data_config.saved_dir, 231 | train_predicates_path=data_config.train_predicates_path) 232 | 233 | # Now training 234 | val_loss, val_acc = train(config, train_dataset, model) 235 | 236 | logger.info(f"Training done: val_loss = {val_loss}, val_acc = {val_acc}") 237 | 238 | 239 | if __name__ == "__main__": 240 | main() 241 | -------------------------------------------------------------------------------- /data_preparation/vrc_extract_frcnn_feats.py: -------------------------------------------------------------------------------- 1 | 
################################################### 2 | # Steps before running the scripts: 3 | 4 | # 1. first install maskrcnn-benchmark : FRCNN Model 5 | 6 | # $ git clone https://gitlab.com/meetshah1995/vqa-maskrcnn-benchmark.git 7 | # $ cd vqa-maskrcnn-benchmark 8 | # $ python setup.py build 9 | # $ python setup.py develop 10 | 11 | # 2. download pre-trained detectron weights 12 | 13 | # $ mkdir detectron_weights 14 | # $ wget -O detectron_weights/detectron_model.pth https://dl.fbaipublicfiles.com/pythia/detectron_model/detectron_model.pth 15 | # $ wget -O detectron_weights/detectron_model.yaml https://dl.fbaipublicfiles.com/pythia/detectron_model/detectron_model.yaml 16 | 17 | # NOTE: just modify the code in /content/vqa-maskrcnn-benchmark/maskrcnn_benchmark/utils/imports.py, change PY3 to PY37 18 | 19 | # to run the script 20 | # $ python faster_rcnn_script.py --image_dir= 21 | 22 | ################################################### 23 | 24 | import argparse 25 | import glob 26 | import os 27 | 28 | import cv2 29 | import numpy as np 30 | import torch 31 | from PIL import Image 32 | 33 | from maskrcnn_benchmark.config import cfg 34 | from maskrcnn_benchmark.layers import nms 35 | from maskrcnn_benchmark.modeling.detector import build_detection_model 36 | from maskrcnn_benchmark.structures.image_list import to_image_list 37 | from maskrcnn_benchmark.utils.model_serialization import load_state_dict 38 | 39 | class FeatureExtractor: 40 | MODEL_URL = ( 41 | "https://dl.fbaipublicfiles.com/pythia/detectron_model/detectron_model.pth" 42 | ) 43 | CONFIG_URL = ( 44 | "https://dl.fbaipublicfiles.com/pythia/detectron_model/detectron_model.yaml" 45 | ) 46 | MAX_SIZE = 1333 47 | MIN_SIZE = 800 48 | 49 | def __init__(self): 50 | self.args = self.get_parser().parse_args() 51 | self.detection_model = self._build_detection_model() 52 | os.makedirs(self.args.output_folder, exist_ok=True) 53 | 54 | def get_parser(self): 55 | parser = argparse.ArgumentParser() 56 | parser.add_argument( 57 | "--model_file", default="detectron_weights/detectron_model.pth", type=str, help="Detectron model file" 58 | ) 59 | parser.add_argument( 60 | "--config_file", default="detectron_weights/detectron_model.yaml", type=str, help="Detectron config file" 61 | ) 62 | parser.add_argument("--batch_size", type=int, default=4, help="Batch size") 63 | parser.add_argument( 64 | "--num_features", type=int, default=100, help="Number of features to extract." 
65 | ) 66 | parser.add_argument( 67 | "--output_folder", type=str, default="./all_vg_frcnn", help="Output folder" 68 | ) 69 | parser.add_argument("--image_dir", default="./input_images", type=str, help="Image directory or file") 70 | parser.add_argument( 71 | "--feature_name", type=str, help="The name of the feature to extract", 72 | default="fc6", 73 | ) 74 | parser.add_argument( 75 | "--confidence_threshold", type=float, default=0, 76 | help="Threshold of detection confidence above which boxes will be selected" 77 | ) 78 | parser.add_argument( 79 | "--background", action="store_true", 80 | help="The model will output predictions for the background class when set" 81 | ) 82 | return parser 83 | 84 | def _build_detection_model(self): 85 | cfg.merge_from_file(self.args.config_file) 86 | cfg.freeze() 87 | 88 | model = build_detection_model(cfg) 89 | checkpoint = torch.load(self.args.model_file, map_location=torch.device("cpu")) 90 | 91 | load_state_dict(model, checkpoint.pop("model")) 92 | 93 | model.to("cuda") 94 | model.eval() 95 | return model 96 | 97 | def _image_transform(self, path): 98 | img = Image.open(path) 99 | im = np.array(img).astype(np.float32) 100 | 101 | # temp fix : for images with 4 channels 102 | if im.shape[-1] > 3: 103 | im = np.array(img.convert("RGB")).astype(np.float32) 104 | 105 | # IndexError: too many indices for array, grayscale images 106 | if len(im.shape) < 3: 107 | im = np.repeat(im[:, :, np.newaxis], 3, axis=2) 108 | im = im[:, :, ::-1] 109 | im -= np.array([102.9801, 115.9465, 122.7717]) 110 | im_shape = im.shape 111 | im_height = im_shape[0] 112 | im_width = im_shape[1] 113 | im_size_min = np.min(im_shape[0:2]) 114 | im_size_max = np.max(im_shape[0:2]) 115 | 116 | # Scale based on minimum size 117 | im_scale = self.MIN_SIZE / im_size_min 118 | 119 | # Prevent the biggest axis from being more than max_size 120 | # If bigger, scale it down 121 | if np.round(im_scale * im_size_max) > self.MAX_SIZE: 122 | im_scale = self.MAX_SIZE / im_size_max 123 | 124 | im = cv2.resize( 125 | im, None, None, fx=im_scale, fy=im_scale, interpolation=cv2.INTER_LINEAR 126 | ) 127 | img = torch.from_numpy(im).permute(2, 0, 1) 128 | 129 | im_info = { 130 | "width": im_width, 131 | "height": im_height 132 | } 133 | 134 | return img, im_scale, im_info 135 | 136 | def _process_feature_extraction( 137 | self, output, im_scales, im_infos, feature_name="fc6", conf_thresh=0 138 | ): 139 | batch_size = len(output[0]["proposals"]) 140 | n_boxes_per_image = [len(boxes) for boxes in output[0]["proposals"]] 141 | score_list = output[0]["scores"].split(n_boxes_per_image) 142 | score_list = [torch.nn.functional.softmax(x, -1) for x in score_list] 143 | feats = output[0][feature_name].split(n_boxes_per_image) 144 | cur_device = score_list[0].device 145 | 146 | feat_list = [] 147 | info_list = [] 148 | 149 | for i in range(batch_size): 150 | dets = output[0]["proposals"][i].bbox / im_scales[i] 151 | scores = score_list[i] 152 | max_conf = torch.zeros((scores.shape[0])).to(cur_device) 153 | conf_thresh_tensor = torch.full_like(max_conf, conf_thresh) 154 | start_index = 1 155 | # Column 0 of the scores matrix is for the background class 156 | if self.args.background: 157 | start_index = 0 158 | for cls_ind in range(start_index, scores.shape[1]): 159 | cls_scores = scores[:, cls_ind] 160 | keep = nms(dets, cls_scores, 0.5) 161 | max_conf[keep] = torch.where( 162 | # Better than max one till now and minimally greater than conf_thresh 163 | (cls_scores[keep] > max_conf[keep]) & 164 | (cls_scores[keep] 
> conf_thresh_tensor[keep]), 165 | cls_scores[keep], max_conf[keep] 166 | ) 167 | 168 | sorted_scores, sorted_indices = torch.sort(max_conf, descending=True) 169 | num_boxes = (sorted_scores[:self.args.num_features] != 0).sum() 170 | keep_boxes = sorted_indices[:self.args.num_features] 171 | feat_list.append(feats[i][keep_boxes]) 172 | bbox = output[0]["proposals"][i][keep_boxes].bbox / im_scales[i] 173 | # Predict the class label using the scores 174 | objects = torch.argmax(scores[keep_boxes], dim=1) 175 | 176 | info_list.append( 177 | { 178 | "bbox": bbox.cpu().numpy(), 179 | "num_boxes": num_boxes.item(), 180 | "objects": objects.cpu().numpy(), 181 | "image_width": im_infos[i]["width"], 182 | "image_height": im_infos[i]["height"], 183 | "class_scores" : scores[keep_boxes].cpu().numpy() 184 | } 185 | ) 186 | 187 | return feat_list, info_list 188 | 189 | def get_detectron_features(self, image_paths): 190 | img_tensor, im_scales, im_infos = [], [], [] 191 | 192 | for image_path in image_paths: 193 | im, im_scale, im_info = self._image_transform(image_path) 194 | img_tensor.append(im) 195 | im_scales.append(im_scale) 196 | im_infos.append(im_info) 197 | 198 | # Image dimensions should be divisible by 32, to allow convolutions 199 | # in detector to work 200 | current_img_list = to_image_list(img_tensor, size_divisible=32) 201 | current_img_list = current_img_list.to("cuda") 202 | 203 | with torch.no_grad(): 204 | output = self.detection_model(current_img_list) 205 | 206 | feat_list = self._process_feature_extraction( 207 | output, im_scales, im_infos, self.args.feature_name, 208 | self.args.confidence_threshold 209 | ) 210 | 211 | return feat_list 212 | 213 | def _chunks(self, array, chunk_size): 214 | for i in range(0, len(array), chunk_size): 215 | yield array[i : i + chunk_size] 216 | 217 | def _save_feature(self, file_name, feature, info): 218 | file_base_name = os.path.basename(file_name) 219 | file_base_name = file_base_name.split(".")[0] 220 | info_file_base_name = file_base_name + "_info.npy" 221 | file_base_name = file_base_name + ".npy" 222 | 223 | np.save( 224 | os.path.join(self.args.output_folder, file_base_name), feature.cpu().numpy() 225 | ) 226 | np.save(os.path.join(self.args.output_folder, info_file_base_name), info) 227 | 228 | def extract_features(self): 229 | image_dir = self.args.image_dir 230 | 231 | if os.path.isfile(image_dir): 232 | features, infos = self.get_detectron_features([image_dir]) 233 | self._save_feature(image_dir, features[0], infos[0]) 234 | else: 235 | files = glob.glob(os.path.join(image_dir, "*.*")) 236 | 237 | for chunk in self._chunks(files, self.args.batch_size): 238 | features, infos = self.get_detectron_features(chunk) 239 | for idx, file_name in enumerate(chunk): 240 | self._save_feature(file_name, features[idx], infos[idx]) 241 | 242 | if __name__ == "__main__": 243 | feature_extractor = FeatureExtractor() 244 | feature_extractor.extract_features() -------------------------------------------------------------------------------- /VR_SimilarityNetwork/dataloader/VrRVGDatasetTrain.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | import glob 5 | import xml.etree.ElementTree as ET 6 | from collections import defaultdict 7 | import glob 8 | 9 | import random 10 | from tqdm import tqdm 11 | import torch 12 | 13 | from utils.sampling_utils import get_iou 14 | from utils.utils import read_json 15 | 16 | 17 | class 
VrRVGDatasetTrain(Dataset): 18 | def __init__(self, data_config, vrNetwork): 19 | self.data_config = data_config 20 | 21 | self.vrNetwork = vrNetwork 22 | self.vrNetwork = self.vrNetwork.cuda() 23 | if(self.data_config.VREncoderEmbeddings == "VTransE"): 24 | chkpt=torch.load(self.data_config.VREncoder_Net_Checkpoint) 25 | self.vrNetwork.load_state_dict(chkpt["model_state_dict"]) 26 | self.vrNetwork.eval() 27 | 28 | self.xml_path = self.data_config.XML_FILE_PATH_VrRVG 29 | self.npy_file_path = self.data_config.NPY_FILE_PATH 30 | 31 | self.all_xml_files_path=glob.glob(self.xml_path+"/*.xml") 32 | self.predicates_name_to_id=read_json(data_config.train_predicates_path) 33 | self.relation_id_to_images=defaultdict(list) 34 | self.get_relation_id_to_images() 35 | 36 | self.sampled_bags=self.load_samples() 37 | self.bag_features=self.sample_generator() # all vtranse features of relations in images in a bag 38 | 39 | def extract_vtranse_embedding_from_relation(self,info,feat,subjs,objs,image_width,image_height): 40 | relation={} 41 | bnd_boxx=info.item().get('bbox')[subjs] # [xmin ymin xmax ymax] 42 | bnd_box=bnd_boxx.copy() 43 | bnd_box[0]=float(bnd_box[0]/image_width) 44 | bnd_box[2]/=image_width 45 | bnd_box[1]/=image_height 46 | bnd_box[3]/=image_height 47 | relation["sub_class_scores"] = info.item()["class_scores"][subjs] 48 | relation["obj_class_scores"]= info.item()["class_scores"][objs] 49 | relation["sub_bnd_box"]=bnd_box 50 | 51 | bnd_boxxx=info.item().get('bbox')[objs] # [xmin ymin xmax ymax] 52 | bnd_box=bnd_boxxx.copy() 53 | bnd_box[0]/=image_width 54 | bnd_box[2]/=image_width 55 | bnd_box[1]/=image_height 56 | bnd_box[3]/=image_height 57 | relation["obj_bnd_box"]=bnd_box 58 | 59 | relation["sub_roi_features"]=feat[subjs] 60 | relation["obj_roi_features"]=feat[objs] 61 | if(self.data_config.VREncoderEmbeddings == "VTransE"): 62 | embedding = self.get_vtranse_embedding(relation) 63 | else: 64 | embedding = self.get_concat_embeddings(relation) 65 | 66 | return embedding 67 | 68 | def get_vtranse_embedding(self,sample_batched): 69 | 70 | with torch.no_grad(): 71 | 72 | subj_sp=torch.tensor(sample_batched["sub_bnd_box"]).float().cuda() 73 | obj_sp=torch.tensor(sample_batched["obj_bnd_box"]).float().cuda() 74 | 75 | subj_cls=torch.tensor(sample_batched["sub_class_scores"]).float().cuda() 76 | obj_cls=torch.tensor(sample_batched["obj_class_scores"]).float().cuda() 77 | 78 | sub_feat=torch.tensor(sample_batched["sub_roi_features"]).float().cuda() 79 | obj_feat=torch.tensor(sample_batched["obj_roi_features"]).float().cuda() 80 | 81 | _, vr_emb = self.vrNetwork.forward_inference(subj_sp,subj_cls,sub_feat,obj_sp,obj_cls,obj_feat) 82 | vr_emb_numpy=vr_emb.cpu().detach().numpy() 83 | return vr_emb_numpy 84 | 85 | def get_concat_embeddings(self, sample_batched): 86 | with torch.no_grad(): 87 | subj_sp=torch.tensor(sample_batched["sub_bnd_box"]).float() 88 | obj_sp=torch.tensor(sample_batched["obj_bnd_box"]).float() 89 | 90 | subj_cls=torch.tensor(sample_batched["sub_class_scores"]).float() 91 | obj_cls=torch.tensor(sample_batched["obj_class_scores"]).float() 92 | 93 | sub_feat=torch.tensor(sample_batched["sub_roi_features"]).float() 94 | obj_feat=torch.tensor(sample_batched["obj_roi_features"]).float() 95 | 96 | sub_emb = torch.cat([subj_sp, subj_cls, sub_feat]) 97 | ob_emb = torch.cat([obj_sp, obj_cls, obj_feat]) 98 | 99 | vr_emb = torch.cat([sub_emb, ob_emb]) 100 | 101 | vr_emb = vr_emb.cpu().detach().numpy() 102 | return vr_emb 103 | 104 | def sample_generator(self): 105 | bags=[] 106 | i=0 107 
| for bag in tqdm(self.sampled_bags, desc="generating dataset"): # what is self.sampled_bags?? a list which has dicts {"relation_id" :relation_id, "images_ids" : img_ids of bag size} 108 | i+=1 109 | bag_relation_id=bag["relation_id"] # relation id is of common relation of the bag 110 | relations_per_bag={} 111 | relations_per_bag["relation_id"]=bag_relation_id 112 | relations_per_bag["relations"]=[] # n1,n2,n3 relations in images 113 | bag_size = len(bag["images_ids"]) 114 | try: 115 | for img in bag["images_ids"]: 116 | relations_per_image={} 117 | img_name=img.split(".")[0] 118 | 119 | relations_per_image["image_name"]=img_name 120 | relations_per_image["positive_relations"]=[] 121 | relations_per_image["negative_relations"]=[] 122 | 123 | npy_info_name=img_name+"_info.npy" 124 | npy_feat_name=img_name+".npy" 125 | info=np.load(os.path.join(self.npy_file_path,npy_info_name),allow_pickle=True) 126 | feat=np.load(os.path.join(self.npy_file_path,npy_feat_name),allow_pickle=True) 127 | xml_file_path=self.xml_path+"/"+img_name+".xml" 128 | 129 | xml_data = ET.parse(xml_file_path) 130 | root = xml_data.getroot() 131 | positive_relations_id_from_xml=[] 132 | all_relations_in_xml = root.findall('./relation') 133 | 134 | # function of this for loop : finding positive relations in xml 135 | for rel in all_relations_in_xml: 136 | predicate=str(rel.find("predicate").text) 137 | try: 138 | verify_test_train = self.predicates_name_to_id[predicate] 139 | if(verify_test_train == bag_relation_id): 140 | rel_sub_id=str(rel.find('./subject_id').text) 141 | rel_obj_id=str(rel.find('./object_id').text) 142 | relation_key=rel_sub_id+rel_obj_id 143 | positive_relations_id_from_xml.append(relation_key) 144 | except: 145 | # if rel is of different test/train set : basically ignore this code 146 | pass 147 | 148 | ## refactored till here 149 | for rel in all_relations_in_xml: 150 | predicate=str(rel.find("predicate").text) 151 | try: 152 | check=(self.predicates_name_to_id[predicate]) 153 | if(check==bag_relation_id): 154 | rel_sub_id=int(rel.find('./subject_id').text) 155 | rel_obj_id=int(rel.find('./object_id').text) 156 | current_rel_key=str(rel_sub_id)+str(rel_obj_id) 157 | for sub in root.findall('./object'): 158 | for obj in root.findall('./object'): 159 | flag=0 160 | sub_id=int(sub.find('object_id').text) 161 | obj_id=int(obj.find('object_id').text) 162 | if(rel_sub_id==sub_id and rel_obj_id==obj_id): 163 | subject_bbox={} 164 | object_bbox={} 165 | 166 | 167 | subject_bbox["xmin"]=float(sub.find('bndbox').find('xmin').text) 168 | subject_bbox["xmax"]=float(sub.find('bndbox').find('xmax').text) 169 | subject_bbox["ymin"]=float(sub.find('bndbox').find('ymin').text) 170 | subject_bbox["ymax"]=float(sub.find('bndbox').find('ymax').text) 171 | 172 | object_bbox["xmin"]=float(obj.find('bndbox').find('xmin').text) 173 | object_bbox["xmax"]=float(obj.find('bndbox').find('xmax').text) 174 | object_bbox["ymin"]=float(obj.find('bndbox').find('ymin').text) 175 | object_bbox["ymax"]=float(obj.find('bndbox').find('ymax').text) 176 | 177 | #get subjects and objects roi index from npy file where the iou is greater than the threshold for subjects and objects 178 | # ERROR in this Function : now fixed 179 | subject_roi_index, neutral_sub_roi_index = self.get_subject_roi_index(root, subject_bbox, info, bag_relation_id, positive_relations_id_from_xml, all_relations_in_xml, type="subject") 180 | object_roi_index, neutral_obj_roi_index = self.get_subject_roi_index(root, object_bbox, info, bag_relation_id, 
positive_relations_id_from_xml, all_relations_in_xml, type="object") 181 | 182 | image_width = int(info.item().get('image_width')) 183 | image_height = int(info.item().get('image_height')) 184 | sampling=0 185 | # extract positive relations embeddings 186 | for i,subjs in enumerate(subject_roi_index): 187 | for j,objs in enumerate(object_roi_index): 188 | embedding = self.extract_vtranse_embedding_from_relation(info,feat,subjs,objs,image_width,image_height) 189 | # print("appending positive-positive pairs relations in img") 190 | relations_per_image["positive_relations"].append(embedding) 191 | sampling+=1 192 | 193 | rois_info=info.item().get('bbox') 194 | rois=rois_info.shape[0] 195 | cntt=0 196 | objs_len=len(object_roi_index) 197 | 198 | # taking positive-negative/ negative-positive pairs as negative samples 199 | for i,subjs in enumerate(subject_roi_index): 200 | for j in range(0,objs_len): 201 | objs=random.randint(0,rois-1) 202 | if(objs not in object_roi_index): 203 | embedding=self.extract_vtranse_embedding_from_relation(info,feat,subjs,objs,image_width,image_height) 204 | # print("appending positive-negative pairs relations in img") 205 | relations_per_image["negative_relations"].append(embedding) 206 | sampling+=1 207 | else: 208 | j-=1 209 | subjs_len=len(subject_roi_index) 210 | for i,objs in enumerate(object_roi_index): 211 | for j in range(0,subjs_len): 212 | subjs=random.randint(0,rois-1) 213 | if(subjs not in subject_roi_index): 214 | embedding=self.extract_vtranse_embedding_from_relation(info,feat,subjs,objs,image_width,image_height) 215 | # print("appending positive-negative pairs relations in img") 216 | 217 | relations_per_image["negative_relations"].append(embedding) 218 | sampling+=1 219 | else: 220 | j-=1 221 | 222 | # taking negative-negative pairs 223 | for i in range(subjs_len): 224 | subjs=random.randint(0,rois-1) 225 | if(subjs in subject_roi_index or subjs in neutral_sub_roi_index): 226 | i-=1 227 | else: 228 | for j in range(objs_len): 229 | objs=random.randint(0,rois-1) 230 | if(objs in object_roi_index or objs in neutral_obj_roi_index): 231 | j-=1 232 | else: 233 | embedding=self.extract_vtranse_embedding_from_relation(info,feat,subjs,objs,image_width,image_height) 234 | # print("appending negative-negative pairs relations in img") 235 | 236 | relations_per_image["negative_relations"].append(embedding) 237 | sampling+=1 238 | 239 | except Exception as e: 240 | # print("exception") 241 | # print(e) 242 | # print("-") 243 | pass 244 | 245 | if(len(relations_per_image["positive_relations"])!=0 and len(relations_per_image["negative_relations"])!=0): 246 | # print("appending image relations in bag") 247 | relations_per_bag["relations"].append(relations_per_image) 248 | except: 249 | pass 250 | 251 | if(len(relations_per_bag["relations"]) == bag_size): 252 | # print("appending bag in list") 253 | bags.append(relations_per_bag) 254 | else : 255 | # print('len(relations_per_bag["relations"]', len(relations_per_bag["relations"])) 256 | pass 257 | 258 | return bags 259 | 260 | def get_subject_roi_index(self, root, bbox, info, bag_relation_id, positive_relations_id_from_xml, all_relations_in_xml, type="subject"): 261 | 262 | indexes=[] 263 | neutral_indexes=set() 264 | 265 | rois_info=info.item().get('bbox') 266 | rois=rois_info.shape[0] 267 | subj_roi_iou=[] 268 | # for positive rois 269 | for i in range(0, rois): 270 | bbox_roi=rois_info[i] 271 | bbox_dict={} 272 | bbox_dict["xmin"]=float(bbox_roi[0]) 273 | bbox_dict["ymin"]=float(bbox_roi[1]) 274 | 
bbox_dict["xmax"]=float(bbox_roi[2]) 275 | bbox_dict["ymax"]=float(bbox_roi[3]) 276 | 277 | iou=get_iou(bbox,bbox_dict) 278 | if(iou>0.50): 279 | indexes.append(i) 280 | #for neutral rois 281 | # since neutrals are ignored : all rois which are in positive are also in neutral 282 | for i in range(0,rois): 283 | bbox_roi=rois_info[i] 284 | bbox_dict={} 285 | bbox_dict["xmin"]=float(bbox_roi[0]) 286 | bbox_dict["ymin"]=float(bbox_roi[1]) 287 | bbox_dict["xmax"]=float(bbox_roi[2]) 288 | bbox_dict["ymax"]=float(bbox_roi[3]) 289 | 290 | for rel in all_relations_in_xml: 291 | predicate=str(rel.find("predicate").text) 292 | check=(self.predicates_name_to_id[predicate]) 293 | if(check==bag_relation_id): 294 | rel_sub_id=int(rel.find('./subject_id').text) 295 | rel_obj_id=int(rel.find('./object_id').text) 296 | current_rel_key=str(rel_sub_id)+str(rel_obj_id) 297 | if current_rel_key in positive_relations_id_from_xml: 298 | for sub in root.findall('./object'): 299 | flag=0 300 | xml_obj_id=int(sub.find('object_id').text) 301 | if(type=="subject"): 302 | if(rel_sub_id!=xml_obj_id): 303 | continue 304 | elif(type=="object"): 305 | if(rel_obj_id!=xml_obj_id): 306 | continue 307 | subject_class=str(sub.find('name').text) 308 | subject_bbox={} 309 | subject_bbox["xmin"]=float(sub.find('bndbox').find('xmin').text) 310 | subject_bbox["xmax"]=float(sub.find('bndbox').find('xmax').text) 311 | subject_bbox["ymin"]=float(sub.find('bndbox').find('ymin').text) 312 | subject_bbox["ymax"]=float(sub.find('bndbox').find('ymax').text) 313 | iou=get_iou(subject_bbox,bbox_dict) 314 | if(iou>0.50): 315 | neutral_indexes.add(i) 316 | 317 | neutral_indexes=list(neutral_indexes) 318 | return indexes,neutral_indexes 319 | 320 | def load_samples(self): 321 | ''' returns a list of dicts. 
Each dict : a bag for training : {"relation_id" : id_of_predicate, "images_ids": list_of_img_ids_of_bag_size} ''' 322 | samples=[] 323 | # bag_size = [4,6,8,10] 324 | # count_bag = np.zeros(4) 325 | bag_size = [2,4] 326 | count_bag = np.zeros(2) 327 | 328 | for relation_id,images_ids in tqdm(self.relation_id_to_images.items(), desc="returing a list of dictionary og relations"): 329 | images_len=len(images_ids) 330 | sample_per_relation = self.data_config.SAMPLE_PER_RELATION 331 | for _ in range(0,sample_per_relation): 332 | bag_n = np.argmin(count_bag) 333 | count_bag[bag_n] += 1 334 | 335 | sample={} 336 | sample["relation_id"]=relation_id 337 | sample["images_ids"]=[] 338 | image_index=random.sample(range(0,images_len), bag_size[bag_n]) 339 | for ind in image_index: 340 | sample["images_ids"].append(images_ids[ind]) 341 | samples.append(sample) 342 | 343 | return samples 344 | 345 | def get_relation_id_to_images(self): 346 | for xml_file_path in tqdm(self.all_xml_files_path, desc="reading xml files"): 347 | data=ET.parse(xml_file_path) 348 | root = data.getroot() 349 | img_file=root.find('filename').text 350 | for rel in root.findall('./relation'): 351 | predicate=str(rel.find("predicate").text) 352 | if predicate in self.predicates_name_to_id: 353 | predicate_id=self.predicates_name_to_id[predicate] 354 | self.relation_id_to_images[predicate_id].append(img_file) 355 | 356 | 357 | def __getitem__(self,idx): 358 | return dict(self.bag_features[idx]) 359 | 360 | def __len__(self): 361 | return len(self.bag_features) 362 | -------------------------------------------------------------------------------- /VR_SimilarityNetwork/dataloader/VrRVGDatasetTest.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | import glob 5 | import xml.etree.ElementTree as ET 6 | from collections import defaultdict 7 | import glob 8 | import random 9 | from tqdm import tqdm 10 | from scipy import special 11 | from utils.utils import read_json 12 | from utils.sampling_utils import get_iou 13 | 14 | import torch 15 | 16 | class VrRVGDatasetTest(Dataset): 17 | def __init__(self, data_config, test_config, vrNetwork): 18 | self.data_config = data_config 19 | self.test_config = test_config 20 | 21 | self.vrNetwork = vrNetwork 22 | self.vrNetwork = self.vrNetwork.cuda() 23 | if(self.test_config.CONCAT== False): 24 | chkpt=torch.load(self.data_config.VREncoder_Net_Checkpoint) 25 | self.vrNetwork.load_state_dict(chkpt["model_state_dict"]) 26 | self.vrNetwork.eval() 27 | 28 | self.xml_path = self.data_config.XML_FILE_PATH_VrRVG 29 | self.all_xml_files_path=glob.glob(self.xml_path+"/*.xml") 30 | self.npy_file_path = self.data_config.NPY_FILE_PATH 31 | 32 | self.bag_size = self.test_config.BAG_SIZE 33 | self.predicates_name_to_id=read_json(data_config.test_predicates_path) 34 | self.relation_id_to_images=defaultdict(list) 35 | self.get_relation_id_to_images() 36 | self.sampl=self.load_samples() 37 | self.samples=self.test_sample_generator() 38 | 39 | def extract_vtranse_embedding_from_relation(self,info,feat,subjs,objs,image_width,image_height, subj_bbox, obj_bbox): 40 | relation={} 41 | bnd_boxx=info.item().get('bbox')[subjs] # [xmin ymin xmax ymax] 42 | bnd_box=bnd_boxx.copy() 43 | sub_pred = {} 44 | sub_pred["xmin"] = bnd_box[0] 45 | sub_pred["ymin"] = bnd_box[1] 46 | sub_pred["xmax"] = bnd_box[2] 47 | sub_pred["ymax"] = bnd_box[3] 48 | sub_iou = get_iou(sub_pred,subj_bbox) 49 | 50 | 
bnd_box[0]=float(bnd_box[0]/image_width) 51 | bnd_box[2]/=image_width 52 | bnd_box[1]/=image_height 53 | bnd_box[3]/=image_height 54 | relation["sub_class_scores"] = info.item()["class_scores"][subjs] 55 | relation["obj_class_scores"]= info.item()["class_scores"][objs] 56 | 57 | relation["sub_bnd_box"]=bnd_box 58 | 59 | bnd_boxxx=info.item().get('bbox')[objs] # [xmin ymin xmax ymax] 60 | bnd_box=bnd_boxxx.copy() 61 | obj_pred = {} 62 | obj_pred["xmin"] = bnd_box[0] 63 | obj_pred["ymin"] = bnd_box[1] 64 | obj_pred["xmax"] = bnd_box[2] 65 | obj_pred["ymax"] = bnd_box[3] 66 | obj_iou = get_iou(obj_pred,obj_bbox) 67 | 68 | bnd_box[0]/=image_width 69 | bnd_box[2]/=image_width 70 | bnd_box[1]/=image_height 71 | bnd_box[3]/=image_height 72 | relation["obj_bnd_box"]=bnd_box 73 | 74 | relation["sub_roi_features"]=feat[subjs] 75 | relation["obj_roi_features"]=feat[objs] 76 | if(self.test_config.CONCAT == 0): 77 | embedding=self.get_vtranse_embedding(relation) 78 | else: 79 | embedding=self.get_concat_embedding(relation) 80 | 81 | return embedding,relation, sub_iou, obj_iou 82 | 83 | 84 | def get_vtranse_embedding(self,sample_batched): 85 | 86 | with torch.no_grad(): 87 | 88 | subj_sp=torch.tensor(sample_batched["sub_bnd_box"]).float().cuda() 89 | obj_sp=torch.tensor(sample_batched["obj_bnd_box"]).float().cuda() 90 | 91 | subj_cls=torch.tensor(sample_batched["sub_class_scores"]).float().cuda() 92 | obj_cls=torch.tensor(sample_batched["obj_class_scores"]).float().cuda() 93 | 94 | sub_feat=torch.tensor(sample_batched["sub_roi_features"]).float().cuda() 95 | obj_feat=torch.tensor(sample_batched["obj_roi_features"]).float().cuda() 96 | 97 | _, vr_emb = self.vrNetwork.forward_inference(subj_sp,subj_cls,sub_feat,obj_sp,obj_cls,obj_feat) 98 | vr_emb_numpy=vr_emb.cpu().detach().numpy() 99 | return vr_emb_numpy 100 | 101 | 102 | def get_concat_embedding(self, sample_batched): 103 | with torch.no_grad(): 104 | subj_sp=torch.tensor(sample_batched["sub_bnd_box"]).float() 105 | obj_sp=torch.tensor(sample_batched["obj_bnd_box"]).float() 106 | 107 | subj_cls=torch.tensor(sample_batched["sub_class_scores"]).float() 108 | obj_cls=torch.tensor(sample_batched["obj_class_scores"]).float() 109 | 110 | sub_feat=torch.tensor(sample_batched["sub_roi_features"]).float() 111 | obj_feat=torch.tensor(sample_batched["obj_roi_features"]).float() 112 | 113 | sub_emb = torch.cat([subj_sp, subj_cls, sub_feat]) 114 | ob_emb = torch.cat([obj_sp, obj_cls, obj_feat]) 115 | 116 | vr_emb = torch.cat([sub_emb, ob_emb]) 117 | 118 | vr_emb = vr_emb.cpu().detach().numpy() 119 | return vr_emb 120 | 121 | def test_sample_generator(self): 122 | samples=[] 123 | positive_examples=0 124 | negative_examples=0 125 | cnt=0 126 | for sample in tqdm(self.sampl, desc =" iterating samples"): 127 | sample_relation_id=sample["relation_id"] 128 | relations_per_sample={} 129 | relations_per_sample["relation_id"]=sample_relation_id 130 | relations_per_sample["relations"]=[] # n1,n2,n3 relations in images 131 | for img_ind,img in enumerate(sample["images_ids"]): 132 | relations_per_image={} 133 | img_name=img.split(".")[0] 134 | 135 | relations_per_image["image_name"]=img_name 136 | relations_per_image["relations"]=[] 137 | 138 | npy_info_name=img_name+"_info.npy" 139 | npy_feat_name=img_name+".npy" 140 | info=np.load(os.path.join(self.npy_file_path,npy_info_name),allow_pickle=True) 141 | feat=np.load(os.path.join(self.npy_file_path,npy_feat_name),allow_pickle=True) 142 | xml_file_path=img_name+".xml" 143 | xml_file_path = os.path.join(self.xml_path, 
xml_file_path) 144 | 145 | data=ET.parse(xml_file_path) 146 | root = data.getroot() 147 | 148 | positive_relations_id_from_xml=[] 149 | all_relations_in_xml = root.findall('./relation') 150 | 151 | # function of this for loop : finding positive relations in xml 152 | for rel in all_relations_in_xml: 153 | predicate=str(rel.find("predicate").text) 154 | verify_test_train = self.predicates_name_to_id.get(predicate,None) 155 | if verify_test_train is None : 156 | continue 157 | 158 | if(verify_test_train == sample_relation_id): 159 | rel_sub_id=str(rel.find('./subject_id').text) 160 | rel_obj_id=str(rel.find('./object_id').text) 161 | relation_key=rel_sub_id+rel_obj_id 162 | positive_relations_id_from_xml.append(relation_key) 163 | 164 | 165 | for rel in all_relations_in_xml: 166 | predicate=str(rel.find("predicate").text) 167 | check=self.predicates_name_to_id.get(predicate, None) 168 | if check is None : 169 | continue 170 | if(check==sample_relation_id): 171 | rel_sub_id=int(rel.find('./subject_id').text) 172 | rel_obj_id=int(rel.find('./object_id').text) 173 | current_rel_key=str(rel_sub_id)+str(rel_obj_id) 174 | for sub in root.findall('./object'): 175 | for obj in root.findall('./object'): 176 | flag=0 177 | sub_id=int(sub.find('object_id').text) 178 | obj_id=int(obj.find('object_id').text) 179 | if(rel_sub_id==sub_id and rel_obj_id==obj_id): 180 | subject_bbox={} 181 | object_bbox={} 182 | 183 | # get subject and object bounding box for xml files 184 | subject_bbox["xmin"]=float(sub.find('bndbox').find('xmin').text) 185 | subject_bbox["xmax"]=float(sub.find('bndbox').find('xmax').text) 186 | subject_bbox["ymin"]=float(sub.find('bndbox').find('ymin').text) 187 | subject_bbox["ymax"]=float(sub.find('bndbox').find('ymax').text) 188 | 189 | object_bbox["xmin"]=float(obj.find('bndbox').find('xmin').text) 190 | object_bbox["xmax"]=float(obj.find('bndbox').find('xmax').text) 191 | object_bbox["ymin"]=float(obj.find('bndbox').find('ymin').text) 192 | object_bbox["ymax"]=float(obj.find('bndbox').find('ymax').text) 193 | 194 | #get subjects and objects roi index from npy file where the iou is greater than the threshold for subjects and objects 195 | 196 | subject_roi_index, neutral_sub_roi_index = self.get_subject_roi_index(root, subject_bbox, info, sample_relation_id, positive_relations_id_from_xml, all_relations_in_xml, type="subject") 197 | object_roi_index, neutral_obj_roi_index = self.get_subject_roi_index(root, object_bbox, info, sample_relation_id, positive_relations_id_from_xml, all_relations_in_xml, type="object") 198 | 199 | image_width = int(info.item().get('image_width')) 200 | image_height = int(info.item().get('image_height')) 201 | sampling=0 202 | 203 | 204 | if(self.test_config.ANCHOR_IMAGE == 1 and img_ind==0): 205 | for i,subjs in enumerate(subject_roi_index): 206 | for j,objs in enumerate(object_roi_index): 207 | embedding,relation, iou_sub, iou_obj = self.extract_vtranse_embedding_from_relation(info,feat,subjs,objs,image_width,image_height, subject_bbox, object_bbox) 208 | relations_per_image["relations"].append((1,embedding,relation,img_name, iou_sub, iou_obj)) 209 | positive_examples+=1 210 | sampling+=1 211 | 212 | else: 213 | for i,subjs in enumerate(subject_roi_index): 214 | for j,objs in enumerate(object_roi_index): 215 | embedding,relation, iou_sub, iou_obj = self.extract_vtranse_embedding_from_relation(info,feat,subjs,objs,image_width,image_height, subject_bbox, object_bbox) 216 | relations_per_image["relations"].append((1,embedding,relation,img_name, iou_sub, 
iou_obj)) 217 | positive_examples+=1 218 | sampling+=1 219 | rois_info=info.item().get('bbox') 220 | rois=rois_info.shape[0] 221 | objs_len=len(object_roi_index) 222 | subjs_len=len(subject_roi_index) 223 | # for negative pair : positive sub - negative-obj 224 | for i,subjs in enumerate(subject_roi_index): 225 | for objs in range(0,rois): 226 | if(objs not in object_roi_index): 227 | embedding,relation, iou_sub, iou_obj = self.extract_vtranse_embedding_from_relation(info,feat,subjs,objs,image_width,image_height, subject_bbox, object_bbox) 228 | relations_per_image["relations"].append((0,embedding,relation,img_name, iou_sub, iou_obj)) 229 | negative_examples+=1 230 | sampling+=1 231 | 232 | if(self.test_config.SUBJECT_ANCHORED==False): 233 | # for negative pair : negative sub - positive-obj 234 | for subjs in range(0,rois): 235 | for j,objs in enumerate(object_roi_index): 236 | if(subjs not in subject_roi_index ): 237 | embedding,relation, iou_sub, iou_obj = self.extract_vtranse_embedding_from_relation(info,feat,subjs,objs,image_width,image_height, subject_bbox, object_bbox) 238 | relations_per_image["relations"].append((0,embedding,relation,img_name, iou_sub, iou_obj)) 239 | negative_examples+=1 240 | sampling+=1 241 | 242 | # taking negative pairs : negative sub - negative obj 243 | for i in range(subjs_len): 244 | subjs=random.randint(0,rois-1) 245 | if(subjs in subject_roi_index or subjs in neutral_sub_roi_index): 246 | i-=1 247 | else: 248 | for j in range(objs_len): 249 | objs=random.randint(0,rois-1) 250 | if(objs in object_roi_index or objs in neutral_obj_roi_index): 251 | j-=1 252 | else: 253 | embedding,relation, iou_sub, iou_obj = self.extract_vtranse_embedding_from_relation(info,feat,subjs,objs,image_width,image_height, subject_bbox, object_bbox) 254 | # print("appending negative-negative pairs relations in img") 255 | negative_examples+=1 256 | relations_per_image["relations"].append((0,embedding,relation,img_name, iou_sub, iou_obj)) 257 | sampling+=1 258 | 259 | 260 | if(len(relations_per_image["relations"])!=0 ): 261 | # print("adding relations of 1 image in bag") 262 | #print('len(relations_per_image["relations"])', len(relations_per_image["relations"])) 263 | 264 | relations_per_sample["relations"].append(relations_per_image) 265 | # else : 266 | # #print('len(relations_per_image["relations"])', len(relations_per_image["relations"])) 267 | 268 | if(len(relations_per_sample["relations"])==self.bag_size): 269 | # print() 270 | cnt += 1 271 | 272 | samples.append(relations_per_sample) 273 | 274 | print("positive_examples=",positive_examples) 275 | print("negative_examples=",negative_examples) 276 | return samples 277 | 278 | def get_subject_roi_index(self, root, bbox, info, bag_relation_id, positive_relations_id_from_xml, all_relations_in_xml, type="subject"): 279 | 280 | indexes=[] 281 | neutral_indexes=set() 282 | 283 | rois_info=info.item().get('bbox') 284 | rois=rois_info.shape[0] 285 | # for positive rois 286 | for i in range(0, rois): 287 | bbox_roi=rois_info[i] 288 | bbox_dict={} 289 | bbox_dict["xmin"]=float(bbox_roi[0]) 290 | bbox_dict["ymin"]=float(bbox_roi[1]) 291 | bbox_dict["xmax"]=float(bbox_roi[2]) 292 | bbox_dict["ymax"]=float(bbox_roi[3]) 293 | 294 | iou=get_iou(bbox,bbox_dict) 295 | if(iou>0.50): 296 | indexes.append(i) 297 | #for neutral rois 298 | # since neutrals are ignored : all rois which are in positive are also in neutral 299 | for i in range(0,rois): 300 | bbox_roi=rois_info[i] 301 | bbox_dict={} 302 | bbox_dict["xmin"]=float(bbox_roi[0]) 303 | 
bbox_dict["ymin"]=float(bbox_roi[1]) 304 | bbox_dict["xmax"]=float(bbox_roi[2]) 305 | bbox_dict["ymax"]=float(bbox_roi[3]) 306 | 307 | for rel in all_relations_in_xml: 308 | predicate=str(rel.find("predicate").text) 309 | check= self.predicates_name_to_id.get(predicate, None) 310 | 311 | if check is None: 312 | continue 313 | if(check==bag_relation_id): 314 | rel_sub_id=int(rel.find('./subject_id').text) 315 | rel_obj_id=int(rel.find('./object_id').text) 316 | current_rel_key=str(rel_sub_id)+str(rel_obj_id) 317 | if current_rel_key in positive_relations_id_from_xml: 318 | for sub in root.findall('./object'): 319 | flag=0 320 | xml_obj_id=int(sub.find('object_id').text) 321 | if(type=="subject"): 322 | if(rel_sub_id!=xml_obj_id): 323 | continue 324 | elif(type=="object"): 325 | if(rel_obj_id!=xml_obj_id): 326 | continue 327 | subject_bbox={} 328 | subject_bbox["xmin"]=float(sub.find('bndbox').find('xmin').text) 329 | subject_bbox["xmax"]=float(sub.find('bndbox').find('xmax').text) 330 | subject_bbox["ymin"]=float(sub.find('bndbox').find('ymin').text) 331 | subject_bbox["ymax"]=float(sub.find('bndbox').find('ymax').text) 332 | iou=get_iou(subject_bbox,bbox_dict) 333 | if(iou>0.50): 334 | neutral_indexes.add(i) 335 | 336 | neutral_indexes=list(neutral_indexes) 337 | return indexes,neutral_indexes 338 | 339 | def load_samples(self): 340 | ''' returns a list of dicts. Each dict : a bag for training : {"relation_id" : id_of_predicate, "images_ids": list_of_img_ids_of_bag_size} ''' 341 | samples=[] 342 | for relation_id,images_ids in tqdm(self.relation_id_to_images.items(), desc="returing a list of dictionary og relations"): 343 | images_len=len(images_ids) 344 | assert images_len>=self.bag_size 345 | sample_per_relation=int((special.comb(images_len,2))) 346 | 347 | if(sample_per_relation > self.data_config.SAMPLE_PER_RELATION): 348 | sample_per_relation= self.data_config.SAMPLE_PER_RELATION 349 | for j in range(0,sample_per_relation): 350 | sample={} 351 | sample["relation_id"]=relation_id 352 | sample["images_ids"]=[] 353 | image_index=random.sample(range(0,images_len),self.bag_size) 354 | for ind in image_index: 355 | sample["images_ids"].append(images_ids[ind]) 356 | samples.append(sample) 357 | return samples 358 | 359 | def get_relation_id_to_images(self): 360 | for xml_file_path in tqdm(self.all_xml_files_path, desc ="reading xml"): 361 | data=ET.parse(xml_file_path) 362 | root = data.getroot() 363 | img_file=root.find('filename').text 364 | for rel in root.findall('./relation'): 365 | predicate=str(rel.find("predicate").text) 366 | 367 | if predicate in self.predicates_name_to_id: 368 | predicate_id=self.predicates_name_to_id[predicate] 369 | self.relation_id_to_images[predicate_id].append(img_file) 370 | 371 | def __getitem__(self,idx): 372 | return dict(self.samples[idx]) 373 | 374 | def __len__(self): 375 | return len(self.samples) 376 | --------------------------------------------------------------------------------