├── figures ├── figure_1.pdf ├── figure_2.jpg ├── figure_2.pdf ├── figure_3.pdf ├── figure_4.jpg ├── figure_4.pdf └── figure_5.pdf ├── .gitignore ├── utils ├── scene_graph_eval_matrix.py ├── utils.py ├── vis_tool.py ├── segmentation_eval_matrix.py └── io.py ├── environment.yml ├── eval_instructions.txt ├── README.md ├── models ├── surgicalDataset.py ├── mtl_model.py ├── scene_graph.py └── segmentation_model.py ├── evaluation.py └── model_train.py /figures/figure_1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lalithjets/Global-reasoned-multi-task-model/HEAD/figures/figure_1.pdf -------------------------------------------------------------------------------- /figures/figure_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lalithjets/Global-reasoned-multi-task-model/HEAD/figures/figure_2.jpg -------------------------------------------------------------------------------- /figures/figure_2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lalithjets/Global-reasoned-multi-task-model/HEAD/figures/figure_2.pdf -------------------------------------------------------------------------------- /figures/figure_3.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lalithjets/Global-reasoned-multi-task-model/HEAD/figures/figure_3.pdf -------------------------------------------------------------------------------- /figures/figure_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lalithjets/Global-reasoned-multi-task-model/HEAD/figures/figure_4.jpg -------------------------------------------------------------------------------- /figures/figure_4.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lalithjets/Global-reasoned-multi-task-model/HEAD/figures/figure_4.pdf -------------------------------------------------------------------------------- /figures/figure_5.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lalithjets/Global-reasoned-multi-task-model/HEAD/figures/figure_5.pdf -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.ipynb_checkpoints/ 2 | .vscode/ 3 | 4 | checkpoints/ 5 | models/r18/ 6 | datasets/ 7 | old_deprecated/ 8 | 9 | feature_extractor/checkpoint 10 | log/ 11 | results/ 12 | venv/ 13 | 14 | sai_transfer/ 15 | process_checkpoint/ 16 | 17 | tmp* 18 | *__pycache__ 19 | *.pyc -------------------------------------------------------------------------------- /utils/scene_graph_eval_matrix.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sklearn.metrics 3 | 4 | def compute_mean_avg_prec(y_true, y_score): 5 | try: 6 | avg_prec = sklearn.metrics.average_precision_score(y_true, y_score, average=None) 7 | mean_avg_prec = np.nansum(avg_prec) / len(avg_prec) 8 | except ValueError: 9 | mean_avg_prec = 0 10 | 11 | return mean_avg_prec 12 | 13 | def calibration_metrics(logits_all, labels_all): 14 | 15 | logits = logits_all.detach().cpu().numpy() 16 | labels = labels_all.detach().cpu().numpy() 17 | map_value = 
compute_mean_avg_prec(labels, logits) 18 | labels = np.argmax(labels, axis=-1) 19 | recall = sklearn.metrics.recall_score(labels, np.argmax(logits,1), average='macro') 20 | return(map_value, recall) -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import requests 2 | 3 | def download_from_url(url, path): 4 | """Download file, with logic (from tensor2tensor) for Google Drive""" 5 | if 'drive.google.com' not in url: 6 | print('Downloading %s; may take a few minutes' % url) 7 | r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}) 8 | with open(path, "wb") as file: 9 | file.write(r.content) 10 | return 11 | print('Downloading from Google Drive; may take a few minutes') 12 | confirm_token = None 13 | session = requests.Session() 14 | response = session.get(url, stream=True) 15 | for k, v in response.cookies.items(): 16 | if k.startswith("download_warning"): 17 | confirm_token = v 18 | 19 | if confirm_token: 20 | url = url + "&confirm=" + confirm_token 21 | response = session.get(url, stream=True) 22 | 23 | chunk_size = 16 * 1024 24 | with open(path, "wb") as f: 25 | for chunk in response.iter_content(chunk_size): 26 | if chunk: 27 | f.write(chunk) 28 | -------------------------------------------------------------------------------- /utils/vis_tool.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import random 4 | import numpy as np 5 | import matplotlib 6 | import torch as t 7 | 8 | matplotlib.use('Agg') 9 | from matplotlib import pyplot as plot 10 | from PIL import Image, ImageDraw, ImageFont 11 | 12 | 13 | def vis_img(img, node_classes, bboxs, det_action, data_const, score_thresh = 0.7): 14 | 15 | Drawer = ImageDraw.Draw(img) 16 | line_width = 3 17 | outline = '#FF0000' 18 | font = ImageFont.truetype(font='/usr/share/fonts/truetype/freefont/FreeMono.ttf', size=25) 19 | 20 | im_w,im_h = img.size 21 | node_num = len(node_classes) 22 | edge_num = len(det_action) 23 | tissue_num = len(np.where(node_classes == 1)[0]) 24 | 25 | for node in range(node_num): 26 | 27 | r_color = random.choice(np.arange(256)) 28 | g_color = random.choice(np.arange(256)) 29 | b_color = random.choice(np.arange(256)) 30 | 31 | text = data_const.instrument_classes[node_classes[node]] 32 | h, w = font.getsize(text) 33 | Drawer.rectangle(list(bboxs[node]), outline=outline, width=line_width) 34 | Drawer.text(xy=(bboxs[node][0], bboxs[node][1]-w-1), text=text, font=font, fill=(r_color,g_color,b_color)) 35 | 36 | edge_idx = 0 37 | 38 | for tissue in range(tissue_num): 39 | for instrument in range(tissue+1, node_num): 40 | 41 | #action_idx = np.where(det_action[edge_idx] > score_thresh)[0] 42 | action_idx = np.argmax(det_action[edge_idx]) 43 | # print('det_action', det_action[edge_idx]) 44 | # print('action_idx',action_idx) 45 | 46 | text = data_const.action_classes[action_idx] 47 | r_color = random.choice(np.arange(256)) 48 | g_color = random.choice(np.arange(256)) 49 | b_color = random.choice(np.arange(256)) 50 | 51 | x1,y1,x2,y2 = bboxs[tissue] 52 | x1_,y1_,x2_,y2_ = bboxs[instrument] 53 | 54 | c0 = int(0.5*x1)+int(0.5*x2) 55 | c0 = max(0,min(c0,im_w-1)) 56 | r0 = int(0.5*y1)+int(0.5*y2) 57 | r0 = max(0,min(r0,im_h-1)) 58 | c1 = int(0.5*x1_)+int(0.5*x2_) 59 | c1 = max(0,min(c1,im_w-1)) 60 | r1 = int(0.5*y1_)+int(0.5*y2_) 61 | r1 = max(0,min(r1,im_h-1)) 62 | Drawer.line(((c0,r0),(c1,r1)), fill=(r_color,g_color,b_color), width=3) 
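# (c0, r0) and (c1, r1) are the image-clamped centres of the tissue and instrument boxes;
# the edge drawn above is now labelled with the predicted action class at the instrument end.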
63 | Drawer.text(xy=(c1, r1), text=text, font=font, fill=(r_color,g_color,b_color)) 64 | 65 | edge_idx +=1 66 | 67 | return img -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: gr-mtl-environment 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _openmp_mutex=4.5=1_gnu 8 | - blas=1.0=mkl 9 | - ca-certificates=2021.7.5=h06a4308_1 10 | - certifi=2021.5.30=py36h06a4308_0 11 | - cudatoolkit=10.2.89=hfd86e86_1 12 | - dataclasses=0.8=pyh4f3eec9_6 13 | - freetype=2.10.4=h5ab3b9f_0 14 | - intel-openmp=2021.3.0=h06a4308_3350 15 | - jpeg=9b=h024ee3a_2 16 | - lcms2=2.12=h3be6417_0 17 | - libedit=3.1.20210216=h27cfd23_1 18 | - libffi=3.2.1=hf484d3e_1007 19 | - libgcc-ng=9.3.0=h5101ec6_17 20 | - libgomp=9.3.0=h5101ec6_17 21 | - libpng=1.6.37=hbc83047_0 22 | - libstdcxx-ng=9.3.0=hd4cf53a_17 23 | - libtiff=4.2.0=h85742a9_0 24 | - libuv=1.40.0=h7b6447c_0 25 | - libwebp-base=1.2.0=h27cfd23_0 26 | - lz4-c=1.9.3=h295c915_1 27 | - mkl=2020.2=256 28 | - mkl-service=2.3.0=py36he8ac12f_0 29 | - mkl_fft=1.3.0=py36h54f3939_0 30 | - mkl_random=1.1.1=py36h0573a6f_0 31 | - ncurses=6.2=he6710b0_1 32 | - ninja=1.10.2=hff7bd54_1 33 | - numpy=1.19.2=py36h54aff64_0 34 | - numpy-base=1.19.2=py36hfa32c7d_0 35 | - olefile=0.46=py36_0 36 | - openjpeg=2.3.0=h05c96fa_1 37 | - openssl=1.1.1k=h27cfd23_0 38 | - pillow=8.3.1=py36h2c7a002_0 39 | - pip=21.2.2=py36h06a4308_0 40 | - python=3.6.9=h265db76_0 41 | - pytorch=1.7.1=py3.6_cuda10.2.89_cudnn7.6.5_0 42 | - readline=7.0=h7b6447c_5 43 | - setuptools=52.0.0=py36h06a4308_0 44 | - six=1.16.0=pyhd3eb1b0_0 45 | - sqlite=3.33.0=h62c20be_0 46 | - tk=8.6.10=hbc83047_0 47 | - torchaudio=0.7.2=py36 48 | - torchvision=0.8.2=py36_cu102 49 | - typing_extensions=3.10.0.0=pyh06a4308_0 50 | - wheel=0.37.0=pyhd3eb1b0_0 51 | - xz=5.2.5=h7b6447c_0 52 | - zlib=1.2.11=h7b6447c_3 53 | - zstd=1.4.9=haebb681_0 54 | - pip: 55 | - albumentations==1.0.3 56 | - cached-property==1.5.2 57 | - charset-normalizer==2.0.4 58 | - cycler==0.10.0 59 | - decorator==4.4.2 60 | - dgl-cu102==0.4.2 61 | - h5py==2.10.0 62 | - idna==3.2 63 | - imageio==2.9.0 64 | - importlib-metadata==4.8.1 65 | - joblib==1.0.1 66 | - kiwisolver==1.3.1 67 | - matplotlib==3.3.4 68 | - networkx==2.5.1 69 | - opencv-python-headless==4.5.3.56 70 | - prettytable==2.2.0 71 | - pynvml==11.0.0 72 | - pyparsing==2.4.7 73 | - python-dateutil==2.8.2 74 | - pywavelets==1.1.1 75 | - pyyaml==5.4.1 76 | - requests==2.26.0 77 | - scikit-image==0.17.2 78 | - scikit-learn==0.24.2 79 | - scipy==1.5.4 80 | - tabulate==0.8.9 81 | - threadpoolctl==2.2.0 82 | - tifffile==2020.9.3 83 | - torchsummary==1.5.1 84 | - tqdm==4.62.2 85 | - urllib3==1.26.6 86 | - wcwidth==0.2.5 87 | - zipp==3.5.0 88 | -------------------------------------------------------------------------------- /utils/segmentation_eval_matrix.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Variable 6 | 7 | 8 | 9 | def batch_pix_accuracy(output, target): 10 | """Batch Pixel Accuracy 11 | Args: 12 | predict: input 4D tensor 13 | target: label 3D tensor 14 | """ 15 | _, predict = torch.max(output, 1) 16 | 17 | predict = predict.cpu().numpy().astype('int64') + 1 18 | target = target.cpu().numpy().astype('int64') + 1 19 | 20 | pixel_labeled = np.sum(target > 0) 21 | pixel_correct = 
np.sum((predict == target)*(target > 0)) 22 | assert pixel_correct <= pixel_labeled, \ 23 | "Correct area should be smaller than Labeled" 24 | return pixel_correct, pixel_labeled 25 | 26 | 27 | def batch_intersection_union(output, target, nclass): 28 | """Batch Intersection of Union 29 | Args: 30 | predict: input 4D tensor 31 | target: label 3D tensor 32 | nclass: number of categories (int) 33 | """ 34 | _, predict = torch.max(output, 1) 35 | mini = 1 36 | maxi = nclass 37 | nbins = nclass 38 | predict = predict.cpu().numpy().astype('int64') + 1 39 | target = target.cpu().numpy().astype('int64') + 1 40 | 41 | predict = predict * (target > 0).astype(predict.dtype) 42 | intersection = predict * (predict == target) 43 | # areas of intersection and union 44 | area_inter, _ = np.histogram(intersection, bins=nbins, range=(mini, maxi)) 45 | area_pred, _ = np.histogram(predict, bins=nbins, range=(mini, maxi)) 46 | area_lab, _ = np.histogram(target, bins=nbins, range=(mini, maxi)) 47 | area_union = area_pred + area_lab - area_inter 48 | assert (area_inter <= area_union).all(), \ 49 | "Intersection area should be smaller than Union area" 50 | return area_inter, area_union 51 | 52 | 53 | class SegmentationLosses(nn.CrossEntropyLoss): 54 | def __init__(self, se_loss=False, se_weight=0.2, nclass=-1, 55 | aux=False, aux_weight=0.4, weight=None, 56 | ignore_index=-1): 57 | super(SegmentationLosses, self).__init__(weight, None, ignore_index) 58 | self.se_loss = se_loss 59 | self.aux = aux 60 | self.nclass = nclass 61 | self.se_weight = se_weight 62 | self.aux_weight = aux_weight 63 | self.bceloss = nn.BCELoss(weight) 64 | 65 | def forward(self, *inputs): 66 | if not self.se_loss and not self.aux: 67 | return super(SegmentationLosses, self).forward(*inputs) 68 | elif not self.se_loss: 69 | pred1, pred2, target = tuple(inputs) 70 | loss1 = super(SegmentationLosses, self).forward(pred1, target) 71 | loss2 = super(SegmentationLosses, self).forward(pred2, target) 72 | return loss1 + self.aux_weight * loss2 73 | elif not self.aux: 74 | pred, se_pred, target = tuple(inputs) 75 | se_target = self._get_batch_label_vector( 76 | target, nclass=self.nclass).type_as(pred) 77 | loss1 = super(SegmentationLosses, self).forward(pred, target) 78 | loss2 = self.bceloss(torch.sigmoid(se_pred), se_target) 79 | return loss1 + self.se_weight * loss2 80 | else: 81 | pred1, se_pred, pred2, target = tuple(inputs) 82 | se_target = self._get_batch_label_vector( 83 | target, nclass=self.nclass).type_as(pred1) 84 | loss1 = super(SegmentationLosses, self).forward(pred1, target) 85 | loss2 = super(SegmentationLosses, self).forward(pred2, target) 86 | loss3 = self.bceloss(torch.sigmoid(se_pred), se_target) 87 | return loss1 + self.aux_weight * loss2 + self.se_weight * loss3 88 | 89 | @staticmethod 90 | def _get_batch_label_vector(target, nclass): 91 | # target is a 3D Variable BxHxW, output is 2D BxnClass 92 | batch = target.size(0) 93 | tvect = Variable(torch.zeros(batch, nclass)) 94 | for i in range(batch): 95 | hist = torch.histc(target[i].cpu().data.float(), 96 | bins=nclass, min=0, 97 | max=nclass-1) 98 | vect = hist > 0 99 | tvect[i] = vect 100 | return tvect 101 | -------------------------------------------------------------------------------- /utils/io.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import json 4 | import yaml 5 | import numpy as np 6 | import gzip 7 | import scipy.io 8 | 9 | def load_pickle_object(file_name, compress=True): 10 | data = 
read(file_name) 11 | if compress: 12 | load_object = pickle.loads(gzip.decompress(data)) 13 | else: 14 | load_object = pickle.loads(data) 15 | return load_object 16 | 17 | 18 | def dump_pickle_object(dump_object, file_name, compress=True, compress_level=9): 19 | data = pickle.dumps(dump_object) 20 | if compress: 21 | write(file_name, gzip.compress(data, compresslevel=compress_level)) 22 | else: 23 | write(file_name, data) 24 | 25 | 26 | def load_json_object(file_name, compress=False): 27 | if compress: 28 | return json.loads(gzip.decompress(read(file_name)).decode('utf8')) 29 | else: 30 | return json.loads(read(file_name, 'r')) 31 | 32 | 33 | def dump_json_object(dump_object, file_name, compress=False, indent=4): 34 | data = json.dumps( 35 | dump_object, cls=NumpyAwareJSONEncoder, sort_keys=True, indent=indent) 36 | if compress: 37 | write(file_name, gzip.compress(data.encode('utf8'))) 38 | else: 39 | write(file_name, data, 'w') 40 | 41 | 42 | def dumps_json_object(dump_object, indent=4): 43 | data = json.dumps( 44 | dump_object, cls=NumpyAwareJSONEncoder, sort_keys=True, indent=indent) 45 | return data 46 | 47 | 48 | def load_mat_object(file_name): 49 | return scipy.io.loadmat(file_name=file_name) 50 | 51 | 52 | def load_yaml_object(file_name): 53 | return yaml.load(read(file_name, 'r')) 54 | 55 | 56 | def read(file_name, mode='rb'): 57 | with open(file_name, mode) as f: 58 | return f.read() 59 | 60 | 61 | def write(file_name, data, mode='wb'): 62 | with open(file_name, mode) as f: 63 | f.write(data) 64 | 65 | 66 | def serialize_object(in_obj, method='json'): 67 | if method == 'json': 68 | return json.dumps(in_obj) 69 | else: 70 | return pickle.dumps(in_obj) 71 | 72 | 73 | def deserialize_object(obj_str, method='json'): 74 | if method == 'json': 75 | return json.loads(obj_str) 76 | else: 77 | return pickle.loads(obj_str) 78 | 79 | 80 | def mkdir_if_not_exists(dir_name, recursive=False): 81 | if os.path.exists(dir_name): 82 | return 83 | if recursive: 84 | os.makedirs(dir_name) 85 | else: 86 | os.mkdir(dir_name) 87 | 88 | 89 | class NumpyAwareJSONEncoder(json.JSONEncoder): 90 | def default(self, obj): 91 | if isinstance(obj, np.ndarray): 92 | if obj.ndim == 1: 93 | return obj.tolist() 94 | else: 95 | return [self.default(obj[i]) for i in range(obj.shape[0])] 96 | elif isinstance(obj, np.int64): 97 | return int(obj) 98 | elif isinstance(obj, np.int32): 99 | return int(obj) 100 | elif isinstance(obj, np.int16): 101 | return int(obj) 102 | elif isinstance(obj, np.float64): 103 | return float(obj) 104 | elif isinstance(obj, np.float32): 105 | return float(obj) 106 | elif isinstance(obj, np.float16): 107 | return float(obj) 108 | elif isinstance(obj, np.uint64): 109 | return int(obj) 110 | elif isinstance(obj, np.uint32): 111 | return int(obj) 112 | elif isinstance(obj, np.uint16): 113 | return int(obj) 114 | return json.JSONEncoder.default(self, obj) 115 | 116 | 117 | class JsonSerializableClass(): 118 | def to_json(self,json_filename=None): 119 | serialized_dict = json.dumps( 120 | self, 121 | default=lambda o: o.__dict__, 122 | sort_keys=True, 123 | indent=4) 124 | serialized_dict = json.loads(serialized_dict) 125 | if json_filename is not None: 126 | dump_json_object(serialized_dict,json_filename) 127 | 128 | return serialized_dict 129 | 130 | def from_json(self,json_filename): 131 | assert(type(json_filename is dict)), 'Use from dict instead' 132 | dict_to_restore = load_json_object(json_filename) 133 | for attr_name, attr_value in dict_to_restore.items(): 134 | 
setattr(self,attr_name,attr_value) 135 | 136 | def from_dict(self,dict_to_restore): 137 | for attr_name, attr_value in dict_to_restore.items(): 138 | setattr(self,attr_name,attr_value) 139 | 140 | 141 | class WritableToFile(): 142 | def to_file(self,filename): 143 | with open(filename,'w') as file: 144 | file.write(self.__str__()) -------------------------------------------------------------------------------- /eval_instructions.txt: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------# 2 | Steps to be followed 3 | # ------------------------------------------------------------------------------------------------# 4 | 5 | 6 | 1. git clone https://github.com/lalithjets/Global-reasoned-multi-task-model.git 7 | 2. cd Global-reasoned-multi-task-model/ 8 | 9 | 10 | # ------------------------- Download Commands ------------------------- # 11 | 12 | # ------------------------- Checkpoints ------------------------- # 13 | Link : https://drive.google.com/file/d/1HTSYta_Dn9-nF1Df4TUym38Nu0VMtl5l/view?usp=sharing 14 | 15 | Command : (GDrive wget download - Optional) - Can be downloaded manually and placed in root 16 | > 3. wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1HTSYta_Dn9-nF1Df4TUym38Nu0VMtl5l' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1HTSYta_Dn9-nF1Df4TUym38Nu0VMtl5l" -O gr_mtl_ssu_checkpoints.zip && rm -rf /tmp/cookies.txt 17 | 18 | 4. unzip gr_mtl_ssu_checkpoints.zip 19 | 5. rm gr_mtl_ssu_checkpoints.zip 20 | 21 | # ------------------------- Dataset ------------------------- # 22 | Link : https://drive.google.com/file/d/1OwWfgBZE0W5grXVaQN63VUUaTvufEmW0/view?usp=sharing 23 | 24 | Command : (GDrive wget download - Optional) - Can be downloaded manually and placed in root 25 | > 6. wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1OwWfgBZE0W5grXVaQN63VUUaTvufEmW0' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1OwWfgBZE0W5grXVaQN63VUUaTvufEmW0" -O gr_mtl_ssu_dataset.zip && rm -rf /tmp/cookies.txt 26 | 27 | 7. unzip gr_mtl_ssu_dataset.zip 28 | 8. rm gr_mtl_ssu_dataset.zip 29 | 30 | 9. Set the model_type, ver, seg_mode and checkpoint_dir in evaluation.py as given in instructions 31 | 32 | # ------------------------- Run the command for Evaluation ------------------------- # 33 | 10. 
CUDA_VISIBLE_DEVICES=1 python3 evaluation.py 34 | 35 | 36 | # --------------------------------------------- Sample Output --------------------------------------------- # 37 | 38 | Settings : 39 | 40 | model_type = 'amtl-t0' 41 | ver = 'amtl_t0_sv1' 42 | seg_mode = 'v1' 43 | checkpoint_dir = 'amtl_t0_sv1' 44 | 45 | # ------------------------------------------------------------------------------------------------# 46 | Output 47 | # ------------------------------------------------------------------------------------------------# 48 | 49 | ================= Evaluation ==================== 50 | Graph : acc: 0.7003 map: 0.2885 recall: 0.3096 loss: 0.3764} 51 | Segmentation : Pacc: 0.9638 mIoU: 0.4354 loss: 0.1500} 52 | 53 | ================= Class-wise IoU ==================== 54 | Mean Value: 0.435358693711956 55 | 56 | | Class | IoU | 57 | |---------------------------+------------| 58 | | Background | 0.971428 | 59 | | Bipolar_Forceps | 0.696591 | 60 | | Prograsp_Forceps | 0.435617 | 61 | | Large_Needle_Driver | 0.00154275 | 62 | | Monopolar_Curved_Scissors | 0.871583 | 63 | | Ultrasound_Probe | 0.120284 | 64 | | Suction_Instrument | 0.347132 | 65 | | Clip_Applier | 0.0386921 | 66 | 67 | 68 | 69 | # ------------------------------------------------------------------------------------------------# 70 | Eval Repository Structure 71 | # ------------------------------------------------------------------------------------------------# 72 | 73 | ├── checkpoints 74 | │   ├── amtl_t0_s 75 | │   │   └── best_epoch.pth 76 | │   ├── amtl_t0_sv1 77 | │   │   └── best_epoch.pth 78 | │   ├── amtl_t0_sv2gc 79 | │   │   └── best_epoch.pth 80 | │   ├── amtl_t3g_sv1 81 | │   │   └── best_epoch.pth 82 | │   ├── amtl_t3pn_sv1 83 | │   │   └── best_epoch.pth 84 | │   ├── mtl_kd_t0_s 85 | │   │   └── best_epoch.pth 86 | │   ├── mtl_kd_t0_sv1 87 | │   │   └── best_epoch.pth 88 | │   ├── mtl_kd_t1_sv1 89 | │   │   └── best_epoch.pth 90 | │   ├── mtl_kd_t3g_sv1 91 | │   │   └── best_epoch.pth 92 | │   ├── stl_s 93 | │   │   └── best_epoch.pth 94 | │   ├── stl_sg 95 | │   │   └── best_epoch.pth 96 | │   ├── stl_s_ng 97 | │   │   └── best_epoch.pth 98 | │   ├── stl_s_v1 99 | │   │   └── best_epoch.pth 100 | │   └── stl_s_v2gc 101 | │   └── best_epoch.pth 102 | ├── dataset 103 | │   ├── labels_isi_dataset.json 104 | │   ├── seq_1 105 | │   │   ├── annotations 106 | │   │   │   ├── frame000.png 107 | │   │   │   ├── ... 108 | │   │   ├── left_frames 109 | │   │   │   ├── frame000.png 110 | │   │   │   ├── ... 111 | │   │   ├── vsgat 112 | │   │   │   └── features 113 | │   │   │   ├── frame000_features.hdf5 114 | │   │   │   ├── ... 115 | │   │   └── xml 116 | │   │   ├── frame000.xml 117 | │   │   ├── ... 118 | │   ├── seq_16 119 | │   │   ├── annotations 120 | │   │   │   ├── frame000.png 121 | │   │   │   ├── ... 122 | │   │   ├── left_frames 123 | │   │   │   ├── frame000.png 124 | │   │   │   ├── ... 125 | │   │   ├── vsgat 126 | │   │   │   └── features 127 | │   │   │   ├── frame000_features.hdf5 128 | │   │   │   ├── ... 129 | │   │   └── xml 130 | │   │   ├── frame000.xml 131 | │   │   ├── ... 132 | │   ├── seq_5 133 | │   │   ├── annotations 134 | │   │   │   ├── frame000.png 135 | │   │   │   ├── ... 136 | │   │   ├── left_frames 137 | │   │   │   ├── frame000.png 138 | │   │   │   ├── ... 139 | │   │   ├── vsgat 140 | │   │   │   └── features 141 | │   │   │   ├── frame000_features.hdf5 142 | │   │   │   ├── ... 143 | │   │   └── xml 144 | │   │   ├── frame000.xml 145 | │   │   ├── ... 
146 | │   └── surgicalscene_word2vec.hdf5 147 | ├── environment.yml 148 | ├── evaluation.py 149 | ├── eval_instructions.txt 150 | ├── figures 151 | │   ├── figure_1.pdf 152 | │   ├── figure_2.pdf 153 | │   ├── figure_3.pdf 154 | │   ├── figure_4.pdf 155 | │   └── figure_5.pdf 156 | ├── models 157 | │   ├── mtl_model.py 158 | │   ├── __pycache__ 159 | │   │   ├── mtl_model.cpython-36.pyc 160 | │   │   ├── scene_graph.cpython-36.pyc 161 | │   │   ├── segmentation_model.cpython-36.pyc 162 | │   │   └── surgicalDataset.cpython-36.pyc 163 | │   ├── scene_graph.py 164 | │   ├── segmentation_model.py 165 | │   └── surgicalDataset.py 166 | ├── model_train.py 167 | ├── README.md 168 | ├── result_logs 169 | │   ├── results_combined 170 | │   └── results_kd.txt 171 | └── utils 172 |    ├── io.py 173 |    ├── __pycache__ 174 |    │   ├── scene_graph_eval_matrix.cpython-36.pyc 175 |    │   └── segmentation_eval_matrix.cpython-36.pyc 176 |    ├── scene_graph_eval_matrix.py 177 |    ├── segmentation_eval_matrix.py 178 |    ├── utils.py 179 |    └── vis_tool.py 180 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
# Global-Reasoned Multi-Task Model for Surgical Scene Understanding

Lalithkumar Seenivasan, Sai Mitheran, Mobarakol Islam, Hongliang Ren
10 | 11 | --- 12 | 13 | | **[ [```arXiv```]() ]** |**[ [```Paper```]() ]** |**[ [```YouTube```]() ]** | 14 | |:-------------------:|:-------------------:|:-------------------:| 15 | 16 | ICRA 2022, IEEE Robotics and Automation Letters (RA-L) 17 | 18 |
19 | 20 | If you find our code or paper useful, please cite as 21 | 22 | ```bibtex 23 | @article{seenivasan2022global, 24 | title={Global-Reasoned Multi-Task Learning Model for Surgical Scene Understanding}, 25 | author={Seenivasan, Lalithkumar and Mitheran, Sai and Islam, Mobarakol and Ren, Hongliang}, 26 | journal={IEEE Robotics and Automation Letters}, 27 | year={2022}, 28 | publisher={IEEE} 29 | } 30 | ``` 31 | 32 | --- 33 | 34 | ## Introduction 35 | Global and local relational reasoning enable scene understanding models to perform human-like scene analysis and understanding. Scene understanding enables better semantic segmentation and object-to-object interaction detection. In the medical domain, a robust surgical scene understanding model allows the automation of surgical skill evaluation, real-time monitoring of surgeon’s performance and post-surgical analysis. This paper introduces a globally-reasoned multi-task surgical scene understanding model capable of performing instrument segmentation and tool-tissue interaction detection. Here, we incorporate global relational reasoning in the latent interaction space and introduce multi-scale local (neighborhood) reasoning in the coordinate space to improve segmentation. Utilizing the multi-task model setup, the performance of the visual-semantic graph attention network in interaction detection is further enhanced through global reasoning. The global interaction space features from the segmentation module are introduced into the graph network, allowing it to detect interactions based on both node-to-node and global interaction reasoning. Our model reduces the computation cost compared to running two independent single-task models by sharing common modules, which is indispensable for practical applications. Using a sequential optimization technique, the proposed multi-task model outperforms other state-of-the-art single-task models on the MICCAI endoscopic vision challenge 2018 dataset. Additionally, we also observe the performance of the multi-task model when trained using the knowledge distillation technique. 36 | 37 | ## Method 38 | 39 | ![framework](figures/figure_2.jpg) 40 | 41 | The proposed network architecture. The proposed globally-reasoned multi-task scene understanding model consists of a shared feature extractor. The segmentation module performs latent global reasoning (GloRe unit [2]) and local reasoning (multi-scale local reasoning) to segment instruments. To detect tool interaction, the scene graph (tool interaction detection) model incorporates the global interaction space features to further improve the performance of the visual-semantic graph attention network [1]. 42 | 43 | ### Feature Sharing 44 | 45 |


48 | 49 | Variants of feature sharing between the segmentation and scene graph modules in multi-task setting to improve single-task performance 50 | 51 | --- 52 | 53 | ## Directory setup 54 | 55 | In this project, we implement our method using the Pytorch and DGL library, the structure is as follows: 56 | 57 | - `dataset/`: Contains the data needed to train the network. 58 | - `checkpoints/`: Contains trained weights. 59 | - `models/`: Contains network models. 60 | - `utils/`: Contains utility tools used for training and evaluation. 61 | 62 | --- 63 | 64 | ## Library Prerequisities 65 | 66 | ### DGL 67 | DGL is a Python package dedicated to deep learning on graphs, built atop existing tensor DL frameworks (e.g. Pytorch, MXNet) and simplifying the implementation of graph-based neural networks 68 | 69 | ### Dependencies (Used for Experiments) 70 | - Python 3.6 71 | - Pytorch 1.7.1 72 | - DGL 0.4.2 73 | - CUDA 10.2 74 | - Ubuntu 16.04 75 | 76 | --- 77 | 78 | ## Setup (From an Env File) 79 | 80 | We have provided environment files for installation using conda 81 | 82 | ### Using Conda 83 | 84 | ```bash 85 | conda env create -f environment.yml 86 | ``` 87 | 88 | --- 89 | 90 | ## Dataset: 91 | 1. Frames - Left camera images from [2018 robotic scene segmentation challenge](https://arxiv.org/pdf/2001.11190.pdf) are used in this work. 92 | 2. Instrument label - To be released! 93 | 3. BBox and Tool-Tissue interaction annotation - [Our annotations](https://drive.google.com/file/d/16G_Pf4E9KjVq7j_7BfBKHg0NyQQ0oTxP/view?usp=sharing) (Cite this paper / [our previous work](https://link.springer.com/chapter/10.1007/978-3-030-59716-0_60) when using these annotations.) 94 | 4. Download the pretrain word2vec model on [GoogleNews](https://code.google.com/archive/p/word2vec/) and put it into `dataset/word2vec` 95 | 96 | --- 97 | 98 | ## Training 99 | ### Process dataset (For Spatial Features) 100 | - To be released! 101 | 102 | ### Run training 103 | - Set the model_type, version for the mode to be trained according to the instructions given in the train file 104 | 105 | ```bash 106 | python3 model_train.py 107 | ``` 108 | 109 | --- 110 | ## Evaluation 111 | For the direct sequence of commands to be followed, refer to [this link](https://github.com/lalithjets/Global-reasoned-multi-task-model/blob/master/eval_instructions.txt) 112 | 113 | ### Pre-trained Models 114 | Download from **[[`Checkpoints Link`](https://drive.google.com/file/d/1HTSYta_Dn9-nF1Df4TUym38Nu0VMtl5l/view?usp=sharing)]**, place it inside the repository root and unzip 115 | 116 | ### Evaluation Data 117 | Download from **[[`Dataset Link`](https://drive.google.com/file/d/1OwWfgBZE0W5grXVaQN63VUUaTvufEmW0/view?usp=sharing)]** and place it inside the repository root and unzip 118 | 119 | ### Inference 120 | To reproduce the results, set the model_type, ver, seg_mode and checkpoint_dir based on the table given [here](https://github.com/lalithjets/Global-reasoned-multi-task-model/blob/c6668fcca712d3bd5ca25c66b11d34305103af94/evaluation.py#L195) 121 | - model_type 122 | - ver 123 | - seg_mode 124 | - checkpoint_dir 125 | 126 | ```bash 127 | python3 evaluation.py 128 | ``` 129 | 130 | --- 131 | 132 | ## Acknowledgement 133 | Code adopted and modified from : 134 | 1. Visual-Semantic Graph Attention Network for Human-Object Interaction Detecion 135 | - Paper [Visual-Semantic Graph Attention Network for Human-Object Interaction Detecion](https://arxiv.org/abs/2001.02302). 
136 | - Official Pytorch implementation [code](https://github.com/birlrobotics/vs-gats). 137 | 1. Graph-Based Global Reasoning Networks 138 | - Paper [Graph-Based Global Reasoning Networks](https://openaccess.thecvf.com/content_CVPR_2019/papers/Chen_Graph-Based_Global_Reasoning_Networks_CVPR_2019_paper.pdf). 139 | - Official code implementation [code](https://github.com/facebookresearch/GloRe.git). 140 | 141 | --- 142 | 143 | ## Other Works: 144 | 1. Learning and Reasoning with the Graph Structure Representation in Robotic Surgery| **[ [```arXiv```]() ]** |**[ [```Paper```]() ]** | 145 | 146 | --- 147 | 148 | ## Contact 149 | 150 | For any queries, please contact [Lalithkumar](mailto:lalithjets@gmail.com) or [Sai Mitheran](mailto:saimitheran06@gmail.com) 151 | -------------------------------------------------------------------------------- /models/surgicalDataset.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Project : Global-Reasoned Multi-Task Surgical Scene Understanding 3 | Lab : MMLAB, National University of Singapore 4 | contributors : Lalithkumar Seenivasan, Sai Mitheran, Mobarakol Islam, Hongliang Ren 5 | Note : Code adopted and modified from Visual-Semantic Graph Attention Networks and Dual attention network for scene segmentation 6 | 7 | ''' 8 | 9 | 10 | import os 11 | import sys 12 | import random 13 | 14 | import h5py 15 | import numpy as np 16 | from glob import glob 17 | from PIL import Image 18 | 19 | import torch 20 | import torch.nn as nn 21 | import torchvision.transforms as transforms 22 | from torch.utils.data import Dataset 23 | 24 | 25 | class SurgicalSceneConstants(): 26 | ''' 27 | Set the instrument classes and action classes, with path to XML and Word2Vec Features (if applicable) 28 | ''' 29 | def __init__(self): 30 | self.instrument_classes = ('kidney', 'bipolar_forceps', 'prograsp_forceps', 'large_needle_driver', 31 | 'monopolar_curved_scissors', 'ultrasound_probe', 'suction', 'clip_applier', 32 | 'stapler', 'maryland_dissector', 'spatulated_monopolar_cautery') 33 | 34 | self.action_classes = ('Idle', 'Grasping', 'Retraction', 'Tissue_Manipulation', 35 | 'Tool_Manipulation', 'Cutting', 'Cauterization', 36 | 'Suction', 'Looping', 'Suturing', 'Clipping', 'Staple', 37 | 'Ultrasound_Sensing') 38 | 39 | self.xml_data_dir = 'dataset/instruments18/seq_' 40 | self.word2vec_loc = 'dataset/surgicalscene_word2vec.hdf5' 41 | 42 | 43 | class SurgicalSceneDataset(Dataset): 44 | ''' 45 | Dataset class for the MTL Model 46 | Inputs: sequence set, data directory (root), image directory, mask directory, augmentation flag (istrain), dataset (dset), feature extractor chosen 47 | ''' 48 | def __init__(self, seq_set, data_dir, img_dir, mask_dir, istrain, dset, dataconst, feature_extractor, reduce_size=False): 49 | 50 | self.data_size = 143 51 | self.dataconst = dataconst 52 | self.img_dir = img_dir 53 | self.mask_dir = mask_dir 54 | self.is_train = istrain 55 | self.feature_extractor = feature_extractor 56 | self.reduce_size = reduce_size 57 | 58 | # Images and masks are resized to (320, 400) 59 | self.resizer = transforms.Compose([transforms.Resize((320, 400))]) 60 | 61 | self.xml_dir_list = [] 62 | self.dset = [] 63 | 64 | for domain in range(len(seq_set)): 65 | domain_dir_list = [] 66 | for i in seq_set[domain]: 67 | xml_dir_temp = data_dir[domain] + str(i) + '/xml/' 68 | domain_dir_list = domain_dir_list + glob(xml_dir_temp + '/*.xml') 69 | if self.reduce_size: 70 | indices = np.random.permutation(len(domain_dir_list)) 71 | 
domain_dir_list = [domain_dir_list[j] for j in indices[0:self.data_size]] 72 | for file in domain_dir_list: 73 | self.xml_dir_list.append(file) 74 | self.dset.append(dset[domain]) 75 | self.word2vec = h5py.File('dataset/surgicalscene_word2vec.hdf5', 'r') 76 | 77 | # Word2Vec function 78 | def _get_word2vec(self, node_ids, sgh=0): 79 | word2vec = np.empty((0, 300)) 80 | for node_id in node_ids: 81 | if sgh == 1 and node_id == 0: 82 | vec = self.word2vec['tissue'] 83 | else: 84 | vec = self.word2vec[self.dataconst.instrument_classes[node_id]] 85 | word2vec = np.vstack((word2vec, vec)) 86 | return word2vec 87 | 88 | # Dataset length 89 | def __len__(self): 90 | return len(self.xml_dir_list) 91 | 92 | # Function to get images and masks 93 | def __getitem__(self, idx): 94 | 95 | file_name = os.path.splitext(os.path.basename(self.xml_dir_list[idx]))[0] 96 | file_root = os.path.dirname(os.path.dirname(self.xml_dir_list[idx])) 97 | if len(self.img_dir) == 1: 98 | _img_loc = os.path.join(file_root+self.img_dir[0] + file_name + '.png') 99 | _mask_loc = os.path.join(file_root+self.mask_dir[0] + file_name + '.png') 100 | 101 | else: 102 | _img_loc = os.path.join( file_root+self.img_dir[self.dset[idx]] + file_name + '.png') 103 | _mask_loc = os.path.join( file_root+self.mask_dir[self.dset[idx]] + file_name + '.png') 104 | 105 | 106 | _img = Image.open(_img_loc).convert('RGB') 107 | _target = Image.open(_mask_loc) 108 | 109 | if self.is_train: 110 | isAugment = random.random() < 0.5 111 | if isAugment: 112 | isHflip = random.random() < 0.5 113 | if isHflip: 114 | _img = _img.transpose(Image.FLIP_LEFT_RIGHT) 115 | _target = _target.transpose(Image.FLIP_LEFT_RIGHT) 116 | else: 117 | _img = _img.transpose(Image.FLIP_TOP_BOTTOM) 118 | _target = _target.transpose(Image.FLIP_TOP_BOTTOM) 119 | 120 | _img = np.asarray(_img, np.float32) * 1.0 / 255 121 | _img = torch.from_numpy(np.array(_img).transpose(2, 0, 1)).float() 122 | _target = torch.from_numpy(np.array(_target)).long() 123 | 124 | frame_data = h5py.File(os.path.join( file_root+'/vsgat/'+self.feature_extractor+'/'+ file_name + '_features.hdf5'), 'r') 125 | 126 | data = {} 127 | 128 | data['img_name'] = frame_data['img_name'][()][:] + '.jpg' 129 | data['img_loc'] = _img_loc 130 | 131 | # segmentation 132 | data['img'] = self.resizer(_img.unsqueeze(0)) 133 | data['mask'] = self.resizer(_target.unsqueeze(0)) 134 | 135 | 136 | data['node_num'] = frame_data['node_num'][()] 137 | data['roi_labels'] = frame_data['classes'][:] 138 | data['det_boxes'] = frame_data['boxes'][:] 139 | 140 | data['edge_labels'] = frame_data['edge_labels'][:] 141 | data['edge_num'] = data['edge_labels'].shape[0] 142 | 143 | data['features'] = frame_data['node_features'][:] 144 | data['spatial_feat'] = frame_data['spatial_features'][:] 145 | 146 | data['word2vec'] = self._get_word2vec(data['roi_labels'], self.dset[idx]) 147 | return data 148 | 149 | 150 | # For Dataset Loader 151 | def collate_fn(batch): 152 | ''' 153 | Default collate_fn(): https://github.com/pytorch/pytorch/blob/1d53d0756668ce641e4f109200d9c65b003d05fa/torch/utils/data/_utils/collate.py#L43 154 | Inputs: Data Batch 155 | ''' 156 | batch_data = {} 157 | batch_data['img_name'] = [] 158 | batch_data['img_loc'] = [] 159 | batch_data['img'] = [] 160 | batch_data['mask'] = [] 161 | batch_data['node_num'] = [] 162 | batch_data['roi_labels'] = [] 163 | batch_data['det_boxes'] = [] 164 | batch_data['edge_labels'] = [] 165 | batch_data['edge_num'] = [] 166 | batch_data['features'] = [] 167 | batch_data['spatial_feat'] = [] 168 
| batch_data['word2vec'] = [] 169 | 170 | for data in batch: 171 | batch_data['img_name'].append(data['img_name']) 172 | batch_data['img_loc'].append(data['img_loc']) 173 | batch_data['img'].append(data['img']) 174 | batch_data['mask'].append(data['mask']) 175 | batch_data['node_num'].append(data['node_num']) 176 | batch_data['roi_labels'].append(data['roi_labels']) 177 | batch_data['det_boxes'].append(data['det_boxes']) 178 | batch_data['edge_labels'].append(data['edge_labels']) 179 | batch_data['edge_num'].append(data['edge_num']) 180 | batch_data['features'].append(data['features']) 181 | batch_data['spatial_feat'].append(data['spatial_feat']) 182 | batch_data['word2vec'].append(data['word2vec']) 183 | 184 | batch_data['img'] = torch.FloatTensor(np.concatenate(batch_data['img'], axis=0)) 185 | batch_data['mask'] = torch.LongTensor(np.concatenate(batch_data['mask'], axis=0)) 186 | batch_data['edge_labels'] = torch.FloatTensor(np.concatenate(batch_data['edge_labels'], axis=0)) 187 | batch_data['features'] = torch.FloatTensor(np.concatenate(batch_data['features'], axis=0)) 188 | batch_data['spatial_feat'] = torch.FloatTensor(np.concatenate(batch_data['spatial_feat'], axis=0)) 189 | batch_data['word2vec'] = torch.FloatTensor(np.concatenate(batch_data['word2vec'], axis=0)) 190 | 191 | return batch_data 192 | -------------------------------------------------------------------------------- /models/mtl_model.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Project : Global-Reasoned Multi-Task Surgical Scene Understanding 3 | Lab : MMLAB, National University of Singapore 4 | contributors : Lalithkumar Seenivasan, Sai Mitheran, Mobarakol Islam, Hongliang Ren 5 | Note : Code adopted and modified from Visual-Semantic Graph Attention Networks and Dual attention network for scene segmentation 6 | ''' 7 | 8 | import cv2 9 | import numpy as np 10 | from PIL import Image 11 | 12 | import torch 13 | import torchvision 14 | import torch.nn as nn 15 | 16 | class mtl_model(nn.Module): 17 | ''' 18 | Multi-task model : Graph Scene Understanding and segmentation 19 | Forward uses features from feature_extractor 20 | ''' 21 | 22 | def __init__(self, feature_encoder, scene_graph, seg_gcn_block, seg_decoder, seg_mode = None): 23 | super(mtl_model, self).__init__() 24 | self.feature_encoder = feature_encoder 25 | self.gcn_unit = seg_gcn_block 26 | self.seg_mode = seg_mode 27 | self.seg_decoder = seg_decoder 28 | self.scene_graph = scene_graph 29 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 30 | self.transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()]) 31 | 32 | def model_type1_insert(self): 33 | self.sg_avgpool = nn.AdaptiveAvgPool1d(1) 34 | self.sg_linear = nn.Linear(1040, 128) 35 | self.sg_feat_s1d1 = nn.Conv1d(1, 1, kernel_size=1, padding=0, stride=1, groups=1, bias=True) 36 | 37 | def model_type2_insert(self): 38 | self.sg2_linear = nn.Linear(1040, 128) 39 | 40 | def model_type3_insert(self): 41 | # self.sf_avgpool = nn.AdaptiveAvgPool2d((1, 1)) 42 | self.sf_avgpool = nn.AdaptiveAvgPool1d(1) 43 | #self.sf_linear = nn.Linear(256, 128) 44 | 45 | def set_train_test(self, model_type): 46 | ''' train Feature extractor for scene graph ''' 47 | # if model_type == 'stl-s' or model_type == 'amtl-t0' or model_type == 'amtl-t3' or model_type == 'stl-sg': 48 | if model_type == 'stl-s' or model_type == 'stl-sg' or model_type == 'amtl-t0' or model_type == 'amtl-t3': 49 | self.train_FE_SG = False 50 | else: 51 | self.train_FE_SG = True 52 | 53 | ''' 
train feature extractor for segmentation ''' 54 | # if model_type == 'stl-sg' or model_type == 'amtl-t0' or model_type == 'amtl-t3': 55 | if model_type == 'stl-sg' or model_type == 'stl-sg-wfe' or model_type == 'amtl-t0' or model_type == 'amtl-t3':# or model_type == 'amtl-t1': 56 | self.Train_FE_SEG = False 57 | else: 58 | self.Train_FE_SEG = True 59 | 60 | ''' train scene graph''' 61 | # set train flag for scene graph 62 | if model_type == 'stl-s': 63 | self.Train_SG = False 64 | else: 65 | self.Train_SG = True 66 | 67 | ''' train segmentation GR-unit (Global-Reasoniing unit) ''' 68 | # if model_type == 'stl-sg' or model_type == 'amtl-t0' or model_type == 'amtl-t3': 69 | if model_type == 'stl-sg' or model_type == 'stl-sg-wfe' or model_type == 'amtl-t0' or model_type == 'amtl-t3': 70 | self.Train_SEG_GR = False 71 | else: 72 | self.Train_SEG_GR = True 73 | 74 | ''' train segmentation decoder ''' 75 | # set train flag for segmentation decoder 76 | # if model_type == 'stl-sg' or model_type == 'amtl-t0' or model_type == 'amtl-t3': 77 | if model_type == 'stl-sg' or model_type == 'stl-sg-wfe' or model_type == 'amtl-t0' or model_type == 'amtl-t3': 78 | self.Train_SG_DECODER = False 79 | else: 80 | self.Train_SG_DECODER = True 81 | 82 | self.model_type = model_type 83 | 84 | 85 | def forward(self, img, img_dir, det_boxes_all, node_num, spatial_feat, word2vec, roi_labels, validation=False): 86 | 87 | gsu_node_feat = None 88 | seg_inputs = None 89 | interaction = None 90 | imsize = img.size()[2:] 91 | 92 | # ====================================================== Extract node features for Scene graph ============================================================== 93 | if not self.train_FE_SG: 94 | ''' skip training the feature extractor for scene graph ''' 95 | with torch.no_grad(): 96 | for index, img_loc in enumerate(img_dir): 97 | _img = Image.open(img_loc).convert('RGB') 98 | _img = np.array(_img) 99 | img_stack = None 100 | for bndbox in det_boxes_all[index]: 101 | roi = np.array(bndbox).astype(int) 102 | roi_image = _img[roi[1]:roi[3] + 1, roi[0]:roi[2] + 1, :] 103 | roi_image = self.transform(cv2.resize(roi_image, (224, 224), interpolation=cv2.INTER_LINEAR)) 104 | roi_image = torch.autograd.Variable(roi_image.unsqueeze(0)) 105 | # stack nodes images per image 106 | img_stack = roi_image if img_stack == None else torch.cat((img_stack, roi_image)) 107 | 108 | img_stack = img_stack.cuda(non_blocking=True) 109 | _, _, _, img_stack = self.feature_encoder(img_stack) 110 | 111 | img_stack = self.avgpool(img_stack) 112 | img_stack = img_stack.view(img_stack.size(0), -1) 113 | 114 | # # prepare graph node features 115 | gsu_node_feat = img_stack if gsu_node_feat == None else torch.cat((gsu_node_feat, img_stack)) 116 | 117 | else: 118 | # print('node_info grad enabled') 119 | for index, img_loc in enumerate(img_dir): 120 | _img = Image.open(img_loc).convert('RGB') 121 | _img = np.array(_img) 122 | img_stack = None 123 | for bndbox in det_boxes_all[index]: 124 | roi = np.array(bndbox).astype(int) 125 | roi_image = _img[roi[1]:roi[3] + 1, roi[0]:roi[2] + 1, :] 126 | roi_image = self.transform(cv2.resize(roi_image, (224, 224), interpolation=cv2.INTER_LINEAR)) 127 | roi_image = torch.autograd.Variable(roi_image.unsqueeze(0)) 128 | # stack nodes images per image 129 | img_stack = roi_image if img_stack == None else torch.cat((img_stack, roi_image)) 130 | 131 | img_stack = img_stack.cuda(non_blocking=True) 132 | _, _, _, img_stack = self.feature_encoder(img_stack) 133 | img_stack = self.avgpool(img_stack) 134 
| img_stack = img_stack.view(img_stack.size(0), -1) 135 | # prepare graph node features 136 | gsu_node_feat = img_stack if gsu_node_feat == None else torch.cat((gsu_node_feat, img_stack)) 137 | # ================================================================================================================================================================ 138 | # ===================================================== Segmentation feature extractor =========================================================================== 139 | if not self.Train_FE_SEG: 140 | ''' Skip training feature encoder for segmentation task ''' 141 | with torch.no_grad(): 142 | s1, s2, s3, seg_inputs = self.feature_encoder(img) 143 | fe_feat = seg_inputs 144 | else: 145 | # print('segment encoder enabled') 146 | s1, s2, s3, seg_inputs = self.feature_encoder(img) 147 | fe_feat = seg_inputs 148 | # ================================================================================================================================================================ 149 | # ================================================= Scene graph and segmentation GR (Global Reasoning) unit ====================================================== 150 | if self.model_type == 'amtl-t1' or self.model_type == 'mtl-t1': 151 | ''' 152 | In type 1, interaction features are passed to segmentation GR (Global Reasoning) module. 153 | inside GR unit, (x = x + h + avg((x)T) * sg_feat[1x128]) 154 | Here interation is called before GR unit. 155 | ''' 156 | ''' ==== scene graph ==== ''' 157 | # print('inside mtl-1') 158 | interaction, sg_feat = self.scene_graph(node_num, gsu_node_feat, spatial_feat, word2vec, roi_labels, validation= validation) 159 | 160 | ''' ==== GR (Global Reasoning) ==== ''' 161 | edge_sum = 0 162 | batch_sg_feat = None 163 | for n in node_num: 164 | active_edges = n-1 if n >1 else n 165 | if batch_sg_feat == None: 166 | batch_sg_feat = self.sg_linear(self.sg_avgpool(sg_feat[edge_sum:edge_sum+active_edges, :].unsqueeze(0).permute(0,2,1)).permute(0,2,1)) 167 | else: 168 | batch_sg_feat = torch.cat((batch_sg_feat, self.sg_linear(self.sg_avgpool(sg_feat[edge_sum:edge_sum+active_edges, :].unsqueeze(0).permute(0,2,1)).permute(0,2,1)))) 169 | edge_sum += active_edges 170 | batch_sg_feat = self.sg_feat_s1d1(batch_sg_feat) 171 | s1, s2, s3, seg_inputs, _ = self.gcn_unit(seg_inputs, s1=s1, s2=s2, s3=s3, scene_feat = batch_sg_feat, seg_mode = self.seg_mode, model_type = self.model_type) 172 | 173 | elif self.model_type == 'amtl-t2' or self.model_type == 'mtl-t2': 174 | ''' 175 | In type 2, interaction features are passed to segmentation GR module. Replace 176 | inside GR, GCN is replaced with x = x * sg_feat [128 x 128] 177 | Here interation is called before GR unit. 
178 | ''' 179 | ''' ==== scene graph ==== ''' 180 | interaction, sg_feat = self.scene_graph(node_num, gsu_node_feat, spatial_feat, word2vec, roi_labels, validation= validation) 181 | 182 | ''' ==== GR (Global Reasoning) ==== ''' 183 | edge_sum = 0 184 | batch_sg_feat = None 185 | for n in node_num: 186 | active_edges = n-1 if n >1 else n 187 | if batch_sg_feat == None: 188 | batch_sg_feat = torch.matmul(self.sg2_linear(sg_feat[edge_sum:edge_sum+active_edges, :]).permute(1, 0), \ 189 | self.sg2_linear(sg_feat[edge_sum:edge_sum+active_edges, :])).unsqueeze(0) 190 | else: 191 | batch_sg_feat = torch.cat((batch_sg_feat, torch.matmul(self.sg2_linear(sg_feat[edge_sum:edge_sum+active_edges, :]).permute(1, 0), \ 192 | self.sg2_linear(sg_feat[edge_sum:edge_sum+active_edges, :])).unsqueeze(0))) 193 | edge_sum += active_edges 194 | s1, s2, s3, seg_inputs, _ = self.gcn_unit(seg_inputs, s1=s1, s2=s2, s3=s3, scene_feat = batch_sg_feat, seg_mode = self.seg_mode, model_type = self.model_type) 195 | 196 | else: 197 | ''' 198 | If it's not type 1 & 2, then GR is processed before interaction. 199 | ''' 200 | ''' ==== GR (Global Reasoning) ==== ''' 201 | if not self.Train_SEG_GR: 202 | ''' skip GR unit training ''' 203 | with torch.no_grad(): 204 | s1, s2, s3, seg_inputs, gi_feat = self.gcn_unit(seg_inputs, s1=s1, s2=s2, s3=s3, seg_mode = self.seg_mode, model_type = self.model_type) 205 | else: 206 | # print('segment gcn enabled') 207 | s1, s2, s3, seg_inputs, gi_feat = self.gcn_unit(seg_inputs, s1=s1, s2=s2, s3=s3, seg_mode = self.seg_mode, model_type = self.model_type) 208 | 209 | ''' ==== scene graph ==== ''' 210 | if self.model_type == 'amtl-t3' or self.model_type == 'mtl-t3': 211 | gr_int_feat = self.sf_avgpool(gi_feat).view(gi_feat.size(0), 128) 212 | 213 | edge_sum = 0 214 | global_spatial_feat = None 215 | 216 | for b_i, n in enumerate(node_num): 217 | active_edges = (n*(n-1)) if n >1 else n 218 | if global_spatial_feat == None: 219 | global_spatial_feat = torch.cat((spatial_feat[edge_sum:edge_sum+active_edges, :], gr_int_feat[b_i,:].repeat(active_edges,1)),1) 220 | else: 221 | global_spatial_feat = torch.cat((global_spatial_feat, torch.cat((spatial_feat[edge_sum:edge_sum+active_edges, :], gr_int_feat[b_i,:].repeat(active_edges,1)),1))) 222 | edge_sum += active_edges 223 | interaction, _ = self.scene_graph(node_num, gsu_node_feat, global_spatial_feat, word2vec, roi_labels, validation= validation) 224 | elif not self.Train_SG: 225 | ''' skip scene graph training ''' 226 | with torch.no_grad(): 227 | global_spatial_feat = spatial_feat 228 | interaction, _ = self.scene_graph(node_num, gsu_node_feat, global_spatial_feat, word2vec, roi_labels, validation= True) 229 | else: 230 | # print('interaction encoder enabled') 231 | global_spatial_feat = spatial_feat 232 | interaction, _ = self.scene_graph(node_num, gsu_node_feat, global_spatial_feat, word2vec, roi_labels, validation= validation) 233 | 234 | # ================================================================================================================================================================ 235 | # ================================================= Scene graph and segmentation GR Unit ========================================================================= 236 | 237 | ''' ============== Segmentation decoder ==============''' 238 | if not self.Train_SG_DECODER: 239 | ''' skip segmentation decoder ''' 240 | with torch.no_grad(): 241 | seg_inputs = self.seg_decoder(seg_inputs, s1 = s1, s2 = s2, s3 =s3, imsize = imsize, seg_mode = 
self.seg_mode) 242 | 243 | else: 244 | # print('segment_decoder_enabled') 245 | seg_inputs = self.seg_decoder(seg_inputs, s1 = s1, s2 = s2, s3 =s3, imsize = imsize, seg_mode = self.seg_mode) 246 | # ================================================================================================================================================================ 247 | 248 | return interaction, seg_inputs, fe_feat -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | #from functools import lru_cache 2 | import os 3 | import time 4 | import json 5 | 6 | import argparse 7 | import numpy as np 8 | from tqdm import tqdm 9 | 10 | import torch 11 | import torch.nn as nn 12 | from torch import optim 13 | import torch.nn.functional as F 14 | from torch.utils.data import DataLoader 15 | 16 | from models.mtl_model import * 17 | from models.scene_graph import * 18 | from models.surgicalDataset import * 19 | from models.segmentation_model import get_gcnet # for the get_gcnet function 20 | 21 | from utils.scene_graph_eval_matrix import * 22 | from utils.segmentation_eval_matrix import * # SegmentationLoss and Eval code 23 | 24 | from tabulate import tabulate 25 | 26 | import torch.multiprocessing as mp 27 | import torch.distributed as dist 28 | from torch.nn.parallel import DistributedDataParallel as DDP 29 | 30 | import warnings 31 | warnings.filterwarnings('ignore') 32 | 33 | def label_to_index(lbl): 34 | ''' 35 | Label to index mapping 36 | Input: class label 37 | Output: class index 38 | ''' 39 | return torch.tensor(map_dict.index(lbl)) 40 | 41 | 42 | def index_to_label(index): 43 | ''' 44 | Index to label mapping 45 | Input: class index 46 | Output: class label 47 | ''' 48 | return map_dict[index] 49 | 50 | 51 | 52 | def seed_everything(seed=27): 53 | ''' 54 | Set random seed for reproducible experiments 55 | Inputs: seed number 56 | ''' 57 | torch.manual_seed(seed) 58 | torch.cuda.manual_seed_all(seed) 59 | os.environ['PYTHONHASHSEED'] = str(seed) 60 | torch.backends.cudnn.deterministic = True 61 | torch.backends.cudnn.benchmark = False 62 | 63 | 64 | def seg_eval_batch(seg_output, target): 65 | ''' 66 | Calculate segmentation loss, pixel acc and IoU 67 | ''' 68 | seg_criterion = SegmentationLosses(se_loss=False, aux=False, nclass=8, se_weight=0.2, aux_weight=0.2) 69 | loss = seg_criterion(seg_output, target) 70 | correct, labeled = batch_pix_accuracy(seg_output.data, target) 71 | inter, union = batch_intersection_union(seg_output.data, target, 8) # 8 is num classes 72 | return correct, labeled, inter, union, loss 73 | 74 | 75 | def build_model(args): 76 | ''' 77 | Build MTL model 78 | 1) Scene Graph Understanding Model 79 | 2) Segmentation Model : Encoder, Reasoning unit, Decoder 80 | 81 | Inputs: args 82 | ''' 83 | 84 | '''==== Graph model ====''' 85 | # graph model 86 | scene_graph = AGRNN(bias=True, bn=False, dropout=0.3, multi_attn=False, layer=1, diff_edge=False, global_feat=args.global_feat) 87 | 88 | # segmentation model 89 | seg_model = get_gcnet(backbone='resnet18_model', pretrained=False) 90 | model = mtl_model(seg_model.pretrained, scene_graph, seg_model.gr_interaction, seg_model.gr_decoder, seg_mode = args.seg_mode) 91 | model.to(torch.device('cpu')) 92 | return model 93 | 94 | 95 | 96 | def model_eval(model, validation_dataloader, nclass=8): 97 | ''' 98 | Evaluate MTL 99 | ''' 100 | 101 | model.eval() 102 | 103 | class_values = np.zeros(nclass) 104 | 105 | # 
graph 106 | scene_graph_criterion = nn.MultiLabelSoftMarginLoss() 107 | scene_graph_edge_count = 0 108 | scene_graph_total_acc = 0.0 109 | scene_graph_total_loss = 0.0 110 | scene_graph_logits_list = [] 111 | scene_graph_labels_list = [] 112 | 113 | test_seg_loss = 0.0 114 | total_inter, total_union, total_correct, total_label = 0, 0, 0, 0 115 | 116 | 117 | for data in tqdm(validation_dataloader): 118 | seg_img = data['img'] 119 | seg_masks = data['mask'] 120 | img_loc = data['img_loc'] 121 | node_num = data['node_num'] 122 | roi_labels = data['roi_labels'] 123 | det_boxes = data['det_boxes'] 124 | edge_labels = data['edge_labels'] 125 | spatial_feat = data['spatial_feat'] 126 | word2vec = data['word2vec'] 127 | 128 | spatial_feat, word2vec, edge_labels = spatial_feat.cuda(non_blocking=True), word2vec.cuda(non_blocking=True), edge_labels.cuda(non_blocking=True) 129 | seg_img, seg_masks = seg_img.cuda(non_blocking=True), seg_masks.cuda(non_blocking=True) 130 | 131 | with torch.no_grad(): 132 | interaction, seg_outputs, _ = model(seg_img, img_loc, det_boxes, node_num, spatial_feat, word2vec, roi_labels, validation=True) 133 | 134 | scene_graph_logits_list.append(interaction) 135 | scene_graph_labels_list.append(edge_labels) 136 | 137 | # loss and accuracy 138 | scene_graph_loss = scene_graph_criterion(interaction, edge_labels.float()) 139 | scene_graph_acc = np.sum(np.equal(np.argmax(interaction.cpu().data.numpy(), axis=-1), np.argmax(edge_labels.cpu().data.numpy(), axis=-1))) 140 | correct, labeled, inter, union, t_loss = seg_eval_batch(seg_outputs, seg_masks) 141 | 142 | # accumulate scene graph loss and acc 143 | scene_graph_total_loss += scene_graph_loss.item() * edge_labels.shape[0] 144 | scene_graph_total_acc += scene_graph_acc 145 | scene_graph_edge_count += edge_labels.shape[0] 146 | 147 | total_correct += correct 148 | total_label += labeled 149 | total_inter += inter 150 | total_union += union 151 | test_seg_loss += t_loss.item() 152 | 153 | # graph evaluation 154 | scene_graph_total_acc = scene_graph_total_acc / scene_graph_edge_count 155 | scene_graph_total_loss = scene_graph_total_loss / len(validation_dataloader) 156 | scene_graph_logits_all = torch.cat(scene_graph_logits_list).cuda() 157 | scene_graph_labels_all = torch.cat(scene_graph_labels_list).cuda() 158 | scene_graph_logits_all = F.softmax(scene_graph_logits_all, dim=1) 159 | scene_graph_map_value, scene_graph_recall = calibration_metrics(scene_graph_logits_all, scene_graph_labels_all) 160 | 161 | # segmentation evaluation 162 | pixAcc = 1.0 * total_correct / (np.spacing(1) + total_label) 163 | IoU = 1.0 * total_inter / (np.spacing(1) + total_union) 164 | class_values += IoU 165 | mIoU = IoU.mean() 166 | 167 | print('\n================= Evaluation ====================') 168 | print('Graph : acc: %0.4f map: %0.4f recall: %0.4f loss: %0.4f}' % (scene_graph_total_acc, scene_graph_map_value, scene_graph_recall, scene_graph_total_loss)) 169 | print('Segmentation : Pacc: %0.4f mIoU: %0.4f loss: %0.4f}' % (pixAcc, mIoU, test_seg_loss/len(validation_dataloader))) 170 | 171 | print('\n================= Class-wise IoU ====================') 172 | class_wise_IoU = [] 173 | m_vals = [] 174 | for idx, value in enumerate(class_values): 175 | class_name = index_to_label(idx) 176 | pair = [class_name, value] 177 | m_vals.append(value) 178 | class_wise_IoU.append(pair) 179 | 180 | print("Mean Value: ", np.mean(np.array(m_vals)), "\n") 181 | 182 | print(tabulate(class_wise_IoU, 183 | headers=['Class', 'IoU'], tablefmt='orgtbl')) 184 | 185 
| return(scene_graph_total_acc, scene_graph_map_value, mIoU) 186 | 187 | 188 | if __name__ == "__main__": 189 | 190 | ''' 191 | Main function to set arguments 192 | ''' 193 | 194 | ''' 195 | To reproduce the results, set the model_type, ver, seg_mode and checkpoint_dir based on the table below 196 | TBR = To be released 197 | ============================================================================================================ 198 | Paper_name | model_type | ver | seg_mode | checkpoint_dir 199 | ============================================================================================================ 200 | STL 201 | ------------------------|----------------------------------------------------------------------------------- 202 | VS-GAT | 'stl-sg' | 'stl_sg' | None | 'stl_sg' 203 | SEG | 'stl-s' | 'stl_s_ng' | TBR | 'stl_s_ng' 204 | SEG-GR | 'stl-s' | 'stl_s' | None | 'stl_s' 205 | SEG-MSGR | 'stl-s' | 'stl_s_v2gc' | 'v2gc' | 'stl_s_v2gc' 206 | SEG-MSLRGR | 'stl-s' | 'stl_s_v1' | 'v1' | 'stl_s_v1' 207 | ------------------------------------------------------------------------------------------------------------ 208 | SMTL 209 | ------------------------------------------------------------------------------------------------------------ 210 | GR | 'amtl-t0' | 'amtl_t0_s' | None | 'amtl_t0_s' 211 | MSGR | 'amtl-t0' | 'amtl_t0_sv2gc' | 'v2gc' | 'amtl_t0_sv2gc' 212 | MSLRGR | 'amtl-t0' | 'amtl_t0_sv1' | 'v1' | 'amtl_t0_sv1' 213 | MSLRGR-GISFSG | 'amtl-t3' | 'amtl_t3pn_sv1' | 'v1' | 'amtl_t3pn_sv1' 214 | ------------------------------------------------------------------------------------------------------------ 215 | v-MTL 216 | ------------------------------------------------------------------------------------------------------------ 217 | V-MTL-GR | 'mtl-t0' | 'mtl_t0_s | None | 'mtl_t0_s' 218 | ------------------------------------------------------------------------------------------------------------ 219 | KD-MTL (set args.KD = True) 220 | ------------------------------------------------------------------------------------------------------------ 221 | KD-MTL-GR | 'mtl-t0' | 'mtl_kd_t0_s' | None | TBR 222 | KD-MTL-MSLRGR | 'mtl-t0' | 'mtl_kd_t0_sv1' | 'v1' | 'mtl_kd_t0_sv1' 223 | KD-MTL-MSLRGR-SGFSEG | 'mtl-t1' | 'mtl_kd_t1_sv1' | 'v1' | 'mtl_kd_t1_sv1' 224 | KD-MTL-MSLRGR-GISFSG | 'mtl-t3' | 'mtl_kd_t3_sv1' | 'v1' | 'mtl_kd_t3_sv1' 225 | ------------------------------------------------------------------------------------------------------------ 226 | ''' 227 | 228 | model_type = 'amtl-t3' 229 | ver = 'amtl_t3_sv1' 230 | seg_mode = 'v1' 231 | checkpoint_dir = 'amtl_t3_sv1' 232 | 233 | port = '8892' 234 | 235 | # Set random seed 236 | seed_everything() 237 | print(ver, seg_mode) 238 | 239 | # arguments 240 | parser = argparse.ArgumentParser(description='GR_MTL_SSU') 241 | 242 | # hyper parameters 243 | parser.add_argument('--lr', type=float, default = 0.00001) #0.00001 244 | parser.add_argument('--epoch', type=int, default = 130) 245 | parser.add_argument('--start_epoch', type=int, default = 0) 246 | parser.add_argument('--batch_size', type=int, default = 1) 247 | parser.add_argument('--gpu', type=bool, default = True) 248 | parser.add_argument('--train_model', type=str, default = 'epoch') 249 | parser.add_argument('--exp_ver', type=str, default = ver) 250 | 251 | # file locations 252 | parser.add_argument('--log_dir', type=str, default = './log/' + ver) 253 | parser.add_argument('--save_dir', type=str, default = './checkpoints/' + ver) 254 | parser.add_argument('--output_img_dir', 
type=str, default = './results/' + ver) 255 | parser.add_argument('--save_every', type=int, default = 10) 256 | parser.add_argument('--pretrained', type=str, default = None) 257 | 258 | # network 259 | parser.add_argument('--layers', type=int, default = 1) 260 | parser.add_argument('--bn', type=bool, default = False) 261 | parser.add_argument('--drop_prob', type=float, default = 0.3) 262 | parser.add_argument('--bias', type=bool, default = True) 263 | parser.add_argument('--multi_attn', type=bool, default = False) 264 | parser.add_argument('--diff_edge', type=bool, default = False) 265 | if model_type == 'mtl-t3' or model_type == 'amtl-t3': 266 | parser.add_argument('--global_feat', type=int, default = 128) 267 | else: 268 | parser.add_argument('--global_feat', type=int, default = 0) 269 | # data_processing 270 | parser.add_argument('--sampler', type=int, default = 0) 271 | parser.add_argument('--data_aug', type=bool, default = False) 272 | parser.add_argument('--feature_extractor', type=str, default = 'features') 273 | parser.add_argument('--seg_mode', type=str, default = seg_mode) 274 | 275 | # CBS 276 | parser.add_argument('--use_cbs', type=bool, default = False) 277 | 278 | # Knowledge distillation 279 | parser.add_argument('--KD', type=bool, default = False) 280 | 281 | parser.add_argument('--model', type=str, default = model_type) 282 | args = parser.parse_args() 283 | 284 | # seed_everything() 285 | data_const = SurgicalSceneConstants() 286 | 287 | label_path = 'dataset/labels_isi_dataset.json' 288 | with open(label_path) as f: 289 | labels = json.load(f) 290 | 291 | CLASSES = [] 292 | CLASS_ID = [] 293 | 294 | for item in labels: 295 | CLASSES.append(item['name']) 296 | CLASS_ID.append(item['classid']) 297 | 298 | map_dict = {k: v for k, v in zip(CLASS_ID, CLASSES)} 299 | 300 | # this is placed above the dist.init process, possibility because of the feature_extraction model. 301 | model = build_model(args) 302 | model.set_train_test(args.model) 303 | 304 | # insert nn layers based on type. 
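# Note on the dispatch below: model_type1/2/3_insert() attach the variant-specific feature-sharing
# layers listed in the table above (e.g. 't3' adds the layers that feed global interaction-space
# features to the scene graph head, GISFSG). They are called before the checkpoint is loaded
# further down, since the saved state_dict is expected to also contain weights for these inserted
# layers; 'stl-*' and '*-t0' variants use the base model as-is.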
305 | if args.model == 'amtl-t1' or args.model == 'mtl-t1': 306 | model.model_type1_insert() 307 | elif args.model == 'amtl-t2' or args.model == 'mtl-t2': 308 | model.model_type2_insert() 309 | elif args.model == 'amtl-t3' or args.model == 'mtl-t3': 310 | model.model_type3_insert() 311 | 312 | # Load the pre-trained checkpoint 313 | print('Loading pre-trained weights') 314 | pretrained_model = torch.load(('checkpoints/'+checkpoint_dir+'/best_epoch.pth')) 315 | model.load_state_dict(pretrained_model) 316 | 317 | # Move the model to the GPU (no DDP wrapping is needed for evaluation) 318 | model.cuda() 319 | 320 | # Validation dataloader 321 | val_seq = [[1, 5, 16]] 322 | data_dir = ['dataset/seq_'] 323 | img_dir = ['/left_frames/'] 324 | mask_dir = ['/annotations/'] 325 | dset = [0] 326 | data_const = SurgicalSceneConstants() 327 | 328 | seq = {'val_seq': val_seq, 'data_dir': data_dir, 'img_dir': img_dir, 'dset': dset, 'mask_dir': mask_dir} 329 | 330 | # val_dataset is evaluated on a single GPU 331 | val_dataset = SurgicalSceneDataset(seq_set=seq['val_seq'], dset=seq['dset'], data_dir=seq['data_dir'], \ 332 | img_dir=seq['img_dir'], mask_dir=seq['mask_dir'], istrain=False, dataconst=data_const, \ 333 | feature_extractor=args.feature_extractor, reduce_size=False) 334 | val_dataloader = DataLoader(dataset=val_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn) 335 | 336 | model_eval(model, val_dataloader) 337 | -------------------------------------------------------------------------------- /models/scene_graph.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Project : Global-Reasoned Multi-Task Surgical Scene Understanding 3 | Lab : MMLAB, National University of Singapore 4 | contributors : Lalithkumar Seenivasan, Sai Mitheran, Mobarakol Islam, Hongliang Ren 5 | Note : Code adopted and modified from Visual-Semantic Graph Attention Networks and Dual attention network for scene segmentation 6 | Visual-Semantic Graph Network: 7 | @article{liang2020visual, 8 | title={Visual-Semantic Graph Attention Networks for Human-Object Interaction Detection}, 9 | author={Liang, Zhijun and Rojas, Juan and Liu, Junfa and Guan, Yisheng}, 10 | journal={arXiv preprint arXiv:2001.02302}, 11 | year={2020} 12 | } 13 | ''' 14 | 15 | 16 | import dgl 17 | import math 18 | import numpy as np 19 | 20 | import torch 21 | 22 | import torch.nn as nn 23 | 24 | import torch.nn.functional as F 25 | 26 | from collections import OrderedDict 27 | 28 | ''' 29 | Configurations of the network 30 | 31 | readout: G_ER_L_S = [1024+300+16+300+1024, 1024, 117] 32 | 33 | node_func: G_N_L_S = [1024+1024, 1024] 34 | node_lang_func: G_N_L_S2 = [300+300+300] 35 | 36 | edge_func : G_E_L_S = [1024*2+16, 1024] 37 | edge_lang_func: [300*2, 1024] 38 | 39 | attn: [1024, 1] 40 | attn_lang: [1024, 1] 41 | ''' 42 | class CONFIGURATION(object): 43 | ''' 44 | Configuration arguments: feature type, layer, bias, batch normalization, dropout, multi-attn 45 | 46 | readout : fc_size, activation, bias, bn, dropout 47 | gnn_node : fc_size, activation, bias, bn, dropout 48 | gnn_node_for_lang : fc_size, activation, bias, bn, dropout 49 | gnn_edge : fc_size, activation, bias, bn, dropout 50 | gnn_edge_for_lang : fc_size, activation, bias, bn, dropout 51 | gnn_attn : fc_size, activation, bias, bn, dropout 52 | gnn_attn_for_lang : fc_size, activation, bias, bn, dropout 53 | ''' 54 | def __init__(self, layer=1, bias=True, bn=False, dropout=0.2, multi_attn=False, global_feat = 0): 55 | 56 | # if multi_attn:
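# Dimension note (descriptive only; the exact values follow the assignments below): with
# layer == 1, feature_size = 512 and additional_sf = global_feat, so the readout MLP G_ER_L_S
# sees [dst node (512) + dst word2vec (300) + spatial (16) + global_feat + src word2vec (300)
# + src node (512)] and maps it through 512 hidden units to 13 interaction logits, matching
# the concatenation order used in Predictor.forward(). global_feat is non-zero only for the
# *-t3 (GISFSG) variants, where the global interaction-space features are presumably
# concatenated with the spatial edge features upstream.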
57 | if True: 58 | if layer==1: 59 | feature_size = 512 60 | additional_sf = global_feat 61 | # readout 62 | self.G_ER_L_S = [feature_size+300+16+additional_sf+300+feature_size, feature_size, 13] 63 | self.G_ER_A = ['ReLU', 'Identity'] 64 | self.G_ER_B = bias #true 65 | self.G_ER_BN = bn #false 66 | self.G_ER_D = dropout #0.3 67 | # self.G_ER_GRU = feature_size 68 | 69 | # # gnn node function 70 | self.G_N_L_S = [feature_size+feature_size, feature_size] 71 | self.G_N_A = ['ReLU'] 72 | self.G_N_B = bias #true 73 | self.G_N_BN = bn #false 74 | self.G_N_D = dropout #0.3 75 | # self.G_N_GRU = feature_size 76 | 77 | # # gnn node function for language 78 | self.G_N_L_S2 = [300+300, 300] 79 | self.G_N_A2 = ['ReLU'] 80 | self.G_N_B2 = bias #true 81 | self.G_N_BN2 = bn #false 82 | self.G_N_D2 = dropout #0.3 83 | # self.G_N_GRU2 = feature_size 84 | 85 | # gnn edge function1 86 | self.G_E_L_S = [feature_size*2+16+additional_sf, feature_size] 87 | self.G_E_A = ['ReLU'] 88 | self.G_E_B = bias # true 89 | self.G_E_BN = bn # false 90 | self.G_E_D = dropout # 0.3 91 | # self.G_E_c_kernel_size = 3 92 | 93 | 94 | # gnn edge function2 for language 95 | self.G_E_L_S2 = [300*2, feature_size] 96 | self.G_E_A2 = ['ReLU'] 97 | self.G_E_B2 = bias #true 98 | self.G_E_BN2 = bn #false 99 | self.G_E_D2 = dropout #0.3 100 | 101 | # gnn attention mechanism 102 | self.G_A_L_S = [feature_size, 1] 103 | self.G_A_A = ['LeakyReLU'] 104 | self.G_A_B = bias #true 105 | self.G_A_BN = bn #false 106 | self.G_A_D = dropout #0.3 107 | 108 | # gnn attention mechanism2 for language 109 | self.G_A_L_S2 = [feature_size, 1] 110 | self.G_A_A2 = ['LeakyReLU'] 111 | self.G_A_B2 = bias #true 112 | self.G_A_BN2 = bn #false 113 | self.G_A_D2 = dropout #0.3 114 | 115 | def save_config(self): 116 | model_config = {'graph_head':{}, 'graph_node':{}, 'graph_edge':{}, 'graph_attn':{}} 117 | CONFIG=self.__dict__ 118 | for k, v in CONFIG.items(): 119 | if 'G_H' in k: 120 | model_config['graph_head'][k]=v 121 | elif 'G_N' in k: 122 | model_config['graph_node'][k]=v 123 | elif 'G_E' in k: 124 | model_config['graph_edge'][k]=v 125 | elif 'G_A' in k: 126 | model_config['graph_attn'][k]=v 127 | else: 128 | model_config[k]=v 129 | 130 | return model_config 131 | 132 | 133 | class Identity(nn.Module): 134 | ''' 135 | Identity class activation layer 136 | f(x) = x 137 | ''' 138 | def __init__(self): 139 | super(Identity,self).__init__() 140 | 141 | def forward(self, x): 142 | return x 143 | 144 | def get_activation(name): 145 | ''' 146 | get_activation function 147 | argument: Activation name (eg. ReLU, Identity, Tanh, Sigmoid, LeakyReLU) 148 | ''' 149 | if name=='ReLU': return nn.ReLU(inplace=True) 150 | elif name=='Identity': return Identity() 151 | elif name=='Tanh': return nn.Tanh() 152 | elif name=='Sigmoid': return nn.Sigmoid() 153 | elif name=='LeakyReLU': return nn.LeakyReLU(0.2,inplace=True) 154 | else: assert(False), 'Not Implemented' 155 | 156 | 157 | class MLP(nn.Module): 158 | ''' 159 | Args: 160 | layer_sizes: a list, [1024,1024,...] 161 | activation: a list, ['ReLU', 'Tanh',...] 
162 | bias : bool 163 | use_bn: bool 164 | drop_prob: default is None, use drop out layer or not 165 | ''' 166 | def __init__(self, layer_sizes, activation, bias=True, use_bn=False, drop_prob=None): 167 | super(MLP, self).__init__() 168 | self.bn = use_bn 169 | self.layers = nn.ModuleList() 170 | for i in range(len(layer_sizes)-1): 171 | layer = nn.Linear(layer_sizes[i], layer_sizes[i+1], bias=bias) 172 | activate = get_activation(activation[i]) 173 | block = nn.Sequential(OrderedDict([(f'L{i}', layer), ])) 174 | 175 | # !NOTE:# Actually, it is inappropriate to use batch-normalization here 176 | if use_bn: 177 | bn = nn.BatchNorm1d(layer_sizes[i+1]) 178 | block.add_module(f'B{i}', bn) 179 | 180 | # batch normalization is put before activation function 181 | block.add_module(f'A{i}', activate) 182 | 183 | # dropout probablility 184 | if drop_prob: 185 | block.add_module(f'D{i}', nn.Dropout(drop_prob)) 186 | 187 | self.layers.append(block) 188 | 189 | def forward(self, x): 190 | for layer in self.layers: 191 | # !NOTE: Sometime the shape of x will be [1,N], and we cannot use batch-normalization in that situation 192 | if self.bn and x.shape[0]==1: 193 | x = layer[0](x) 194 | x = layer[:-1](x) 195 | else: 196 | x = layer(x) 197 | return x 198 | 199 | 200 | class H_H_EdgeApplyModule(nn.Module): #Human to Human edge 201 | ''' 202 | init : config, multi_attn 203 | forward : edge 204 | ''' 205 | def __init__(self, CONFIG, multi_attn=False): 206 | super(H_H_EdgeApplyModule, self).__init__() 207 | self.edge_fc = MLP(CONFIG.G_E_L_S, CONFIG.G_E_A, CONFIG.G_E_B, CONFIG.G_E_BN, CONFIG.G_E_D) 208 | self.edge_fc_lang = MLP(CONFIG.G_E_L_S2, CONFIG.G_E_A2, CONFIG.G_E_B2, CONFIG.G_E_BN2, CONFIG.G_E_D2) 209 | 210 | def forward(self, edge): 211 | feat = torch.cat([edge.src['n_f'], edge.data['s_f'], edge.dst['n_f']], dim=1) 212 | feat_lang = torch.cat([edge.src['word2vec'], edge.dst['word2vec']], dim=1) 213 | e_feat = self.edge_fc(feat) 214 | e_feat_lang = self.edge_fc_lang(feat_lang) 215 | 216 | return {'e_f': e_feat, 'e_f_lang': e_feat_lang} 217 | 218 | 219 | 220 | class H_NodeApplyModule(nn.Module): #human node 221 | ''' 222 | init : config 223 | forward : node 224 | ''' 225 | def __init__(self, CONFIG): 226 | super(H_NodeApplyModule, self).__init__() 227 | self.node_fc = MLP(CONFIG.G_N_L_S, CONFIG.G_N_A, CONFIG.G_N_B, CONFIG.G_N_BN, CONFIG.G_N_D) 228 | self.node_fc_lang = MLP(CONFIG.G_N_L_S2, CONFIG.G_N_A2, CONFIG.G_N_B2, CONFIG.G_N_BN2, CONFIG.G_N_D2) 229 | 230 | def forward(self, node): 231 | feat = torch.cat([node.data['n_f'], node.data['z_f']], dim=1) 232 | feat_lang = torch.cat([node.data['word2vec'], node.data['z_f_lang']], dim=1) 233 | n_feat = self.node_fc(feat) 234 | n_feat_lang = self.node_fc_lang(feat_lang) 235 | 236 | return {'new_n_f': n_feat, 'new_n_f_lang': n_feat_lang} 237 | 238 | 239 | class E_AttentionModule1(nn.Module): #edge attention 240 | ''' 241 | init : config 242 | forward : edge 243 | ''' 244 | def __init__(self, CONFIG): 245 | super(E_AttentionModule1, self).__init__() 246 | self.attn_fc = MLP(CONFIG.G_A_L_S, CONFIG.G_A_A, CONFIG.G_A_B, CONFIG.G_A_BN, CONFIG.G_A_D) 247 | self.attn_fc_lang = MLP(CONFIG.G_A_L_S2, CONFIG.G_A_A2, CONFIG.G_A_B2, CONFIG.G_A_BN2, CONFIG.G_A_D2) 248 | 249 | def forward(self, edge): 250 | a_feat = self.attn_fc(edge.data['e_f']) 251 | a_feat_lang = self.attn_fc_lang(edge.data['e_f_lang']) 252 | return {'a_feat': a_feat, 'a_feat_lang': a_feat_lang} 253 | 254 | 255 | class GNN(nn.Module): 256 | ''' 257 | init : config, multi_attn, diff_edge 258 | forward : 
g, h_node, o_node, h_h_e_list, o_o_e_list, h_o_e_list, pop_features 259 | ''' 260 | def __init__(self, CONFIG, multi_attn=False, diff_edge=True): 261 | super(GNN, self).__init__() 262 | self.diff_edge = diff_edge # false 263 | self.apply_h_h_edge = H_H_EdgeApplyModule(CONFIG, multi_attn) 264 | self.apply_edge_attn1 = E_AttentionModule1(CONFIG) 265 | self.apply_h_node = H_NodeApplyModule(CONFIG) 266 | 267 | def _message_func(self, edges): 268 | return {'nei_n_f': edges.src['n_f'], 'nei_n_w': edges.src['word2vec'], 'e_f': edges.data['e_f'], 'e_f_lang': edges.data['e_f_lang'], 'a_feat': edges.data['a_feat'], 'a_feat_lang': edges.data['a_feat_lang']} 269 | 270 | def _reduce_func(self, nodes): 271 | alpha = F.softmax(nodes.mailbox['a_feat'], dim=1) 272 | alpha_lang = F.softmax(nodes.mailbox['a_feat_lang'], dim=1) 273 | 274 | z_raw_f = nodes.mailbox['nei_n_f']+nodes.mailbox['e_f'] 275 | z_f = torch.sum( alpha * z_raw_f, dim=1) 276 | 277 | z_raw_f_lang = nodes.mailbox['nei_n_w'] 278 | z_f_lang = torch.sum(alpha_lang * z_raw_f_lang, dim=1) 279 | 280 | # we cannot return 'alpha' for the different dimension 281 | if self.training or validation: return {'z_f': z_f, 'z_f_lang': z_f_lang} 282 | else: return {'z_f': z_f, 'z_f_lang': z_f_lang, 'alpha': alpha, 'alpha_lang': alpha_lang} 283 | 284 | def forward(self, g, h_node, o_node, h_h_e_list, o_o_e_list, h_o_e_list, pop_feat=False): 285 | 286 | g.apply_edges(self.apply_h_h_edge, g.edges()) 287 | g.apply_edges(self.apply_edge_attn1) 288 | g.update_all(self._message_func, self._reduce_func) 289 | g.apply_nodes(self.apply_h_node, h_node+o_node) 290 | 291 | # !NOTE:PAY ATTENTION WHEN ADDING MORE FEATURE 292 | g.ndata.pop('n_f') 293 | g.ndata.pop('word2vec') 294 | 295 | g.ndata.pop('z_f') 296 | g.edata.pop('e_f') 297 | g.edata.pop('a_feat') 298 | 299 | g.ndata.pop('z_f_lang') 300 | g.edata.pop('e_f_lang') 301 | g.edata.pop('a_feat_lang') 302 | 303 | 304 | class GRNN(nn.Module): 305 | ''' 306 | init: 307 | config, multi_attn, diff_edge 308 | forward: 309 | batch_graph, batch_h_node_list, batch_obj_node_list, 310 | batch_h_h_e_list, batch_o_o_e_list, batch_h_o_e_list, 311 | features, spatial_features, word2vec, 312 | valid, pop_features, initial_features 313 | ''' 314 | def __init__(self, CONFIG, multi_attn=False, diff_edge=True): 315 | super(GRNN, self).__init__() 316 | self.multi_attn = multi_attn #false 317 | self.gnn = GNN(CONFIG, multi_attn, diff_edge) 318 | 319 | def forward(self, batch_graph, batch_h_node_list, batch_obj_node_list, batch_h_h_e_list, batch_o_o_e_list, batch_h_o_e_list, feat, spatial_feat, word2vec, valid=False, pop_feat=False, initial_feat=False): 320 | 321 | # !NOTE: if node_num==1, there will be something wrong to forward the attention mechanism 322 | global validation 323 | validation = valid 324 | 325 | # initialize the graph with some datas 326 | batch_graph.ndata['n_f'] = feat # node: features 327 | batch_graph.ndata['word2vec'] = word2vec # node: words 328 | batch_graph.edata['s_f'] = spatial_feat # edge: spatial features 329 | 330 | try: 331 | self.gnn(batch_graph, batch_h_node_list, batch_obj_node_list, batch_h_h_e_list, batch_o_o_e_list, batch_h_o_e_list) 332 | except Exception as e: 333 | print(e) 334 | 335 | 336 | class Predictor(nn.Module): 337 | ''' 338 | init : config 339 | forward : edge 340 | ''' 341 | def __init__(self, CONFIG): 342 | super(Predictor, self).__init__() 343 | self.classifier = MLP(CONFIG.G_ER_L_S, CONFIG.G_ER_A, CONFIG.G_ER_B, CONFIG.G_ER_BN, CONFIG.G_ER_D) 344 | self.sigmoid = nn.Sigmoid() 345 | 346 | 
def forward(self, edge): 347 | feat = torch.cat([edge.dst['new_n_f'], edge.dst['new_n_f_lang'], edge.data['s_f'], edge.src['new_n_f_lang'], edge.src['new_n_f']], dim=1) 348 | scene_feat = torch.cat([edge.dst['new_n_f'], edge.src['new_n_f'],edge.data['s_f']], dim=1) 349 | pred = self.classifier(feat) 350 | # If the criterion is BCELoss, uncomment the following code -> 351 | # output = self.sigmoid(output) 352 | return {'pred': pred, 'scene_feat': scene_feat} 353 | 354 | 355 | class AGRNN(nn.Module): 356 | ''' 357 | init : 358 | feature_type, bias, bn, dropout, multi_attn, layer, diff_edge 359 | 360 | forward : 361 | node_num, features, spatial_features, word2vec, roi_label, 362 | validation, choose_nodes, remove_nodes 363 | ''' 364 | def __init__(self, bias=True, bn=True, dropout=None, multi_attn=False, layer=1, diff_edge=True, global_feat = 0): 365 | super(AGRNN, self).__init__() 366 | 367 | self.multi_attn = multi_attn # false 368 | self.layer = layer # 1 layer 369 | self.diff_edge = diff_edge # false 370 | 371 | self.CONFIG1 = CONFIGURATION(layer=1, bias=bias, bn=bn, dropout=dropout, multi_attn=multi_attn, global_feat=global_feat) 372 | 373 | self.grnn1 = GRNN(self.CONFIG1, multi_attn=multi_attn, diff_edge=diff_edge) 374 | self.edge_readout = Predictor(self.CONFIG1) 375 | 376 | def _collect_edge(self, node_num, roi_label, node_space, diff_edge): 377 | ''' 378 | arguments: node_num, roi_label, node_space, diff_edge 379 | ''' 380 | 381 | # get human nodes && object nodes 382 | h_node_list = np.where(roi_label == 0)[0] 383 | obj_node_list = np.where(roi_label != 0)[0] 384 | edge_list = [] 385 | 386 | h_h_e_list = [] 387 | o_o_e_list = [] 388 | h_o_e_list = [] 389 | 390 | readout_edge_list = [] 391 | readout_h_h_e_list = [] 392 | readout_h_o_e_list = [] 393 | 394 | # get all edge in the fully-connected graph, edge_list, For node_num = 2, edge_list = [(0, 1), (1, 0)] 395 | for src in range(node_num): 396 | for dst in range(node_num): 397 | if src == dst: 398 | continue 399 | else: 400 | edge_list.append((src, dst)) 401 | 402 | # readout_edge_list, get corresponding readout edge in the graph 403 | src_box_list = np.arange(roi_label.shape[0]) 404 | for dst in h_node_list: 405 | for src in src_box_list: 406 | if src not in h_node_list: 407 | readout_edge_list.append((src, dst)) 408 | 409 | # readout h_h_e_list, get corresponding readout h_h edges && h_o edges 410 | temp_h_node_list = h_node_list[:] 411 | for dst in h_node_list: 412 | if dst == h_node_list.shape[0]-1: 413 | continue 414 | temp_h_node_list = temp_h_node_list[1:] 415 | for src in temp_h_node_list: 416 | if src == dst: continue 417 | readout_h_h_e_list.append((src, dst)) 418 | 419 | # readout h_o_e_list 420 | readout_h_o_e_list = [x for x in readout_edge_list if x not in readout_h_h_e_list] 421 | 422 | # add node space to match the batch graph 423 | h_node_list = (np.array(h_node_list)+node_space).tolist() 424 | obj_node_list = (np.array(obj_node_list)+node_space).tolist() 425 | 426 | h_h_e_list = (np.array(h_h_e_list)+node_space).tolist() #empty no diff_edge 427 | o_o_e_list = (np.array(o_o_e_list)+node_space).tolist() #empty no diff_edge 428 | h_o_e_list = (np.array(h_o_e_list)+node_space).tolist() #empty no diff_edge 429 | 430 | readout_h_h_e_list = (np.array(readout_h_h_e_list)+node_space).tolist() 431 | readout_h_o_e_list = (np.array(readout_h_o_e_list)+node_space).tolist() 432 | readout_edge_list = (np.array(readout_edge_list)+node_space).tolist() 433 | 434 | return edge_list, h_node_list, obj_node_list, h_h_e_list, 
o_o_e_list, h_o_e_list, readout_edge_list, readout_h_h_e_list, readout_h_o_e_list 435 | 436 | def _build_graph(self, node_num, roi_label, node_space, diff_edge): 437 | ''' 438 | Declare graph, add_nodes, collect edges, add_edges 439 | ''' 440 | graph = dgl.DGLGraph() 441 | graph.add_nodes(node_num) 442 | 443 | edge_list, h_node_list, obj_node_list, h_h_e_list, o_o_e_list, h_o_e_list, readout_edge_list, readout_h_h_e_list, readout_h_o_e_list = self._collect_edge(node_num, roi_label, node_space, diff_edge) 444 | src, dst = tuple(zip(*edge_list)) 445 | graph.add_edges(src, dst) # make the graph bi-directional 446 | 447 | return graph, h_node_list, obj_node_list, h_h_e_list, o_o_e_list, h_o_e_list, readout_edge_list, readout_h_h_e_list, readout_h_o_e_list 448 | 449 | def forward(self, node_num=None, feat=None, spatial_feat=None, word2vec=None, roi_label=None, validation=False, choose_nodes=None, remove_nodes=None): 450 | 451 | batch_graph, batch_h_node_list, batch_obj_node_list, batch_h_h_e_list, batch_o_o_e_list, batch_h_o_e_list, batch_readout_edge_list, batch_readout_h_h_e_list, batch_readout_h_o_e_list = [], [], [], [], [], [], [], [], [] 452 | node_num_cum = np.cumsum(node_num) # !IMPORTANT 453 | 454 | for i in range(len(node_num)): 455 | # set node space 456 | node_space = 0 457 | if i != 0: 458 | node_space = node_num_cum[i-1] 459 | graph, h_node_list, obj_node_list, h_h_e_list, o_o_e_list, h_o_e_list, readout_edge_list, readout_h_h_e_list, readout_h_o_e_list = self._build_graph(node_num[i], roi_label[i], node_space, diff_edge=self.diff_edge) 460 | 461 | # update batch 462 | batch_graph.append(graph) 463 | batch_h_node_list += h_node_list 464 | batch_obj_node_list += obj_node_list 465 | 466 | batch_h_h_e_list += h_h_e_list 467 | batch_o_o_e_list += o_o_e_list 468 | batch_h_o_e_list += h_o_e_list 469 | 470 | batch_readout_edge_list += readout_edge_list 471 | batch_readout_h_h_e_list += readout_h_h_e_list 472 | batch_readout_h_o_e_list += readout_h_o_e_list 473 | 474 | batch_graph = dgl.batch(batch_graph) 475 | 476 | # GRNN 477 | self.grnn1(batch_graph, batch_h_node_list, batch_obj_node_list, batch_h_h_e_list, batch_o_o_e_list, batch_h_o_e_list, feat, spatial_feat, word2vec, validation, initial_feat=True) 478 | batch_graph.apply_edges(self.edge_readout, tuple(zip(*(batch_readout_h_o_e_list+batch_readout_h_h_e_list)))) 479 | 480 | if self.training or validation: 481 | # !NOTE: cannot use "batch_readout_h_o_e_list+batch_readout_h_h_e_list" because of the wrong order 482 | return batch_graph.edges[tuple(zip(*batch_readout_edge_list))].data['pred'], \ 483 | batch_graph.edges[tuple(zip(*batch_readout_edge_list))].data['scene_feat'] 484 | else: 485 | return batch_graph.edges[tuple(zip(*batch_readout_edge_list))].data['pred'], \ 486 | batch_graph.nodes[batch_h_node_list].data['alpha'], \ 487 | batch_graph.nodes[batch_h_node_list].data['alpha_lang'] 488 | -------------------------------------------------------------------------------- /model_train.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Project : Global-Reasoned Multi-Task Surgical Scene Understanding 3 | Lab : MMLAB, National University of Singapore 4 | contributors : Lalithkumar Seenivasan, Sai Mitheran, Mobarakol Islam, Hongliang Ren 5 | ''' 6 | 7 | import os 8 | import time 9 | 10 | import argparse 11 | import numpy as np 12 | from tqdm import tqdm 13 | 14 | import torch 15 | import torch.nn as nn 16 | from torch import optim 17 | import torch.nn.functional as F 18 | from 
torch.utils.data import DataLoader 19 | 20 | from models.mtl_model import * 21 | from models.scene_graph import * 22 | from models.surgicalDataset import * 23 | from models.segmentation_model import get_gcnet 24 | 25 | from utils.scene_graph_eval_matrix import * 26 | from utils.segmentation_eval_matrix import * 27 | 28 | 29 | import torch.multiprocessing as mp 30 | import torch.distributed as dist 31 | from torch.nn.parallel import DistributedDataParallel as DDP 32 | 33 | 34 | def seed_everything(seed=27): 35 | ''' 36 | Set random seed for reproducible experiments 37 | Inputs: seed number 38 | ''' 39 | torch.manual_seed(seed) 40 | torch.cuda.manual_seed_all(seed) 41 | os.environ['PYTHONHASHSEED'] = str(seed) 42 | torch.backends.cudnn.deterministic = True 43 | torch.backends.cudnn.benchmark = False 44 | 45 | 46 | def seg_eval_batch(seg_output, target): 47 | ''' 48 | Calculate segmentation loss, pixel acc and IoU 49 | Inputs: predicted segmentation mask, GT segmentation mask 50 | ''' 51 | seg_criterion = SegmentationLosses(se_loss=False, aux=False, nclass=8, se_weight=0.2, aux_weight=0.2) 52 | loss = seg_criterion(seg_output, target) 53 | correct, labeled = batch_pix_accuracy(seg_output.data, target) 54 | inter, union = batch_intersection_union(seg_output.data, target, 8) # 8 is num classes 55 | return correct, labeled, inter, union, loss 56 | 57 | def get_checkpoint_loc(model_type, seg_mode = None): 58 | loc = None 59 | if model_type == 'amtl-t0' or model_type == 'amtl-t3': 60 | if seg_mode is None: 61 | loc = 'checkpoints/stl_s/stl_s/epoch_train/checkpoint_D153_epoch.pth' 62 | elif seg_mode == 'v1': 63 | loc = 'checkpoints/stl_s_v1/stl_s_v1/epoch_train/checkpoint_D168_epoch.pth' 64 | elif seg_mode == 'v2_gc': 65 | loc = 'checkpoints/stl_s_v2_gc/stl_s_v2_gc/epoch_train/checkpoint_D168_epoch.pth' 66 | elif model_type == 'amtl-t1': 67 | loc = 'checkpoints/stl_s/stl_s/epoch_train/checkpoint_D168_epoch.pth' 68 | elif model_type == 'amtl-t2': 69 | loc = 'checkpoints/stl_sg_wfe/stl_sg_wfe/epoch_train/checkpoint_D110_epoch.pth' 70 | return loc 71 | 72 | def build_model(args): 73 | ''' 74 | Build MTL model 75 | 1) Scene Graph Understanding Model 76 | 2) Segmentation Model : Encoder, Reasoning unit, Decoder 77 | 78 | Inputs: args 79 | ''' 80 | 81 | '''==== Graph model ====''' 82 | # graph model 83 | scene_graph = AGRNN(bias=True, bn=False, dropout=0.3, multi_attn=False, layer=1, diff_edge=False, global_feat=args.global_feat) 84 | 85 | # segmentation model 86 | seg_model = get_gcnet(backbone='resnet18_model', pretrained=True) 87 | model = mtl_model(seg_model.pretrained, scene_graph, seg_model.gr_interaction, seg_model.gr_decoder, seg_mode = args.seg_mode) 88 | model.to(torch.device('cpu')) 89 | return model 90 | 91 | 92 | def model_eval(args, model, validation_dataloader): 93 | ''' 94 | Evaluate function for the MTL model (Segmentation and Scene Graph Performance) 95 | Inputs: args, model, val-dataloader 96 | 97 | ''' 98 | 99 | model.eval() 100 | 101 | # graph 102 | scene_graph_criterion = nn.MultiLabelSoftMarginLoss() 103 | scene_graph_edge_count = 0 104 | scene_graph_total_acc = 0.0 105 | scene_graph_total_loss = 0.0 106 | scene_graph_logits_list = [] 107 | scene_graph_labels_list = [] 108 | 109 | test_seg_loss = 0.0 110 | total_inter, total_union, total_correct, total_label = 0, 0, 0, 0 111 | 112 | for data in tqdm(validation_dataloader): 113 | seg_img = data['img'] 114 | seg_masks = data['mask'] 115 | img_loc = data['img_loc'] 116 | node_num = data['node_num'] 117 | roi_labels = 
data['roi_labels'] 118 | det_boxes = data['det_boxes'] 119 | edge_labels = data['edge_labels'] 120 | spatial_feat = data['spatial_feat'] 121 | word2vec = data['word2vec'] 122 | 123 | spatial_feat, word2vec, edge_labels = spatial_feat.cuda(non_blocking=True), word2vec.cuda(non_blocking=True), edge_labels.cuda(non_blocking=True) 124 | seg_img, seg_masks = seg_img.cuda(non_blocking=True), seg_masks.cuda(non_blocking=True) 125 | 126 | with torch.no_grad(): 127 | interaction, seg_outputs, _ = model(seg_img, img_loc, det_boxes, node_num, spatial_feat, word2vec, roi_labels, validation=True) 128 | 129 | scene_graph_logits_list.append(interaction) 130 | scene_graph_labels_list.append(edge_labels) 131 | 132 | # Loss and accuracy 133 | scene_graph_loss = scene_graph_criterion(interaction, edge_labels.float()) 134 | scene_graph_acc = np.sum(np.equal(np.argmax(interaction.cpu().data.numpy(), axis=-1), np.argmax(edge_labels.cpu().data.numpy(), axis=-1))) 135 | correct, labeled, inter, union, t_loss = seg_eval_batch(seg_outputs, seg_masks) 136 | 137 | # Accumulate scene graph loss and acc 138 | scene_graph_total_loss += scene_graph_loss.item() * edge_labels.shape[0] 139 | scene_graph_total_acc += scene_graph_acc 140 | scene_graph_edge_count += edge_labels.shape[0] 141 | 142 | total_correct += correct 143 | total_label += labeled 144 | total_inter += inter 145 | total_union += union 146 | test_seg_loss += t_loss.item() 147 | 148 | # Graph evaluation 149 | scene_graph_total_acc = scene_graph_total_acc / scene_graph_edge_count 150 | scene_graph_total_loss = scene_graph_total_loss / len(validation_dataloader) 151 | scene_graph_logits_all = torch.cat(scene_graph_logits_list).cuda() 152 | scene_graph_labels_all = torch.cat(scene_graph_labels_list).cuda() 153 | scene_graph_logits_all = F.softmax(scene_graph_logits_all, dim=1) 154 | scene_graph_map_value, scene_graph_recall = calibration_metrics(scene_graph_logits_all, scene_graph_labels_all) 155 | 156 | # Segmentation evaluation 157 | pixAcc = 1.0 * total_correct / (np.spacing(1) + total_label) 158 | IoU = 1.0 * total_inter / (np.spacing(1) + total_union) 159 | mIoU = IoU.mean() 160 | 161 | print('================= Evaluation ====================') 162 | print('Graph : acc: %0.4f map: %0.4f recall: %0.4f loss: %0.4f}' % (scene_graph_total_acc, scene_graph_map_value, scene_graph_recall, scene_graph_total_loss)) 163 | print('Segmentation : Pacc: %0.4f mIoU: %0.4f loss: %0.4f}' % (pixAcc, mIoU, test_seg_loss/len(validation_dataloader))) 164 | return(scene_graph_total_acc, scene_graph_map_value, mIoU) 165 | 166 | 167 | def train_model(gpu, args): 168 | ''' 169 | Train function for the MTL model 170 | Inputs: number of gpus per node, args 171 | 172 | ''' 173 | # Store best value and epoch number 174 | best_value = [0.0, 0.0, 0.0] 175 | best_epoch = [0, 0, 0] 176 | 177 | # Decaying learning rate 178 | decay_lr = args.lr 179 | 180 | # This is placed above the dist.init process, because of the feature_extraction model. 
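# Rough flow of train_model() from here on: build the MTL model, load the pre-trained
# single-task weights for sequential (amtl-*) optimisation, insert the variant-specific
# layers, initialise the DDP process group and wrap the model, build the train/val
# dataloaders, then run the epoch loop with Adam and a step-decayed learning rate
# (decay_lr * 0.98 every 10 epochs). Checkpointing, evaluation and the best SC acc /
# SC mAP / Seg mIoU bookkeeping are done from GPU 0 only. The total loss combines the two
# tasks as loss_total = 0.4 * scene_graph_loss + 0.6 * seg_loss for naive MTL, or
# 0.4 * scene_graph_loss + seg_loss + dist_loss when KD is enabled.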
181 | model = build_model(args) 182 | 183 | # Load pre-trained weights 184 | if args.model == 'amtl-t0' or args.model == 'amtl-t3' or args.model == 'amtl-t0-ft' or args.model == 'amtl-t1' or args.model == 'amtl-t2': 185 | print('Loading pre-trained weights for Sequential Optimisation') 186 | pretrained_model = torch.load(get_checkpoint_loc(args.model, args.seg_mode)) 187 | pretrained_dict = pretrained_model['state_dict'] 188 | model_dict = model.state_dict() 189 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if (k in model_dict) and (model_dict[k].shape == pretrained_dict[k].shape)} 190 | model_dict.update(pretrained_dict) 191 | model.load_state_dict(model_dict) 192 | 193 | # Set the train/test flags of the submodules based on the model type. 194 | model.set_train_test(args.model) 195 | 196 | 197 | if args.KD: 198 | teacher_model = build_model(args) # build_model() takes only args; the teacher weights are loaded below 199 | # Load pre-trained stl_mtl_model 200 | print('Preparing teacher model') 201 | pretrained_model = torch.load('/media/mobarak/data/lalith/mtl_scene_understanding_and_segmentation/checkpoints/stl_s_v1/stl_s_v1/epoch_train/checkpoint_D168_epoch.pth') 202 | pretrained_dict = pretrained_model['state_dict'] 203 | model_dict = teacher_model.state_dict() 204 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if (k in model_dict) and (model_dict[k].shape == pretrained_dict[k].shape)} 205 | model_dict.update(pretrained_dict) 206 | teacher_model.load_state_dict(model_dict) 207 | if args.model == 'mtl-t3': 208 | teacher_model.set_train_test('mtl-t3') 209 | teacher_model.model_type3_insert() 210 | teacher_model.cuda() 211 | else: 212 | teacher_model.set_train_test('stl-s') 213 | teacher_model.cuda() 214 | teacher_model.eval() 215 | 216 | # Insert nn layers based on type. 217 | if args.model == 'amtl-t1' or args.model == 'mtl-t1': 218 | model.model_type1_insert() 219 | elif args.model == 'amtl-t2' or args.model == 'mtl-t2': 220 | model.model_type2_insert() 221 | elif args.model == 'amtl-t3' or args.model == 'mtl-t3': 222 | model.model_type3_insert() 223 | 224 | # Priority rank given to node 0 (the current PC); with more nodes, ranks span multiple PCs 225 | os.environ['MASTER_ADDR'] = 'localhost' 226 | os.environ['MASTER_PORT'] = args.port #8892 227 | rank = args.nr * args.gpus + gpu 228 | dist.init_process_group(backend='nccl', init_method='env://', world_size=args.world_size, rank=rank) 229 | 230 | # Set cuda 231 | torch.cuda.set_device(gpu) 232 | 233 | # Wrap the model with ddp 234 | model.cuda() 235 | model = DDP(model, device_ids=[gpu], find_unused_parameters=True) 236 | 237 | # Define loss function (criterion) and optimizer 238 | seg_criterion = SegmentationLosses(se_loss=False, aux=False, nclass=8, se_weight=0.2, aux_weight=0.2).cuda(gpu) 239 | graph_scene_criterion = nn.MultiLabelSoftMarginLoss().cuda(gpu) 240 | 241 | # Train and validation dataloaders 242 | train_seq = [[2, 3, 4, 6, 7, 9, 10, 11, 12, 14, 15]] 243 | val_seq = [[1, 5, 16]] 244 | data_dir = ['datasets/instruments18/seq_'] 245 | img_dir = ['/left_frames/'] 246 | mask_dir = ['/annotations/'] 247 | dset = [0] 248 | data_const = SurgicalSceneConstants() 249 | 250 | seq = {'train_seq': train_seq, 'val_seq': val_seq, 'data_dir': data_dir, 'img_dir': img_dir, 'dset': dset, 'mask_dir': mask_dir} 251 | 252 | # Val_dataset is only set on 1 GPU 253 | val_dataset = SurgicalSceneDataset(seq_set=seq['val_seq'], dset=seq['dset'], data_dir=seq['data_dir'], \ 254 | img_dir=seq['img_dir'], mask_dir=seq['mask_dir'], istrain=False, dataconst=data_const, \ 255
| feature_extractor=args.feature_extractor, reduce_size=False) 256 | val_dataloader = DataLoader(dataset=val_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn) 257 | 258 | # Train_dataset distributed to 2 GPU 259 | train_dataset = SurgicalSceneDataset(seq_set=seq['train_seq'], data_dir=seq['data_dir'], 260 | img_dir=seq['img_dir'], mask_dir=seq['mask_dir'], dset=seq['dset'], istrain=True, dataconst=data_const, 261 | feature_extractor=args.feature_extractor, reduce_size=False) 262 | 263 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas=args.world_size, rank=rank, shuffle=True) 264 | train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0, pin_memory=True, sampler=train_sampler) 265 | 266 | # Evaluate the model before start of training 267 | if gpu == 0: 268 | if args.KD: 269 | print("=================== Teacher Model=========================") 270 | eval_sc_acc, eval_sc_map, eval_seg_miou = model_eval(args, teacher_model, val_dataloader) 271 | print("=================== Student Model=========================") 272 | eval_sc_acc, eval_sc_map, eval_seg_miou = model_eval(args, model, val_dataloader) 273 | print("PT SC ACC: [value: {:0.4f}] PT SC mAP: [value: {:0.4f}] PT Seg mIoU: [value: {:0.4f}]".format(eval_sc_acc, eval_sc_map, eval_seg_miou)) 274 | 275 | for epoch_count in range(args.epoch): 276 | 277 | start_time = time.time() 278 | 279 | # Set model / submodules in train mode 280 | model.train() 281 | if args.model == 'stl-sg' or args.model == 'amtl-t0' or args.model == 'amtl-t3': 282 | model.module.feature_encoder.eval() 283 | model.module.gcn_unit.eval() 284 | model.module.seg_decoder.eval() 285 | elif args.model == 'stl-sg-wfe': 286 | model.module.gcn_unit.eval() 287 | model.module.seg_decoder.eval() 288 | elif args.model == 'stl-s': 289 | model.module.scene_graph.eval() 290 | 291 | train_seg_loss = 0.0 292 | train_scene_graph_loss = 0.0 293 | 294 | model.cuda() 295 | 296 | # Optimizer with decaying learning rate 297 | decay_lr = decay_lr*0.98 if ((epoch_count+1) %10 == 0) else decay_lr 298 | optimizer = optim.Adam(model.parameters(), lr=decay_lr, weight_decay=0) 299 | 300 | train_sampler.set_epoch(epoch_count) 301 | 302 | if gpu == 0: print('================= Train ====================') 303 | 304 | for data in tqdm(train_dataloader): 305 | seg_img = data['img'] 306 | seg_masks = data['mask'] 307 | img_loc = data['img_loc'] 308 | node_num = data['node_num'] 309 | roi_labels = data['roi_labels'] 310 | det_boxes = data['det_boxes'] 311 | edge_labels = data['edge_labels'] 312 | spatial_feat = data['spatial_feat'] 313 | word2vec = data['word2vec'] 314 | 315 | spatial_feat, word2vec, edge_labels = spatial_feat.cuda(non_blocking=True), word2vec.cuda(non_blocking=True), edge_labels.cuda(non_blocking=True) 316 | seg_img, seg_masks = seg_img.cuda(non_blocking=True), seg_masks.cuda(non_blocking=True) 317 | 318 | # Forward propagation 319 | interaction, seg_outputs, fe_feat = model(seg_img, img_loc, det_boxes, node_num, spatial_feat, word2vec, roi_labels) 320 | 321 | # Loss calculation 322 | seg_loss = seg_criterion(seg_outputs, seg_masks) 323 | scene_graph_loss = graph_scene_criterion(interaction, edge_labels.float()) 324 | 325 | # KD-Loss 326 | if args.KD: 327 | with torch.no_grad(): 328 | _, _, t_fe_feat = teacher_model(seg_img, img_loc, det_boxes, node_num, spatial_feat, word2vec, roi_labels, validation=True) 329 | t_fe_feat = 
t_fe_feat.detach() 330 | t_fe_feat = t_fe_feat / (t_fe_feat.pow(2).sum(1) + 1e-6).sqrt().view(t_fe_feat.size(0), 1, t_fe_feat.size(2), t_fe_feat.size(3)) 331 | 332 | 333 | 334 | fe_feat = fe_feat / (fe_feat.pow(2).sum(1) + 1e-6).sqrt().view(fe_feat.size(0), 1, fe_feat.size(2), fe_feat.size(3)) 335 | dist_loss = (fe_feat - t_fe_feat).pow(2).sum(1).mean() 336 | 337 | 338 | if args.model == 'stl-s': 339 | loss_total = seg_loss 340 | elif args.model == 'stl-sg' or args.model == 'stl-sg-wfe' or args.model == 'amtl-t0' or args.model == 'amtl-t3': 341 | loss_total = scene_graph_loss 342 | elif args.KD: 343 | loss_total = (0.4 * scene_graph_loss) + seg_loss + dist_loss 344 | else: 345 | loss_total = (0.4 * scene_graph_loss) + (0.6 * seg_loss) 346 | 347 | optimizer.zero_grad() 348 | loss_total.backward() 349 | optimizer.step() 350 | 351 | train_seg_loss += seg_loss.item() 352 | train_scene_graph_loss += scene_graph_loss.item() * edge_labels.shape[0] 353 | 354 | # Average the accumulated losses over the epoch 355 | train_seg_loss = train_seg_loss / len(train_dataloader) 356 | train_scene_graph_loss = train_scene_graph_loss / len(train_dataloader) 357 | 358 | if gpu == 0: 359 | end_time = time.time() 360 | print("Train Epoch: {}/{} lr: {:0.9f} Graph_loss: {:0.4f} Segmentation_Loss: {:0.4f} Execution time: {:0.4f}".format(\ 361 | epoch_count + 1, args.epoch, decay_lr, train_scene_graph_loss, train_seg_loss, (end_time-start_time))) 362 | 363 | 364 | # Save the model checkpoint 365 | 366 | checkpoint = { 'lr': args.lr, 'b_s': args.batch_size, 'bias': args.bias, 'bn': args.bn, 'dropout': args.drop_prob, 367 | 'layers': args.layers, 'multi_head': args.multi_attn, 368 | 'diff_edge': args.diff_edge, 'state_dict': model.module.state_dict() } 369 | 370 | save_name = "checkpoint_D1" + str(epoch_count+1) + '_epoch.pth' 371 | torch.save(checkpoint, os.path.join(args.save_dir, args.exp_ver, 'epoch_train', save_name)) 372 | 373 | eval_sc_acc, eval_sc_map, eval_seg_miou = model_eval(args, model, val_dataloader) 374 | if eval_sc_acc > best_value[0]: 375 | best_value[0] = eval_sc_acc 376 | best_epoch[0] = epoch_count+1 377 | if eval_sc_map > best_value[1]: 378 | best_value[1] = eval_sc_map 379 | best_epoch[1] = epoch_count+1 380 | if eval_seg_miou > best_value[2]: 381 | best_value[2] = eval_seg_miou 382 | best_epoch[2] = epoch_count+1 383 | print("Best SC Acc: [Epoch: {} value: {:0.4f}] Best SC mAP: [Epoch: {} value: {:0.4f}] Best Seg mIoU: [Epoch: {} value: {:0.4f}]".format(\ 384 | best_epoch[0], best_value[0], best_epoch[1], best_value[1], best_epoch[2], best_value[2])) 385 | 386 | return 387 | 388 | 389 | if __name__ == "__main__": 390 | ''' 391 | Main function to set arguments 392 | ''' 393 | 394 | # ---------------------------------------------- Optimization and feature sharing variants ---------------------------------------------- 395 | ''' 396 | Format for the model_type : X-Y 397 | 398 | -> X : Optimisation technique [1. amtl - Sequential MTL Optimisation, 2. mtl - Naive MTL Optimisation] 399 | -> Y : Feature Sharing mechanism [1. t0 - Base model, 400 | 2. t1 - Scene graph features to enhance segmentation (SGFSEG), 401 | 3. 
t3 - Global interaction space features to improve scene graph (GISFSG)] 402 | 403 | ''' 404 | model_type = 'amtl-t0' 405 | ver = model_type + '_v5' 406 | port = '8892' 407 | f_e = 'resnet18_11_cbs_ts' 408 | 409 | 410 | # ----------------------------------------------Global reasoning variant in segmentation ----------------------------------------------- 411 | ''' 412 | -> seg_mode : v1 - (MSLRGR - multi-scale local reasoning and global reasoning) 413 | v2gc - (MSLR - multi-scale local reasoning) 414 | None - Base model 415 | ''' 416 | seg_mode = 'v1' 417 | 418 | # Set random seed 419 | seed_everything() 420 | print(ver, seg_mode) 421 | 422 | # Device Count 423 | num_gpu = torch.cuda.device_count() 424 | 425 | # Arguments 426 | parser = argparse.ArgumentParser(description='MTL Scene graph and Segmentation') 427 | 428 | # Hyperparameters 429 | parser.add_argument('--lr', type=float, default = 0.00001) 430 | parser.add_argument('--epoch', type=int, default = 130) 431 | parser.add_argument('--start_epoch', type=int, default = 0) 432 | parser.add_argument('--batch_size', type=int, default = 4) 433 | parser.add_argument('--gpu', type=bool, default = True) 434 | parser.add_argument('--train_model', type=str, default = 'epoch') 435 | parser.add_argument('--exp_ver', type=str, default = ver) 436 | 437 | # File locations 438 | parser.add_argument('--log_dir', type=str, default = './log/' + ver) 439 | parser.add_argument('--save_dir', type=str, default = './checkpoints/' + ver) 440 | parser.add_argument('--output_img_dir', type=str, default = './results/' + ver) 441 | parser.add_argument('--save_every', type=int, default = 10) 442 | parser.add_argument('--pretrained', type=str, default = None) 443 | 444 | # Network settings 445 | parser.add_argument('--layers', type=int, default = 1) 446 | parser.add_argument('--bn', type=bool, default = False) 447 | parser.add_argument('--drop_prob', type=float, default = 0.3) 448 | parser.add_argument('--bias', type=bool, default = True) 449 | parser.add_argument('--multi_attn', type=bool, default = False) 450 | parser.add_argument('--diff_edge', type=bool, default = False) 451 | 452 | if model_type == 'mtl-t3' or model_type == 'amtl-t3': 453 | parser.add_argument('--global_feat', type=int, default = 128) 454 | else: 455 | parser.add_argument('--global_feat', type=int, default = 0) 456 | 457 | # Data processing 458 | parser.add_argument('--sampler', type=int, default = 0) 459 | parser.add_argument('--data_aug', type=bool, default = False) 460 | parser.add_argument('--feature_extractor', type=str, default = f_e) 461 | parser.add_argument('--seg_mode', type=str, default = seg_mode) # v1/v2_gc 462 | 463 | parser.add_argument('--KD', type=bool, default = False) 464 | 465 | # GPU distributor 466 | parser.add_argument('--port', type=str, default = port) 467 | parser.add_argument('--nodes', type=int, default = 1, metavar='N', help='Number of data loading workers (default: 4)') 468 | parser.add_argument('--gpus', type=int, default = num_gpu, help='Number of gpus per node') 469 | parser.add_argument('--nr', type=int, default = 0, help='Ranking within the nodes') 470 | 471 | # Model type 472 | parser.add_argument('--model', type=str, default = model_type) 473 | args = parser.parse_args() 474 | 475 | # Constants for the surgical scene 476 | data_const = SurgicalSceneConstants() 477 | 478 | # GPU distributed 479 | args.world_size = args.gpus * args.nodes 480 | 481 | # Train model in distributed settings - (train function, number of GPUs, arguments) 482 | mp.spawn(train_model, 
nprocs=args.gpus, args=(args,)) -------------------------------------------------------------------------------- /models/segmentation_model.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Project : Global-Reasoned Multi-Task Surgical Scene Understanding 3 | Lab : MMLAB, National University of Singapore 4 | contributors : Lalithkumar Seenivasan, Sai Mitheran, Mobarakol Islam, Hongliang Ren 5 | Note : Code adopted and modified from Visual-Semantic Graph Attention Networks and Dual attention network for scene segmentation 6 | 7 | @inproceedings{fu2019dual, 8 | title={Dual attention network for scene segmentation}, 9 | author={Fu, Jun and Liu, Jing and Tian, Haijie and Li, Yong and Bao, Yongjun and Fang, Zhiwei and Lu, Hanqing}, 10 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, 11 | pages={3146--3154}, 12 | year={2019} 13 | } 14 | ''' 15 | 16 | 17 | import math 18 | import numpy as np 19 | from collections import OrderedDict 20 | 21 | import torch 22 | import torch.nn as nn 23 | from torch import Tensor 24 | from torch.nn import functional as F 25 | from torch.nn.functional import interpolate 26 | from typing import Type, Any, Callable, Union, List, Optional 27 | 28 | # Setting the kwargs for upsample configuration 29 | up_kwargs = {'mode': 'bilinear', 'align_corners': True} 30 | 31 | 32 | class Namespace: 33 | """ 34 | Namespace class for custom args to be parsed 35 | Inputs: **kwargs 36 | 37 | """ 38 | def __init__(self, **kwargs): 39 | self.__dict__.update(kwargs) 40 | 41 | def get_backbone(name, **kwargs): 42 | """ 43 | Function to get backbone feature extractor 44 | Inputs: name of backbone, **kwargs 45 | 46 | """ 47 | models = { 48 | 'resnet18_model': resnet18_model, 49 | } 50 | name = name.lower() 51 | if name not in models: 52 | raise ValueError('%s\n\t%s' % (str(name), '\n\t'.join(sorted(models.keys())))) 53 | net = models[name](**kwargs) 54 | return net 55 | 56 | 57 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: 58 | """ 59 | 3x3 convolution with padding 60 | Inputs: in_planes, out_planes, stride, groups, dilation 61 | 62 | """ 63 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 64 | padding=dilation, groups=groups, bias=False, dilation=dilation) 65 | 66 | 67 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: 68 | """ 69 | 1x1 convolution 70 | Inputs: in_planes, out_planes, stride 71 | 72 | """ 73 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 74 | 75 | 76 | class BasicBlock(nn.Module): 77 | """ 78 | Basic block for ResNet18 backbone 79 | init : 80 | inplanes, planes, stride, downsample, groups, base_width, dilation, norm_layer 81 | 82 | forward : x 83 | 84 | """ 85 | expansion: int = 1 86 | 87 | def __init__( 88 | self, 89 | inplanes: int, 90 | planes: int, 91 | stride: int = 1, 92 | downsample: Optional[nn.Module] = None, 93 | groups: int = 1, 94 | base_width: int = 64, 95 | dilation: int = 1, 96 | norm_layer: Optional[Callable[..., nn.Module]] = None 97 | ) -> None: 98 | super(BasicBlock, self).__init__() 99 | if norm_layer is None: 100 | norm_layer = nn.BatchNorm2d 101 | if groups != 1 or base_width != 64: 102 | raise ValueError( 103 | 'BasicBlock only supports groups=1 and base_width=64') 104 | if dilation > 1: 105 | raise NotImplementedError( 106 | "Dilation > 1 not supported in BasicBlock") 107 | 108 | self.planes = planes 109 | 110 
| self.conv1 = conv3x3(inplanes, planes, stride) 111 | self.bn1 = norm_layer(planes) 112 | self.relu = nn.ReLU(inplace=True) 113 | self.conv2 = conv3x3(planes, planes) 114 | self.bn2 = norm_layer(planes) 115 | self.downsample = downsample 116 | self.stride = stride 117 | 118 | def forward(self, x: Tensor) -> Tensor: 119 | identity = x 120 | 121 | out = self.conv1(x) 122 | out = self.bn1(out) 123 | out = self.relu(out) 124 | 125 | out = self.conv2(out) 126 | out = self.bn2(out) 127 | 128 | if self.downsample is not None: 129 | identity = self.downsample(x) 130 | 131 | out += identity 132 | out = self.relu(out) 133 | 134 | return out 135 | 136 | 137 | class Bottleneck(nn.Module): 138 | """ 139 | Bottleneck block for ResNet18 140 | init : 141 | inplanes, planes, stride, downsample, groups, base_width, dilation, norm_layer 142 | 143 | forward : x 144 | 145 | """ 146 | expansion: int = 4 147 | 148 | def __init__( 149 | self, 150 | inplanes: int, 151 | planes: int, 152 | stride: int = 1, 153 | downsample: Optional[nn.Module] = None, 154 | groups: int = 1, 155 | base_width: int = 64, 156 | dilation: int = 1, 157 | norm_layer: Optional[Callable[..., nn.Module]] = None 158 | ) -> None: 159 | super(Bottleneck, self).__init__() 160 | if norm_layer is None: 161 | norm_layer = nn.BatchNorm2d 162 | width = int(planes * (base_width / 64.)) * groups 163 | 164 | # self.conv2 and self.downsample layers downsample the input when stride != 1 165 | self.conv1 = conv1x1(inplanes, width) 166 | self.bn1 = norm_layer(width) 167 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 168 | self.bn2 = norm_layer(width) 169 | self.conv3 = conv1x1(width, planes * self.expansion) 170 | self.bn3 = norm_layer(planes * self.expansion) 171 | self.relu = nn.ReLU(inplace=True) 172 | self.downsample = downsample 173 | self.stride = stride 174 | 175 | def forward(self, x: Tensor) -> Tensor: 176 | identity = x 177 | 178 | out = self.conv1(x) 179 | out = self.bn1(out) 180 | out = self.relu(out) 181 | 182 | out = self.conv2(out) 183 | out = self.bn2(out) 184 | out = self.relu(out) 185 | 186 | out = self.conv3(out) 187 | out = self.bn3(out) 188 | 189 | # Downsampling of the input variable (x) 190 | if self.downsample is not None: 191 | identity = self.downsample(x) 192 | 193 | out += identity 194 | out = self.relu(out) 195 | 196 | return out 197 | 198 | 199 | class ResNet(nn.Module): 200 | """ 201 | ResNet base class for different variants 202 | init : 203 | block, layers, num_classes (ImageNet), zero_init_residual, groups, width_per_group, replace_stride_with_dilation, norm_layer 204 | 205 | forward : x 206 | """ 207 | 208 | def __init__( 209 | self, 210 | block: Type[Union[BasicBlock, Bottleneck]], 211 | layers: List[int], 212 | num_classes: int = 1000, 213 | zero_init_residual: bool = False, 214 | groups: int = 1, 215 | width_per_group: int = 64, 216 | replace_stride_with_dilation: Optional[List[bool]] = None, 217 | norm_layer: Optional[Callable[..., nn.Module]] = None 218 | ) -> None: 219 | 220 | super(ResNet, self).__init__() 221 | if norm_layer is None: 222 | norm_layer = nn.BatchNorm2d 223 | self._norm_layer = norm_layer 224 | self.inplanes = 64 225 | self.dilation = 1 226 | 227 | if replace_stride_with_dilation is None: 228 | # Each element in the tuple indicates whether we should replace the 2x2 stride with a dilated convolution 229 | replace_stride_with_dilation = [False, False, False] 230 | 231 | if len(replace_stride_with_dilation) != 3: 232 | raise ValueError("replace_stride_with_dilation should be None " 233 | 
"or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 234 | 235 | self.groups = groups 236 | self.base_width = width_per_group 237 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 238 | bias=False) 239 | self.bn1 = norm_layer(self.inplanes) 240 | self.relu = nn.ReLU(inplace=True) 241 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 242 | self.layer1 = self._make_layer(block, 64, layers[0]) 243 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 244 | dilate=replace_stride_with_dilation[0]) 245 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 246 | dilate=replace_stride_with_dilation[1]) 247 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 248 | dilate=replace_stride_with_dilation[2]) 249 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 250 | self.fc = nn.Linear(512 * block.expansion, num_classes) 251 | 252 | for m in self.modules(): 253 | if isinstance(m, nn.Conv2d): 254 | nn.init.kaiming_normal_( 255 | m.weight, mode='fan_out', nonlinearity='relu') 256 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 257 | nn.init.constant_(m.weight, 1) 258 | nn.init.constant_(m.bias, 0) 259 | 260 | if zero_init_residual: 261 | for m in self.modules(): 262 | if isinstance(m, Bottleneck): 263 | nn.init.constant_(m.bn3.weight, 0) 264 | elif isinstance(m, BasicBlock): 265 | nn.init.constant_(m.bn2.weight, 0) 266 | 267 | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, 268 | stride: int = 1, dilate: bool = False) -> nn.Sequential: 269 | norm_layer = self._norm_layer 270 | downsample = None 271 | previous_dilation = self.dilation 272 | if dilate: 273 | self.dilation *= stride 274 | stride = 1 275 | if stride != 1 or self.inplanes != planes * block.expansion: 276 | downsample = nn.Sequential( 277 | conv1x1(self.inplanes, planes * block.expansion, stride), 278 | norm_layer(planes * block.expansion), 279 | ) 280 | 281 | layers = [] 282 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 283 | self.base_width, previous_dilation, norm_layer)) 284 | self.inplanes = planes * block.expansion 285 | for _ in range(1, blocks): 286 | layers.append(block(self.inplanes, planes, groups=self.groups, 287 | base_width=self.base_width, dilation=self.dilation, 288 | norm_layer=norm_layer)) 289 | 290 | return nn.Sequential(*layers) 291 | 292 | 293 | def _forward_impl(self, x) -> Tensor: 294 | x = self.conv1(x) 295 | x = self.bn1(x) 296 | x = self.relu(x) 297 | x = self.maxpool(x) 298 | 299 | c1 = self.layer1(x) 300 | c2 = self.layer2(c1) 301 | c3 = self.layer3(c2) 302 | c4 = self.layer4(c3) 303 | 304 | return c1, c2, c3, c4 305 | 306 | 307 | def forward(self, x: Tensor) -> Tensor: 308 | return self._forward_impl(x) 309 | 310 | 311 | class BaseNet(nn.Module): 312 | """ 313 | BaseNet class for Multi-scale global reasoned segmentation module 314 | 315 | init : 316 | block, layers, num_classes (ImageNet), zero_init_residual, groups, width_per_group, replace_stride_with_dilation, norm_layer 317 | 318 | forward : x 319 | 320 | """ 321 | def __init__(self, nclass, backbone, pretrained, dilated=True, norm_layer=None, 322 | root='~/.encoding/models', *args, **kwargs): 323 | super(BaseNet, self).__init__() 324 | self.nclass = nclass 325 | 326 | # Copying modules from pretrained models 327 | self.backbone = backbone 328 | self.pretrained = get_backbone(backbone, pretrained=pretrained, dilated=dilated, 329 | norm_layer=norm_layer, root=root, 330 | *args, **kwargs) 
class BaseNet(nn.Module):
    """
    BaseNet class for Multi-scale global reasoned segmentation module

    init :
        nclass, backbone, pretrained, dilated, norm_layer, root

    base_forward : x

    """
    def __init__(self, nclass, backbone, pretrained, dilated=True, norm_layer=None,
                 root='~/.encoding/models', *args, **kwargs):
        super(BaseNet, self).__init__()
        self.nclass = nclass

        # Copying modules from pretrained models
        self.backbone = backbone
        self.pretrained = get_backbone(backbone, pretrained=pretrained, dilated=dilated,
                                       norm_layer=norm_layer, root=root,
                                       *args, **kwargs)
        self.pretrained.fc = None
        self._up_kwargs = up_kwargs

    def base_forward(self, x):

        x = self.pretrained.conv1(x)
        x = self.pretrained.bn1(x)
        x = self.pretrained.relu(x)
        x = self.pretrained.maxpool(x)
        c = self.pretrained.layer1(x)
        c = self.pretrained.layer2(c)
        c = self.pretrained.layer3(c)
        c = self.pretrained.layer4(c)

        return None, None, None, c

    def evaluate(self, x, target=None):
        pred = self.forward(x)
        if isinstance(pred, (tuple, list)):
            pred = pred[0]
        if target is None:
            return pred
        correct, labeled = batch_pix_accuracy(pred.data, target.data)
        inter, union = batch_intersection_union(
            pred.data, target.data, self.nclass)
        return correct, labeled, inter, union


def _resnet(
    arch: str,
    block: Type[Union[BasicBlock, Bottleneck]],
    layers: List[int],
    pretrained: bool,
    progress: bool,
    **kwargs: Any
) -> ResNet:
    """
    Builds a ResNet and optionally loads pre-trained ImageNet weights
    Inputs: arch, block, layers, pretrained, progress, **kwargs
    """
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        print("Loading pre-trained ImageNet weights")
        state_dict = torch.load('models/r18/resnet18-f37072fd.pth')
        model.load_state_dict(state_dict)
    return model


def resnet18(pretrained: bool = True, progress: bool = True, **kwargs: Any) -> ResNet:
    """
    ResNet18 model call function
    Inputs: pretrained, progress, **kwargs
    """
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
                   **kwargs)


class Resnet18_main(nn.Module):
    """
    ResNet-18 wrapper class for the feature extractor
    init : pretrained, num_classes
    forward : x
    """
    def __init__(self, pretrained, num_classes=1000):

        super(Resnet18_main, self).__init__()
        resnet18_block = resnet18(
            pretrained=pretrained)

        resnet18_block.fc = nn.Conv2d(resnet18_block.inplanes, num_classes, 1)

        self.resnet18_block = resnet18_block
        self._normal_initialization(self.resnet18_block.fc)

        self.in_planes = 64
        self.kernel_size = 3

    def _normal_initialization(self, layer):

        layer.weight.data.normal_(0, 0.01)
        layer.bias.data.zero_()

    def forward(self, x):
        c1, c2, c3, c4 = self.resnet18_block(x)

        return c1, c2, c3, c4

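
# A minimal sketch of the backbone factory above (illustrative only; not called
# anywhere in the repository). With pretrained=True, _resnet expects the checkpoint
# at the hard-coded path 'models/r18/resnet18-f37072fd.pth'; the filename appears to
# match torchvision's ImageNet ResNet-18 weights, which would have to be placed there
# manually. pretrained=False avoids that requirement for a quick shape check.
def _demo_backbone_factory():
    backbone = Resnet18_main(pretrained=False, num_classes=8)
    x = torch.randn(1, 3, 256, 320)   # arbitrary frame-like resolution
    c1, c2, c3, c4 = backbone(x)
    # c1: (1, 64, 64, 80), c2: (1, 128, 32, 40), c3: (1, 256, 16, 20), c4: (1, 512, 8, 10)
    return c1.shape, c2.shape, c3.shape, c4.shape
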
class GCN(nn.Module):
    """
    Graph Convolution network for Global interaction space
    init :
        num_state, num_node, bias=False

    forward : x, scene_feat = None, model_type = None

    """
    def __init__(self, num_state, num_node, bias=False):
        super(GCN, self).__init__()
        self.conv1 = nn.Conv1d(num_node, num_node, kernel_size=1, padding=0,
                               stride=1, groups=1, bias=True)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv1d(num_state, num_state, kernel_size=1, padding=0,
                               stride=1, groups=1, bias=bias)
        self.x_avg_pool = nn.AvgPool1d(128, 1)

    def forward(self, x, scene_feat=None, model_type=None):
        h = self.conv1(x.permute(0, 2, 1).contiguous()).permute(0, 2, 1)

        if (model_type == 'amtl-t1' or model_type == 'mtl-t1') and scene_feat is not None:  # (x+h+(avg(x)*f))
            x_p = torch.matmul(self.x_avg_pool(x.permute(0, 2, 1).contiguous()), scene_feat)
            h = h + x + x_p.permute(0, 2, 1).contiguous()
        else:
            h = h + x

        h = self.relu(h)
        h = self.conv2(h)

        return h

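
# A minimal sketch of the plain GCN path (illustrative only; not called anywhere in
# the repository). The input is laid out as (batch, num_state, num_node); conv1 mixes
# information across the graph nodes and conv2 across the state channels. The sizes
# below match those used by GloRe_Unit (num_state=128, num_node=64). The 'amtl-t1' /
# 'mtl-t1' branch additionally folds in a scene-graph feature via the average pool and
# is not exercised here, because the expected scene_feat shape is model-specific.
def _demo_gcn_plain_path():
    gcn = GCN(num_state=128, num_node=64)
    x = torch.randn(2, 128, 64)
    h = gcn(x)          # h = conv2(relu(conv1(x') + x)), shape preserved
    return h.shape      # -> (2, 128, 64)
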
class GloRe_Unit(nn.Module):
    """
    Global Reasoning Unit (GR/GloRe)
    init :
        num_in, num_mid, stride=(1, 1), kernel=1

    forward : x, scene_feat = None, model_type = None
        AMTL - Sequential MTL Optimisation
        MTL - Naive MTL Optimisation

    """
    def __init__(self, num_in, num_mid, stride=(1, 1), kernel=1):
        super(GloRe_Unit, self).__init__()

        self.num_s = int(2 * num_mid)
        self.num_n = int(1 * num_mid)

        kernel_size = (kernel, kernel)
        padding = (1, 1) if kernel == 3 else (0, 0)

        # Reduce dimension
        self.conv_state = nn.Conv2d(num_in, self.num_s, kernel_size=kernel_size, padding=padding)
        # Generate the graph projection (coordinate space -> interaction space)
        self.conv_proj = nn.Conv2d(num_in, self.num_n, kernel_size=kernel_size, padding=padding)
        # ----------
        self.gcn = GCN(num_state=self.num_s, num_node=self.num_n)
        # ----------
        # Tail: extend dimension back to num_in
        self.fc_2 = nn.Conv2d(self.num_s, num_in, kernel_size=kernel_size, padding=padding, stride=(1, 1), groups=1, bias=False)

        self.blocker = nn.BatchNorm2d(num_in)

    def forward(self, x, scene_feat=None, model_type=None):
        '''
        Parameter x dimension : (N, C, H, W)
        '''
        batch_size = x.size(0)
        x_state_reshaped = self.conv_state(x).view(batch_size, self.num_s, -1)
        x_proj_reshaped = self.conv_proj(x).view(batch_size, self.num_n, -1)
        x_rproj_reshaped = x_proj_reshaped

        # Projection: Coordinate space -> Interaction space
        x_n_state = torch.matmul(x_state_reshaped, x_proj_reshaped.permute(0, 2, 1))
        x_n_state = x_n_state * (1. / x_state_reshaped.size(2))

        if model_type == 'amtl-t2' or model_type == 'mtl-t2':
            x_n_rel = torch.matmul(x_n_state.permute(0, 2, 1).contiguous(), scene_feat).permute(0, 2, 1)
        else:
            x_n_rel = self.gcn(x_n_state, scene_feat, model_type)

        out2 = None
        if model_type == 'amtl-t3' or model_type == 'mtl-t3':
            out2 = x_n_rel

        # Reverse projection: Interaction space -> Coordinate space
        x_state_reshaped = torch.matmul(x_n_rel, x_rproj_reshaped)

        x_state = x_state_reshaped.view(batch_size, self.num_s, *x.size()[2:])
        out = x + self.blocker(self.fc_2(x_state))

        return out, out2


class GR_Decoder(nn.Module):
    """
    Multi-scale Global Reasoned (GR) Decoder for Feature Aggregation
    init :
        in_channels, out_channels, norm_layer

    forward : s4, s1 = None, s2 = None, s3 = None, imsize = None, seg_mode = None

    -> s1-s4 are scale-specific features
    -> out_channels = num_classes (8)
    -> seg_mode : V1 (MSLRGR - multi-scale local reasoning and global reasoning)
                  V2GC (MSGR - multi-scale global reasoning)
    """
    def __init__(self, in_channels, out_channels, norm_layer):
        super(GR_Decoder, self).__init__()

        # Scale-specific channel dimensions
        inter_channels = in_channels // 2  # 256
        c2 = inter_channels // 2           # 128
        c1 = c2 // 2                       # 64

        # Scale-specific decoder layers with a simple Conv-BN-ReLU-Dropout-Conv block
        self.s1_layer = nn.Sequential(nn.Sequential(nn.Conv2d(c1, c1, 3, padding=1, bias=False), norm_layer(c1), nn.ReLU()),
                                      nn.Sequential(nn.Dropout2d(0.1), nn.Conv2d(c1, out_channels, 1)))

        self.s2_layer = nn.Sequential(nn.Sequential(nn.Conv2d(c2, c2, 3, padding=1, bias=False), norm_layer(c2), nn.ReLU()),
                                      nn.Sequential(nn.Dropout2d(0.1), nn.Conv2d(c2, out_channels, 1)))

        self.s3_layer = nn.Sequential(nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False), norm_layer(inter_channels), nn.ReLU()),
                                      nn.Sequential(nn.Dropout2d(0.1), nn.Conv2d(inter_channels, out_channels, 1)))

        self.s4_decoder = nn.Sequential(nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False), norm_layer(inter_channels), nn.ReLU()),
                                        nn.Sequential(nn.Dropout2d(0.1), nn.Conv2d(inter_channels, out_channels, 1)))

    def forward(self, x, s1=None, s2=None, s3=None, imsize=None, seg_mode=None):
        x = [self.s4_decoder(x)]
        outputs = []
        for i in range(len(x)):
            outputs.append(
                interpolate(x[i], imsize, mode='bilinear', align_corners=True))

        # V1 and V2_GC are segmentation modes MSLRGR and MSGR respectively
        if seg_mode == 'v2_gc' or seg_mode == 'v1':
            s1 = interpolate(self.s1_layer(s1), imsize, mode='bilinear', align_corners=True)
            s2 = interpolate(self.s2_layer(s2), imsize, mode='bilinear', align_corners=True)
            s3 = interpolate(self.s3_layer(s3), imsize, mode='bilinear', align_corners=True)
            outputs = outputs[0]
            outputs = s1 + s2 + s3 + outputs
            return outputs
        else:
            return outputs[0]

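
# A minimal sketch of the Global Reasoning unit above (illustrative only; not called
# anywhere in the repository). With the default model_type, the unit projects the
# (N, C, H, W) input into an interaction space of num_n=64 nodes with num_s=128 state
# channels, applies the GCN, projects back, and adds the result residually.
def _demo_glore_unit():
    unit = GloRe_Unit(num_in=256, num_mid=64)
    x = torch.randn(2, 256, 16, 20)
    out, _ = unit(x)    # second output is None unless an 'amtl-t3' / 'mtl-t3' model type is used
    return out.shape    # -> (2, 256, 16, 20), same shape as the input
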
class GR_Segmentation(BaseNet):
    """
    Global-Reasoned (GR) Segmentation module initialisation
    init :
        nclass, backbone, pretrained, aux=False, se_loss=False, norm_layer=nn.BatchNorm2d, gcn_search=None, **kwargs

    forward : x (not used in the MTL forward pass)

    """
    def __init__(self, nclass, backbone, pretrained, aux=False, se_loss=False, norm_layer=nn.BatchNorm2d, gcn_search=None, **kwargs):
        super(GR_Segmentation, self).__init__(nclass, backbone, pretrained, norm_layer=norm_layer, **kwargs)

        in_channels = 512

        # GR module
        self.gr_interaction = GR_module(in_channels, nclass, norm_layer, gcn_search)

        # GR decoder
        self.gr_decoder = GR_Decoder(in_channels, nclass, norm_layer)

    # NOTE: this forward function is NOT used in the MTL forward pass
    def forward(self, x):
        imsize = x.size()[2:]

        # Encoder module
        s1, s2, s3, s4 = self.base_forward(x)

        # GR module (single conv block bridging to the GloRe unit)
        _, _, _, x, _ = self.gr_interaction(s4)

        # Decoder module
        x = self.gr_decoder(x, imsize=imsize)
        return x


class GR_module(nn.Module):
    """
    Multi-scale Global Reasoning (GR) Unit
    init :
        in_channels, out_channels, norm_layer, gcn_search

    forward : x, s1 = None, s2 = None, s3 = None, scene_feat = None, seg_mode = None, model_type = None
        -> x is the deepest backbone feature; s1-s3 are scale-specific features
        -> out_channels = num_classes (8)
        -> seg_mode : V1 (MSLRGR - multi-scale local reasoning and global reasoning)
                      V2GC (MSGR - multi-scale global reasoning)

    """
    def __init__(self, in_channels, out_channels, norm_layer, gcn_search):
        super(GR_module, self).__init__()

        inter_channels = in_channels // 2  # 256
        c2 = inter_channels // 2           # 128
        c1 = c2 // 2                       # 64

        # Simple Conv-BN-ReLU block
        self.conv_s4 = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False), norm_layer(inter_channels), nn.ReLU())

        # Scale-specific GR units (GloRe)
        self.gcn1 = GloRe_Unit(c1, 64, kernel=1)
        self.gcn2 = GloRe_Unit(c2, 64, kernel=1)
        self.gcn3 = GloRe_Unit(inter_channels, 64, kernel=1)
        self.gcn4 = GloRe_Unit(inter_channels, 64, kernel=1)

    def forward(self, x, s1=None, s2=None, s3=None, scene_feat=None, seg_mode=None, model_type=None):

        feat1 = None
        feat2 = None
        feat3 = None
        feat5 = None

        if seg_mode == 'v2_gc':  # MODE - MSGR: global reasoning at every scale
            feat1, _ = self.gcn1(s1, scene_feat)
            feat2, _ = self.gcn2(s2, scene_feat)
            feat3, _ = self.gcn3(s3, scene_feat)
            feat4, feat5 = self.gcn4(self.conv_s4(x), scene_feat, model_type)

        elif seg_mode == 'v1':  # MODE - MSLRGR: local features at s1-s3, global reasoning on the deepest scale
            feat1, feat2, feat3 = s1, s2, s3
            feat4, feat5 = self.gcn4(self.conv_s4(x), scene_feat, model_type)

        else:
            feat4, feat5 = self.gcn4(self.conv_s4(x), scene_feat, model_type)

        return feat1, feat2, feat3, feat4, feat5


def resnet18_model(pretrained=True, root='~/.encoding/models', **kwargs):
    model = Resnet18_main(pretrained, num_classes=8)
    return model


def get_gcnet(dataset='endovis18', backbone='resnet18_model', num_classes=8, pretrained=False, root='./pretrain_models', **kwargs):
    model = GR_Segmentation(nclass=num_classes, backbone=backbone, pretrained=pretrained, root=root, **kwargs)
    return model

--------------------------------------------------------------------------------
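
A hedged usage sketch for the GR_module / GR_Decoder pair defined in the file above, assuming these classes live in models/segmentation_model.py as the repository tree suggests. The snippet is illustrative only (it is not part of the repository's training or evaluation code) and exercises the MSLRGR ('v1') segmentation path with example feature sizes for a 256x320 frame:

    import torch
    from torch import nn
    from models.segmentation_model import GR_module, GR_Decoder  # assumed module path

    s1 = torch.randn(1, 64, 64, 80)    # stride-4 features
    s2 = torch.randn(1, 128, 32, 40)   # stride-8 features
    s3 = torch.randn(1, 256, 16, 20)   # stride-16 features
    s4 = torch.randn(1, 512, 8, 10)    # stride-32 features from the backbone
    gr = GR_module(512, 8, nn.BatchNorm2d, gcn_search=None)
    decoder = GR_Decoder(512, 8, nn.BatchNorm2d)
    f1, f2, f3, f4, _ = gr(s4, s1, s2, s3, seg_mode='v1')
    logits = decoder(f4, f1, f2, f3, imsize=(256, 320), seg_mode='v1')
    # logits: (1, 8, 256, 320) per-class segmentation logits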