├── figures
│   ├── figure_1.pdf
│   ├── figure_2.jpg
│   ├── figure_2.pdf
│   ├── figure_3.pdf
│   ├── figure_4.jpg
│   ├── figure_4.pdf
│   └── figure_5.pdf
├── .gitignore
├── utils
│   ├── scene_graph_eval_matrix.py
│   ├── utils.py
│   ├── vis_tool.py
│   ├── segmentation_eval_matrix.py
│   └── io.py
├── environment.yml
├── eval_instructions.txt
├── README.md
├── models
│   ├── surgicalDataset.py
│   ├── mtl_model.py
│   ├── scene_graph.py
│   └── segmentation_model.py
├── evaluation.py
└── model_train.py
/figures/figure_1.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lalithjets/Global-reasoned-multi-task-model/HEAD/figures/figure_1.pdf
--------------------------------------------------------------------------------
/figures/figure_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lalithjets/Global-reasoned-multi-task-model/HEAD/figures/figure_2.jpg
--------------------------------------------------------------------------------
/figures/figure_2.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lalithjets/Global-reasoned-multi-task-model/HEAD/figures/figure_2.pdf
--------------------------------------------------------------------------------
/figures/figure_3.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lalithjets/Global-reasoned-multi-task-model/HEAD/figures/figure_3.pdf
--------------------------------------------------------------------------------
/figures/figure_4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lalithjets/Global-reasoned-multi-task-model/HEAD/figures/figure_4.jpg
--------------------------------------------------------------------------------
/figures/figure_4.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lalithjets/Global-reasoned-multi-task-model/HEAD/figures/figure_4.pdf
--------------------------------------------------------------------------------
/figures/figure_5.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lalithjets/Global-reasoned-multi-task-model/HEAD/figures/figure_5.pdf
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.ipynb_checkpoints/
2 | .vscode/
3 |
4 | checkpoints/
5 | models/r18/
6 | datasets/
7 | old_deprecated/
8 |
9 | feature_extractor/checkpoint
10 | log/
11 | results/
12 | venv/
13 |
14 | sai_transfer/
15 | process_checkpoint/
16 |
17 | tmp*
18 | *__pycache__
19 | *.pyc
--------------------------------------------------------------------------------
/utils/scene_graph_eval_matrix.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import sklearn.metrics
3 |
4 | def compute_mean_avg_prec(y_true, y_score):
5 | try:
6 | avg_prec = sklearn.metrics.average_precision_score(y_true, y_score, average=None)
7 | mean_avg_prec = np.nansum(avg_prec) / len(avg_prec)
8 | except ValueError:
9 | mean_avg_prec = 0
10 |
11 | return mean_avg_prec
12 |
13 | def calibration_metrics(logits_all, labels_all):
14 |
15 | logits = logits_all.detach().cpu().numpy()
16 | labels = labels_all.detach().cpu().numpy()
17 | map_value = compute_mean_avg_prec(labels, logits)
18 | labels = np.argmax(labels, axis=-1)
19 | recall = sklearn.metrics.recall_score(labels, np.argmax(logits,1), average='macro')
20 | return(map_value, recall)
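21 | 
22 | 
23 | # Minimal usage sketch (illustrative only): random logits and one-hot labels with 13 classes,
24 | # mirroring the number of interaction classes used in this project. calibration_metrics expects
25 | # stacked torch tensors, as assembled in evaluation.py.
26 | if __name__ == '__main__':
27 |     import torch
28 |     import torch.nn.functional as F
29 |     logits = torch.randn(8, 13)
30 |     labels = F.one_hot(torch.tensor([0, 2, 5, 1, 3, 3, 0, 7]), num_classes=13).float()
31 |     map_value, recall = calibration_metrics(logits, labels)
32 |     print('mAP: %.4f  recall: %.4f' % (map_value, recall))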
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import requests
2 |
3 | def download_from_url(url, path):
4 | """Download file, with logic (from tensor2tensor) for Google Drive"""
5 | if 'drive.google.com' not in url:
6 | print('Downloading %s; may take a few minutes' % url)
7 | r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'})
8 | with open(path, "wb") as file:
9 | file.write(r.content)
10 | return
11 | print('Downloading from Google Drive; may take a few minutes')
12 | confirm_token = None
13 | session = requests.Session()
14 | response = session.get(url, stream=True)
15 | for k, v in response.cookies.items():
16 | if k.startswith("download_warning"):
17 | confirm_token = v
18 |
19 | if confirm_token:
20 | url = url + "&confirm=" + confirm_token
21 | response = session.get(url, stream=True)
22 |
23 | chunk_size = 16 * 1024
24 | with open(path, "wb") as f:
25 | for chunk in response.iter_content(chunk_size):
26 | if chunk:
27 | f.write(chunk)
28 |
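29 | # Usage sketch (comments only; the URL and output path below are hypothetical and only
30 | # illustrate the call signature):
31 | #   download_from_url('https://example.com/checkpoint.zip', 'checkpoint.zip')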
--------------------------------------------------------------------------------
/utils/vis_tool.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import random
4 | import numpy as np
5 | import matplotlib
6 | import torch as t
7 |
8 | matplotlib.use('Agg')
9 | from matplotlib import pyplot as plot
10 | from PIL import Image, ImageDraw, ImageFont
11 |
12 |
13 | def vis_img(img, node_classes, bboxs, det_action, data_const, score_thresh = 0.7):
14 |
15 | Drawer = ImageDraw.Draw(img)
16 | line_width = 3
17 | outline = '#FF0000'
18 | font = ImageFont.truetype(font='/usr/share/fonts/truetype/freefont/FreeMono.ttf', size=25)
19 |
20 | im_w,im_h = img.size
21 | node_num = len(node_classes)
22 | edge_num = len(det_action)
23 | tissue_num = len(np.where(node_classes == 1)[0])
24 |
25 | for node in range(node_num):
26 |
27 | r_color = random.choice(np.arange(256))
28 | g_color = random.choice(np.arange(256))
29 | b_color = random.choice(np.arange(256))
30 |
31 | text = data_const.instrument_classes[node_classes[node]]
32 |         w, h = font.getsize(text)
33 |         Drawer.rectangle(list(bboxs[node]), outline=outline, width=line_width)
34 |         Drawer.text(xy=(bboxs[node][0], bboxs[node][1]-h-1), text=text, font=font, fill=(r_color,g_color,b_color))
35 |
36 | edge_idx = 0
37 |
38 | for tissue in range(tissue_num):
39 | for instrument in range(tissue+1, node_num):
40 |
41 | #action_idx = np.where(det_action[edge_idx] > score_thresh)[0]
42 | action_idx = np.argmax(det_action[edge_idx])
43 | # print('det_action', det_action[edge_idx])
44 | # print('action_idx',action_idx)
45 |
46 | text = data_const.action_classes[action_idx]
47 | r_color = random.choice(np.arange(256))
48 | g_color = random.choice(np.arange(256))
49 | b_color = random.choice(np.arange(256))
50 |
51 | x1,y1,x2,y2 = bboxs[tissue]
52 | x1_,y1_,x2_,y2_ = bboxs[instrument]
53 |
54 | c0 = int(0.5*x1)+int(0.5*x2)
55 | c0 = max(0,min(c0,im_w-1))
56 | r0 = int(0.5*y1)+int(0.5*y2)
57 | r0 = max(0,min(r0,im_h-1))
58 | c1 = int(0.5*x1_)+int(0.5*x2_)
59 | c1 = max(0,min(c1,im_w-1))
60 | r1 = int(0.5*y1_)+int(0.5*y2_)
61 | r1 = max(0,min(r1,im_h-1))
62 | Drawer.line(((c0,r0),(c1,r1)), fill=(r_color,g_color,b_color), width=3)
63 | Drawer.text(xy=(c1, r1), text=text, font=font, fill=(r_color,g_color,b_color))
64 |
65 | edge_idx +=1
66 |
67 | return img
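68 | 
69 | 
70 | # Usage note (comments only): vis_img expects a PIL.Image, per-node class ids, bounding boxes
71 | # and per-edge interaction scores for one frame, plus a constants object providing
72 | # instrument_classes / action_classes (e.g. models.surgicalDataset.SurgicalSceneConstants).
73 | # The hard-coded FreeMono.ttf path above assumes a standard Ubuntu font installation.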
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: gr-mtl-environment
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - _libgcc_mutex=0.1=main
7 | - _openmp_mutex=4.5=1_gnu
8 | - blas=1.0=mkl
9 | - ca-certificates=2021.7.5=h06a4308_1
10 | - certifi=2021.5.30=py36h06a4308_0
11 | - cudatoolkit=10.2.89=hfd86e86_1
12 | - dataclasses=0.8=pyh4f3eec9_6
13 | - freetype=2.10.4=h5ab3b9f_0
14 | - intel-openmp=2021.3.0=h06a4308_3350
15 | - jpeg=9b=h024ee3a_2
16 | - lcms2=2.12=h3be6417_0
17 | - libedit=3.1.20210216=h27cfd23_1
18 | - libffi=3.2.1=hf484d3e_1007
19 | - libgcc-ng=9.3.0=h5101ec6_17
20 | - libgomp=9.3.0=h5101ec6_17
21 | - libpng=1.6.37=hbc83047_0
22 | - libstdcxx-ng=9.3.0=hd4cf53a_17
23 | - libtiff=4.2.0=h85742a9_0
24 | - libuv=1.40.0=h7b6447c_0
25 | - libwebp-base=1.2.0=h27cfd23_0
26 | - lz4-c=1.9.3=h295c915_1
27 | - mkl=2020.2=256
28 | - mkl-service=2.3.0=py36he8ac12f_0
29 | - mkl_fft=1.3.0=py36h54f3939_0
30 | - mkl_random=1.1.1=py36h0573a6f_0
31 | - ncurses=6.2=he6710b0_1
32 | - ninja=1.10.2=hff7bd54_1
33 | - numpy=1.19.2=py36h54aff64_0
34 | - numpy-base=1.19.2=py36hfa32c7d_0
35 | - olefile=0.46=py36_0
36 | - openjpeg=2.3.0=h05c96fa_1
37 | - openssl=1.1.1k=h27cfd23_0
38 | - pillow=8.3.1=py36h2c7a002_0
39 | - pip=21.2.2=py36h06a4308_0
40 | - python=3.6.9=h265db76_0
41 | - pytorch=1.7.1=py3.6_cuda10.2.89_cudnn7.6.5_0
42 | - readline=7.0=h7b6447c_5
43 | - setuptools=52.0.0=py36h06a4308_0
44 | - six=1.16.0=pyhd3eb1b0_0
45 | - sqlite=3.33.0=h62c20be_0
46 | - tk=8.6.10=hbc83047_0
47 | - torchaudio=0.7.2=py36
48 | - torchvision=0.8.2=py36_cu102
49 | - typing_extensions=3.10.0.0=pyh06a4308_0
50 | - wheel=0.37.0=pyhd3eb1b0_0
51 | - xz=5.2.5=h7b6447c_0
52 | - zlib=1.2.11=h7b6447c_3
53 | - zstd=1.4.9=haebb681_0
54 | - pip:
55 | - albumentations==1.0.3
56 | - cached-property==1.5.2
57 | - charset-normalizer==2.0.4
58 | - cycler==0.10.0
59 | - decorator==4.4.2
60 | - dgl-cu102==0.4.2
61 | - h5py==2.10.0
62 | - idna==3.2
63 | - imageio==2.9.0
64 | - importlib-metadata==4.8.1
65 | - joblib==1.0.1
66 | - kiwisolver==1.3.1
67 | - matplotlib==3.3.4
68 | - networkx==2.5.1
69 | - opencv-python-headless==4.5.3.56
70 | - prettytable==2.2.0
71 | - pynvml==11.0.0
72 | - pyparsing==2.4.7
73 | - python-dateutil==2.8.2
74 | - pywavelets==1.1.1
75 | - pyyaml==5.4.1
76 | - requests==2.26.0
77 | - scikit-image==0.17.2
78 | - scikit-learn==0.24.2
79 | - scipy==1.5.4
80 | - tabulate==0.8.9
81 | - threadpoolctl==2.2.0
82 | - tifffile==2020.9.3
83 | - torchsummary==1.5.1
84 | - tqdm==4.62.2
85 | - urllib3==1.26.6
86 | - wcwidth==0.2.5
87 | - zipp==3.5.0
88 |
--------------------------------------------------------------------------------
/utils/segmentation_eval_matrix.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.autograd import Variable
6 |
7 |
8 |
9 | def batch_pix_accuracy(output, target):
10 | """Batch Pixel Accuracy
11 | Args:
12 |         output: predicted logits, 4D tensor (B x C x H x W)
13 |         target: label, 3D tensor (B x H x W)
14 | """
15 | _, predict = torch.max(output, 1)
16 |
17 | predict = predict.cpu().numpy().astype('int64') + 1
18 | target = target.cpu().numpy().astype('int64') + 1
19 |
20 | pixel_labeled = np.sum(target > 0)
21 | pixel_correct = np.sum((predict == target)*(target > 0))
22 | assert pixel_correct <= pixel_labeled, \
23 | "Correct area should be smaller than Labeled"
24 | return pixel_correct, pixel_labeled
25 |
26 |
27 | def batch_intersection_union(output, target, nclass):
28 | """Batch Intersection of Union
29 | Args:
30 |         output: predicted logits, 4D tensor (B x C x H x W)
31 |         target: label, 3D tensor (B x H x W)
32 | nclass: number of categories (int)
33 | """
34 | _, predict = torch.max(output, 1)
35 | mini = 1
36 | maxi = nclass
37 | nbins = nclass
38 | predict = predict.cpu().numpy().astype('int64') + 1
39 | target = target.cpu().numpy().astype('int64') + 1
40 |
41 | predict = predict * (target > 0).astype(predict.dtype)
42 | intersection = predict * (predict == target)
43 | # areas of intersection and union
44 | area_inter, _ = np.histogram(intersection, bins=nbins, range=(mini, maxi))
45 | area_pred, _ = np.histogram(predict, bins=nbins, range=(mini, maxi))
46 | area_lab, _ = np.histogram(target, bins=nbins, range=(mini, maxi))
47 | area_union = area_pred + area_lab - area_inter
48 | assert (area_inter <= area_union).all(), \
49 | "Intersection area should be smaller than Union area"
50 | return area_inter, area_union
51 |
52 |
53 | class SegmentationLosses(nn.CrossEntropyLoss):
54 | def __init__(self, se_loss=False, se_weight=0.2, nclass=-1,
55 | aux=False, aux_weight=0.4, weight=None,
56 | ignore_index=-1):
57 | super(SegmentationLosses, self).__init__(weight, None, ignore_index)
58 | self.se_loss = se_loss
59 | self.aux = aux
60 | self.nclass = nclass
61 | self.se_weight = se_weight
62 | self.aux_weight = aux_weight
63 | self.bceloss = nn.BCELoss(weight)
64 |
65 | def forward(self, *inputs):
66 | if not self.se_loss and not self.aux:
67 | return super(SegmentationLosses, self).forward(*inputs)
68 | elif not self.se_loss:
69 | pred1, pred2, target = tuple(inputs)
70 | loss1 = super(SegmentationLosses, self).forward(pred1, target)
71 | loss2 = super(SegmentationLosses, self).forward(pred2, target)
72 | return loss1 + self.aux_weight * loss2
73 | elif not self.aux:
74 | pred, se_pred, target = tuple(inputs)
75 | se_target = self._get_batch_label_vector(
76 | target, nclass=self.nclass).type_as(pred)
77 | loss1 = super(SegmentationLosses, self).forward(pred, target)
78 | loss2 = self.bceloss(torch.sigmoid(se_pred), se_target)
79 | return loss1 + self.se_weight * loss2
80 | else:
81 | pred1, se_pred, pred2, target = tuple(inputs)
82 | se_target = self._get_batch_label_vector(
83 | target, nclass=self.nclass).type_as(pred1)
84 | loss1 = super(SegmentationLosses, self).forward(pred1, target)
85 | loss2 = super(SegmentationLosses, self).forward(pred2, target)
86 | loss3 = self.bceloss(torch.sigmoid(se_pred), se_target)
87 | return loss1 + self.aux_weight * loss2 + self.se_weight * loss3
88 |
89 | @staticmethod
90 | def _get_batch_label_vector(target, nclass):
91 | # target is a 3D Variable BxHxW, output is 2D BxnClass
92 | batch = target.size(0)
93 | tvect = Variable(torch.zeros(batch, nclass))
94 | for i in range(batch):
95 | hist = torch.histc(target[i].cpu().data.float(),
96 | bins=nclass, min=0,
97 | max=nclass-1)
98 | vect = hist > 0
99 | tvect[i] = vect
100 | return tvect
101 |
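102 | # Minimal usage sketch (illustrative only): random logits/labels with 8 classes, matching the
103 | # class count used in evaluation.py.
104 | if __name__ == '__main__':
105 |     nclass = 8
106 |     output = torch.randn(2, nclass, 32, 32)            # B x C x H x W logits
107 |     target = torch.randint(0, nclass, (2, 32, 32))     # B x H x W labels
108 |     correct, labeled = batch_pix_accuracy(output, target)
109 |     inter, union = batch_intersection_union(output, target, nclass)
110 |     print('pixAcc: %.4f  mIoU: %.4f' % (correct / labeled, (inter / (union + np.spacing(1))).mean()))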
--------------------------------------------------------------------------------
/utils/io.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import json
4 | import yaml
5 | import numpy as np
6 | import gzip
7 | import scipy.io
8 |
9 | def load_pickle_object(file_name, compress=True):
10 | data = read(file_name)
11 | if compress:
12 | load_object = pickle.loads(gzip.decompress(data))
13 | else:
14 | load_object = pickle.loads(data)
15 | return load_object
16 |
17 |
18 | def dump_pickle_object(dump_object, file_name, compress=True, compress_level=9):
19 | data = pickle.dumps(dump_object)
20 | if compress:
21 | write(file_name, gzip.compress(data, compresslevel=compress_level))
22 | else:
23 | write(file_name, data)
24 |
25 |
26 | def load_json_object(file_name, compress=False):
27 | if compress:
28 | return json.loads(gzip.decompress(read(file_name)).decode('utf8'))
29 | else:
30 | return json.loads(read(file_name, 'r'))
31 |
32 |
33 | def dump_json_object(dump_object, file_name, compress=False, indent=4):
34 | data = json.dumps(
35 | dump_object, cls=NumpyAwareJSONEncoder, sort_keys=True, indent=indent)
36 | if compress:
37 | write(file_name, gzip.compress(data.encode('utf8')))
38 | else:
39 | write(file_name, data, 'w')
40 |
41 |
42 | def dumps_json_object(dump_object, indent=4):
43 | data = json.dumps(
44 | dump_object, cls=NumpyAwareJSONEncoder, sort_keys=True, indent=indent)
45 | return data
46 |
47 |
48 | def load_mat_object(file_name):
49 | return scipy.io.loadmat(file_name=file_name)
50 |
51 |
52 | def load_yaml_object(file_name):
53 |     return yaml.safe_load(read(file_name, 'r'))
54 |
55 |
56 | def read(file_name, mode='rb'):
57 | with open(file_name, mode) as f:
58 | return f.read()
59 |
60 |
61 | def write(file_name, data, mode='wb'):
62 | with open(file_name, mode) as f:
63 | f.write(data)
64 |
65 |
66 | def serialize_object(in_obj, method='json'):
67 | if method == 'json':
68 | return json.dumps(in_obj)
69 | else:
70 | return pickle.dumps(in_obj)
71 |
72 |
73 | def deserialize_object(obj_str, method='json'):
74 | if method == 'json':
75 | return json.loads(obj_str)
76 | else:
77 | return pickle.loads(obj_str)
78 |
79 |
80 | def mkdir_if_not_exists(dir_name, recursive=False):
81 | if os.path.exists(dir_name):
82 | return
83 | if recursive:
84 | os.makedirs(dir_name)
85 | else:
86 | os.mkdir(dir_name)
87 |
88 |
89 | class NumpyAwareJSONEncoder(json.JSONEncoder):
90 | def default(self, obj):
91 | if isinstance(obj, np.ndarray):
92 | if obj.ndim == 1:
93 | return obj.tolist()
94 | else:
95 | return [self.default(obj[i]) for i in range(obj.shape[0])]
96 | elif isinstance(obj, np.int64):
97 | return int(obj)
98 | elif isinstance(obj, np.int32):
99 | return int(obj)
100 | elif isinstance(obj, np.int16):
101 | return int(obj)
102 | elif isinstance(obj, np.float64):
103 | return float(obj)
104 | elif isinstance(obj, np.float32):
105 | return float(obj)
106 | elif isinstance(obj, np.float16):
107 | return float(obj)
108 | elif isinstance(obj, np.uint64):
109 | return int(obj)
110 | elif isinstance(obj, np.uint32):
111 | return int(obj)
112 | elif isinstance(obj, np.uint16):
113 | return int(obj)
114 | return json.JSONEncoder.default(self, obj)
115 |
116 |
117 | class JsonSerializableClass():
118 | def to_json(self,json_filename=None):
119 | serialized_dict = json.dumps(
120 | self,
121 | default=lambda o: o.__dict__,
122 | sort_keys=True,
123 | indent=4)
124 | serialized_dict = json.loads(serialized_dict)
125 | if json_filename is not None:
126 | dump_json_object(serialized_dict,json_filename)
127 |
128 | return serialized_dict
129 |
130 | def from_json(self,json_filename):
131 |         assert not isinstance(json_filename, dict), 'Use from_dict instead'
132 | dict_to_restore = load_json_object(json_filename)
133 | for attr_name, attr_value in dict_to_restore.items():
134 | setattr(self,attr_name,attr_value)
135 |
136 | def from_dict(self,dict_to_restore):
137 | for attr_name, attr_value in dict_to_restore.items():
138 | setattr(self,attr_name,attr_value)
139 |
140 |
141 | class WritableToFile():
142 | def to_file(self,filename):
143 | with open(filename,'w') as file:
144 | file.write(self.__str__())
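145 | 
146 | 
147 | # Minimal usage sketch (illustrative only; writes a throwaway file to the system temp dir):
148 | if __name__ == '__main__':
149 |     import tempfile
150 |     tmp_path = os.path.join(tempfile.gettempdir(), 'io_demo.json')
151 |     dump_json_object({'classes': np.arange(3)}, tmp_path)
152 |     print(load_json_object(tmp_path))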
--------------------------------------------------------------------------------
/eval_instructions.txt:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------#
2 | Steps to be followed
3 | # ------------------------------------------------------------------------------------------------#
4 |
5 |
6 | 1. git clone https://github.com/lalithjets/Global-reasoned-multi-task-model.git
7 | 2. cd Global-reasoned-multi-task-model/
8 |
9 |
10 | # ------------------------- Download Commands ------------------------- #
11 |
12 | # ------------------------- Checkpoints ------------------------- #
13 | Link : https://drive.google.com/file/d/1HTSYta_Dn9-nF1Df4TUym38Nu0VMtl5l/view?usp=sharing
14 |
15 | Command : (GDrive wget download - Optional) - Can be downloaded manually and placed in root
16 | > 3. wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1HTSYta_Dn9-nF1Df4TUym38Nu0VMtl5l' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1HTSYta_Dn9-nF1Df4TUym38Nu0VMtl5l" -O gr_mtl_ssu_checkpoints.zip && rm -rf /tmp/cookies.txt
17 |
18 | 4. unzip gr_mtl_ssu_checkpoints.zip
19 | 5. rm gr_mtl_ssu_checkpoints.zip
20 |
21 | # ------------------------- Dataset ------------------------- #
22 | Link : https://drive.google.com/file/d/1OwWfgBZE0W5grXVaQN63VUUaTvufEmW0/view?usp=sharing
23 |
24 | Command : (GDrive wget download - Optional) - Can be downloaded manually and placed in root
25 | > 6. wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1OwWfgBZE0W5grXVaQN63VUUaTvufEmW0' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1OwWfgBZE0W5grXVaQN63VUUaTvufEmW0" -O gr_mtl_ssu_dataset.zip && rm -rf /tmp/cookies.txt
26 |
27 | 7. unzip gr_mtl_ssu_dataset.zip
28 | 8. rm gr_mtl_ssu_dataset.zip
29 |
30 | 9. Set the model_type, ver, seg_mode and checkpoint_dir in evaluation.py as given in instructions
31 |
32 | # ------------------------- Run the command for Evaluation ------------------------- #
33 | 10. CUDA_VISIBLE_DEVICES=1 python3 evaluation.py
34 |
35 |
36 | # --------------------------------------------- Sample Output --------------------------------------------- #
37 |
38 | Settings :
39 |
40 | model_type = 'amtl-t0'
41 | ver = 'amtl_t0_sv1'
42 | seg_mode = 'v1'
43 | checkpoint_dir = 'amtl_t0_sv1'
44 |
45 | # ------------------------------------------------------------------------------------------------#
46 | Output
47 | # ------------------------------------------------------------------------------------------------#
48 |
49 | ================= Evaluation ====================
50 | Graph : acc: 0.7003 map: 0.2885 recall: 0.3096 loss: 0.3764}
51 | Segmentation : Pacc: 0.9638 mIoU: 0.4354 loss: 0.1500}
52 |
53 | ================= Class-wise IoU ====================
54 | Mean Value: 0.435358693711956
55 |
56 | | Class | IoU |
57 | |---------------------------+------------|
58 | | Background | 0.971428 |
59 | | Bipolar_Forceps | 0.696591 |
60 | | Prograsp_Forceps | 0.435617 |
61 | | Large_Needle_Driver | 0.00154275 |
62 | | Monopolar_Curved_Scissors | 0.871583 |
63 | | Ultrasound_Probe | 0.120284 |
64 | | Suction_Instrument | 0.347132 |
65 | | Clip_Applier | 0.0386921 |
66 |
67 |
68 |
69 | # ------------------------------------------------------------------------------------------------#
70 | Eval Repository Structure
71 | # ------------------------------------------------------------------------------------------------#
72 |
73 | ├── checkpoints
74 | │ ├── amtl_t0_s
75 | │ │ └── best_epoch.pth
76 | │ ├── amtl_t0_sv1
77 | │ │ └── best_epoch.pth
78 | │ ├── amtl_t0_sv2gc
79 | │ │ └── best_epoch.pth
80 | │ ├── amtl_t3g_sv1
81 | │ │ └── best_epoch.pth
82 | │ ├── amtl_t3pn_sv1
83 | │ │ └── best_epoch.pth
84 | │ ├── mtl_kd_t0_s
85 | │ │ └── best_epoch.pth
86 | │ ├── mtl_kd_t0_sv1
87 | │ │ └── best_epoch.pth
88 | │ ├── mtl_kd_t1_sv1
89 | │ │ └── best_epoch.pth
90 | │ ├── mtl_kd_t3g_sv1
91 | │ │ └── best_epoch.pth
92 | │ ├── stl_s
93 | │ │ └── best_epoch.pth
94 | │ ├── stl_sg
95 | │ │ └── best_epoch.pth
96 | │ ├── stl_s_ng
97 | │ │ └── best_epoch.pth
98 | │ ├── stl_s_v1
99 | │ │ └── best_epoch.pth
100 | │ └── stl_s_v2gc
101 | │ └── best_epoch.pth
102 | ├── dataset
103 | │ ├── labels_isi_dataset.json
104 | │ ├── seq_1
105 | │ │ ├── annotations
106 | │ │ │ ├── frame000.png
107 | │ │ │ ├── ...
108 | │ │ ├── left_frames
109 | │ │ │ ├── frame000.png
110 | │ │ │ ├── ...
111 | │ │ ├── vsgat
112 | │ │ │ └── features
113 | │ │ │ ├── frame000_features.hdf5
114 | │ │ │ ├── ...
115 | │ │ └── xml
116 | │ │ ├── frame000.xml
117 | │ │ ├── ...
118 | │ ├── seq_16
119 | │ │ ├── annotations
120 | │ │ │ ├── frame000.png
121 | │ │ │ ├── ...
122 | │ │ ├── left_frames
123 | │ │ │ ├── frame000.png
124 | │ │ │ ├── ...
125 | │ │ ├── vsgat
126 | │ │ │ └── features
127 | │ │ │ ├── frame000_features.hdf5
128 | │ │ │ ├── ...
129 | │ │ └── xml
130 | │ │ ├── frame000.xml
131 | │ │ ├── ...
132 | │ ├── seq_5
133 | │ │ ├── annotations
134 | │ │ │ ├── frame000.png
135 | │ │ │ ├── ...
136 | │ │ ├── left_frames
137 | │ │ │ ├── frame000.png
138 | │ │ │ ├── ...
139 | │ │ ├── vsgat
140 | │ │ │ └── features
141 | │ │ │ ├── frame000_features.hdf5
142 | │ │ │ ├── ...
143 | │ │ └── xml
144 | │ │ ├── frame000.xml
145 | │ │ ├── ...
146 | │ └── surgicalscene_word2vec.hdf5
147 | ├── environment.yml
148 | ├── evaluation.py
149 | ├── eval_instructions.txt
150 | ├── figures
151 | │ ├── figure_1.pdf
152 | │ ├── figure_2.pdf
153 | │ ├── figure_3.pdf
154 | │ ├── figure_4.pdf
155 | │ └── figure_5.pdf
156 | ├── models
157 | │ ├── mtl_model.py
158 | │ ├── __pycache__
159 | │ │ ├── mtl_model.cpython-36.pyc
160 | │ │ ├── scene_graph.cpython-36.pyc
161 | │ │ ├── segmentation_model.cpython-36.pyc
162 | │ │ └── surgicalDataset.cpython-36.pyc
163 | │ ├── scene_graph.py
164 | │ ├── segmentation_model.py
165 | │ └── surgicalDataset.py
166 | ├── model_train.py
167 | ├── README.md
168 | ├── result_logs
169 | │ ├── results_combined
170 | │ └── results_kd.txt
171 | └── utils
172 | ├── io.py
173 | ├── __pycache__
174 | │ ├── scene_graph_eval_matrix.cpython-36.pyc
175 | │ └── segmentation_eval_matrix.cpython-36.pyc
176 | ├── scene_graph_eval_matrix.py
177 | ├── segmentation_eval_matrix.py
178 | ├── utils.py
179 | └── vis_tool.py
180 |
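181 | # ------------------------------------------------------------------------------------------------#
182 | Optional sanity check
183 | # ------------------------------------------------------------------------------------------------#
184 | 
185 | After unzipping, the repository root should contain the checkpoints/ and dataset/ folders listed
186 | in the structure above. A quick check:
187 | > ls checkpoints dataset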
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | # Global-Reasoned Multi-Task Model for Surgical Scene Understanding
6 |
7 | Lalithkumar Seenivasan, Sai Mitheran, Mobarakol Islam, Hongliang Ren
8 |
9 |
10 |
11 | ---
12 |
13 | | **[ [```arXiv```]() ]** |**[ [```Paper```]() ]** |**[ [```YouTube```]() ]** |
14 | |:-------------------:|:-------------------:|:-------------------:|
15 |
16 | ICRA 2022, IEEE Robotics and Automation Letters (RA-L)
17 |
18 |
19 |
20 | If you find our code or paper useful, please cite as
21 |
22 | ```bibtex
23 | @article{seenivasan2022global,
24 | title={Global-Reasoned Multi-Task Learning Model for Surgical Scene Understanding},
25 | author={Seenivasan, Lalithkumar and Mitheran, Sai and Islam, Mobarakol and Ren, Hongliang},
26 | journal={IEEE Robotics and Automation Letters},
27 | year={2022},
28 | publisher={IEEE}
29 | }
30 | ```
31 |
32 | ---
33 |
34 | ## Introduction
35 | Global and local relational reasoning enable scene understanding models to perform human-like scene analysis and understanding. Scene understanding enables better semantic segmentation and object-to-object interaction detection. In the medical domain, a robust surgical scene understanding model allows the automation of surgical skill evaluation, real-time monitoring of surgeon’s performance and post-surgical analysis. This paper introduces a globally-reasoned multi-task surgical scene understanding model capable of performing instrument segmentation and tool-tissue interaction detection. Here, we incorporate global relational reasoning in the latent interaction space and introduce multi-scale local (neighborhood) reasoning in the coordinate space to improve segmentation. Utilizing the multi-task model setup, the performance of the visual-semantic graph attention network in interaction detection is further enhanced through global reasoning. The global interaction space features from the segmentation module are introduced into the graph network, allowing it to detect interactions based on both node-to-node and global interaction reasoning. Our model reduces the computation cost compared to running two independent single-task models by sharing common modules, which is indispensable for practical applications. Using a sequential optimization technique, the proposed multi-task model outperforms other state-of-the-art single-task models on the MICCAI endoscopic vision challenge 2018 dataset. Additionally, we also observe the performance of the multi-task model when trained using the knowledge distillation technique.
36 |
37 | ## Method
38 |
39 | 
40 |
41 | The proposed network architecture. The proposed globally-reasoned multi-task scene understanding model consists of a shared feature extractor. The segmentation module performs latent global reasoning (GloRe unit [2]) and local reasoning (multi-scale local reasoning) to segment instruments. To detect tool interaction, the scene graph (tool interaction detection) model incorporates the global interaction space features to further improve the performance of the visual-semantic graph attention network [1].
42 |
43 | ### Feature Sharing
44 |
45 |
46 |
47 |
48 |
49 | Variants of feature sharing between the segmentation and scene graph modules in multi-task setting to improve single-task performance
50 |
51 | ---
52 |
53 | ## Directory setup
54 |
55 | In this project, we implement our method using the PyTorch and DGL libraries. The repository structure is as follows:
56 |
57 | - `dataset/`: Contains the data needed to train the network.
58 | - `checkpoints/`: Contains trained weights.
59 | - `models/`: Contains network models.
60 | - `utils/`: Contains utility tools used for training and evaluation.
61 |
62 | ---
63 |
64 | ## Library Prerequisites
65 |
66 | ### DGL
67 | DGL is a Python package dedicated to deep learning on graphs. It is built atop existing tensor DL frameworks (e.g., PyTorch, MXNet) and simplifies the implementation of graph-based neural networks.
68 |
69 | ### Dependencies (Used for Experiments)
70 | - Python 3.6
71 | - Pytorch 1.7.1
72 | - DGL 0.4.2
73 | - CUDA 10.2
74 | - Ubuntu 16.04
75 |
76 | ---
77 |
78 | ## Setup (From an Env File)
79 |
80 | We provide an environment file for installation using conda.
81 |
82 | ### Using Conda
83 |
84 | ```bash
85 | conda env create -f environment.yml
86 | ```
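
Afterwards, you can activate the environment (the name `gr-mtl-environment` comes from `environment.yml`):

```bash
conda activate gr-mtl-environment
```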
87 |
88 | ---
89 |
90 | ## Dataset:
91 | 1. Frames - Left camera images from [2018 robotic scene segmentation challenge](https://arxiv.org/pdf/2001.11190.pdf) are used in this work.
92 | 2. Instrument label - To be released!
93 | 3. BBox and Tool-Tissue interaction annotation - [Our annotations](https://drive.google.com/file/d/16G_Pf4E9KjVq7j_7BfBKHg0NyQQ0oTxP/view?usp=sharing) (Cite this paper / [our previous work](https://link.springer.com/chapter/10.1007/978-3-030-59716-0_60) when using these annotations.)
94 | 4. Download the word2vec model pretrained on [GoogleNews](https://code.google.com/archive/p/word2vec/) and put it into `dataset/word2vec`
95 |
96 | ---
97 |
98 | ## Training
99 | ### Process dataset (For Spatial Features)
100 | - To be released!
101 |
102 | ### Run training
103 | - Set the `model_type` and version for the model to be trained, according to the instructions given in the training file (`model_train.py`)
104 |
105 | ```bash
106 | python3 model_train.py
107 | ```
108 |
109 | ---
110 | ## Evaluation
111 | For the direct sequence of commands to be followed, refer to [this link](https://github.com/lalithjets/Global-reasoned-multi-task-model/blob/master/eval_instructions.txt)
112 |
113 | ### Pre-trained Models
114 | Download from **[[`Checkpoints Link`](https://drive.google.com/file/d/1HTSYta_Dn9-nF1Df4TUym38Nu0VMtl5l/view?usp=sharing)]**, place it inside the repository root and unzip
115 |
116 | ### Evaluation Data
117 | Download from **[[`Dataset Link`](https://drive.google.com/file/d/1OwWfgBZE0W5grXVaQN63VUUaTvufEmW0/view?usp=sharing)]** and place it inside the repository root and unzip
118 |
119 | ### Inference
120 | To reproduce the results, set the `model_type`, `ver`, `seg_mode` and `checkpoint_dir` based on the table given [here](https://github.com/lalithjets/Global-reasoned-multi-task-model/blob/c6668fcca712d3bd5ca25c66b11d34305103af94/evaluation.py#L195):
121 | - model_type
122 | - ver
123 | - seg_mode
124 | - checkpoint_dir
125 |
126 | ```bash
127 | python3 evaluation.py
128 | ```
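
A sample configuration (values reproduced from `eval_instructions.txt`, corresponding to the MSLRGR / `amtl_t0_sv1` setting):

```python
model_type = 'amtl-t0'
ver = 'amtl_t0_sv1'
seg_mode = 'v1'
checkpoint_dir = 'amtl_t0_sv1'
```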
129 |
130 | ---
131 |
132 | ## Acknowledgement
133 | Code adapted and modified from:
134 | 1. Visual-Semantic Graph Attention Network for Human-Object Interaction Detection
135 |    - Paper [Visual-Semantic Graph Attention Network for Human-Object Interaction Detection](https://arxiv.org/abs/2001.02302).
136 | - Official Pytorch implementation [code](https://github.com/birlrobotics/vs-gats).
137 | 1. Graph-Based Global Reasoning Networks
138 | - Paper [Graph-Based Global Reasoning Networks](https://openaccess.thecvf.com/content_CVPR_2019/papers/Chen_Graph-Based_Global_Reasoning_Networks_CVPR_2019_paper.pdf).
139 | - Official code implementation [code](https://github.com/facebookresearch/GloRe.git).
140 |
141 | ---
142 |
143 | ## Other Works:
144 | 1. Learning and Reasoning with the Graph Structure Representation in Robotic Surgery | **[ [```arXiv```]() ]** | **[ [```Paper```]() ]** |
145 |
146 | ---
147 |
148 | ## Contact
149 |
150 | For any queries, please contact [Lalithkumar](mailto:lalithjets@gmail.com) or [Sai Mitheran](mailto:saimitheran06@gmail.com)
151 |
--------------------------------------------------------------------------------
/models/surgicalDataset.py:
--------------------------------------------------------------------------------
1 | '''
2 | Project : Global-Reasoned Multi-Task Surgical Scene Understanding
3 | Lab : MMLAB, National University of Singapore
4 | contributors : Lalithkumar Seenivasan, Sai Mitheran, Mobarakol Islam, Hongliang Ren
5 | Note : Code adapted and modified from Visual-Semantic Graph Attention Networks and Dual Attention Network for Scene Segmentation
6 |
7 | '''
8 |
9 |
10 | import os
11 | import sys
12 | import random
13 |
14 | import h5py
15 | import numpy as np
16 | from glob import glob
17 | from PIL import Image
18 |
19 | import torch
20 | import torch.nn as nn
21 | import torchvision.transforms as transforms
22 | from torch.utils.data import Dataset
23 |
24 |
25 | class SurgicalSceneConstants():
26 | '''
27 | Set the instrument classes and action classes, with path to XML and Word2Vec Features (if applicable)
28 | '''
29 | def __init__(self):
30 | self.instrument_classes = ('kidney', 'bipolar_forceps', 'prograsp_forceps', 'large_needle_driver',
31 | 'monopolar_curved_scissors', 'ultrasound_probe', 'suction', 'clip_applier',
32 | 'stapler', 'maryland_dissector', 'spatulated_monopolar_cautery')
33 |
34 | self.action_classes = ('Idle', 'Grasping', 'Retraction', 'Tissue_Manipulation',
35 | 'Tool_Manipulation', 'Cutting', 'Cauterization',
36 | 'Suction', 'Looping', 'Suturing', 'Clipping', 'Staple',
37 | 'Ultrasound_Sensing')
38 |
39 | self.xml_data_dir = 'dataset/instruments18/seq_'
40 | self.word2vec_loc = 'dataset/surgicalscene_word2vec.hdf5'
41 |
42 |
43 | class SurgicalSceneDataset(Dataset):
44 | '''
45 | Dataset class for the MTL Model
46 | Inputs: sequence set, data directory (root), image directory, mask directory, augmentation flag (istrain), dataset (dset), feature extractor chosen
47 | '''
48 | def __init__(self, seq_set, data_dir, img_dir, mask_dir, istrain, dset, dataconst, feature_extractor, reduce_size=False):
49 |
50 | self.data_size = 143
51 | self.dataconst = dataconst
52 | self.img_dir = img_dir
53 | self.mask_dir = mask_dir
54 | self.is_train = istrain
55 | self.feature_extractor = feature_extractor
56 | self.reduce_size = reduce_size
57 |
58 | # Images and masks are resized to (320, 400)
59 | self.resizer = transforms.Compose([transforms.Resize((320, 400))])
60 |
61 | self.xml_dir_list = []
62 | self.dset = []
63 |
64 | for domain in range(len(seq_set)):
65 | domain_dir_list = []
66 | for i in seq_set[domain]:
67 | xml_dir_temp = data_dir[domain] + str(i) + '/xml/'
68 | domain_dir_list = domain_dir_list + glob(xml_dir_temp + '/*.xml')
69 | if self.reduce_size:
70 | indices = np.random.permutation(len(domain_dir_list))
71 | domain_dir_list = [domain_dir_list[j] for j in indices[0:self.data_size]]
72 | for file in domain_dir_list:
73 | self.xml_dir_list.append(file)
74 | self.dset.append(dset[domain])
75 | self.word2vec = h5py.File('dataset/surgicalscene_word2vec.hdf5', 'r')
76 |
77 | # Word2Vec function
78 | def _get_word2vec(self, node_ids, sgh=0):
79 | word2vec = np.empty((0, 300))
80 | for node_id in node_ids:
81 | if sgh == 1 and node_id == 0:
82 | vec = self.word2vec['tissue']
83 | else:
84 | vec = self.word2vec[self.dataconst.instrument_classes[node_id]]
85 | word2vec = np.vstack((word2vec, vec))
86 | return word2vec
87 |
88 | # Dataset length
89 | def __len__(self):
90 | return len(self.xml_dir_list)
91 |
92 | # Function to get images and masks
93 | def __getitem__(self, idx):
94 |
95 | file_name = os.path.splitext(os.path.basename(self.xml_dir_list[idx]))[0]
96 | file_root = os.path.dirname(os.path.dirname(self.xml_dir_list[idx]))
97 | if len(self.img_dir) == 1:
98 | _img_loc = os.path.join(file_root+self.img_dir[0] + file_name + '.png')
99 | _mask_loc = os.path.join(file_root+self.mask_dir[0] + file_name + '.png')
100 |
101 | else:
102 | _img_loc = os.path.join( file_root+self.img_dir[self.dset[idx]] + file_name + '.png')
103 | _mask_loc = os.path.join( file_root+self.mask_dir[self.dset[idx]] + file_name + '.png')
104 |
105 |
106 | _img = Image.open(_img_loc).convert('RGB')
107 | _target = Image.open(_mask_loc)
108 |
109 | if self.is_train:
110 | isAugment = random.random() < 0.5
111 | if isAugment:
112 | isHflip = random.random() < 0.5
113 | if isHflip:
114 | _img = _img.transpose(Image.FLIP_LEFT_RIGHT)
115 | _target = _target.transpose(Image.FLIP_LEFT_RIGHT)
116 | else:
117 | _img = _img.transpose(Image.FLIP_TOP_BOTTOM)
118 | _target = _target.transpose(Image.FLIP_TOP_BOTTOM)
119 |
120 | _img = np.asarray(_img, np.float32) * 1.0 / 255
121 | _img = torch.from_numpy(np.array(_img).transpose(2, 0, 1)).float()
122 | _target = torch.from_numpy(np.array(_target)).long()
123 |
124 | frame_data = h5py.File(os.path.join( file_root+'/vsgat/'+self.feature_extractor+'/'+ file_name + '_features.hdf5'), 'r')
125 |
126 | data = {}
127 |
128 | data['img_name'] = frame_data['img_name'][()][:] + '.jpg'
129 | data['img_loc'] = _img_loc
130 |
131 | # segmentation
132 | data['img'] = self.resizer(_img.unsqueeze(0))
133 | data['mask'] = self.resizer(_target.unsqueeze(0))
134 |
135 |
136 | data['node_num'] = frame_data['node_num'][()]
137 | data['roi_labels'] = frame_data['classes'][:]
138 | data['det_boxes'] = frame_data['boxes'][:]
139 |
140 | data['edge_labels'] = frame_data['edge_labels'][:]
141 | data['edge_num'] = data['edge_labels'].shape[0]
142 |
143 | data['features'] = frame_data['node_features'][:]
144 | data['spatial_feat'] = frame_data['spatial_features'][:]
145 |
146 | data['word2vec'] = self._get_word2vec(data['roi_labels'], self.dset[idx])
147 | return data
148 |
149 |
150 | # For Dataset Loader
151 | def collate_fn(batch):
152 | '''
153 | Default collate_fn(): https://github.com/pytorch/pytorch/blob/1d53d0756668ce641e4f109200d9c65b003d05fa/torch/utils/data/_utils/collate.py#L43
154 | Inputs: Data Batch
155 | '''
156 | batch_data = {}
157 | batch_data['img_name'] = []
158 | batch_data['img_loc'] = []
159 | batch_data['img'] = []
160 | batch_data['mask'] = []
161 | batch_data['node_num'] = []
162 | batch_data['roi_labels'] = []
163 | batch_data['det_boxes'] = []
164 | batch_data['edge_labels'] = []
165 | batch_data['edge_num'] = []
166 | batch_data['features'] = []
167 | batch_data['spatial_feat'] = []
168 | batch_data['word2vec'] = []
169 |
170 | for data in batch:
171 | batch_data['img_name'].append(data['img_name'])
172 | batch_data['img_loc'].append(data['img_loc'])
173 | batch_data['img'].append(data['img'])
174 | batch_data['mask'].append(data['mask'])
175 | batch_data['node_num'].append(data['node_num'])
176 | batch_data['roi_labels'].append(data['roi_labels'])
177 | batch_data['det_boxes'].append(data['det_boxes'])
178 | batch_data['edge_labels'].append(data['edge_labels'])
179 | batch_data['edge_num'].append(data['edge_num'])
180 | batch_data['features'].append(data['features'])
181 | batch_data['spatial_feat'].append(data['spatial_feat'])
182 | batch_data['word2vec'].append(data['word2vec'])
183 |
184 | batch_data['img'] = torch.FloatTensor(np.concatenate(batch_data['img'], axis=0))
185 | batch_data['mask'] = torch.LongTensor(np.concatenate(batch_data['mask'], axis=0))
186 | batch_data['edge_labels'] = torch.FloatTensor(np.concatenate(batch_data['edge_labels'], axis=0))
187 | batch_data['features'] = torch.FloatTensor(np.concatenate(batch_data['features'], axis=0))
188 | batch_data['spatial_feat'] = torch.FloatTensor(np.concatenate(batch_data['spatial_feat'], axis=0))
189 | batch_data['word2vec'] = torch.FloatTensor(np.concatenate(batch_data['word2vec'], axis=0))
190 |
191 | return batch_data
192 |
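193 | 
194 | # Usage sketch (comments only; the sequence ids and folder names below mirror the evaluation
195 | # dataset layout in eval_instructions.txt, and 'features' refers to the vsgat/features/ folder):
196 | #   const   = SurgicalSceneConstants()
197 | #   dataset = SurgicalSceneDataset(seq_set=[[1, 5, 16]], data_dir=['dataset/seq_'],
198 | #                                  img_dir=['/left_frames/'], mask_dir=['/annotations/'],
199 | #                                  istrain=False, dset=[0], dataconst=const,
200 | #                                  feature_extractor='features')
201 | #   loader  = torch.utils.data.DataLoader(dataset, batch_size=4, collate_fn=collate_fn)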
--------------------------------------------------------------------------------
/models/mtl_model.py:
--------------------------------------------------------------------------------
1 | '''
2 | Project : Global-Reasoned Multi-Task Surgical Scene Understanding
3 | Lab : MMLAB, National University of Singapore
4 | contributors : Lalithkumar Seenivasan, Sai Mitheran, Mobarakol Islam, Hongliang Ren
5 | Note : Code adapted and modified from Visual-Semantic Graph Attention Networks and Dual Attention Network for Scene Segmentation
6 | '''
7 |
8 | import cv2
9 | import numpy as np
10 | from PIL import Image
11 |
12 | import torch
13 | import torchvision
14 | import torch.nn as nn
15 |
16 | class mtl_model(nn.Module):
17 | '''
18 | Multi-task model : Graph Scene Understanding and segmentation
19 | Forward uses features from feature_extractor
20 | '''
21 |
22 | def __init__(self, feature_encoder, scene_graph, seg_gcn_block, seg_decoder, seg_mode = None):
23 | super(mtl_model, self).__init__()
24 | self.feature_encoder = feature_encoder
25 | self.gcn_unit = seg_gcn_block
26 | self.seg_mode = seg_mode
27 | self.seg_decoder = seg_decoder
28 | self.scene_graph = scene_graph
29 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
30 | self.transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
31 |
32 | def model_type1_insert(self):
33 | self.sg_avgpool = nn.AdaptiveAvgPool1d(1)
34 | self.sg_linear = nn.Linear(1040, 128)
35 | self.sg_feat_s1d1 = nn.Conv1d(1, 1, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
36 |
37 | def model_type2_insert(self):
38 | self.sg2_linear = nn.Linear(1040, 128)
39 |
40 | def model_type3_insert(self):
41 | # self.sf_avgpool = nn.AdaptiveAvgPool2d((1, 1))
42 | self.sf_avgpool = nn.AdaptiveAvgPool1d(1)
43 | #self.sf_linear = nn.Linear(256, 128)
44 |
45 | def set_train_test(self, model_type):
46 | ''' train Feature extractor for scene graph '''
47 | # if model_type == 'stl-s' or model_type == 'amtl-t0' or model_type == 'amtl-t3' or model_type == 'stl-sg':
48 | if model_type == 'stl-s' or model_type == 'stl-sg' or model_type == 'amtl-t0' or model_type == 'amtl-t3':
49 | self.train_FE_SG = False
50 | else:
51 | self.train_FE_SG = True
52 |
53 | ''' train feature extractor for segmentation '''
54 | # if model_type == 'stl-sg' or model_type == 'amtl-t0' or model_type == 'amtl-t3':
55 | if model_type == 'stl-sg' or model_type == 'stl-sg-wfe' or model_type == 'amtl-t0' or model_type == 'amtl-t3':# or model_type == 'amtl-t1':
56 | self.Train_FE_SEG = False
57 | else:
58 | self.Train_FE_SEG = True
59 |
60 | ''' train scene graph'''
61 | # set train flag for scene graph
62 | if model_type == 'stl-s':
63 | self.Train_SG = False
64 | else:
65 | self.Train_SG = True
66 |
67 |         ''' train segmentation GR-unit (Global-Reasoning unit) '''
68 | # if model_type == 'stl-sg' or model_type == 'amtl-t0' or model_type == 'amtl-t3':
69 | if model_type == 'stl-sg' or model_type == 'stl-sg-wfe' or model_type == 'amtl-t0' or model_type == 'amtl-t3':
70 | self.Train_SEG_GR = False
71 | else:
72 | self.Train_SEG_GR = True
73 |
74 | ''' train segmentation decoder '''
75 | # set train flag for segmentation decoder
76 | # if model_type == 'stl-sg' or model_type == 'amtl-t0' or model_type == 'amtl-t3':
77 | if model_type == 'stl-sg' or model_type == 'stl-sg-wfe' or model_type == 'amtl-t0' or model_type == 'amtl-t3':
78 | self.Train_SG_DECODER = False
79 | else:
80 | self.Train_SG_DECODER = True
81 |
82 | self.model_type = model_type
83 |
84 |
85 | def forward(self, img, img_dir, det_boxes_all, node_num, spatial_feat, word2vec, roi_labels, validation=False):
86 |
87 | gsu_node_feat = None
88 | seg_inputs = None
89 | interaction = None
90 | imsize = img.size()[2:]
91 |
92 | # ====================================================== Extract node features for Scene graph ==============================================================
93 | if not self.train_FE_SG:
94 | ''' skip training the feature extractor for scene graph '''
95 | with torch.no_grad():
96 | for index, img_loc in enumerate(img_dir):
97 | _img = Image.open(img_loc).convert('RGB')
98 | _img = np.array(_img)
99 | img_stack = None
100 | for bndbox in det_boxes_all[index]:
101 | roi = np.array(bndbox).astype(int)
102 | roi_image = _img[roi[1]:roi[3] + 1, roi[0]:roi[2] + 1, :]
103 | roi_image = self.transform(cv2.resize(roi_image, (224, 224), interpolation=cv2.INTER_LINEAR))
104 | roi_image = torch.autograd.Variable(roi_image.unsqueeze(0))
105 | # stack nodes images per image
106 | img_stack = roi_image if img_stack == None else torch.cat((img_stack, roi_image))
107 |
108 | img_stack = img_stack.cuda(non_blocking=True)
109 | _, _, _, img_stack = self.feature_encoder(img_stack)
110 |
111 | img_stack = self.avgpool(img_stack)
112 | img_stack = img_stack.view(img_stack.size(0), -1)
113 |
114 | # # prepare graph node features
115 | gsu_node_feat = img_stack if gsu_node_feat == None else torch.cat((gsu_node_feat, img_stack))
116 |
117 | else:
118 | # print('node_info grad enabled')
119 | for index, img_loc in enumerate(img_dir):
120 | _img = Image.open(img_loc).convert('RGB')
121 | _img = np.array(_img)
122 | img_stack = None
123 | for bndbox in det_boxes_all[index]:
124 | roi = np.array(bndbox).astype(int)
125 | roi_image = _img[roi[1]:roi[3] + 1, roi[0]:roi[2] + 1, :]
126 | roi_image = self.transform(cv2.resize(roi_image, (224, 224), interpolation=cv2.INTER_LINEAR))
127 | roi_image = torch.autograd.Variable(roi_image.unsqueeze(0))
128 | # stack nodes images per image
129 | img_stack = roi_image if img_stack == None else torch.cat((img_stack, roi_image))
130 |
131 | img_stack = img_stack.cuda(non_blocking=True)
132 | _, _, _, img_stack = self.feature_encoder(img_stack)
133 | img_stack = self.avgpool(img_stack)
134 | img_stack = img_stack.view(img_stack.size(0), -1)
135 | # prepare graph node features
136 | gsu_node_feat = img_stack if gsu_node_feat == None else torch.cat((gsu_node_feat, img_stack))
137 | # ================================================================================================================================================================
138 | # ===================================================== Segmentation feature extractor ===========================================================================
139 | if not self.Train_FE_SEG:
140 | ''' Skip training feature encoder for segmentation task '''
141 | with torch.no_grad():
142 | s1, s2, s3, seg_inputs = self.feature_encoder(img)
143 | fe_feat = seg_inputs
144 | else:
145 | # print('segment encoder enabled')
146 | s1, s2, s3, seg_inputs = self.feature_encoder(img)
147 | fe_feat = seg_inputs
148 | # ================================================================================================================================================================
149 | # ================================================= Scene graph and segmentation GR (Global Reasoning) unit ======================================================
150 | if self.model_type == 'amtl-t1' or self.model_type == 'mtl-t1':
151 | '''
152 |             In type 1, interaction features are passed to the segmentation GR (Global Reasoning) module:
153 |             inside the GR unit, x = x + h + avg((x)T) * sg_feat [1 x 128].
154 |             Here the scene graph (interaction) head is called before the GR unit.
155 | '''
156 | ''' ==== scene graph ==== '''
157 | # print('inside mtl-1')
158 | interaction, sg_feat = self.scene_graph(node_num, gsu_node_feat, spatial_feat, word2vec, roi_labels, validation= validation)
159 |
160 | ''' ==== GR (Global Reasoning) ==== '''
161 | edge_sum = 0
162 | batch_sg_feat = None
163 | for n in node_num:
164 | active_edges = n-1 if n >1 else n
165 | if batch_sg_feat == None:
166 | batch_sg_feat = self.sg_linear(self.sg_avgpool(sg_feat[edge_sum:edge_sum+active_edges, :].unsqueeze(0).permute(0,2,1)).permute(0,2,1))
167 | else:
168 | batch_sg_feat = torch.cat((batch_sg_feat, self.sg_linear(self.sg_avgpool(sg_feat[edge_sum:edge_sum+active_edges, :].unsqueeze(0).permute(0,2,1)).permute(0,2,1))))
169 | edge_sum += active_edges
170 | batch_sg_feat = self.sg_feat_s1d1(batch_sg_feat)
171 | s1, s2, s3, seg_inputs, _ = self.gcn_unit(seg_inputs, s1=s1, s2=s2, s3=s3, scene_feat = batch_sg_feat, seg_mode = self.seg_mode, model_type = self.model_type)
172 |
173 | elif self.model_type == 'amtl-t2' or self.model_type == 'mtl-t2':
174 | '''
175 |             In type 2, interaction features are passed to the segmentation GR module:
176 |             inside GR, the GCN is replaced with x = x * sg_feat [128 x 128].
177 |             Here the scene graph (interaction) head is called before the GR unit.
178 | '''
179 | ''' ==== scene graph ==== '''
180 | interaction, sg_feat = self.scene_graph(node_num, gsu_node_feat, spatial_feat, word2vec, roi_labels, validation= validation)
181 |
182 | ''' ==== GR (Global Reasoning) ==== '''
183 | edge_sum = 0
184 | batch_sg_feat = None
185 | for n in node_num:
186 | active_edges = n-1 if n >1 else n
187 | if batch_sg_feat == None:
188 | batch_sg_feat = torch.matmul(self.sg2_linear(sg_feat[edge_sum:edge_sum+active_edges, :]).permute(1, 0), \
189 | self.sg2_linear(sg_feat[edge_sum:edge_sum+active_edges, :])).unsqueeze(0)
190 | else:
191 | batch_sg_feat = torch.cat((batch_sg_feat, torch.matmul(self.sg2_linear(sg_feat[edge_sum:edge_sum+active_edges, :]).permute(1, 0), \
192 | self.sg2_linear(sg_feat[edge_sum:edge_sum+active_edges, :])).unsqueeze(0)))
193 | edge_sum += active_edges
194 | s1, s2, s3, seg_inputs, _ = self.gcn_unit(seg_inputs, s1=s1, s2=s2, s3=s3, scene_feat = batch_sg_feat, seg_mode = self.seg_mode, model_type = self.model_type)
195 |
196 | else:
197 | '''
198 |             If it is not type 1 or 2, GR is processed before the interaction head.
199 | '''
200 | ''' ==== GR (Global Reasoning) ==== '''
201 | if not self.Train_SEG_GR:
202 | ''' skip GR unit training '''
203 | with torch.no_grad():
204 | s1, s2, s3, seg_inputs, gi_feat = self.gcn_unit(seg_inputs, s1=s1, s2=s2, s3=s3, seg_mode = self.seg_mode, model_type = self.model_type)
205 | else:
206 | # print('segment gcn enabled')
207 | s1, s2, s3, seg_inputs, gi_feat = self.gcn_unit(seg_inputs, s1=s1, s2=s2, s3=s3, seg_mode = self.seg_mode, model_type = self.model_type)
208 |
209 | ''' ==== scene graph ==== '''
210 | if self.model_type == 'amtl-t3' or self.model_type == 'mtl-t3':
211 | gr_int_feat = self.sf_avgpool(gi_feat).view(gi_feat.size(0), 128)
212 |
213 | edge_sum = 0
214 | global_spatial_feat = None
215 |
216 | for b_i, n in enumerate(node_num):
217 | active_edges = (n*(n-1)) if n >1 else n
218 | if global_spatial_feat == None:
219 | global_spatial_feat = torch.cat((spatial_feat[edge_sum:edge_sum+active_edges, :], gr_int_feat[b_i,:].repeat(active_edges,1)),1)
220 | else:
221 | global_spatial_feat = torch.cat((global_spatial_feat, torch.cat((spatial_feat[edge_sum:edge_sum+active_edges, :], gr_int_feat[b_i,:].repeat(active_edges,1)),1)))
222 | edge_sum += active_edges
223 | interaction, _ = self.scene_graph(node_num, gsu_node_feat, global_spatial_feat, word2vec, roi_labels, validation= validation)
224 | elif not self.Train_SG:
225 | ''' skip scene graph training '''
226 | with torch.no_grad():
227 | global_spatial_feat = spatial_feat
228 | interaction, _ = self.scene_graph(node_num, gsu_node_feat, global_spatial_feat, word2vec, roi_labels, validation= True)
229 | else:
230 | # print('interaction encoder enabled')
231 | global_spatial_feat = spatial_feat
232 | interaction, _ = self.scene_graph(node_num, gsu_node_feat, global_spatial_feat, word2vec, roi_labels, validation= validation)
233 |
234 | # ================================================================================================================================================================
235 | # ================================================= Scene graph and segmentation GR Unit =========================================================================
236 |
237 | ''' ============== Segmentation decoder =============='''
238 | if not self.Train_SG_DECODER:
239 | ''' skip segmentation decoder '''
240 | with torch.no_grad():
241 | seg_inputs = self.seg_decoder(seg_inputs, s1 = s1, s2 = s2, s3 =s3, imsize = imsize, seg_mode = self.seg_mode)
242 |
243 | else:
244 | # print('segment_decoder_enabled')
245 | seg_inputs = self.seg_decoder(seg_inputs, s1 = s1, s2 = s2, s3 =s3, imsize = imsize, seg_mode = self.seg_mode)
246 | # ================================================================================================================================================================
247 |
248 | return interaction, seg_inputs, fe_feat
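249 | 
250 | 
251 | # Usage sketch (comments only; the argument layout mirrors the call in evaluation.py, with
252 | # inputs taken from surgicalDataset.collate_fn):
253 | #   interaction, seg_logits, fe_feat = model(seg_img, img_loc, det_boxes, node_num,
254 | #                                             spatial_feat, word2vec, roi_labels, validation=True)
255 | #   `interaction` holds per-edge interaction logits and `seg_logits` the segmentation output.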
--------------------------------------------------------------------------------
/evaluation.py:
--------------------------------------------------------------------------------
1 | #from functools import lru_cache
2 | import os
3 | import time
4 | import json
5 |
6 | import argparse
7 | import numpy as np
8 | from tqdm import tqdm
9 |
10 | import torch
11 | import torch.nn as nn
12 | from torch import optim
13 | import torch.nn.functional as F
14 | from torch.utils.data import DataLoader
15 |
16 | from models.mtl_model import *
17 | from models.scene_graph import *
18 | from models.surgicalDataset import *
19 | from models.segmentation_model import get_gcnet # for the get_gcnet function
20 |
21 | from utils.scene_graph_eval_matrix import *
22 | from utils.segmentation_eval_matrix import * # SegmentationLoss and Eval code
23 |
24 | from tabulate import tabulate
25 |
26 | import torch.multiprocessing as mp
27 | import torch.distributed as dist
28 | from torch.nn.parallel import DistributedDataParallel as DDP
29 |
30 | import warnings
31 | warnings.filterwarnings('ignore')
32 |
33 | def label_to_index(lbl):
34 | '''
35 | Label to index mapping
36 | Input: class label
37 | Output: class index
38 | '''
39 | return torch.tensor(map_dict.index(lbl))
40 |
41 |
42 | def index_to_label(index):
43 | '''
44 | Index to label mapping
45 | Input: class index
46 | Output: class label
47 | '''
48 | return map_dict[index]
49 |
50 |
51 |
52 | def seed_everything(seed=27):
53 | '''
54 | Set random seed for reproducible experiments
55 | Inputs: seed number
56 | '''
57 | torch.manual_seed(seed)
58 | torch.cuda.manual_seed_all(seed)
59 | os.environ['PYTHONHASHSEED'] = str(seed)
60 | torch.backends.cudnn.deterministic = True
61 | torch.backends.cudnn.benchmark = False
62 |
63 |
64 | def seg_eval_batch(seg_output, target):
65 | '''
66 | Calculate segmentation loss, pixel acc and IoU
67 | '''
68 | seg_criterion = SegmentationLosses(se_loss=False, aux=False, nclass=8, se_weight=0.2, aux_weight=0.2)
69 | loss = seg_criterion(seg_output, target)
70 | correct, labeled = batch_pix_accuracy(seg_output.data, target)
71 | inter, union = batch_intersection_union(seg_output.data, target, 8) # 8 is num classes
72 | return correct, labeled, inter, union, loss
73 |
74 |
75 | def build_model(args):
76 | '''
77 | Build MTL model
78 | 1) Scene Graph Understanding Model
79 | 2) Segmentation Model : Encoder, Reasoning unit, Decoder
80 |
81 | Inputs: args
82 | '''
83 |
84 | '''==== Graph model ===='''
85 | # graph model
86 | scene_graph = AGRNN(bias=True, bn=False, dropout=0.3, multi_attn=False, layer=1, diff_edge=False, global_feat=args.global_feat)
87 |
88 | # segmentation model
89 | seg_model = get_gcnet(backbone='resnet18_model', pretrained=False)
90 | model = mtl_model(seg_model.pretrained, scene_graph, seg_model.gr_interaction, seg_model.gr_decoder, seg_mode = args.seg_mode)
91 | model.to(torch.device('cpu'))
92 | return model
93 |
94 |
95 |
96 | def model_eval(model, validation_dataloader, nclass=8):
97 | '''
98 | Evaluate MTL
99 | '''
100 |
101 | model.eval()
102 |
103 | class_values = np.zeros(nclass)
104 |
105 | # graph
106 | scene_graph_criterion = nn.MultiLabelSoftMarginLoss()
107 | scene_graph_edge_count = 0
108 | scene_graph_total_acc = 0.0
109 | scene_graph_total_loss = 0.0
110 | scene_graph_logits_list = []
111 | scene_graph_labels_list = []
112 |
113 | test_seg_loss = 0.0
114 | total_inter, total_union, total_correct, total_label = 0, 0, 0, 0
115 |
116 |
117 | for data in tqdm(validation_dataloader):
118 | seg_img = data['img']
119 | seg_masks = data['mask']
120 | img_loc = data['img_loc']
121 | node_num = data['node_num']
122 | roi_labels = data['roi_labels']
123 | det_boxes = data['det_boxes']
124 | edge_labels = data['edge_labels']
125 | spatial_feat = data['spatial_feat']
126 | word2vec = data['word2vec']
127 |
128 | spatial_feat, word2vec, edge_labels = spatial_feat.cuda(non_blocking=True), word2vec.cuda(non_blocking=True), edge_labels.cuda(non_blocking=True)
129 | seg_img, seg_masks = seg_img.cuda(non_blocking=True), seg_masks.cuda(non_blocking=True)
130 |
131 | with torch.no_grad():
132 | interaction, seg_outputs, _ = model(seg_img, img_loc, det_boxes, node_num, spatial_feat, word2vec, roi_labels, validation=True)
133 |
134 | scene_graph_logits_list.append(interaction)
135 | scene_graph_labels_list.append(edge_labels)
136 |
137 | # loss and accuracy
138 | scene_graph_loss = scene_graph_criterion(interaction, edge_labels.float())
139 | scene_graph_acc = np.sum(np.equal(np.argmax(interaction.cpu().data.numpy(), axis=-1), np.argmax(edge_labels.cpu().data.numpy(), axis=-1)))
140 | correct, labeled, inter, union, t_loss = seg_eval_batch(seg_outputs, seg_masks)
141 |
142 | # accumulate scene graph loss and acc
143 | scene_graph_total_loss += scene_graph_loss.item() * edge_labels.shape[0]
144 | scene_graph_total_acc += scene_graph_acc
145 | scene_graph_edge_count += edge_labels.shape[0]
146 |
147 | total_correct += correct
148 | total_label += labeled
149 | total_inter += inter
150 | total_union += union
151 | test_seg_loss += t_loss.item()
152 |
153 | # graph evaluation
154 | scene_graph_total_acc = scene_graph_total_acc / scene_graph_edge_count
155 | scene_graph_total_loss = scene_graph_total_loss / len(validation_dataloader)
156 | scene_graph_logits_all = torch.cat(scene_graph_logits_list).cuda()
157 | scene_graph_labels_all = torch.cat(scene_graph_labels_list).cuda()
158 | scene_graph_logits_all = F.softmax(scene_graph_logits_all, dim=1)
159 | scene_graph_map_value, scene_graph_recall = calibration_metrics(scene_graph_logits_all, scene_graph_labels_all)
160 |
161 | # segmentation evaluation
162 | pixAcc = 1.0 * total_correct / (np.spacing(1) + total_label)
163 | IoU = 1.0 * total_inter / (np.spacing(1) + total_union)
164 | class_values += IoU
165 | mIoU = IoU.mean()
166 |
167 | print('\n================= Evaluation ====================')
168 |     print('Graph : acc: %0.4f map: %0.4f recall: %0.4f loss: %0.4f' % (scene_graph_total_acc, scene_graph_map_value, scene_graph_recall, scene_graph_total_loss))
169 |     print('Segmentation : Pacc: %0.4f mIoU: %0.4f loss: %0.4f' % (pixAcc, mIoU, test_seg_loss/len(validation_dataloader)))
170 |
171 | print('\n================= Class-wise IoU ====================')
172 | class_wise_IoU = []
173 | m_vals = []
174 | for idx, value in enumerate(class_values):
175 | class_name = index_to_label(idx)
176 | pair = [class_name, value]
177 | m_vals.append(value)
178 | class_wise_IoU.append(pair)
179 |
180 | print("Mean Value: ", np.mean(np.array(m_vals)), "\n")
181 |
182 | print(tabulate(class_wise_IoU,
183 | headers=['Class', 'IoU'], tablefmt='orgtbl'))
184 |
185 | return(scene_graph_total_acc, scene_graph_map_value, mIoU)
186 |
187 |
188 | if __name__ == "__main__":
189 |
190 | '''
191 | Main function to set arguments
192 | '''
193 |
194 | '''
195 | To reproduce the results, set the model_type, ver, seg_mode and checkpoint_dir based on the table below
196 | TBR = To be released
197 | ============================================================================================================
198 | Paper_name | model_type | ver | seg_mode | checkpoint_dir
199 | ============================================================================================================
200 | STL
201 | ------------------------|-----------------------------------------------------------------------------------
202 | VS-GAT | 'stl-sg' | 'stl_sg' | None | 'stl_sg'
203 | SEG | 'stl-s' | 'stl_s_ng' | TBR | 'stl_s_ng'
204 | SEG-GR | 'stl-s' | 'stl_s' | None | 'stl_s'
205 | SEG-MSGR | 'stl-s' | 'stl_s_v2gc' | 'v2gc' | 'stl_s_v2gc'
206 | SEG-MSLRGR | 'stl-s' | 'stl_s_v1' | 'v1' | 'stl_s_v1'
207 | ------------------------------------------------------------------------------------------------------------
208 | SMTL
209 | ------------------------------------------------------------------------------------------------------------
210 | GR | 'amtl-t0' | 'amtl_t0_s' | None | 'amtl_t0_s'
211 | MSGR | 'amtl-t0' | 'amtl_t0_sv2gc' | 'v2gc' | 'amtl_t0_sv2gc'
212 | MSLRGR | 'amtl-t0' | 'amtl_t0_sv1' | 'v1' | 'amtl_t0_sv1'
213 | MSLRGR-GISFSG | 'amtl-t3' | 'amtl_t3pn_sv1' | 'v1' | 'amtl_t3pn_sv1'
214 | ------------------------------------------------------------------------------------------------------------
215 | v-MTL
216 | ------------------------------------------------------------------------------------------------------------
217 |     V-MTL-GR                | 'mtl-t0'   | 'mtl_t0_s'        | None     | 'mtl_t0_s'
218 | ------------------------------------------------------------------------------------------------------------
219 | KD-MTL (set args.KD = True)
220 | ------------------------------------------------------------------------------------------------------------
221 | KD-MTL-GR | 'mtl-t0' | 'mtl_kd_t0_s' | None | TBR
222 | KD-MTL-MSLRGR | 'mtl-t0' | 'mtl_kd_t0_sv1' | 'v1' | 'mtl_kd_t0_sv1'
223 | KD-MTL-MSLRGR-SGFSEG | 'mtl-t1' | 'mtl_kd_t1_sv1' | 'v1' | 'mtl_kd_t1_sv1'
224 | KD-MTL-MSLRGR-GISFSG | 'mtl-t3' | 'mtl_kd_t3_sv1' | 'v1' | 'mtl_kd_t3_sv1'
225 | ------------------------------------------------------------------------------------------------------------
226 | '''
227 |
228 | model_type = 'amtl-t3'
229 | ver = 'amtl_t3_sv1'
230 | seg_mode = 'v1'
231 | checkpoint_dir = 'amtl_t3_sv1'
232 |
233 | port = '8892'
234 |
235 | # Set random seed
236 | seed_everything()
237 | print(ver, seg_mode)
238 |
239 | # arguments
240 | parser = argparse.ArgumentParser(description='GR_MTL_SSU')
241 |
242 | # hyper parameters
243 | parser.add_argument('--lr', type=float, default = 0.00001) #0.00001
244 | parser.add_argument('--epoch', type=int, default = 130)
245 | parser.add_argument('--start_epoch', type=int, default = 0)
246 | parser.add_argument('--batch_size', type=int, default = 1)
247 | parser.add_argument('--gpu', type=bool, default = True)
248 | parser.add_argument('--train_model', type=str, default = 'epoch')
249 | parser.add_argument('--exp_ver', type=str, default = ver)
250 |
251 | # file locations
252 | parser.add_argument('--log_dir', type=str, default = './log/' + ver)
253 | parser.add_argument('--save_dir', type=str, default = './checkpoints/' + ver)
254 | parser.add_argument('--output_img_dir', type=str, default = './results/' + ver)
255 | parser.add_argument('--save_every', type=int, default = 10)
256 | parser.add_argument('--pretrained', type=str, default = None)
257 |
258 | # network
259 | parser.add_argument('--layers', type=int, default = 1)
260 | parser.add_argument('--bn', type=bool, default = False)
261 | parser.add_argument('--drop_prob', type=float, default = 0.3)
262 | parser.add_argument('--bias', type=bool, default = True)
263 | parser.add_argument('--multi_attn', type=bool, default = False)
264 | parser.add_argument('--diff_edge', type=bool, default = False)
265 | if model_type == 'mtl-t3' or model_type == 'amtl-t3':
266 | parser.add_argument('--global_feat', type=int, default = 128)
267 | else:
268 | parser.add_argument('--global_feat', type=int, default = 0)
269 | # data_processing
270 | parser.add_argument('--sampler', type=int, default = 0)
271 | parser.add_argument('--data_aug', type=bool, default = False)
272 | parser.add_argument('--feature_extractor', type=str, default = 'features')
273 | parser.add_argument('--seg_mode', type=str, default = seg_mode)
274 |
275 | # CBS
276 | parser.add_argument('--use_cbs', type=bool, default = False)
277 |
278 | # Knowledge distillation
279 | parser.add_argument('--KD', type=bool, default = False)
280 |
281 | parser.add_argument('--model', type=str, default = model_type)
282 | args = parser.parse_args()
283 |
284 | # seed_everything()
285 | data_const = SurgicalSceneConstants()
286 |
287 | label_path = 'dataset/labels_isi_dataset.json'
288 | with open(label_path) as f:
289 | labels = json.load(f)
290 |
291 | CLASSES = []
292 | CLASS_ID = []
293 |
294 | for item in labels:
295 | CLASSES.append(item['name'])
296 | CLASS_ID.append(item['classid'])
297 |
298 | map_dict = {k: v for k, v in zip(CLASS_ID, CLASSES)}
299 |
300 |     # This is placed above the dist.init process, possibly because of the feature_extraction model.
301 | model = build_model(args)
302 | model.set_train_test(args.model)
303 |
304 | # insert nn layers based on type.
305 | if args.model == 'amtl-t1' or args.model == 'mtl-t1':
306 | model.model_type1_insert()
307 | elif args.model == 'amtl-t2' or args.model == 'mtl-t2':
308 | model.model_type2_insert()
309 | elif args.model == 'amtl-t3' or args.model == 'mtl-t3':
310 | model.model_type3_insert()
311 |
312 | # load pre-trained stl_mtl_model
313 | print('Loading pre-trained weights')
314 | pretrained_model = torch.load(('checkpoints/'+checkpoint_dir+'/best_epoch.pth'))
315 | model.load_state_dict(pretrained_model)
316 |
317 | # Wrap the model with ddp
318 | model.cuda()
319 |
320 | # train and test dataloader
321 | val_seq = [[1, 5, 16]]
322 | data_dir = ['dataset/seq_']
323 | img_dir = ['/left_frames/']
324 | mask_dir = ['/annotations/']
325 | dset = [0]
326 | data_const = SurgicalSceneConstants()
327 |
328 | seq = {'val_seq': val_seq, 'data_dir': data_dir, 'img_dir': img_dir, 'dset': dset, 'mask_dir': mask_dir}
329 |
330 | # val_dataset only set in 1 GPU
331 | val_dataset = SurgicalSceneDataset(seq_set=seq['val_seq'], dset=seq['dset'], data_dir=seq['data_dir'], \
332 | img_dir=seq['img_dir'], mask_dir=seq['mask_dir'], istrain=False, dataconst=data_const, \
333 | feature_extractor=args.feature_extractor, reduce_size=False)
334 | val_dataloader = DataLoader(dataset=val_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn)
335 |
336 | model_eval(model, val_dataloader)
337 |
--------------------------------------------------------------------------------
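Note on the metrics above: a minimal, self-contained sketch of how the counters accumulated in model_eval (evaluation.py) turn into pixel accuracy and mean IoU. The counter values below are invented purely for illustration; only the formulas mirror the file above.

    import numpy as np

    # hypothetical accumulated counters for an 8-class run
    total_correct, total_label = 90_000, 100_000                                # correctly classified / labelled pixels
    total_inter = np.array([800.0, 50.0, 120.0, 0.0, 30.0, 10.0, 5.0, 2.0])     # per-class intersection
    total_union = np.array([1000.0, 80.0, 200.0, 1.0, 60.0, 25.0, 9.0, 4.0])    # per-class union

    pixAcc = 1.0 * total_correct / (np.spacing(1) + total_label)                # overall pixel accuracy
    IoU = 1.0 * total_inter / (np.spacing(1) + total_union)                     # per-class IoU
    mIoU = IoU.mean()                                                           # mean IoU reported by model_eval
    print(pixAcc, IoU, mIoU)
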
/models/scene_graph.py:
--------------------------------------------------------------------------------
1 | '''
2 | Project : Global-Reasoned Multi-Task Surgical Scene Understanding
3 | Lab : MMLAB, National University of Singapore
4 | contributors : Lalithkumar Seenivasan, Sai Mitheran, Mobarakol Islam, Hongliang Ren
5 | Note : Code adopted and modified from Visual-Semantic Graph Attention Networks and Dual attention network for scene segmentation
6 | Visual-Semantic Graph Network:
7 | @article{liang2020visual,
8 | title={Visual-Semantic Graph Attention Networks for Human-Object Interaction Detection},
9 | author={Liang, Zhijun and Rojas, Juan and Liu, Junfa and Guan, Yisheng},
10 | journal={arXiv preprint arXiv:2001.02302},
11 | year={2020}
12 | }
13 | '''
14 |
15 |
16 | import dgl
17 | import math
18 | import numpy as np
19 |
20 | import torch
21 | import torch.nn as nn
22 | import torch.nn.functional as F
23 | 
24 | 
25 |
26 | from collections import OrderedDict
27 |
28 | '''
29 | Configurations of the network (feature_size = 512, additional_sf = global_feat)
30 | 
31 | readout: G_ER_L_S = [512+300+16+additional_sf+300+512, 512, 13]
32 | 
33 | node_func: G_N_L_S = [512+512, 512]
34 | node_lang_func: G_N_L_S2 = [300+300, 300]
35 | 
36 | edge_func : G_E_L_S = [512*2+16+additional_sf, 512]
37 | edge_lang_func: G_E_L_S2 = [300*2, 512]
38 | 
39 | attn: G_A_L_S = [512, 1]
40 | attn_lang: G_A_L_S2 = [512, 1]
41 | '''
42 | class CONFIGURATION(object):
43 | '''
44 | Configuration arguments: feature type, layer, bias, batch normalization, dropout, multi-attn
45 |
46 |     readout : fc_size, activation, bias, bn, dropout
47 |     gnn_node : fc_size, activation, bias, bn, dropout
48 |     gnn_node_for_lang : fc_size, activation, bias, bn, dropout
49 |     gnn_edge : fc_size, activation, bias, bn, dropout
50 |     gnn_edge_for_lang : fc_size, activation, bias, bn, dropout
51 |     gnn_attn : fc_size, activation, bias, bn, dropout
52 |     gnn_attn_for_lang : fc_size, activation, bias, bn, dropout
53 | '''
54 | def __init__(self, layer=1, bias=True, bn=False, dropout=0.2, multi_attn=False, global_feat = 0):
55 |
56 | # if multi_attn:
57 | if True:
58 | if layer==1:
59 | feature_size = 512
60 | additional_sf = global_feat
61 | # readout
62 | self.G_ER_L_S = [feature_size+300+16+additional_sf+300+feature_size, feature_size, 13]
63 | self.G_ER_A = ['ReLU', 'Identity']
64 | self.G_ER_B = bias #true
65 | self.G_ER_BN = bn #false
66 | self.G_ER_D = dropout #0.3
67 | # self.G_ER_GRU = feature_size
68 |
69 | # # gnn node function
70 | self.G_N_L_S = [feature_size+feature_size, feature_size]
71 | self.G_N_A = ['ReLU']
72 | self.G_N_B = bias #true
73 | self.G_N_BN = bn #false
74 | self.G_N_D = dropout #0.3
75 | # self.G_N_GRU = feature_size
76 |
77 | # # gnn node function for language
78 | self.G_N_L_S2 = [300+300, 300]
79 | self.G_N_A2 = ['ReLU']
80 | self.G_N_B2 = bias #true
81 | self.G_N_BN2 = bn #false
82 | self.G_N_D2 = dropout #0.3
83 | # self.G_N_GRU2 = feature_size
84 |
85 | # gnn edge function1
86 | self.G_E_L_S = [feature_size*2+16+additional_sf, feature_size]
87 | self.G_E_A = ['ReLU']
88 | self.G_E_B = bias # true
89 | self.G_E_BN = bn # false
90 | self.G_E_D = dropout # 0.3
91 | # self.G_E_c_kernel_size = 3
92 |
93 |
94 | # gnn edge function2 for language
95 | self.G_E_L_S2 = [300*2, feature_size]
96 | self.G_E_A2 = ['ReLU']
97 | self.G_E_B2 = bias #true
98 | self.G_E_BN2 = bn #false
99 | self.G_E_D2 = dropout #0.3
100 |
101 | # gnn attention mechanism
102 | self.G_A_L_S = [feature_size, 1]
103 | self.G_A_A = ['LeakyReLU']
104 | self.G_A_B = bias #true
105 | self.G_A_BN = bn #false
106 | self.G_A_D = dropout #0.3
107 |
108 | # gnn attention mechanism2 for language
109 | self.G_A_L_S2 = [feature_size, 1]
110 | self.G_A_A2 = ['LeakyReLU']
111 | self.G_A_B2 = bias #true
112 | self.G_A_BN2 = bn #false
113 | self.G_A_D2 = dropout #0.3
114 |
115 | def save_config(self):
116 | model_config = {'graph_head':{}, 'graph_node':{}, 'graph_edge':{}, 'graph_attn':{}}
117 | CONFIG=self.__dict__
118 | for k, v in CONFIG.items():
119 | if 'G_H' in k:
120 | model_config['graph_head'][k]=v
121 | elif 'G_N' in k:
122 | model_config['graph_node'][k]=v
123 | elif 'G_E' in k:
124 | model_config['graph_edge'][k]=v
125 | elif 'G_A' in k:
126 | model_config['graph_attn'][k]=v
127 | else:
128 | model_config[k]=v
129 |
130 | return model_config
131 |
132 |
133 | class Identity(nn.Module):
134 | '''
135 | Identity class activation layer
136 | f(x) = x
137 | '''
138 | def __init__(self):
139 | super(Identity,self).__init__()
140 |
141 | def forward(self, x):
142 | return x
143 |
144 | def get_activation(name):
145 | '''
146 | get_activation function
147 | argument: Activation name (eg. ReLU, Identity, Tanh, Sigmoid, LeakyReLU)
148 | '''
149 | if name=='ReLU': return nn.ReLU(inplace=True)
150 | elif name=='Identity': return Identity()
151 | elif name=='Tanh': return nn.Tanh()
152 | elif name=='Sigmoid': return nn.Sigmoid()
153 | elif name=='LeakyReLU': return nn.LeakyReLU(0.2,inplace=True)
154 | else: assert(False), 'Not Implemented'
155 |
156 |
157 | class MLP(nn.Module):
158 | '''
159 | Args:
160 | layer_sizes: a list, [1024,1024,...]
161 | activation: a list, ['ReLU', 'Tanh',...]
162 | bias : bool
163 | use_bn: bool
164 | drop_prob: default is None, use drop out layer or not
165 | '''
166 | def __init__(self, layer_sizes, activation, bias=True, use_bn=False, drop_prob=None):
167 | super(MLP, self).__init__()
168 | self.bn = use_bn
169 | self.layers = nn.ModuleList()
170 | for i in range(len(layer_sizes)-1):
171 | layer = nn.Linear(layer_sizes[i], layer_sizes[i+1], bias=bias)
172 | activate = get_activation(activation[i])
173 | block = nn.Sequential(OrderedDict([(f'L{i}', layer), ]))
174 |
175 |             # !NOTE: Actually, it is inappropriate to use batch normalization here
176 | if use_bn:
177 | bn = nn.BatchNorm1d(layer_sizes[i+1])
178 | block.add_module(f'B{i}', bn)
179 |
180 | # batch normalization is put before activation function
181 | block.add_module(f'A{i}', activate)
182 |
183 |             # dropout probability
184 | if drop_prob:
185 | block.add_module(f'D{i}', nn.Dropout(drop_prob))
186 |
187 | self.layers.append(block)
188 |
189 | def forward(self, x):
190 | for layer in self.layers:
191 |             # !NOTE: Sometimes the shape of x is [1, N]; batch normalization cannot be used in that case
192 | if self.bn and x.shape[0]==1:
193 | x = layer[0](x)
194 | x = layer[:-1](x)
195 | else:
196 | x = layer(x)
197 | return x
198 |
199 |
200 | class H_H_EdgeApplyModule(nn.Module): #Human to Human edge
201 | '''
202 | init : config, multi_attn
203 | forward : edge
204 | '''
205 | def __init__(self, CONFIG, multi_attn=False):
206 | super(H_H_EdgeApplyModule, self).__init__()
207 | self.edge_fc = MLP(CONFIG.G_E_L_S, CONFIG.G_E_A, CONFIG.G_E_B, CONFIG.G_E_BN, CONFIG.G_E_D)
208 | self.edge_fc_lang = MLP(CONFIG.G_E_L_S2, CONFIG.G_E_A2, CONFIG.G_E_B2, CONFIG.G_E_BN2, CONFIG.G_E_D2)
209 |
210 | def forward(self, edge):
211 | feat = torch.cat([edge.src['n_f'], edge.data['s_f'], edge.dst['n_f']], dim=1)
212 | feat_lang = torch.cat([edge.src['word2vec'], edge.dst['word2vec']], dim=1)
213 | e_feat = self.edge_fc(feat)
214 | e_feat_lang = self.edge_fc_lang(feat_lang)
215 |
216 | return {'e_f': e_feat, 'e_f_lang': e_feat_lang}
217 |
218 |
219 |
220 | class H_NodeApplyModule(nn.Module): #human node
221 | '''
222 | init : config
223 | forward : node
224 | '''
225 | def __init__(self, CONFIG):
226 | super(H_NodeApplyModule, self).__init__()
227 | self.node_fc = MLP(CONFIG.G_N_L_S, CONFIG.G_N_A, CONFIG.G_N_B, CONFIG.G_N_BN, CONFIG.G_N_D)
228 | self.node_fc_lang = MLP(CONFIG.G_N_L_S2, CONFIG.G_N_A2, CONFIG.G_N_B2, CONFIG.G_N_BN2, CONFIG.G_N_D2)
229 |
230 | def forward(self, node):
231 | feat = torch.cat([node.data['n_f'], node.data['z_f']], dim=1)
232 | feat_lang = torch.cat([node.data['word2vec'], node.data['z_f_lang']], dim=1)
233 | n_feat = self.node_fc(feat)
234 | n_feat_lang = self.node_fc_lang(feat_lang)
235 |
236 | return {'new_n_f': n_feat, 'new_n_f_lang': n_feat_lang}
237 |
238 |
239 | class E_AttentionModule1(nn.Module): #edge attention
240 | '''
241 | init : config
242 | forward : edge
243 | '''
244 | def __init__(self, CONFIG):
245 | super(E_AttentionModule1, self).__init__()
246 | self.attn_fc = MLP(CONFIG.G_A_L_S, CONFIG.G_A_A, CONFIG.G_A_B, CONFIG.G_A_BN, CONFIG.G_A_D)
247 | self.attn_fc_lang = MLP(CONFIG.G_A_L_S2, CONFIG.G_A_A2, CONFIG.G_A_B2, CONFIG.G_A_BN2, CONFIG.G_A_D2)
248 |
249 | def forward(self, edge):
250 | a_feat = self.attn_fc(edge.data['e_f'])
251 | a_feat_lang = self.attn_fc_lang(edge.data['e_f_lang'])
252 | return {'a_feat': a_feat, 'a_feat_lang': a_feat_lang}
253 |
254 |
255 | class GNN(nn.Module):
256 | '''
257 | init : config, multi_attn, diff_edge
258 | forward : g, h_node, o_node, h_h_e_list, o_o_e_list, h_o_e_list, pop_features
259 | '''
260 | def __init__(self, CONFIG, multi_attn=False, diff_edge=True):
261 | super(GNN, self).__init__()
262 | self.diff_edge = diff_edge # false
263 | self.apply_h_h_edge = H_H_EdgeApplyModule(CONFIG, multi_attn)
264 | self.apply_edge_attn1 = E_AttentionModule1(CONFIG)
265 | self.apply_h_node = H_NodeApplyModule(CONFIG)
266 |
267 | def _message_func(self, edges):
268 | return {'nei_n_f': edges.src['n_f'], 'nei_n_w': edges.src['word2vec'], 'e_f': edges.data['e_f'], 'e_f_lang': edges.data['e_f_lang'], 'a_feat': edges.data['a_feat'], 'a_feat_lang': edges.data['a_feat_lang']}
269 |
270 | def _reduce_func(self, nodes):
271 | alpha = F.softmax(nodes.mailbox['a_feat'], dim=1)
272 | alpha_lang = F.softmax(nodes.mailbox['a_feat_lang'], dim=1)
273 |
274 | z_raw_f = nodes.mailbox['nei_n_f']+nodes.mailbox['e_f']
275 | z_f = torch.sum( alpha * z_raw_f, dim=1)
276 |
277 | z_raw_f_lang = nodes.mailbox['nei_n_w']
278 | z_f_lang = torch.sum(alpha_lang * z_raw_f_lang, dim=1)
279 |
280 |         # 'alpha' cannot always be returned because its dimension differs between graphs
281 | if self.training or validation: return {'z_f': z_f, 'z_f_lang': z_f_lang}
282 | else: return {'z_f': z_f, 'z_f_lang': z_f_lang, 'alpha': alpha, 'alpha_lang': alpha_lang}
283 |
284 | def forward(self, g, h_node, o_node, h_h_e_list, o_o_e_list, h_o_e_list, pop_feat=False):
285 |
286 | g.apply_edges(self.apply_h_h_edge, g.edges())
287 | g.apply_edges(self.apply_edge_attn1)
288 | g.update_all(self._message_func, self._reduce_func)
289 | g.apply_nodes(self.apply_h_node, h_node+o_node)
290 |
291 |         # !NOTE: PAY ATTENTION WHEN ADDING MORE FEATURES
292 | g.ndata.pop('n_f')
293 | g.ndata.pop('word2vec')
294 |
295 | g.ndata.pop('z_f')
296 | g.edata.pop('e_f')
297 | g.edata.pop('a_feat')
298 |
299 | g.ndata.pop('z_f_lang')
300 | g.edata.pop('e_f_lang')
301 | g.edata.pop('a_feat_lang')
302 |
303 |
304 | class GRNN(nn.Module):
305 | '''
306 | init:
307 | config, multi_attn, diff_edge
308 | forward:
309 | batch_graph, batch_h_node_list, batch_obj_node_list,
310 | batch_h_h_e_list, batch_o_o_e_list, batch_h_o_e_list,
311 | features, spatial_features, word2vec,
312 | valid, pop_features, initial_features
313 | '''
314 | def __init__(self, CONFIG, multi_attn=False, diff_edge=True):
315 | super(GRNN, self).__init__()
316 | self.multi_attn = multi_attn #false
317 | self.gnn = GNN(CONFIG, multi_attn, diff_edge)
318 |
319 | def forward(self, batch_graph, batch_h_node_list, batch_obj_node_list, batch_h_h_e_list, batch_o_o_e_list, batch_h_o_e_list, feat, spatial_feat, word2vec, valid=False, pop_feat=False, initial_feat=False):
320 |
321 |         # !NOTE: if node_num == 1, forwarding the attention mechanism will fail
322 | global validation
323 | validation = valid
324 |
325 |         # initialize the graph with the input data
326 | batch_graph.ndata['n_f'] = feat # node: features
327 | batch_graph.ndata['word2vec'] = word2vec # node: words
328 | batch_graph.edata['s_f'] = spatial_feat # edge: spatial features
329 |
330 | try:
331 | self.gnn(batch_graph, batch_h_node_list, batch_obj_node_list, batch_h_h_e_list, batch_o_o_e_list, batch_h_o_e_list)
332 | except Exception as e:
333 | print(e)
334 |
335 |
336 | class Predictor(nn.Module):
337 | '''
338 | init : config
339 | forward : edge
340 | '''
341 | def __init__(self, CONFIG):
342 | super(Predictor, self).__init__()
343 | self.classifier = MLP(CONFIG.G_ER_L_S, CONFIG.G_ER_A, CONFIG.G_ER_B, CONFIG.G_ER_BN, CONFIG.G_ER_D)
344 | self.sigmoid = nn.Sigmoid()
345 |
346 | def forward(self, edge):
347 | feat = torch.cat([edge.dst['new_n_f'], edge.dst['new_n_f_lang'], edge.data['s_f'], edge.src['new_n_f_lang'], edge.src['new_n_f']], dim=1)
348 | scene_feat = torch.cat([edge.dst['new_n_f'], edge.src['new_n_f'],edge.data['s_f']], dim=1)
349 | pred = self.classifier(feat)
350 | # If the criterion is BCELoss, uncomment the following code ->
351 | # output = self.sigmoid(output)
352 | return {'pred': pred, 'scene_feat': scene_feat}
353 |
354 |
355 | class AGRNN(nn.Module):
356 | '''
357 | init :
358 | feature_type, bias, bn, dropout, multi_attn, layer, diff_edge
359 |
360 | forward :
361 | node_num, features, spatial_features, word2vec, roi_label,
362 | validation, choose_nodes, remove_nodes
363 | '''
364 | def __init__(self, bias=True, bn=True, dropout=None, multi_attn=False, layer=1, diff_edge=True, global_feat = 0):
365 | super(AGRNN, self).__init__()
366 |
367 | self.multi_attn = multi_attn # false
368 | self.layer = layer # 1 layer
369 | self.diff_edge = diff_edge # false
370 |
371 | self.CONFIG1 = CONFIGURATION(layer=1, bias=bias, bn=bn, dropout=dropout, multi_attn=multi_attn, global_feat=global_feat)
372 |
373 | self.grnn1 = GRNN(self.CONFIG1, multi_attn=multi_attn, diff_edge=diff_edge)
374 | self.edge_readout = Predictor(self.CONFIG1)
375 |
376 | def _collect_edge(self, node_num, roi_label, node_space, diff_edge):
377 | '''
378 | arguments: node_num, roi_label, node_space, diff_edge
379 | '''
380 |
381 | # get human nodes && object nodes
382 | h_node_list = np.where(roi_label == 0)[0]
383 | obj_node_list = np.where(roi_label != 0)[0]
384 | edge_list = []
385 |
386 | h_h_e_list = []
387 | o_o_e_list = []
388 | h_o_e_list = []
389 |
390 | readout_edge_list = []
391 | readout_h_h_e_list = []
392 | readout_h_o_e_list = []
393 |
394 |         # get all edges in the fully-connected graph; for node_num = 2, edge_list = [(0, 1), (1, 0)]
395 | for src in range(node_num):
396 | for dst in range(node_num):
397 | if src == dst:
398 | continue
399 | else:
400 | edge_list.append((src, dst))
401 |
402 | # readout_edge_list, get corresponding readout edge in the graph
403 | src_box_list = np.arange(roi_label.shape[0])
404 | for dst in h_node_list:
405 | for src in src_box_list:
406 | if src not in h_node_list:
407 | readout_edge_list.append((src, dst))
408 |
409 | # readout h_h_e_list, get corresponding readout h_h edges && h_o edges
410 | temp_h_node_list = h_node_list[:]
411 | for dst in h_node_list:
412 | if dst == h_node_list.shape[0]-1:
413 | continue
414 | temp_h_node_list = temp_h_node_list[1:]
415 | for src in temp_h_node_list:
416 | if src == dst: continue
417 | readout_h_h_e_list.append((src, dst))
418 |
419 | # readout h_o_e_list
420 | readout_h_o_e_list = [x for x in readout_edge_list if x not in readout_h_h_e_list]
421 |
422 | # add node space to match the batch graph
423 | h_node_list = (np.array(h_node_list)+node_space).tolist()
424 | obj_node_list = (np.array(obj_node_list)+node_space).tolist()
425 |
426 | h_h_e_list = (np.array(h_h_e_list)+node_space).tolist() #empty no diff_edge
427 | o_o_e_list = (np.array(o_o_e_list)+node_space).tolist() #empty no diff_edge
428 | h_o_e_list = (np.array(h_o_e_list)+node_space).tolist() #empty no diff_edge
429 |
430 | readout_h_h_e_list = (np.array(readout_h_h_e_list)+node_space).tolist()
431 | readout_h_o_e_list = (np.array(readout_h_o_e_list)+node_space).tolist()
432 | readout_edge_list = (np.array(readout_edge_list)+node_space).tolist()
433 |
434 | return edge_list, h_node_list, obj_node_list, h_h_e_list, o_o_e_list, h_o_e_list, readout_edge_list, readout_h_h_e_list, readout_h_o_e_list
435 |
436 | def _build_graph(self, node_num, roi_label, node_space, diff_edge):
437 | '''
438 | Declare graph, add_nodes, collect edges, add_edges
439 | '''
440 | graph = dgl.DGLGraph()
441 | graph.add_nodes(node_num)
442 |
443 | edge_list, h_node_list, obj_node_list, h_h_e_list, o_o_e_list, h_o_e_list, readout_edge_list, readout_h_h_e_list, readout_h_o_e_list = self._collect_edge(node_num, roi_label, node_space, diff_edge)
444 | src, dst = tuple(zip(*edge_list))
445 | graph.add_edges(src, dst) # make the graph bi-directional
446 |
447 | return graph, h_node_list, obj_node_list, h_h_e_list, o_o_e_list, h_o_e_list, readout_edge_list, readout_h_h_e_list, readout_h_o_e_list
448 |
449 | def forward(self, node_num=None, feat=None, spatial_feat=None, word2vec=None, roi_label=None, validation=False, choose_nodes=None, remove_nodes=None):
450 |
451 | batch_graph, batch_h_node_list, batch_obj_node_list, batch_h_h_e_list, batch_o_o_e_list, batch_h_o_e_list, batch_readout_edge_list, batch_readout_h_h_e_list, batch_readout_h_o_e_list = [], [], [], [], [], [], [], [], []
452 | node_num_cum = np.cumsum(node_num) # !IMPORTANT
453 |
454 | for i in range(len(node_num)):
455 | # set node space
456 | node_space = 0
457 | if i != 0:
458 | node_space = node_num_cum[i-1]
459 | graph, h_node_list, obj_node_list, h_h_e_list, o_o_e_list, h_o_e_list, readout_edge_list, readout_h_h_e_list, readout_h_o_e_list = self._build_graph(node_num[i], roi_label[i], node_space, diff_edge=self.diff_edge)
460 |
461 | # update batch
462 | batch_graph.append(graph)
463 | batch_h_node_list += h_node_list
464 | batch_obj_node_list += obj_node_list
465 |
466 | batch_h_h_e_list += h_h_e_list
467 | batch_o_o_e_list += o_o_e_list
468 | batch_h_o_e_list += h_o_e_list
469 |
470 | batch_readout_edge_list += readout_edge_list
471 | batch_readout_h_h_e_list += readout_h_h_e_list
472 | batch_readout_h_o_e_list += readout_h_o_e_list
473 |
474 | batch_graph = dgl.batch(batch_graph)
475 |
476 | # GRNN
477 | self.grnn1(batch_graph, batch_h_node_list, batch_obj_node_list, batch_h_h_e_list, batch_o_o_e_list, batch_h_o_e_list, feat, spatial_feat, word2vec, validation, initial_feat=True)
478 | batch_graph.apply_edges(self.edge_readout, tuple(zip(*(batch_readout_h_o_e_list+batch_readout_h_h_e_list))))
479 |
480 | if self.training or validation:
481 | # !NOTE: cannot use "batch_readout_h_o_e_list+batch_readout_h_h_e_list" because of the wrong order
482 | return batch_graph.edges[tuple(zip(*batch_readout_edge_list))].data['pred'], \
483 | batch_graph.edges[tuple(zip(*batch_readout_edge_list))].data['scene_feat']
484 | else:
485 | return batch_graph.edges[tuple(zip(*batch_readout_edge_list))].data['pred'], \
486 | batch_graph.nodes[batch_h_node_list].data['alpha'], \
487 | batch_graph.nodes[batch_h_node_list].data['alpha_lang']
488 |
--------------------------------------------------------------------------------
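Note on AGRNN._collect_edge above: an illustrative sketch (not repository code) of the fully connected edge list and the readout edges built for one graph. The roi_label values are made up, and the list comprehensions simply restate the loops in the file.

    import numpy as np

    node_num = 3
    roi_label = np.array([0, 2, 5])   # 0 marks the 'human'-type node (VS-GAT terminology); non-zero marks object nodes

    # fully connected, bi-directional edge list
    edge_list = [(src, dst) for src in range(node_num) for dst in range(node_num) if src != dst]
    # -> [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]

    # readout edges: every object node paired with each 'human'-type node as destination
    h_node_list = np.where(roi_label == 0)[0]
    readout_edge_list = [(src, dst) for dst in h_node_list
                         for src in range(roi_label.shape[0]) if src not in h_node_list]
    # -> [(1, 0), (2, 0)]
    print(edge_list, readout_edge_list)
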
/model_train.py:
--------------------------------------------------------------------------------
1 | '''
2 | Project : Global-Reasoned Multi-Task Surgical Scene Understanding
3 | Lab : MMLAB, National University of Singapore
4 | contributors : Lalithkumar Seenivasan, Sai Mitheran, Mobarakol Islam, Hongliang Ren
5 | '''
6 |
7 | import os
8 | import time
9 |
10 | import argparse
11 | import numpy as np
12 | from tqdm import tqdm
13 |
14 | import torch
15 | import torch.nn as nn
16 | from torch import optim
17 | import torch.nn.functional as F
18 | from torch.utils.data import DataLoader
19 |
20 | from models.mtl_model import *
21 | from models.scene_graph import *
22 | from models.surgicalDataset import *
23 | from models.segmentation_model import get_gcnet
24 |
25 | from utils.scene_graph_eval_matrix import *
26 | from utils.segmentation_eval_matrix import *
27 |
28 |
29 | import torch.multiprocessing as mp
30 | import torch.distributed as dist
31 | from torch.nn.parallel import DistributedDataParallel as DDP
32 |
33 |
34 | def seed_everything(seed=27):
35 | '''
36 | Set random seed for reproducible experiments
37 | Inputs: seed number
38 | '''
39 | torch.manual_seed(seed)
40 | torch.cuda.manual_seed_all(seed)
41 | os.environ['PYTHONHASHSEED'] = str(seed)
42 | torch.backends.cudnn.deterministic = True
43 | torch.backends.cudnn.benchmark = False
44 |
45 |
46 | def seg_eval_batch(seg_output, target):
47 | '''
48 | Calculate segmentation loss, pixel acc and IoU
49 | Inputs: predicted segmentation mask, GT segmentation mask
50 | '''
51 | seg_criterion = SegmentationLosses(se_loss=False, aux=False, nclass=8, se_weight=0.2, aux_weight=0.2)
52 | loss = seg_criterion(seg_output, target)
53 | correct, labeled = batch_pix_accuracy(seg_output.data, target)
54 | inter, union = batch_intersection_union(seg_output.data, target, 8) # 8 is num classes
55 | return correct, labeled, inter, union, loss
56 |
57 | def get_checkpoint_loc(model_type, seg_mode = None):
58 | loc = None
59 | if model_type == 'amtl-t0' or model_type == 'amtl-t3':
60 | if seg_mode is None:
61 | loc = 'checkpoints/stl_s/stl_s/epoch_train/checkpoint_D153_epoch.pth'
62 | elif seg_mode == 'v1':
63 | loc = 'checkpoints/stl_s_v1/stl_s_v1/epoch_train/checkpoint_D168_epoch.pth'
64 | elif seg_mode == 'v2_gc':
65 | loc = 'checkpoints/stl_s_v2_gc/stl_s_v2_gc/epoch_train/checkpoint_D168_epoch.pth'
66 | elif model_type == 'amtl-t1':
67 | loc = 'checkpoints/stl_s/stl_s/epoch_train/checkpoint_D168_epoch.pth'
68 | elif model_type == 'amtl-t2':
69 | loc = 'checkpoints/stl_sg_wfe/stl_sg_wfe/epoch_train/checkpoint_D110_epoch.pth'
70 | return loc
71 |
72 | def build_model(args):
73 | '''
74 | Build MTL model
75 | 1) Scene Graph Understanding Model
76 | 2) Segmentation Model : Encoder, Reasoning unit, Decoder
77 |
78 | Inputs: args
79 | '''
80 |
81 | '''==== Graph model ===='''
82 | # graph model
83 | scene_graph = AGRNN(bias=True, bn=False, dropout=0.3, multi_attn=False, layer=1, diff_edge=False, global_feat=args.global_feat)
84 |
85 | # segmentation model
86 | seg_model = get_gcnet(backbone='resnet18_model', pretrained=True)
87 | model = mtl_model(seg_model.pretrained, scene_graph, seg_model.gr_interaction, seg_model.gr_decoder, seg_mode = args.seg_mode)
88 | model.to(torch.device('cpu'))
89 | return model
90 |
91 |
92 | def model_eval(args, model, validation_dataloader):
93 | '''
94 | Evaluate function for the MTL model (Segmentation and Scene Graph Performance)
95 | Inputs: args, model, val-dataloader
96 |
97 | '''
98 |
99 | model.eval()
100 |
101 | # graph
102 | scene_graph_criterion = nn.MultiLabelSoftMarginLoss()
103 | scene_graph_edge_count = 0
104 | scene_graph_total_acc = 0.0
105 | scene_graph_total_loss = 0.0
106 | scene_graph_logits_list = []
107 | scene_graph_labels_list = []
108 |
109 | test_seg_loss = 0.0
110 | total_inter, total_union, total_correct, total_label = 0, 0, 0, 0
111 |
112 | for data in tqdm(validation_dataloader):
113 | seg_img = data['img']
114 | seg_masks = data['mask']
115 | img_loc = data['img_loc']
116 | node_num = data['node_num']
117 | roi_labels = data['roi_labels']
118 | det_boxes = data['det_boxes']
119 | edge_labels = data['edge_labels']
120 | spatial_feat = data['spatial_feat']
121 | word2vec = data['word2vec']
122 |
123 | spatial_feat, word2vec, edge_labels = spatial_feat.cuda(non_blocking=True), word2vec.cuda(non_blocking=True), edge_labels.cuda(non_blocking=True)
124 | seg_img, seg_masks = seg_img.cuda(non_blocking=True), seg_masks.cuda(non_blocking=True)
125 |
126 | with torch.no_grad():
127 | interaction, seg_outputs, _ = model(seg_img, img_loc, det_boxes, node_num, spatial_feat, word2vec, roi_labels, validation=True)
128 |
129 | scene_graph_logits_list.append(interaction)
130 | scene_graph_labels_list.append(edge_labels)
131 |
132 | # Loss and accuracy
133 | scene_graph_loss = scene_graph_criterion(interaction, edge_labels.float())
134 | scene_graph_acc = np.sum(np.equal(np.argmax(interaction.cpu().data.numpy(), axis=-1), np.argmax(edge_labels.cpu().data.numpy(), axis=-1)))
135 | correct, labeled, inter, union, t_loss = seg_eval_batch(seg_outputs, seg_masks)
136 |
137 | # Accumulate scene graph loss and acc
138 | scene_graph_total_loss += scene_graph_loss.item() * edge_labels.shape[0]
139 | scene_graph_total_acc += scene_graph_acc
140 | scene_graph_edge_count += edge_labels.shape[0]
141 |
142 | total_correct += correct
143 | total_label += labeled
144 | total_inter += inter
145 | total_union += union
146 | test_seg_loss += t_loss.item()
147 |
148 | # Graph evaluation
149 | scene_graph_total_acc = scene_graph_total_acc / scene_graph_edge_count
150 | scene_graph_total_loss = scene_graph_total_loss / len(validation_dataloader)
151 | scene_graph_logits_all = torch.cat(scene_graph_logits_list).cuda()
152 | scene_graph_labels_all = torch.cat(scene_graph_labels_list).cuda()
153 | scene_graph_logits_all = F.softmax(scene_graph_logits_all, dim=1)
154 | scene_graph_map_value, scene_graph_recall = calibration_metrics(scene_graph_logits_all, scene_graph_labels_all)
155 |
156 | # Segmentation evaluation
157 | pixAcc = 1.0 * total_correct / (np.spacing(1) + total_label)
158 | IoU = 1.0 * total_inter / (np.spacing(1) + total_union)
159 | mIoU = IoU.mean()
160 |
161 | print('================= Evaluation ====================')
162 |     print('Graph : acc: %0.4f map: %0.4f recall: %0.4f loss: %0.4f' % (scene_graph_total_acc, scene_graph_map_value, scene_graph_recall, scene_graph_total_loss))
163 |     print('Segmentation : Pacc: %0.4f mIoU: %0.4f loss: %0.4f' % (pixAcc, mIoU, test_seg_loss/len(validation_dataloader)))
164 | return(scene_graph_total_acc, scene_graph_map_value, mIoU)
165 |
166 |
167 | def train_model(gpu, args):
168 | '''
169 | Train function for the MTL model
170 | Inputs: number of gpus per node, args
171 |
172 | '''
173 | # Store best value and epoch number
174 | best_value = [0.0, 0.0, 0.0]
175 | best_epoch = [0, 0, 0]
176 |
177 | # Decaying learning rate
178 | decay_lr = args.lr
179 |
180 | # This is placed above the dist.init process, because of the feature_extraction model.
181 | model = build_model(args)
182 |
183 | # Load pre-trained weights
184 | if args.model == 'amtl-t0' or args.model == 'amtl-t3' or args.model == 'amtl-t0-ft' or args.model == 'amtl-t1' or args.model == 'amtl-t2':
185 | print('Loading pre-trained weights for Sequential Optimisation')
186 | pretrained_model = torch.load(get_checkpoint_loc(args.model, args.seg_mode))
187 | pretrained_dict = pretrained_model['state_dict']
188 | model_dict = model.state_dict()
189 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if (k in model_dict) and (model_dict[k].shape == pretrained_dict[k].shape)}
190 | model_dict.update(pretrained_dict)
191 | model.load_state_dict(model_dict)
192 |
193 | # Set training flag for submodules based on train model.
194 | model.set_train_test(args.model)
195 |
196 |
197 | if args.KD:
198 |         teacher_model = build_model(args)
199 | # Load pre-trained stl_mtl_model
200 | print('Preparing teacher model')
201 | pretrained_model = torch.load('/media/mobarak/data/lalith/mtl_scene_understanding_and_segmentation/checkpoints/stl_s_v1/stl_s_v1/epoch_train/checkpoint_D168_epoch.pth')
202 | pretrained_dict = pretrained_model['state_dict']
203 | model_dict = teacher_model.state_dict()
204 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if (k in model_dict) and (model_dict[k].shape == pretrained_dict[k].shape)}
205 | model_dict.update(pretrained_dict)
206 | teacher_model.load_state_dict(model_dict)
207 | if args.model == 'mtl-t3':
208 | teacher_model.set_train_test('mtl-t3')
209 | teacher_model.model_type3_insert()
210 | teacher_model.cuda()
211 | else:
212 | teacher_model.set_train_test('stl-s')
213 | teacher_model.cuda()
214 | teacher_model.eval()
215 |
216 | # Insert nn layers based on type.
217 | if args.model == 'amtl-t1' or args.model == 'mtl-t1':
218 | model.model_type1_insert()
219 | elif args.model == 'amtl-t2' or args.model == 'mtl-t2':
220 | model.model_type2_insert()
221 | elif args.model == 'amtl-t3' or args.model == 'mtl-t3':
222 | model.model_type3_insert()
223 |
224 | # Priority rank given to node 0 -> current pc, if more nodes -> multiple PCs
225 | os.environ['MASTER_ADDR'] = 'localhost'
226 | os.environ['MASTER_PORT'] = args.port #8892
227 | rank = args.nr * args.gpus + gpu
228 | dist.init_process_group(backend='nccl', init_method='env://', world_size=args.world_size, rank=rank)
229 |
230 | # Set cuda
231 | torch.cuda.set_device(gpu)
232 |
233 | # Wrap the model with ddp
234 | model.cuda()
235 |     model = DDP(model, device_ids=[gpu], find_unused_parameters=True)
236 |
237 | # Define loss function (criterion) and optimizer
238 | seg_criterion = SegmentationLosses(se_loss=False, aux=False, nclass=8, se_weight=0.2, aux_weight=0.2).cuda(gpu)
239 | graph_scene_criterion = nn.MultiLabelSoftMarginLoss().cuda(gpu)
240 |
241 | # train and test dataloader
242 | train_seq = [[2, 3, 4, 6, 7, 9, 10, 11, 12, 14, 15]]
243 | val_seq = [[1, 5, 16]]
244 | data_dir = ['datasets/instruments18/seq_']
245 | img_dir = ['/left_frames/']
246 | mask_dir = ['/annotations/']
247 | dset = [0]
248 | data_const = SurgicalSceneConstants()
249 |
250 | seq = {'train_seq': train_seq, 'val_seq': val_seq, 'data_dir': data_dir, 'img_dir': img_dir, 'dset': dset, 'mask_dir': mask_dir}
251 |
252 | # Val_dataset only set in 1 GPU
253 | val_dataset = SurgicalSceneDataset(seq_set=seq['val_seq'], dset=seq['dset'], data_dir=seq['data_dir'], \
254 | img_dir=seq['img_dir'], mask_dir=seq['mask_dir'], istrain=False, dataconst=data_const, \
255 | feature_extractor=args.feature_extractor, reduce_size=False)
256 | val_dataloader = DataLoader(dataset=val_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn)
257 |
258 | # Train_dataset distributed to 2 GPU
259 | train_dataset = SurgicalSceneDataset(seq_set=seq['train_seq'], data_dir=seq['data_dir'],
260 | img_dir=seq['img_dir'], mask_dir=seq['mask_dir'], dset=seq['dset'], istrain=True, dataconst=data_const,
261 | feature_extractor=args.feature_extractor, reduce_size=False)
262 |
263 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas=args.world_size, rank=rank, shuffle=True)
264 | train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0, pin_memory=True, sampler=train_sampler)
265 |
266 | # Evaluate the model before start of training
267 | if gpu == 0:
268 | if args.KD:
269 | print("=================== Teacher Model=========================")
270 | eval_sc_acc, eval_sc_map, eval_seg_miou = model_eval(args, teacher_model, val_dataloader)
271 | print("=================== Student Model=========================")
272 | eval_sc_acc, eval_sc_map, eval_seg_miou = model_eval(args, model, val_dataloader)
273 | print("PT SC ACC: [value: {:0.4f}] PT SC mAP: [value: {:0.4f}] PT Seg mIoU: [value: {:0.4f}]".format(eval_sc_acc, eval_sc_map, eval_seg_miou))
274 |
275 | for epoch_count in range(args.epoch):
276 |
277 | start_time = time.time()
278 |
279 | # Set model / submodules in train mode
280 | model.train()
281 | if args.model == 'stl-sg' or args.model == 'amtl-t0' or args.model == 'amtl-t3':
282 | model.module.feature_encoder.eval()
283 | model.module.gcn_unit.eval()
284 | model.module.seg_decoder.eval()
285 | elif args.model == 'stl-sg-wfe':
286 | model.module.gcn_unit.eval()
287 | model.module.seg_decoder.eval()
288 | elif args.model == 'stl-s':
289 | model.module.scene_graph.eval()
290 |
291 | train_seg_loss = 0.0
292 | train_scene_graph_loss = 0.0
293 |
294 | model.cuda()
295 |
296 | # Optimizer with decaying learning rate
297 | decay_lr = decay_lr*0.98 if ((epoch_count+1) %10 == 0) else decay_lr
298 | optimizer = optim.Adam(model.parameters(), lr=decay_lr, weight_decay=0)
299 |
300 | train_sampler.set_epoch(epoch_count)
301 |
302 | if gpu == 0: print('================= Train ====================')
303 |
304 | for data in tqdm(train_dataloader):
305 | seg_img = data['img']
306 | seg_masks = data['mask']
307 | img_loc = data['img_loc']
308 | node_num = data['node_num']
309 | roi_labels = data['roi_labels']
310 | det_boxes = data['det_boxes']
311 | edge_labels = data['edge_labels']
312 | spatial_feat = data['spatial_feat']
313 | word2vec = data['word2vec']
314 |
315 | spatial_feat, word2vec, edge_labels = spatial_feat.cuda(non_blocking=True), word2vec.cuda(non_blocking=True), edge_labels.cuda(non_blocking=True)
316 | seg_img, seg_masks = seg_img.cuda(non_blocking=True), seg_masks.cuda(non_blocking=True)
317 |
318 | # Forward propagation
319 | interaction, seg_outputs, fe_feat = model(seg_img, img_loc, det_boxes, node_num, spatial_feat, word2vec, roi_labels)
320 |
321 | # Loss calculation
322 | seg_loss = seg_criterion(seg_outputs, seg_masks)
323 | scene_graph_loss = graph_scene_criterion(interaction, edge_labels.float())
324 |
325 | # KD-Loss
326 | if args.KD:
327 | with torch.no_grad():
328 | _, _, t_fe_feat = teacher_model(seg_img, img_loc, det_boxes, node_num, spatial_feat, word2vec, roi_labels, validation=True)
329 | t_fe_feat = t_fe_feat.detach()
330 | t_fe_feat = t_fe_feat / (t_fe_feat.pow(2).sum(1) + 1e-6).sqrt().view(t_fe_feat.size(0), 1, t_fe_feat.size(2), t_fe_feat.size(3))
331 |
332 |
333 |                 # normalise the student features in the same way as the teacher features
334 |                 fe_feat = fe_feat / (fe_feat.pow(2).sum(1) + 1e-6).sqrt().view(fe_feat.size(0), 1, fe_feat.size(2), fe_feat.size(3))
335 |                 dist_loss = (fe_feat - t_fe_feat).pow(2).sum(1).mean()
336 |
337 |
338 | if args.model == 'stl-s':
339 | loss_total = seg_loss
340 | elif args.model == 'stl-sg' or args.model == 'stl-sg-wfe' or args.model == 'amtl-t0' or args.model == 'amtl-t3':
341 | loss_total = scene_graph_loss
342 | elif args.KD:
343 | loss_total = (0.4 * scene_graph_loss) + seg_loss + dist_loss
344 | else:
345 | loss_total = (0.4 * scene_graph_loss)+ (0.6 * seg_loss)
346 |
347 | optimizer.zero_grad()
348 | loss_total.backward()
349 | optimizer.step()
350 |
351 | train_seg_loss += seg_loss.item()
352 | train_scene_graph_loss += scene_graph_loss.item() * edge_labels.shape[0]
353 |
354 |         # average the accumulated losses over the epoch
355 |         train_seg_loss = train_seg_loss / len(train_dataloader)
356 | train_scene_graph_loss = train_scene_graph_loss / len(train_dataloader)
357 |
358 | if gpu == 0:
359 | end_time = time.time()
360 | print("Train Epoch: {}/{} lr: {:0.9f} Graph_loss: {:0.4f} Segmentation_Loss: {:0.4f} Execution time: {:0.4f}".format(\
361 | epoch_count + 1, args.epoch, decay_lr, train_scene_graph_loss, train_seg_loss, (end_time-start_time)))
362 |
363 | #if epoch_count % 2 == 0:
364 | # save model
365 | # if epoch_loss<0.0405 or epoch_count % args.save_every == (args.save_every - 1):
366 | checkpoint = { 'lr': args.lr, 'b_s': args.batch_size, 'bias': args.bias, 'bn': args.bn, 'dropout': args.drop_prob,
367 | 'layers': args.layers, 'multi_head': args.multi_attn,
368 | 'diff_edge': args.diff_edge, 'state_dict': model.module.state_dict() }
369 |
370 | save_name = "checkpoint_D1" + str(epoch_count+1) + '_epoch.pth'
371 | torch.save(checkpoint, os.path.join(args.save_dir, args.exp_ver, 'epoch_train', save_name))
372 |
373 | eval_sc_acc, eval_sc_map, eval_seg_miou = model_eval(args, model, val_dataloader)
374 | if eval_sc_acc > best_value[0]:
375 | best_value[0] = eval_sc_acc
376 | best_epoch[0] = epoch_count+1
377 | if eval_sc_map > best_value[1]:
378 | best_value[1] = eval_sc_map
379 | best_epoch[1] = epoch_count+1
380 | if eval_seg_miou > best_value[2]:
381 | best_value[2] = eval_seg_miou
382 | best_epoch[2] = epoch_count+1
383 | print("Best SC Acc: [Epoch: {} value: {:0.4f}] Best SC mAP: [Epoch: {} value: {:0.4f}] Best Seg mIoU: [Epoch: {} value: {:0.4f}]".format(\
384 | best_epoch[0], best_value[0], best_epoch[1], best_value[1], best_epoch[2], best_value[2]))
385 |
386 | return
387 |
388 |
389 | if __name__ == "__main__":
390 | '''
391 | Main function to set arguments
392 | '''
393 |
394 | # ---------------------------------------------- Optimization and feature sharing variants ----------------------------------------------
395 | '''
396 | Format for the model_type : X-Y
397 |
398 | -> X : Optimisation technique [1. amtl - Sequential MTL Optimisation, 2. mtl - Naive MTL Optimisation]
399 | -> Y : Feature Sharing mechanism [1. t0 - Base model,
400 | 2. t1 - Scene graph features to enhance segmentation (SGFSEG),
401 | 3. t3 - Global interaction space features to improve scene graph (GISFSG)]
402 |
403 | '''
404 | model_type = 'amtl-t0'
405 | ver = model_type + '_v5'
406 | port = '8892'
407 | f_e = 'resnet18_11_cbs_ts'
408 |
409 |
410 | # ----------------------------------------------Global reasoning variant in segmentation -----------------------------------------------
411 | '''
412 | -> seg_mode : v1 - (MSLRGR - multi-scale local reasoning and global reasoning)
413 |                      v2gc - (MSGR - multi-scale global reasoning)
414 | None - Base model
415 | '''
416 | seg_mode = 'v1'
417 |
418 | # Set random seed
419 | seed_everything()
420 | print(ver, seg_mode)
421 |
422 | # Device Count
423 | num_gpu = torch.cuda.device_count()
424 |
425 | # Arguments
426 | parser = argparse.ArgumentParser(description='MTL Scene graph and Segmentation')
427 |
428 | # Hyperparameters
429 | parser.add_argument('--lr', type=float, default = 0.00001)
430 | parser.add_argument('--epoch', type=int, default = 130)
431 | parser.add_argument('--start_epoch', type=int, default = 0)
432 | parser.add_argument('--batch_size', type=int, default = 4)
433 | parser.add_argument('--gpu', type=bool, default = True)
434 | parser.add_argument('--train_model', type=str, default = 'epoch')
435 | parser.add_argument('--exp_ver', type=str, default = ver)
436 |
437 | # File locations
438 | parser.add_argument('--log_dir', type=str, default = './log/' + ver)
439 | parser.add_argument('--save_dir', type=str, default = './checkpoints/' + ver)
440 | parser.add_argument('--output_img_dir', type=str, default = './results/' + ver)
441 | parser.add_argument('--save_every', type=int, default = 10)
442 | parser.add_argument('--pretrained', type=str, default = None)
443 |
444 | # Network settings
445 | parser.add_argument('--layers', type=int, default = 1)
446 | parser.add_argument('--bn', type=bool, default = False)
447 | parser.add_argument('--drop_prob', type=float, default = 0.3)
448 | parser.add_argument('--bias', type=bool, default = True)
449 | parser.add_argument('--multi_attn', type=bool, default = False)
450 | parser.add_argument('--diff_edge', type=bool, default = False)
451 |
452 | if model_type == 'mtl-t3' or model_type == 'amtl-t3':
453 | parser.add_argument('--global_feat', type=int, default = 128)
454 | else:
455 | parser.add_argument('--global_feat', type=int, default = 0)
456 |
457 | # Data processing
458 | parser.add_argument('--sampler', type=int, default = 0)
459 | parser.add_argument('--data_aug', type=bool, default = False)
460 | parser.add_argument('--feature_extractor', type=str, default = f_e)
461 | parser.add_argument('--seg_mode', type=str, default = seg_mode) # v1/v2_gc
462 |
463 | parser.add_argument('--KD', type=bool, default = False)
464 |
465 | # GPU distributor
466 | parser.add_argument('--port', type=str, default = port)
467 |     parser.add_argument('--nodes', type=int, default = 1, metavar='N', help='Number of nodes (default: 1)')
468 | parser.add_argument('--gpus', type=int, default = num_gpu, help='Number of gpus per node')
469 | parser.add_argument('--nr', type=int, default = 0, help='Ranking within the nodes')
470 |
471 | # Model type
472 | parser.add_argument('--model', type=str, default = model_type)
473 | args = parser.parse_args()
474 |
475 | # Constants for the surgical scene
476 | data_const = SurgicalSceneConstants()
477 |
478 | # GPU distributed
479 | args.world_size = args.gpus * args.nodes
480 |
481 | # Train model in distributed settings - (train function, number of GPUs, arguments)
482 | mp.spawn(train_model, nprocs=args.gpus, args=(args,))
--------------------------------------------------------------------------------
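Note on the KD branch in train_model above: a short sketch of the feature-distillation term computed when args.KD is True. The tensor shapes and the helper name kd_feature_loss are illustrative assumptions; the normalisation and squared-difference steps mirror the training loop above.

    import torch

    def kd_feature_loss(student_feat, teacher_feat, eps=1e-6):
        # L2-normalise each spatial position across the channel dimension,
        # then penalise the squared channel-wise difference, averaged over batch and space
        t = teacher_feat.detach()
        t = t / (t.pow(2).sum(1) + eps).sqrt().view(t.size(0), 1, t.size(2), t.size(3))
        s = student_feat / (student_feat.pow(2).sum(1) + eps).sqrt().view(
            student_feat.size(0), 1, student_feat.size(2), student_feat.size(3))
        return (s - t).pow(2).sum(1).mean()

    # hypothetical encoder feature maps: (batch, channels, H, W)
    fe_feat = torch.randn(2, 512, 16, 20)     # student encoder features
    t_fe_feat = torch.randn(2, 512, 16, 20)   # teacher encoder features
    dist_loss = kd_feature_loss(fe_feat, t_fe_feat)
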
/models/segmentation_model.py:
--------------------------------------------------------------------------------
1 | '''
2 | Project : Global-Reasoned Multi-Task Surgical Scene Understanding
3 | Lab : MMLAB, National University of Singapore
4 | contributors : Lalithkumar Seenivasan, Sai Mitheran, Mobarakol Islam, Hongliang Ren
5 | Note : Code adopted and modified from Visual-Semantic Graph Attention Networks and Dual attention network for scene segmentation
6 |
7 | @inproceedings{fu2019dual,
8 | title={Dual attention network for scene segmentation},
9 | author={Fu, Jun and Liu, Jing and Tian, Haijie and Li, Yong and Bao, Yongjun and Fang, Zhiwei and Lu, Hanqing},
10 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
11 | pages={3146--3154},
12 | year={2019}
13 | }
14 | '''
15 |
16 |
17 | import math
18 | import numpy as np
19 | from collections import OrderedDict
20 |
21 | import torch
22 | import torch.nn as nn
23 | from torch import Tensor
24 | from torch.nn import functional as F
25 | from torch.nn.functional import interpolate
26 | from typing import Type, Any, Callable, Union, List, Optional
27 |
28 | # Setting the kwargs for upsample configuration
29 | up_kwargs = {'mode': 'bilinear', 'align_corners': True}
30 |
31 |
32 | class Namespace:
33 | """
34 | Namespace class for custom args to be parsed
35 | Inputs: **kwargs
36 |
37 | """
38 | def __init__(self, **kwargs):
39 | self.__dict__.update(kwargs)
40 |
41 | def get_backbone(name, **kwargs):
42 | """
43 | Function to get backbone feature extractor
44 | Inputs: name of backbone, **kwargs
45 |
46 | """
47 | models = {
48 | 'resnet18_model': resnet18_model,
49 | }
50 | name = name.lower()
51 | if name not in models:
52 | raise ValueError('%s\n\t%s' % (str(name), '\n\t'.join(sorted(models.keys()))))
53 | net = models[name](**kwargs)
54 | return net
55 |
56 |
57 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
58 | """
59 | 3x3 convolution with padding
60 | Inputs: in_planes, out_planes, stride, groups, dilation
61 |
62 | """
63 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
64 | padding=dilation, groups=groups, bias=False, dilation=dilation)
65 |
66 |
67 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
68 | """
69 | 1x1 convolution
70 | Inputs: in_planes, out_planes, stride
71 |
72 | """
73 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
74 |
75 |
76 | class BasicBlock(nn.Module):
77 | """
78 | Basic block for ResNet18 backbone
79 | init :
80 | inplanes, planes, stride, downsample, groups, base_width, dilation, norm_layer
81 |
82 | forward : x
83 |
84 | """
85 | expansion: int = 1
86 |
87 | def __init__(
88 | self,
89 | inplanes: int,
90 | planes: int,
91 | stride: int = 1,
92 | downsample: Optional[nn.Module] = None,
93 | groups: int = 1,
94 | base_width: int = 64,
95 | dilation: int = 1,
96 | norm_layer: Optional[Callable[..., nn.Module]] = None
97 | ) -> None:
98 | super(BasicBlock, self).__init__()
99 | if norm_layer is None:
100 | norm_layer = nn.BatchNorm2d
101 | if groups != 1 or base_width != 64:
102 | raise ValueError(
103 | 'BasicBlock only supports groups=1 and base_width=64')
104 | if dilation > 1:
105 | raise NotImplementedError(
106 | "Dilation > 1 not supported in BasicBlock")
107 |
108 | self.planes = planes
109 |
110 | self.conv1 = conv3x3(inplanes, planes, stride)
111 | self.bn1 = norm_layer(planes)
112 | self.relu = nn.ReLU(inplace=True)
113 | self.conv2 = conv3x3(planes, planes)
114 | self.bn2 = norm_layer(planes)
115 | self.downsample = downsample
116 | self.stride = stride
117 |
118 | def forward(self, x: Tensor) -> Tensor:
119 | identity = x
120 |
121 | out = self.conv1(x)
122 | out = self.bn1(out)
123 | out = self.relu(out)
124 |
125 | out = self.conv2(out)
126 | out = self.bn2(out)
127 |
128 | if self.downsample is not None:
129 | identity = self.downsample(x)
130 |
131 | out += identity
132 | out = self.relu(out)
133 |
134 | return out
135 |
136 |
137 | class Bottleneck(nn.Module):
138 | """
139 |     Bottleneck block for deeper ResNet variants (defined for completeness; ResNet18 uses BasicBlock)
140 | init :
141 | inplanes, planes, stride, downsample, groups, base_width, dilation, norm_layer
142 |
143 | forward : x
144 |
145 | """
146 | expansion: int = 4
147 |
148 | def __init__(
149 | self,
150 | inplanes: int,
151 | planes: int,
152 | stride: int = 1,
153 | downsample: Optional[nn.Module] = None,
154 | groups: int = 1,
155 | base_width: int = 64,
156 | dilation: int = 1,
157 | norm_layer: Optional[Callable[..., nn.Module]] = None
158 | ) -> None:
159 | super(Bottleneck, self).__init__()
160 | if norm_layer is None:
161 | norm_layer = nn.BatchNorm2d
162 | width = int(planes * (base_width / 64.)) * groups
163 |
164 | # self.conv2 and self.downsample layers downsample the input when stride != 1
165 | self.conv1 = conv1x1(inplanes, width)
166 | self.bn1 = norm_layer(width)
167 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
168 | self.bn2 = norm_layer(width)
169 | self.conv3 = conv1x1(width, planes * self.expansion)
170 | self.bn3 = norm_layer(planes * self.expansion)
171 | self.relu = nn.ReLU(inplace=True)
172 | self.downsample = downsample
173 | self.stride = stride
174 |
175 | def forward(self, x: Tensor) -> Tensor:
176 | identity = x
177 |
178 | out = self.conv1(x)
179 | out = self.bn1(out)
180 | out = self.relu(out)
181 |
182 | out = self.conv2(out)
183 | out = self.bn2(out)
184 | out = self.relu(out)
185 |
186 | out = self.conv3(out)
187 | out = self.bn3(out)
188 |
189 | # Downsampling of the input variable (x)
190 | if self.downsample is not None:
191 | identity = self.downsample(x)
192 |
193 | out += identity
194 | out = self.relu(out)
195 |
196 | return out
197 |
198 |
199 | class ResNet(nn.Module):
200 | """
201 | ResNet base class for different variants
202 | init :
203 | block, layers, num_classes (ImageNet), zero_init_residual, groups, width_per_group, replace_stride_with_dilation, norm_layer
204 |
205 | forward : x
206 | """
207 |
208 | def __init__(
209 | self,
210 | block: Type[Union[BasicBlock, Bottleneck]],
211 | layers: List[int],
212 | num_classes: int = 1000,
213 | zero_init_residual: bool = False,
214 | groups: int = 1,
215 | width_per_group: int = 64,
216 | replace_stride_with_dilation: Optional[List[bool]] = None,
217 | norm_layer: Optional[Callable[..., nn.Module]] = None
218 | ) -> None:
219 |
220 | super(ResNet, self).__init__()
221 | if norm_layer is None:
222 | norm_layer = nn.BatchNorm2d
223 | self._norm_layer = norm_layer
224 | self.inplanes = 64
225 | self.dilation = 1
226 |
227 | if replace_stride_with_dilation is None:
228 | # Each element in the tuple indicates whether we should replace the 2x2 stride with a dilated convolution
229 | replace_stride_with_dilation = [False, False, False]
230 |
231 | if len(replace_stride_with_dilation) != 3:
232 | raise ValueError("replace_stride_with_dilation should be None "
233 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
234 |
235 | self.groups = groups
236 | self.base_width = width_per_group
237 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
238 | bias=False)
239 | self.bn1 = norm_layer(self.inplanes)
240 | self.relu = nn.ReLU(inplace=True)
241 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
242 | self.layer1 = self._make_layer(block, 64, layers[0])
243 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
244 | dilate=replace_stride_with_dilation[0])
245 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
246 | dilate=replace_stride_with_dilation[1])
247 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
248 | dilate=replace_stride_with_dilation[2])
249 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
250 | self.fc = nn.Linear(512 * block.expansion, num_classes)
251 |
252 | for m in self.modules():
253 | if isinstance(m, nn.Conv2d):
254 | nn.init.kaiming_normal_(
255 | m.weight, mode='fan_out', nonlinearity='relu')
256 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
257 | nn.init.constant_(m.weight, 1)
258 | nn.init.constant_(m.bias, 0)
259 |
260 | if zero_init_residual:
261 | for m in self.modules():
262 | if isinstance(m, Bottleneck):
263 | nn.init.constant_(m.bn3.weight, 0)
264 | elif isinstance(m, BasicBlock):
265 | nn.init.constant_(m.bn2.weight, 0)
266 |
267 | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
268 | stride: int = 1, dilate: bool = False) -> nn.Sequential:
269 | norm_layer = self._norm_layer
270 | downsample = None
271 | previous_dilation = self.dilation
272 | if dilate:
273 | self.dilation *= stride
274 | stride = 1
275 | if stride != 1 or self.inplanes != planes * block.expansion:
276 | downsample = nn.Sequential(
277 | conv1x1(self.inplanes, planes * block.expansion, stride),
278 | norm_layer(planes * block.expansion),
279 | )
280 |
281 | layers = []
282 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
283 | self.base_width, previous_dilation, norm_layer))
284 | self.inplanes = planes * block.expansion
285 | for _ in range(1, blocks):
286 | layers.append(block(self.inplanes, planes, groups=self.groups,
287 | base_width=self.base_width, dilation=self.dilation,
288 | norm_layer=norm_layer))
289 |
290 | return nn.Sequential(*layers)
291 |
292 |
293 |     def _forward_impl(self, x: Tensor):  # returns the multi-scale features (c1, c2, c3, c4)
294 | x = self.conv1(x)
295 | x = self.bn1(x)
296 | x = self.relu(x)
297 | x = self.maxpool(x)
298 |
299 | c1 = self.layer1(x)
300 | c2 = self.layer2(c1)
301 | c3 = self.layer3(c2)
302 | c4 = self.layer4(c3)
303 |
304 | return c1, c2, c3, c4
305 |
306 |
307 |     def forward(self, x: Tensor):
308 | return self._forward_impl(x)
309 |
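# --- Illustrative sketch (added for clarity; hypothetical helper, not part of the
# original file). Unlike torchvision's ResNet, the class above returns the four
# stage feature maps (c1..c4) instead of classification logits, so downstream
# modules can reason over multiple scales. Assumes torch and BasicBlock are already
# imported/defined earlier in this module.
def _example_resnet_feature_shapes():
    net = ResNet(BasicBlock, [2, 2, 2, 2])              # ResNet-18 layout
    c1, c2, c3, c4 = net(torch.randn(1, 3, 256, 320))   # dummy 256x320 frame
    # Expected shapes (strides 4 / 8 / 16 / 32):
    # c1: (1, 64, 64, 80), c2: (1, 128, 32, 40), c3: (1, 256, 16, 20), c4: (1, 512, 8, 10)
    return c1.shape, c2.shape, c3.shape, c4.shape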
310 |
311 | class BaseNet(nn.Module):
312 | """
313 | BaseNet class for Multi-scale global reasoned segmentation module
314 |
315 | init :
316 |         nclass, backbone, pretrained, dilated, norm_layer, root
317 |
318 |     base_forward : x  /  evaluate : x, target
319 |
320 | """
321 | def __init__(self, nclass, backbone, pretrained, dilated=True, norm_layer=None,
322 | root='~/.encoding/models', *args, **kwargs):
323 | super(BaseNet, self).__init__()
324 | self.nclass = nclass
325 |
326 | # Copying modules from pretrained models
327 | self.backbone = backbone
328 | self.pretrained = get_backbone(backbone, pretrained=pretrained, dilated=dilated,
329 | norm_layer=norm_layer, root=root,
330 | *args, **kwargs)
331 | self.pretrained.fc = None
332 | self._up_kwargs = up_kwargs
333 |
334 | def base_forward(self, x):
335 |
336 | x = self.pretrained.conv1(x)
337 | x = self.pretrained.bn1(x)
338 | x = self.pretrained.relu(x)
339 | x = self.pretrained.maxpool(x)
340 | c = self.pretrained.layer1(x)
341 | c = self.pretrained.layer2(c)
342 | c = self.pretrained.layer3(c)
343 | c = self.pretrained.layer4(c)
344 |
345 | return None, None, None, c
346 |
347 | def evaluate(self, x, target=None):
348 | pred = self.forward(x)
349 | if isinstance(pred, (tuple, list)):
350 | pred = pred[0]
351 | if target is None:
352 | return pred
353 | correct, labeled = batch_pix_accuracy(pred.data, target.data)
354 | inter, union = batch_intersection_union(
355 | pred.data, target.data, self.nclass)
356 | return correct, labeled, inter, union
357 |
358 |
359 | def _resnet(
360 | arch: str,
361 | block: Type[Union[BasicBlock, Bottleneck]],
362 | layers: List[int],
363 | pretrained: bool,
364 | progress: bool,
365 | **kwargs: Any
366 | ) -> ResNet:
367 |
368 |     """
369 |     Builds the requested ResNet variant and optionally loads pre-trained ImageNet weights
370 |     Inputs :
371 |         arch, block, layers, pretrained, progress, **kwargs
372 | 
373 |     Returns : ResNet model instance
374 |     """
375 |
376 | model = ResNet(block, layers, **kwargs)
377 | if pretrained:
378 | print("Loading pre-trained ImageNet weights")
379 | state_dict = torch.load('models/r18/resnet18-f37072fd.pth')
380 | model.load_state_dict(state_dict)
381 | return model
382 |
383 |
384 | def resnet18(pretrained: bool = True, progress: bool = True, **kwargs: Any) -> ResNet:
385 | """
386 | ResNet18 model call function
387 | Inputs: pretrained, progress, **kwargs
388 |
389 | """
390 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
391 | **kwargs)
392 |
393 | class Resnet18_main(nn.Module):
394 | """
395 |     ResNet-18 wrapper used as the feature extractor backbone
396 | init : pretrained, num_classes
397 | forward : x
398 | """
399 | def __init__(self, pretrained, num_classes=1000):
400 |
401 | super(Resnet18_main, self).__init__()
402 | resnet18_block = resnet18(
403 | pretrained=pretrained)
404 |
405 | resnet18_block.fc = nn.Conv2d(resnet18_block.inplanes, num_classes, 1)
406 |
407 | self.resnet18_block = resnet18_block
408 | self._normal_initialization(self.resnet18_block.fc)
409 |
410 | self.in_planes = 64
411 | self.kernel_size = 3
412 |
413 |
414 | def _normal_initialization(self, layer):
415 |
416 | layer.weight.data.normal_(0, 0.01)
417 | layer.bias.data.zero_()
418 |
419 | def forward(self, x):
420 | c1, c2, c3, c4 = self.resnet18_block(x)
421 |
422 | return c1, c2, c3, c4
423 |
424 |
425 | class GCN(nn.Module):
426 | """
427 | Graph Convolution network for Global interaction space
428 | init :
429 | num_state, num_node, bias=False
430 |
431 | forward : x, scene_feat = None, model_type = None
432 |
433 | """
434 | def __init__(self, num_state, num_node, bias=False):
435 | super(GCN, self).__init__()
436 | self.conv1 = nn.Conv1d(num_node, num_node, kernel_size=1, padding=0,
437 | stride=1, groups=1, bias=True)
438 | self.relu = nn.ReLU(inplace=True)
439 | self.conv2 = nn.Conv1d(num_state, num_state, kernel_size=1, padding=0,
440 | stride=1, groups=1, bias=bias)
441 | self.x_avg_pool = nn.AvgPool1d(128,1)
442 |
443 | def forward(self, x, scene_feat = None, model_type = None):
444 | h = self.conv1(x.permute(0, 2, 1).contiguous()).permute(0, 2, 1)
445 |
446 | if (model_type == 'amtl-t1' or model_type == 'mtl-t1') and scene_feat is not None: # (x+h+(avg(x)*f))
447 | x_p = torch.matmul(self.x_avg_pool(x.permute(0, 2, 1).contiguous()), scene_feat)
448 | h = h + x + x_p.permute(0, 2, 1).contiguous()
449 | else:
450 | h = h + x
451 |
452 | h = self.relu(h)
453 | h = self.conv2(h)
454 |
455 | return h
456 |
457 |
458 | class GloRe_Unit(nn.Module):
459 | """
460 | Global Reasoning Unit (GR/GloRe)
461 | init :
462 | num_in, num_mid, stride=(1, 1), kernel=1
463 |
464 | forward : x, scene_feat = None, model_type = None
465 | AMTL - Sequential MTL Optimisation
466 | MTL - Naive MTL Optimisation
467 |
468 | """
469 | def __init__(self, num_in, num_mid, stride=(1, 1), kernel=1):
470 | super(GloRe_Unit, self).__init__()
471 |
472 | self.num_s = int(2 * num_mid)
473 | self.num_n = int(1 * num_mid)
474 |
475 | kernel_size = (kernel, kernel)
476 | padding = (1, 1) if kernel == 3 else (0, 0)
477 |
478 | # Reduce dimension
479 | self.conv_state = nn.Conv2d(num_in, self.num_s, kernel_size=kernel_size, padding=padding)
480 | # generate graph transformation function
481 | self.conv_proj = nn.Conv2d(num_in, self.num_n, kernel_size=kernel_size, padding=padding)
482 | # ----------
483 | self.gcn = GCN(num_state=self.num_s, num_node=self.num_n)
484 | # ----------
485 | # tail: extend dimension
486 | self.fc_2 = nn.Conv2d(self.num_s, num_in, kernel_size=kernel_size, padding=padding, stride=(1, 1),groups=1, bias=False)
487 |
488 | self.blocker = nn.BatchNorm2d(num_in)
489 |
490 | def forward(self, x, scene_feat = None, model_type = None):
491 | '''
492 | Parameter x dimension : (N, C, H, W)
493 | '''
494 | batch_size = x.size(0)
495 | x_state_reshaped = self.conv_state(x).view(batch_size, self.num_s, -1)
496 | x_proj_reshaped = self.conv_proj(x).view(batch_size, self.num_n, -1)
497 | x_rproj_reshaped = x_proj_reshaped
498 |
499 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
500 |
501 | # Projection: Coordinate space -> Interaction space
502 |         x_n_state = torch.matmul(x_state_reshaped, x_proj_reshaped.permute(0, 2, 1))
503 | x_n_state = x_n_state * (1. / x_state_reshaped.size(2))
504 |
505 | if model_type == 'amtl-t2' or model_type == 'mtl-t2':
506 | x_n_rel = torch.matmul(x_n_state.permute(0, 2, 1).contiguous(), scene_feat).permute(0, 2, 1)
507 | else:
508 | x_n_rel = self.gcn(x_n_state, scene_feat, model_type)
509 |
510 | out2 = None
511 | if model_type == 'amtl-t3' or model_type == 'mtl-t3':
512 | out2 = x_n_rel
513 |
514 | # Reverse projection: Interaction space -> Coordinate space
515 | x_state_reshaped = torch.matmul(x_n_rel, x_rproj_reshaped)
516 |
517 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
518 | x_state = x_state_reshaped.view(batch_size, self.num_s, *x.size()[2:])
519 | out = x + self.blocker(self.fc_2(x_state))
520 |
521 | return out, out2
522 |
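# --- Illustrative sketch (added for clarity; hypothetical helper, not part of the
# original model). Traces tensor shapes through GloRe_Unit on the plain path
# (scene_feat=None, model_type=None): project coordinate-space features into the
# interaction space, reason with the GCN, then project back and fuse residually.
def _example_glore_shapes():
    unit = GloRe_Unit(num_in=256, num_mid=64)   # num_s = 128 states, num_n = 64 nodes
    x = torch.randn(2, 256, 8, 10)              # (N, C, H, W) coordinate space
    # conv_state -> (2, 128, 80); conv_proj -> (2, 64, 80)
    # projection:         (2, 128, 80) x (2, 80, 64) -> (2, 128, 64) interaction space
    # GCN:                keeps (2, 128, 64)
    # reverse projection: (2, 128, 64) x (2, 64, 80) -> (2, 128, 80) -> (2, 128, 8, 10)
    out, out2 = unit(x)                         # out: (2, 256, 8, 10); out2: None
    return out.shape, out2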
523 |
524 | class GR_Decoder(nn.Module):
525 | """
526 | Multi-scale Global Reasoned (GR) Decoder for Feature Aggregation
527 | init :
528 | in_channels, out_channels, norm_layer
529 |
530 | forward : s4, s1 = None, s2 = None, s3 = None, imsize = None, seg_mode = None
531 |
532 | -> s1-s4 are Scale-specific features
533 | -> out_channels = num_classes (8)
534 | -> seg_mode : V1 (MSLRGR - multi-scale local reasoning and global reasoning)
535 |                       V2GC (MSGR - multi-scale global reasoning)
536 | """
537 | def __init__(self, in_channels, out_channels, norm_layer):
538 | super(GR_Decoder, self).__init__()
539 |
540 | # Scale-specific channel dimensions
541 | inter_channels = in_channels // 2 # 256
542 | c2 = inter_channels // 2 # 128
543 | c1 = c2 // 2 # 64
544 |
545 | # Scale-specific decoder layers with simple Conv-BN-ReLU-Dropout-Conv Block
546 | self.s1_layer = nn.Sequential(nn.Sequential(nn.Conv2d(c1, c1, 3, padding=1, bias=False), norm_layer(c1), nn.ReLU()),
547 | nn.Sequential(nn.Dropout2d(0.1), nn.Conv2d(c1, out_channels, 1)))
548 |
549 | self.s2_layer = nn.Sequential(nn.Sequential(nn.Conv2d(c2, c2, 3, padding=1, bias=False), norm_layer(c2), nn.ReLU()),
550 | nn.Sequential(nn.Dropout2d(0.1), nn.Conv2d(c2, out_channels, 1)))
551 |
552 | self.s3_layer = nn.Sequential(nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False), norm_layer(inter_channels), nn.ReLU()),
553 | nn.Sequential(nn.Dropout2d(0.1), nn.Conv2d(inter_channels, out_channels, 1)))
554 |
555 | self.s4_decoder = nn.Sequential(nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False), norm_layer(inter_channels), nn.ReLU()),
556 |                                         nn.Sequential(nn.Dropout2d(0.1), nn.Conv2d(inter_channels, out_channels, 1)))
557 |
558 |
559 | def forward(self, x, s1 = None, s2 = None, s3 = None, imsize = None, seg_mode = None):
560 |         x = [self.s4_decoder(x)]
561 | outputs = []
562 | for i in range(len(x)):
563 | outputs.append(
564 | interpolate(x[i], imsize, mode='bilinear', align_corners=True))
565 |
566 |         # 'v1' (MSLRGR) and 'v2_gc' (MSGR) are the two segmentation modes
567 | if seg_mode == 'v2_gc' or seg_mode == 'v1':
568 | s1 = interpolate(self.s1_layer(s1), imsize, mode='bilinear', align_corners=True)
569 | s2 = interpolate(self.s2_layer(s2), imsize, mode='bilinear', align_corners=True)
570 | s3 = interpolate(self.s3_layer(s3), imsize, mode='bilinear', align_corners=True)
571 | outputs = outputs[0]
572 | outputs = s1 + s2 + s3 + outputs
573 | return outputs
574 | else:
575 |             return outputs[0]
576 |
577 |
578 | class GR_Segmentation(BaseNet):
579 | """
580 | Global-Reasoned (GR) Segmentation module INITIALISATION
581 | init :
582 | nclass, backbone, aux=False, se_loss=False, norm_layer=nn.BatchNorm2d, gcn_search=None, **kwargs
583 |
584 | forward : x (Not used in MTL forward pass)
585 |
586 | """
587 | def __init__(self, nclass, backbone, pretrained, aux=False, se_loss=False, norm_layer=nn.BatchNorm2d, gcn_search=None, **kwargs):
588 | super(GR_Segmentation, self).__init__(nclass, backbone, pretrained, norm_layer=norm_layer, **kwargs)
589 |
590 | in_channels = 512
591 |
592 | # GR module
593 | self.gr_interaction = GR_module(in_channels, nclass, norm_layer, gcn_search)
594 |
595 | # GR decoder
596 | self.gr_decoder = GR_Decoder(in_channels, nclass, norm_layer)
597 |
598 |     # NOTE: this forward function is NOT used in the MTL forward pass
599 |
600 | def forward(self, x):
601 | imsize = x.size()[2:]
602 |
603 | # Encoder module
604 | s1, s2, s3, s4 = self.base_forward(x)
605 |
606 |         # GR module (conv block bridging the backbone features to the GloRe unit)
607 |         _, _, _, x, _ = self.gr_interaction(s4)
608 | 
609 |         # Decoder module
610 |         x = self.gr_decoder(x, imsize=imsize)
611 | return x
612 |
613 |
614 | class GR_module(nn.Module):
615 | """
616 | Multi-scale Global Reasoning (GR) Unit
617 | init :
618 | in_channels, out_channels, norm_layer, gcn_search
619 |
620 | forward : x, s1 = None, s2 = None, s3 = None, scene_feat = None, seg_mode = None, model_type = None
621 | -> s1-s4 are Scale-specific features
622 | -> out_channels = num_classes (8)
623 | -> seg_mode : V1 (MSLRGR - multi-scale local reasoning and global reasoning)
624 |                       V2GC (MSGR - multi-scale global reasoning)
625 |
626 | """
627 | def __init__(self, in_channels, out_channels, norm_layer, gcn_search):
628 | super(GR_module, self).__init__()
629 |
630 | inter_channels = in_channels // 2 # 256
631 | c2 = inter_channels // 2 # 128
632 | c1 = c2 // 2 # 64
633 |
634 | # Simple Conv-BN-ReLU Block
635 | self.conv_s4 = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False), norm_layer(inter_channels), nn.ReLU())
636 |
637 | # Scale-specific GR unit (GloRE)
638 | self.gcn1 = GloRe_Unit(c1, 64, kernel=1)
639 | self.gcn2 = GloRe_Unit(c2, 64, kernel=1)
640 | self.gcn3 = GloRe_Unit(inter_channels, 64, kernel=1)
641 | self.gcn4 = GloRe_Unit(inter_channels, 64, kernel=1)
642 |
643 | def forward(self, x, s1 = None, s2 = None, s3 = None, scene_feat = None, seg_mode = None, model_type = None):
644 |
645 | feat1 = None
646 | feat2 = None
647 | feat3 = None
648 | feat5 = None
649 |
650 | if seg_mode == 'v2_gc': # MODE - MSGR
651 | feat1, _ = self.gcn1(s1, scene_feat)
652 | feat2, _ = self.gcn2(s2, scene_feat)
653 | feat3, _ = self.gcn3(s3, scene_feat)
654 | feat4, feat5 = self.gcn4(self.conv_s4(x), scene_feat, model_type)
655 |
656 | elif seg_mode == 'v1': # MODE - MSLRGR
657 | feat1, feat2, feat3 = s1, s2, s3
658 | feat4, feat5 = self.gcn4(self.conv_s4(x), scene_feat, model_type)
659 |
660 | else:
661 | feat4, feat5 = self.gcn4(self.conv_s4(x), scene_feat, model_type)
662 |
663 | return feat1, feat2, feat3, feat4, feat5
664 |
665 | def resnet18_model(pretrained=True, root='~/.encoding/models', **kwargs):
666 | model = Resnet18_main(pretrained, num_classes=8)
667 | return model
668 |
669 |
670 | def get_gcnet(dataset='endovis18', backbone='resnet18_model', num_classes=8, pretrained=False, root='./pretrain_models', **kwargs):
671 | model = GR_Segmentation(nclass=num_classes, backbone=backbone, pretrained=pretrained, root=root, **kwargs)
672 | return model
673 |
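# --- Illustrative usage sketch (added for clarity; assumptions only, not part of the
# original training pipeline). Wires GR_module and GR_Decoder together in the 'v1'
# (MSLRGR) mode with dummy multi-scale features shaped like ResNet-18 outputs for a
# 256x320 frame, and checks that the decoder aggregates them into num_classes maps.
def _example_gr_segmentation_heads():
    norm_layer = nn.BatchNorm2d
    gr_module = GR_module(in_channels=512, out_channels=8, norm_layer=norm_layer, gcn_search=None)
    gr_decoder = GR_Decoder(in_channels=512, out_channels=8, norm_layer=norm_layer)

    s1 = torch.randn(1, 64, 64, 80)    # 1/4-scale features
    s2 = torch.randn(1, 128, 32, 40)   # 1/8-scale features
    s3 = torch.randn(1, 256, 16, 20)   # 1/16-scale features
    c4 = torch.randn(1, 512, 8, 10)    # 1/32-scale features

    # 'v1' (MSLRGR): s1-s3 pass through as local features, c4 is globally reasoned
    f1, f2, f3, f4, _ = gr_module(c4, s1, s2, s3, seg_mode='v1')
    seg = gr_decoder(f4, f1, f2, f3, imsize=(256, 320), seg_mode='v1')
    return seg.shape                   # expected: torch.Size([1, 8, 256, 320])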
--------------------------------------------------------------------------------