├── matlab-eval
│   ├── octave-workspace
│   ├── README.md
│   ├── ReadPhaseLabel.m
│   ├── Evaluate_m2cai.m
│   ├── Evaluate.m
│   ├── Main_m2cai.m
│   └── Main.m
├── resources
│   └── fig_architecture.pdf
├── requirements.txt
├── PositionalEncoding.py
├── README.md
├── dataset.py
├── main.py
├── utils.py
├── decoder.py
├── prototype.py
└── hierarch_tcn2.py

/matlab-eval/octave-workspace:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xmed-lab/SAHC/HEAD/matlab-eval/octave-workspace
--------------------------------------------------------------------------------
/resources/fig_architecture.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xmed-lab/SAHC/HEAD/resources/fig_architecture.pdf
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib==3.5.1
2 | numpy==1.20.3
3 | scikit_learn==1.0.2
4 | seaborn==0.11.2
5 | thop==0.0.31.post2005241907
6 | torch==1.9.0
7 | torchvision==0.10.0
8 | tqdm==4.61.2
--------------------------------------------------------------------------------
/matlab-eval/README.md:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | ### MATLAB scripts to perform the evaluation
4 | 
5 | ### Acknowledgement:
6 | MICCAI M2CAI challenge; the official webpage of the challenge can be found here: http://camma.u-strasbg.fr/m2cai2016
7 | 
8 | 
9 | 
10 | 
11 | 
12 | 
13 | 
14 | 
--------------------------------------------------------------------------------
/matlab-eval/ReadPhaseLabel.m:
--------------------------------------------------------------------------------
1 | function [ outp ] = ReadPhaseLabel( file )
2 | %READPHASELABEL
3 | %   Read the phase labels (annotation and prediction) from a two-column file
4 | 
5 | fid = fopen(file, 'r');
6 | 
7 | % read the header first
8 | %disp(file)
9 | tline = fgets(fid);
10 | 
11 | % read the labels
12 | [outp] = textscan(fid, '%d %s\n');
13 | fclose(fid);
14 | end
15 | 
--------------------------------------------------------------------------------
/PositionalEncoding.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | 
5 | class FixedPositionalEncoding(nn.Module):
6 |     def __init__(self, embedding_dim, max_length=5000):
7 |         super(FixedPositionalEncoding, self).__init__()
8 | 
9 |         pe = torch.zeros(max_length, embedding_dim)
10 |         position = torch.arange(0, max_length, dtype=torch.float).unsqueeze(1)
11 |         div_term = torch.exp(
12 |             torch.arange(0, embedding_dim, 2).float()
13 |             * (-torch.log(torch.tensor(10000.0)) / embedding_dim)
14 |         )
15 |         pe[:, 0::2] = torch.sin(position * div_term)
16 |         pe[:, 1::2] = torch.cos(position * div_term)
17 |         pe = pe.unsqueeze(0).transpose(0, 1)
18 |         self.register_buffer('pe', pe)
19 | 
20 |     def forward(self, x):
21 |         pos = self.pe[: x.size(0), :] + x
22 |         return pos
23 | 
24 | 
25 | class LearnedPositionalEncoding(nn.Module):
26 |     def __init__(self, max_position_embeddings, embedding_dim):
27 |         super(LearnedPositionalEncoding, self).__init__()
28 |         self.pe = nn.Embedding(max_position_embeddings, embedding_dim)
29 | 
30 | 
31 |         self.register_buffer(
32 |             "position_ids",
33 |             torch.arange(max_position_embeddings).expand((1, -1)),
34 |         )
35 | 
36 |     def forward(self, x, position_ids=None):
37 |         if position_ids is None:
38 |             # print(self.position_ids)
39 |             position_ids = self.position_ids[:, : x.size(2)]
40 | 
41 |         # print(self.pe(position_ids).size(), x.size())
42 | 
43 |         position_embeddings = self.pe(position_ids).transpose(1,2) + x
44 |         return position_embeddings
45 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Exploring Segment-level Semantics for Online Phase Recognition from Surgical Videos
2 | 
3 | ## Introduction
4 | 
5 | This is a PyTorch implementation of the IEEE TMI paper [Exploring Segment-level Semantics for Online Phase Recognition from Surgical Videos](https://arxiv.org/pdf/2111.11044.pdf).
6 | 
7 | In this paper, we design a temporal hierarchical network that generates hierarchical high-level segments and uses them to refine low-level frame predictions. The implementation is based on [NETE](https://github.com/ChinaYi/NETE).
8 | 
9 | Framework visualization
10 | ![framework visualization](resources/fig_architecture.pdf)
11 | 
12 | 
13 | ## Preparation
14 | 
15 | **Datasets and our trained model**
16 | 
17 | Cholec80, M2CAI16 and our trained model: [GoogleDrive](https://drive.google.com/drive/folders/1grGXjTTUnN717WpN4k7T-bAq3GKy2bLv?usp=sharing)
18 | 
19 | 
20 | 
21 | ## Run the code
22 | 
23 | 
24 | **Installation**
25 | ```
26 | matplotlib==3.5.1
27 | numpy==1.20.3
28 | scikit_learn==1.0.2
29 | seaborn==0.11.2
30 | thop==0.0.31.post2005241907
31 | torch==1.9.0
32 | torchvision==0.10.0
33 | tqdm==4.61.2
34 | ```
35 | 
36 | 
37 | **Train the model**
38 | ```shell
39 | python main.py --action=hierarch_train --hier=True --first=True --trans=True
40 | ```
41 | (The trained model will be saved in "models/")
42 | 
43 | **Evaluate**
44 | 
45 | ***Generate predictions***
46 | ```shell
47 | python main.py --action=hierarch_predict --hier=True --first=True --trans=True
48 | ```
49 | (This will generate predictions in "results/")
50 | 
51 | ***Evaluate the predictions***
52 | ```shell
53 | matlab-eval/Main.m (cholec80)
54 | matlab-eval/Main_m2cai.m (m2cai16)
55 | ```
56 | 
57 | Mean jaccard: 83.53 +- 5.76
58 | Mean accuracy: 91.85 +- 7.55
59 | Mean precision: 91.75 +- 5.46
60 | Mean recall: 91.74 +- 5.77
61 | 
62 | Mean jaccard: 84.92 +- 7.70
63 | Mean accuracy: 91.99 +- 8.44
64 | Mean precision: 93.74 +- 5.77
65 | Mean recall: 92.88 +- 4.83
66 | 
67 | 
68 | 
69 | ## Citation
70 | If this code is useful for your research, please cite:
71 | ```
72 | @article{ding2022exploring,
73 |   title={Exploring Segment-level Semantics for Online Phase Recognition from Surgical Videos},
74 |   author={Ding, Xinpeng and Li, Xiaomeng},
75 |   journal={IEEE Transactions on Medical Imaging},
76 |   year={2022},
77 |   publisher={IEEE}
78 | }
79 | ```
80 | 
--------------------------------------------------------------------------------
/matlab-eval/Evaluate_m2cai.m:
--------------------------------------------------------------------------------
1 | function [ res, prec, rec, acc ] = Evaluate_m2cai( gtLabelID, predLabelID, fps )
2 | %EVALUATE_M2CAI
3 | %   A function to evaluate the performance of the phase recognition method
4 | %   providing jaccard index, precision, and recall for each phase
5 | %   and accuracy over the surgery. All metrics are computed in a relaxed
6 | %   boundary mode.
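% INPUT (inferred from the call in Main_m2cai.m):
%   gtLabelID:   vector of per-frame ground-truth phase IDs (1..8 for m2cai16)
%   predLabelID: vector of per-frame predicted phase IDs, same length as gtLabelID
%   fps:         frame rate of the annotations (25 in Main_m2cai.m)
% Example (assumed usage):
%   [jacc, prec, rec, acc] = Evaluate_m2cai(gtLabelID, predLabelID, 25);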
7 | % OUTPUT:
8 | % res: the jaccard index per phase (relaxed) - NaN for a phase that does not appear in the GT
9 | % prec: precision per phase (relaxed) - NaN for a phase that does not appear in the GT
10 | % rec: recall per phase (relaxed) - NaN for a phase that does not appear in the GT
11 | % acc: the accuracy over the video (relaxed)
12 | 
13 | oriT = 10 * fps; % 10 seconds relaxed boundary
14 | 
15 | res = []; prec = []; rec = [];
16 | diff = predLabelID - gtLabelID;
17 | updatedDiff = [];
18 | 
19 | % obtain the true positives with the relaxed boundary
20 | for iPhase = 1:8 % nPhases
21 |     gtConn = bwconncomp(gtLabelID == iPhase);
22 | 
23 |     for iConn = 1:gtConn.NumObjects
24 |         startIdx = min(gtConn.PixelIdxList{iConn});
25 |         endIdx = max(gtConn.PixelIdxList{iConn});
26 | 
27 |         curDiff = diff(startIdx:endIdx);
28 | 
29 |         % in the case where the phase is shorter than the relaxed boundary
30 |         t = oriT;
31 |         if(t > length(curDiff))
32 |             t = length(curDiff);
33 |             disp(['Very short phase ' num2str(iPhase)]);
34 |         end
35 | 
36 |         % relaxed boundary
37 |         if(iPhase == 5 || iPhase == 6) % GallbladderDissection and GallbladderPackaging might jump between two phases
38 |             curDiff(curDiff(1:t)==-1) = 0; % late transition
39 |             curDiff(curDiff(end-t+1:end)==1 | curDiff(end-t+1:end)==2) = 0; % early transition (the phase may already be predicted as one of the next two phases near its end)
40 |         elseif(iPhase == 7 || iPhase == 8) % CleaningCoagulation and GallbladderRetraction might jump between two phases
41 |             curDiff(curDiff(1:t)==-1 | curDiff(1:t)==-2) = 0; % late transition
42 |             curDiff(curDiff(end-t+1:end)==1 | curDiff(end-t+1:end)==2) = 0; % early transition
43 |         else
44 |             % general situation
45 |             curDiff(curDiff(1:t)==-1) = 0; % late transition
46 |             curDiff(curDiff(end-t+1:end)==1) = 0; % early transition
47 |         end
48 | 
49 |         updatedDiff(startIdx:endIdx) = curDiff;
50 |     end
51 | end
52 | 
53 | % compute jaccard index, prec, and rec per phase
54 | for iPhase = 1:8
55 |     gtConn = bwconncomp(gtLabelID == iPhase);
56 |     predConn = bwconncomp(predLabelID == iPhase);
57 | 
58 |     if(gtConn.NumObjects == 0)
59 |         % no iPhase in the current ground truth, assign NaN values
60 |         % SHOULD be excluded in the computation of the mean (use nanmean)
61 |         res(end+1) = NaN;
62 |         prec(end+1) = NaN;
63 |         rec(end+1) = NaN;
64 |         continue;
65 |     end
66 | 
67 |     iPUnion = union(vertcat(predConn.PixelIdxList{:}), vertcat(gtConn.PixelIdxList{:}));
68 |     tp = sum(updatedDiff(iPUnion) == 0);
69 |     jaccard = tp/length(iPUnion);
70 |     jaccard = jaccard * 100;
71 | 
72 |     % res(end+1, :) = [iPhase jaccard];
73 |     res(end+1) = jaccard;
74 | 
75 |     % Compute prec and rec
76 |     indx = (gtLabelID == iPhase);
77 | 
78 |     sumTP = tp; % sum(predLabelID(indx) == iPhase);
79 |     sumPred = sum(predLabelID == iPhase);
80 |     sumGT = sum(indx);
81 | 
82 |     prec(end+1) = sumTP * 100 / sumPred;
83 |     rec(end+1) = sumTP * 100 / sumGT;
84 | end
85 | 
86 | % compute accuracy
87 | acc = sum(updatedDiff==0) / length(gtLabelID);
88 | acc = acc * 100;
89 | 
90 | end
91 | 
92 | 
--------------------------------------------------------------------------------
/matlab-eval/Evaluate.m:
--------------------------------------------------------------------------------
1 | function [ res, prec, rec, acc ] = Evaluate( gtLabelID, predLabelID, fps )
2 | %EVALUATE
3 | % A function to evaluate the performance of the phase recognition method
4 | % providing jaccard index, precision, and recall for each phase
5 | % and accuracy over the surgery. All metrics are computed in a relaxed
6 | % boundary mode.
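% INPUT (inferred from the call in Main.m):
%   gtLabelID:   vector of per-frame ground-truth phase IDs (1..7 for cholec80)
%   predLabelID: vector of per-frame predicted phase IDs, same length as gtLabelID
%   fps:         frame rate of the annotations (25 in Main.m)
% Example (assumed usage):
%   [jacc, prec, rec, acc] = Evaluate(gtLabelID, predLabelID, 25);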
7 | % OUTPUT:
8 | % res: the jaccard index per phase (relaxed) - NaN for a phase that does not appear in the GT
9 | % prec: precision per phase (relaxed) - NaN for a phase that does not appear in the GT
10 | % rec: recall per phase (relaxed) - NaN for a phase that does not appear in the GT
11 | % acc: the accuracy over the video (relaxed)
12 | 
13 | oriT = 10 * fps; % 10 seconds relaxed boundary
14 | 
15 | res = []; prec = []; rec = [];
16 | diff = predLabelID - gtLabelID;
17 | updatedDiff = [];
18 | 
19 | % obtain the true positives with the relaxed boundary
20 | for iPhase = 1:7 % nPhases
21 |     gtConn = bwconncomp(gtLabelID == iPhase);
22 | 
23 |     for iConn = 1:gtConn.NumObjects
24 |         startIdx = min(gtConn.PixelIdxList{iConn});
25 |         endIdx = max(gtConn.PixelIdxList{iConn});
26 | 
27 |         curDiff = diff(startIdx:endIdx);
28 | 
29 |         % in the case where the phase is shorter than the relaxed boundary
30 |         t = oriT;
31 |         if(t > length(curDiff))
32 |             t = length(curDiff);
33 |             disp(['Very short phase ' num2str(iPhase)]);
34 |         end
35 | 
36 |         % relaxed boundary
37 |         % revised for cholec80 dataset !!!!!!!!!!!
38 |         if(iPhase == 4 || iPhase == 5) % GallbladderDissection and GallbladderPackaging might jump between two phases
39 |             curDiff(curDiff(1:t)==-1) = 0; % late transition
40 |             curDiff(curDiff(end-t+1:end)==1 | curDiff(end-t+1:end)==2) = 0; % early transition (phase 5 may already be predicted as 6/7 near its end)
41 |         elseif(iPhase == 6 || iPhase == 7) % CleaningCoagulation and GallbladderRetraction might jump between two phases
42 |             curDiff(curDiff(1:t)==-1 | curDiff(1:t)==-2) = 0; % late transition
43 |             curDiff(curDiff(end-t+1:end)==1 | curDiff(end-t+1:end)==2) = 0; % early transition
44 |         else
45 |             % general situation
46 |             curDiff(curDiff(1:t)==-1) = 0; % late transition
47 |             curDiff(curDiff(end-t+1:end)==1) = 0; % early transition
48 |         end
49 | 
50 |         updatedDiff(startIdx:endIdx) = curDiff;
51 |     end
52 | end
53 | 
54 | % compute jaccard index, prec, and rec per phase
55 | for iPhase = 1:7
56 |     gtConn = bwconncomp(gtLabelID == iPhase);
57 |     predConn = bwconncomp(predLabelID == iPhase);
58 | 
59 |     if(gtConn.NumObjects == 0)
60 |         % no iPhase in the current ground truth, assign NaN values
61 |         % SHOULD be excluded in the computation of the mean (use nanmean)
62 |         res(end+1) = NaN;
63 |         prec(end+1) = NaN;
64 |         rec(end+1) = NaN;
65 |         continue;
66 |     end
67 | 
68 |     iPUnion = union(vertcat(predConn.PixelIdxList{:}), vertcat(gtConn.PixelIdxList{:}));
69 |     tp = sum(updatedDiff(iPUnion) == 0);
70 |     jaccard = tp/length(iPUnion);
71 |     jaccard = jaccard * 100;
72 | 
73 |     % res(end+1, :) = [iPhase jaccard];
74 |     res(end+1) = jaccard;
75 | 
76 |     % Compute prec and rec
77 |     indx = (gtLabelID == iPhase);
78 | 
79 |     sumTP = tp; % sum(predLabelID(indx) == iPhase);
80 |     sumPred = sum(predLabelID == iPhase);
81 |     sumGT = sum(indx);
82 | 
83 |     prec(end+1) = sumTP * 100 / sumPred;
84 |     rec(end+1) = sumTP * 100 / sumGT;
85 | end
86 | 
87 | % compute accuracy
88 | acc = sum(updatedDiff==0) / length(gtLabelID);
89 | acc = acc * 100;
90 | 
91 | end
92 | 
93 | 
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | # from numpy.lib.function_base import append  # unused
2 | from torch.utils.data import Dataset, DataLoader
3 | from torchvision.datasets.folder import default_loader
4 | from torchvision import transforms
5 | import os
6 | import numpy as np
7 | import torch
8 | phase2label_dicts = {
9 |     'cholec80':{
10 |         'Preparation':0,
11 |         'CalotTriangleDissection':1,
12 |         'ClippingCutting':2,
13 |         'GallbladderDissection':3,
14 |         'GallbladderPackaging':4,
15 |         'CleaningCoagulation':5,
16 |         'GallbladderRetraction':6},
17 | 
18 |     'm2cai16':{
19 |         'TrocarPlacement':0,
20 |         'Preparation':1,
21 |         'CalotTriangleDissection':2,
22 |         'ClippingCutting':3,
23 |         'GallbladderDissection':4,
24 |         'GallbladderPackaging':5,
25 |         'CleaningCoagulation':6,
26 |         'GallbladderRetraction':7}
27 | }
28 | 
29 | 
30 | 
31 | 
32 | def phase2label(phases, phase2label_dict):
33 |     labels = [phase2label_dict[phase] if phase in phase2label_dict.keys() else len(phase2label_dict) for phase in phases]
34 |     return labels
35 | 
36 | def label2phase(labels, phase2label_dict):
37 |     label2phase_dict = {phase2label_dict[k]:k for k in phase2label_dict.keys()}
38 |     phases = [label2phase_dict[label] if label in label2phase_dict.keys() else 'HardFrame' for label in labels]
39 |     return phases
40 | 
41 | 
42 | 
43 | class TestVideoDataset(Dataset):
44 |     def __init__(self, dataset, root, sample_rate, video_feature_folder):
45 |         self.dataset = dataset
46 |         self.sample_rate = sample_rate
47 |         self.videos = []
48 |         self.labels = []
49 |         ###
50 |         self.video_names = []
51 |         if dataset == 'cholec80':
52 |             self.hard_frame_index = 7
53 |         if dataset == 'm2cai16':
54 |             self.hard_frame_index = 8
55 | 
56 |         video_feature_folder = os.path.join(root, video_feature_folder)
57 |         label_folder = os.path.join(root, 'annotation_folder')
58 | 
59 | 
60 |         num_len = 0
61 | 
62 |         ans = 0
63 |         for v_f in os.listdir(video_feature_folder):
64 | 
65 | 
66 |             v_f_abs_path = os.path.join(video_feature_folder, v_f)
67 | 
68 |             v_label_file_abs_path = os.path.join(label_folder, v_f.split('.')[0] + '.txt')
69 | 
70 | 
71 |             labels = self.read_labels(v_label_file_abs_path)
72 |             #
73 |             labels = labels[::sample_rate]
74 | 
75 |             videos = np.load(v_f_abs_path)[::sample_rate,]
76 | 
77 |             num_len += videos.shape[0]
78 | 
79 | 
80 |             self.videos.append(videos)
81 | 
82 |             self.labels.append(labels)
83 |             phase = 1  # count the number of phase segments in this video (diagnostic only)
84 |             for i in range(len(labels)-1):
85 |                 if labels[i] == labels[i+1]:
86 |                     continue
87 |                 else:
88 |                     phase += 1
89 | 
90 |             ans += 1
91 |             self.video_names.append(v_f)
92 | 
93 |         print('VideoDataset: Load dataset {} with {} videos.'.format(self.dataset, self.__len__()))
94 | 
95 |     def __len__(self):
96 |         return len(self.videos)
97 | 
98 | 
99 | 
100 |     def __getitem__(self, item):
101 |         video, label, video_name = self.videos[item], self.labels[item], self.video_names[item]
102 |         return video, label, video_name
103 | 
104 | 
105 |     def read_labels(self, label_file):
106 |         with open(label_file,'r') as f:
107 |             phases = [line.strip().split('\t')[1] for line in f.readlines()]
108 |             labels = phase2label(phases, phase2label_dicts[self.dataset])
109 |         return labels
110 | 
--------------------------------------------------------------------------------
/matlab-eval/Main_m2cai.m:
--------------------------------------------------------------------------------
1 | close all; clear all;
2 | 
3 | phaseGroundTruths = {};
4 | gt_root_folder = '/datasets/m2cai16/annotation_folder/'; % trailing slash required: file names are appended directly below
5 | for k = 1:14
6 |     %num = num2str(k);
7 |     if k<10
8 |         to_add = ['workflow_video_' num2str(k,'%02d')];
9 |     end
10 | 
11 |     if k>9
12 |         to_add = ['workflow_video_' num2str(k)];
13 |     end
14 |     video_name = [gt_root_folder to_add '.txt'];
15 |     disp(video_name)
16 | 
17 |     phaseGroundTruths = [phaseGroundTruths video_name];
18 | end
19 | 
20 | % phaseGroundTruths = {'video41-phase.txt', ...
21 | %     'video42-phase.txt'};
22 | % phaseGroundTruths
23 | 
24 | phases = {'TrocarPlacement', 'Preparation', 'CalotTriangleDissection', ...
25 |     'ClippingCutting', 'GallbladderDissection', 'GallbladderPackaging', 'CleaningCoagulation', ...
26 |     'GallbladderRetraction'};
27 | 
28 | fps = 25;
29 | 
30 | for i = 1:length(phaseGroundTruths)
31 |     predroot = ''; %% your prediction folder
32 |     %predroot = '../../Results/multi/phase';
33 |     %predroot = '../../Results/multi_kl_best_890_882/phase_post';
34 |     phaseGroundTruth = phaseGroundTruths{i};
35 |     predFile = [predroot phaseGroundTruth(end-5:end-4) '_pred.txt'];
36 |     disp(predFile)
37 |     [gt] = ReadPhaseLabel(phaseGroundTruth);
38 |     [pred] = ReadPhaseLabel(predFile);
39 | 
40 |     if(size(gt{1}, 1) ~= size(pred{1},1) || size(gt{2}, 1) ~= size(pred{2},1))
41 |         error(['ERROR: ' phaseGroundTruth '\nGround truth and prediction have different sizes']);
42 |     end
43 | 
44 |     if(~isempty(find(gt{1} ~= pred{1})))
45 |         error(['ERROR: ' phaseGroundTruth '\nThe frame index in ground truth and prediction is not equal']);
46 |     end
47 | 
48 |     % reassigning the phase labels to numbers
49 |     gtLabelID = [];
50 |     predLabelID = [];
51 |     for j = 1:8
52 |         gtLabelID(find(strcmp(phases(j), gt{2}))) = j;
53 |         predLabelID(find(strcmp(phases(j), pred{2}))) = j;
54 |     end
55 | 
56 |     % compute jaccard index, precision, recall, and the accuracy
57 |     [jaccard(:,i), prec(:,i), rec(:,i), acc(i)] = Evaluate_m2cai(gtLabelID, predLabelID, fps);
58 | 
59 | end
60 | 
61 | % Compute means and stds
62 | index = find(jaccard>100);
63 | jaccard(index)=100;
64 | meanJaccPerPhase = nanmean(jaccard, 2);
65 | meanJacc = mean(meanJaccPerPhase);
66 | stdJacc = std(meanJaccPerPhase);
67 | for h = 1:8
68 |     jaccphase = jaccard(h,:);
69 |     meanjaccphase(h) = nanmean(jaccphase);
70 |     stdjaccphase(h) = nanstd(jaccphase);
71 | end
72 | 
73 | index = find(prec>100);
74 | prec(index)=100;
75 | meanPrecPerPhase = nanmean(prec, 2);
76 | meanPrec = nanmean(meanPrecPerPhase);
77 | stdPrec = nanstd(meanPrecPerPhase);
78 | for h = 1:8
79 |     precphase = prec(h,:);
80 |     meanprecphase(h) = nanmean(precphase);
81 |     stdprecphase(h) = nanstd(precphase);
82 | end
83 | 
84 | index = find(rec>100);
85 | rec(index)=100;
86 | meanRecPerPhase = nanmean(rec, 2);
87 | meanRec = mean(meanRecPerPhase);
88 | stdRec = std(meanRecPerPhase);
89 | for h = 1:8
90 |     recphase = rec(h,:);
91 |     meanrecphase(h) = nanmean(recphase);
92 |     stdrecphase(h) = nanstd(recphase);
93 | end
94 | 
95 | 
96 | meanAcc = mean(acc);
97 | stdAcc = std(acc);
98 | 
99 | % Display results
100 | % fprintf('model is :%s\n', model_rootfolder);
101 | disp('================================================');
102 | disp([sprintf('%25s', 'Phase') '|' sprintf('%6s', 'Jacc') '|'...
103 |     sprintf('%6s', 'Prec') '|' sprintf('%6s', 'Rec') '|']);
104 | disp('================================================');
105 | for iPhase = 1:length(phases)
106 |     disp([sprintf('%25s', phases{iPhase}) '|' sprintf('%6.2f', meanJaccPerPhase(iPhase)) '|' ...
107 |         sprintf('%6.2f', meanPrecPerPhase(iPhase)) '|' sprintf('%6.2f', meanRecPerPhase(iPhase)) '|']);
108 |     disp('---------------------------------------------');
109 | end
110 | disp('================================================');
111 | 
112 | disp(['Mean jaccard: ' sprintf('%5.2f', meanJacc) ' +- ' sprintf('%5.2f', stdJacc)]);
113 | disp(['Mean accuracy: ' sprintf('%5.2f', meanAcc) ' +- ' sprintf('%5.2f', stdAcc)]);
114 | disp(['Mean precision: ' sprintf('%5.2f', meanPrec) ' +- ' sprintf('%5.2f', stdPrec)]);
115 | disp(['Mean recall: ' sprintf('%5.2f', meanRec) ' +- ' sprintf('%5.2f', stdRec)]);
116 | 
--------------------------------------------------------------------------------
/matlab-eval/Main.m:
--------------------------------------------------------------------------------
1 | close all; clear all;
2 | 
3 | phaseGroundTruths = {};
4 | gt_root_folder = '/datasets/cholec80/annotation_folder/'; %annotation_folder (trailing slash required: file names are appended directly below)
5 | for k = 41:80
6 |     num = num2str(k);
7 |     to_add = ['video' num];
8 |     video_name = [gt_root_folder to_add '.txt'];
9 |     phaseGroundTruths = [phaseGroundTruths video_name];
10 | end
11 | % phaseGroundTruths = {'video41-phase.txt', ...
12 | %     'video42-phase.txt'};
13 | % phaseGroundTruths
14 | 
15 | phases = {'Preparation', 'CalotTriangleDissection', ...
16 |     'ClippingCutting', 'GallbladderDissection', 'GallbladderPackaging', 'CleaningCoagulation', ...
17 |     'GallbladderRetraction'};
18 | 
19 | fps = 25;
20 | 
21 | for i = 1:length(phaseGroundTruths)
22 |     predroot = ''; % your prediction folder
23 | 
24 |     %predroot = '../../Results/multi/phase';
25 |     %predroot = '../../Results/multi_kl_best_890_882/phase_post';
26 |     phaseGroundTruth = phaseGroundTruths{i};
27 |     num = cell2mat(regexp(phaseGroundTruth,'\d', 'match'));
28 |     predFile = [predroot phaseGroundTruth(end-10:end-4) '_pred.txt'];
29 | 
30 |     [gt] = ReadPhaseLabel(phaseGroundTruth);
31 |     disp(predFile)
32 |     [pred] = ReadPhaseLabel(predFile);
33 | 
34 |     if(size(gt{1}, 1) ~= size(pred{1},1) || size(gt{2}, 1) ~= size(pred{2},1))
35 |         error(['ERROR: ' phaseGroundTruth '\nGround truth and prediction have different sizes']);
36 |     end
37 | 
38 |     if(~isempty(find(gt{1} ~= pred{1})))
39 |         error(['ERROR: ' phaseGroundTruth '\nThe frame index in ground truth and prediction is not equal']);
40 |     end
41 | 
42 |     % reassigning the phase labels to numbers
43 |     gtLabelID = [];
44 |     predLabelID = [];
45 |     for j = 1:7
46 |         % disp(phases(j))
47 |         %sss
48 |         %num2str(j-1)
49 |         %sss
50 |         gtLabelID(find(strcmp(phases(j), gt{2}))) = j;
51 |         predLabelID(find(strcmp(phases(j), pred{2}))) = j;
52 |         %gtLabelID(find(strcmp(num2str(j-1), gt{2}))) = j;
53 |         %predLabelID(find(strcmp(num2str(j-1), pred{2}))) = j;
54 |     end
55 |     %disp(gtLabelID)
56 |     % compute jaccard index, precision, recall, and the accuracy
57 |     [jaccard(:,i), prec(:,i), rec(:,i), acc(i)] = Evaluate(gtLabelID, predLabelID, fps);
58 | 
59 | end
60 | 
61 | accPerVideo = acc;
62 | 
63 | % Compute means and stds
64 | index = find(jaccard>100);
65 | jaccard(index)=100;
66 | meanJaccPerPhase = nanmean(jaccard, 2);
67 | meanJaccPerVideo = nanmean(jaccard, 1);
68 | meanJacc = mean(meanJaccPerPhase);
69 | stdJacc = std(meanJaccPerPhase);
70 | for h = 1:7
71 |     jaccphase = jaccard(h,:);
72 |     meanjaccphase(h) = nanmean(jaccphase);
73 |     stdjaccphase(h) = nanstd(jaccphase);
74 | end
75 | 
76 | index = find(prec>100);
77 | prec(index)=100;
78 | meanPrecPerPhase = nanmean(prec, 2);
79 | meanPrecPerVideo = nanmean(prec, 1);
80 | meanPrec = nanmean(meanPrecPerPhase);
81 | stdPrec = nanstd(meanPrecPerPhase);
82 | for 
h = 1:7 83 | precphase = prec(h,:); 84 | meanprecphase(h) = nanmean(precphase); 85 | stdprecphase(h) = nanstd(precphase); 86 | end 87 | 88 | index = find(rec>100); 89 | rec(index)=100; 90 | meanRecPerPhase = nanmean(rec, 2); 91 | meanRecPerVideo = nanmean(rec, 1); 92 | meanRec = mean(meanRecPerPhase); 93 | stdRec = std(meanRecPerPhase); 94 | for h = 1:7 95 | recphase = rec(h,:); 96 | meanrecphase(h) = nanmean(recphase); 97 | stdrecphase(h) = nanstd(recphase); 98 | end 99 | 100 | 101 | meanAcc = mean(acc); 102 | stdAcc = std(acc); 103 | 104 | % Display results 105 | % fprintf('model is :%s\n', model_rootfolder); 106 | disp('================================================'); 107 | disp([sprintf('%25s', 'Phase') '|' sprintf('%6s', 'Jacc') '|'... 108 | sprintf('%6s', 'Prec') '|' sprintf('%6s', 'Rec') '|']); 109 | disp('================================================'); 110 | for iPhase = 1:length(phases) 111 | disp([sprintf('%25s', phases{iPhase}) '|' sprintf('%6.2f', meanJaccPerPhase(iPhase)) '|' ... 112 | sprintf('%6.2f', meanPrecPerPhase(iPhase)) '|' sprintf('%6.2f', meanRecPerPhase(iPhase)) '|']); 113 | disp('---------------------------------------------'); 114 | end 115 | disp('================================================'); 116 | 117 | disp(['Mean jaccard: ' sprintf('%5.2f', meanJacc) ' +- ' sprintf('%5.2f', stdJacc)]); 118 | disp(['Mean accuracy: ' sprintf('%5.2f', meanAcc) ' +- ' sprintf('%5.2f', stdAcc)]); 119 | disp(['Mean precision: ' sprintf('%5.2f', meanPrec) ' +- ' sprintf('%5.2f', stdPrec)]); 120 | disp(['Mean recall: ' sprintf('%5.2f', meanRec) ' +- ' sprintf('%5.2f', stdRec)]); 121 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | 2 | from enum import Flag 3 | import torch 4 | import torchvision.models as models 5 | import torch.nn as nn 6 | from torch.utils.data import DataLoader 7 | import torch.nn.functional as F 8 | import pickle 9 | import os 10 | import argparse 11 | import numpy as np 12 | import random 13 | from tqdm import tqdm 14 | from sklearn.model_selection import KFold 15 | 16 | from dataset import * 17 | from prototype import hierarch_train,base_predict 18 | from utils import * 19 | from hierarch_tcn2 import Hierarch_TCN2 20 | # os.environ["CUDA_VISIBLE_DEVICES"] = "0" 21 | 22 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 23 | seed = 19980125 24 | # print(device) 25 | torch.manual_seed(seed) 26 | torch.cuda.manual_seed(seed) 27 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. 28 | np.random.seed(seed) # Numpy module. 29 | random.seed(seed) # Python random module. 
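# NOTE: the cuDNN settings below (benchmark=False, deterministic=True) trade a
# little speed for run-to-run reproducibility of the seeded results.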
30 | torch.manual_seed(seed)
31 | torch.backends.cudnn.benchmark = False
32 | torch.backends.cudnn.deterministic = True
33 | 
34 | parser = argparse.ArgumentParser()
35 | parser.add_argument('--action', default='hierarch_train')
36 | parser.add_argument('--dataset', default="m2cai16")
37 | parser.add_argument('--dataset_path', default="./datasets/{}/")
38 | 
39 | # parser.add_argument('--dataset', default="cholec80")
40 | # parser.add_argument('--dataset_path', default="./datasets/cholec80")
41 | parser.add_argument('--sample_rate', default=5, type=int)
42 | parser.add_argument('--test_sample_rate', default=5, type=int)
43 | parser.add_argument('--refine_model', default='gru')
44 | parser.add_argument('--num_classes', default=7, type=int)
45 | parser.add_argument('--model', default="Hierarch_TCN2")
46 | parser.add_argument('--learning_rate', default=5e-4, type=float)
47 | parser.add_argument('--epochs', default=100, type=int)
48 | parser.add_argument('--gpu', default="3", type=str)
49 | parser.add_argument('--combine_loss', default=False, type=bool)  # NOTE: argparse's type=bool maps any non-empty string to True, so pass these flags as in the README (e.g. --hier=True) and simply omit them for False
50 | parser.add_argument('--ms_loss', default=True, type=bool)
51 | 
52 | parser.add_argument('--fpn', default=True, type=bool)
53 | parser.add_argument('--output', default=False, type=bool)
54 | parser.add_argument('--feature', default=False, type=bool)
55 | parser.add_argument('--trans', default=False, type=bool)
56 | parser.add_argument('--prototype', default=False, type=bool)
57 | parser.add_argument('--last', default=False, type=bool)
58 | parser.add_argument('--first', default=False, type=bool)
59 | parser.add_argument('--hier', default=False, type=bool)
60 | ####ms-tcn2
61 | parser.add_argument('--num_layers_PG', default=11, type=int)
62 | parser.add_argument('--num_layers_R', default=10, type=int)
63 | parser.add_argument('--num_R', default=3, type=int)
64 | 
65 | ##Transformer
66 | parser.add_argument('--head_num', default=8)
67 | parser.add_argument('--embed_num', default=512)
68 | parser.add_argument('--block_num', default=1)
69 | parser.add_argument('--positional_encoding_type', default="learned", type=str, help="fixed or learned")
70 | args = parser.parse_args()
71 | 
72 | learning_rate = 5e-5
73 | epochs = 100
74 | refine_epochs = 40
75 | 
76 | f_path = os.path.abspath('..')
77 | root_path = f_path.split('surgical_code')[0]
78 | 
79 | if args.dataset == 'm2cai16':
80 |     refine_epochs = 15 # early stopping
81 |     args.num_classes = 8
82 | 
83 | 
84 | loss_layer = nn.CrossEntropyLoss()
85 | mse_layer = nn.MSELoss(reduction='none')
86 | 
87 | 
88 | num_stages = 3 # refinement stages
89 | if args.dataset == 'm2cai16':
90 |     num_stages = 2 # fewer stages to reduce over-fitting
91 | num_layers = 12 # layers of the prediction TCN
92 | num_f_maps = 64
93 | dim = 2048
94 | sample_rate = args.sample_rate
95 | test_sample_rate = args.test_sample_rate
96 | num_classes = len(phase2label_dicts[args.dataset])
97 | args.num_classes = num_classes
98 | # print(args.num_classes)
99 | num_layers_PG = args.num_layers_PG
100 | num_layers_R = args.num_layers_R
101 | num_R = args.num_R
102 | 
103 | 
104 | 
105 | print(args)
106 | 
107 | base_model = Hierarch_TCN2(args, num_layers_PG, num_layers_R, num_R, num_f_maps, dim, num_classes)
108 | 
109 | 
110 | 
111 | if args.action == 'hierarch_train':
112 | 
113 |     video_traindataset = TestVideoDataset(args.dataset, args.dataset_path.format(args.dataset) + '/train_dataset', sample_rate, 'video_feature')
114 |     video_train_dataloader = DataLoader(video_traindataset, batch_size=1, shuffle=False, drop_last=False)
115 |     video_testdataset = TestVideoDataset(args.dataset, 
args.dataset_path.format(args.dataset) + '/test_dataset', test_sample_rate, 'video_feature') 116 | video_test_dataloader = DataLoader(video_testdataset, batch_size=1, shuffle=False, drop_last=False) 117 | model_save_dir = 'models/{}/'.format(args.dataset) 118 | hierarch_train(args, base_model, video_train_dataloader, video_test_dataloader, device, save_dir=model_save_dir, debug=True) 119 | 120 | elif args.action == 'hierarch_predict': 121 | 122 | # print('ssss') 123 | 124 | model_path = '' # use your model 125 | 126 | base_model.load_state_dict(torch.load(model_path)) 127 | video_testdataset =TestVideoDataset(args.dataset, root_path+ args.dataset_path.format(args.dataset) + '/test_dataset', test_sample_rate, 'video_feature') 128 | video_test_dataloader = DataLoader(video_testdataset, batch_size=1, shuffle=False, drop_last=False) 129 | base_predict(base_model,args, device, video_test_dataloader) 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from matplotlib import pyplot as plt 2 | from matplotlib import * 3 | import os 4 | import sys 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | import torch.nn as nn 9 | # from MulticoreTSNE import MulticoreTSNE as TSNE 10 | 11 | import seaborn as sns 12 | phase2label_dicts = { 13 | 'cholec80':{ 14 | 'Preparation':0, 15 | 'CalotTriangleDissection':1, 16 | 'ClippingCutting':2, 17 | 'GallbladderDissection':3, 18 | 'GallbladderPackaging':4, 19 | 'CleaningCoagulation':5, 20 | 'GallbladderRetraction':6}, 21 | 22 | 'm2cai16':{ 23 | 'TrocarPlacement':0, 24 | 'Preparation':1, 25 | 'CalotTriangleDissection':2, 26 | 'ClippingCutting':3, 27 | 'GallbladderDissection':4, 28 | 'GallbladderPackaging':5, 29 | 'CleaningCoagulation':6, 30 | 'GallbladderRetraction':7} 31 | } 32 | def label2phase(labels, phase2label_dict): 33 | label2phase_dict = {phase2label_dict[k]:k for k in phase2label_dict.keys()} 34 | phases = [label2phase_dict[label] for label in labels] 35 | return phases 36 | max_pool = nn.MaxPool1d(kernel_size=13,stride=5,dilation=3) 37 | 38 | 39 | path_p= "/home/xmli/phwang/ntfs/xinpeng/code/casual_tcn/results/m2cai16/eva/resize/" 40 | def fusion(predicted_list,labels,args): 41 | 42 | all_out_list = [] 43 | resize_out_list = [] 44 | labels_list = [] 45 | all_out = 0 46 | len_layer = len(predicted_list) 47 | weight_list = [1.0/len_layer for i in range (0, len_layer)] 48 | # print(weight_list) 49 | num=0 50 | for out, w in zip(predicted_list, weight_list): 51 | resize_out =F.interpolate(out,size=labels.size(0),mode='nearest') 52 | resize_out_list.append(resize_out) 53 | # align_corners=True 54 | # print(out.size()) 55 | resize_label = F.interpolate(labels.float().unsqueeze(0).unsqueeze(0),size=out.size(2),mode='linear',align_corners=False) 56 | if out.size(2)==labels.size(0): 57 | resize_label = labels 58 | labels_list.append(resize_label.squeeze().long()) 59 | else: 60 | # resize_label = max_pool(labels_list[-1].float().unsqueeze(0).unsqueeze(0)) 61 | resize_label = F.interpolate(labels.float().unsqueeze(0).unsqueeze(0),size=out.size(2),mode='nearest') 62 | # resize_label2 = F.interpolate(resize_label,size=labels.size(0),mode='nearest') 63 | # ,align_corners=True 64 | # print(resize_label.size(), resize_label2.size()) 65 | # print((resize_label2 == labels).sum()/labels.size(0)) 66 | # with open(path_p+'{}.txt'.format(num),"w") as f: 67 | # for 
labl1, lab2 in zip(resize_label2.squeeze(), labels.squeeze()): 68 | # f.writelines(str(labl1)+'\t'+str(lab2)+'\n') 69 | # num+=1 70 | labels_list.append(resize_label.squeeze().long()) 71 | # labels_list.append(labels.squeeze().long()) 72 | # print(resize_label.size(), out.size()) 73 | # labels_list.append(labels.squeeze().long()) 74 | # assert resize_out.size(2) == resize_label.size(0) 75 | # assert resize_label.size(2) == out.size(2) 76 | # print(out.size()) 77 | # print(resize_label.size()) 78 | # print(resize_out.size()) 79 | # all_out_list.append(out) 80 | # all_out_list.append(resize_out) 81 | 82 | all_out_list.append(out) 83 | # resize_out=out 84 | # all_out = all_out + w*resize_out 85 | 86 | # sss 87 | return all_out, all_out_list, labels_list 88 | 89 | def cosine_distance(a, b): 90 | if a.shape != b.shape: 91 | raise RuntimeError("array {} shape not match {}".format(a.shape, b.shape)) 92 | if a.ndim==1: 93 | a_norm = np.linalg.norm(a) 94 | b_norm = np.linalg.norm(b) 95 | elif a.ndim==2: 96 | a_norm = np.linalg.norm(a, axis=1, keepdims=True) 97 | b_norm = np.linalg.norm(b, axis=1, keepdims=True) 98 | else: 99 | raise RuntimeError("array dimensions {} not right".format(a.ndim)) 100 | similiarity = np.dot(a, b.T)/(a_norm * b_norm) 101 | # dist = 1. - similiarity 102 | return similiarity 103 | 104 | 105 | 106 | 107 | def segment_bars(save_path, *labels): 108 | num_pics = len(labels) 109 | color_map = plt.cm.tab10 110 | fig = plt.figure(figsize=(15, num_pics * 1.5)) 111 | 112 | barprops = dict(aspect='auto', cmap=color_map, 113 | interpolation='nearest', vmin=0, vmax=10) 114 | 115 | for i, label in enumerate(labels): 116 | plt.subplot(num_pics, 1, i+1) 117 | plt.xticks([]) 118 | plt.yticks([]) 119 | plt.imshow([label], **barprops) 120 | 121 | if save_path is not None: 122 | plt.savefig(save_path) 123 | else: 124 | plt.show() 125 | 126 | plt.close() 127 | 128 | 129 | def segment_bars_with_confidence_score(save_path, confidence_score, labels=[]): 130 | num_pics = len(labels) 131 | color_map = plt.cm.tab10 132 | 133 | # axprops = dict(xticks=[], yticks=[0,0.5,1], frameon=False) 134 | barprops = dict(aspect='auto', cmap=color_map, 135 | interpolation='nearest', vmin=0, vmax=15) 136 | fig = plt.figure(figsize=(15, (num_pics+1) * 1.5)) 137 | 138 | interval = 1 / (num_pics+2) 139 | axes = [] 140 | for i, label in enumerate(labels): 141 | i = i + 1 142 | axes.append(fig.add_axes([0.1, 1-i*interval, 0.8, interval - interval/num_pics])) 143 | # ax1.imshow([label], **barprops) 144 | titles = ['Ground Truth','Causal-TCN', 'Causal-TCN + PKI', 'Causal-TCN + MS-GRU'] 145 | for i, label in enumerate(labels): 146 | axes[i].set_xticks([]) 147 | axes[i].set_yticks([]) 148 | axes[i].imshow([label], **barprops) 149 | # axes[i].set_title(titles[i]) 150 | 151 | ax99 = fig.add_axes([0.1, 0.05, 0.8, interval - interval/num_pics]) 152 | # ax99.set_xlim(-len(confidence_score)/15, len(confidence_score) + len(confidence_score)/15) 153 | ax99.set_xlim(0, len(confidence_score)) 154 | ax99.set_ylim(-0.2, 1.2) 155 | ax99.set_yticks([0,0.5,1]) 156 | ax99.set_xticks([]) 157 | 158 | 159 | ax99.plot(range(len(confidence_score)), confidence_score) 160 | 161 | if save_path is not None: 162 | print(save_path) 163 | plt.savefig(save_path) 164 | else: 165 | plt.show() 166 | 167 | plt.close() 168 | 169 | def PKI(confidence_seq, prediction_seq, transition_prior_matrix, alpha, beta, gamma): # fix the predictions that do not meet priors 170 | initital_phase = 0 171 | previous_phase = 0 172 | alpha_count = 0 173 | assert 
len(confidence_seq) == len(prediction_seq) 174 | refined_seq = [] 175 | for i, prediction in enumerate(prediction_seq): 176 | if prediction == initital_phase: 177 | alpha_count = 0 178 | refined_seq.append(initital_phase) 179 | else: 180 | if prediction != previous_phase or confidence_seq[i] <= beta: 181 | alpha_count = 0 182 | 183 | if confidence_seq[i] >= beta: 184 | alpha_count += 1 185 | 186 | if transition_prior_matrix[initital_phase][prediction] == 1: 187 | refined_seq.append(prediction) 188 | else: 189 | refined_seq.append(initital_phase) 190 | 191 | if alpha_count >= alpha and transition_prior_matrix[initital_phase][prediction] == 1: 192 | initital_phase = prediction 193 | alpha_count = 0 194 | 195 | if alpha_count >= gamma: 196 | initital_phase = prediction 197 | alpha_count = 0 198 | previous_phase = prediction 199 | 200 | 201 | assert len(refined_seq) == len(prediction_seq) 202 | return refined_seq -------------------------------------------------------------------------------- /decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import Tensor 4 | import torch.nn.functional as F 5 | from typing import Optional, List 6 | import torch.nn.init as init 7 | import copy 8 | 9 | # class SelfAttention(nn.Module): 10 | # def __init__( 11 | # self, dim, heads=8, qkv_bias=False, qk_scale=None, dropout_rate=0.0 12 | # ): 13 | # super().__init__() 14 | # self.num_heads = heads 15 | # head_dim = dim // heads 16 | # self.scale = qk_scale or head_dim ** -0.5 17 | 18 | # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 19 | # self.attn_drop = nn.Dropout(dropout_rate) 20 | # self.proj = nn.Linear(dim, dim) 21 | # self.proj_drop = nn.Dropout(dropout_rate) 22 | 23 | # def forward(self, x): 24 | # B, N, C = x.shape 25 | # qkv = ( 26 | # self.qkv(x) 27 | # .reshape(B, N, 3, self.num_heads, C // self.num_heads) 28 | # .permute(2, 0, 3, 1, 4) 29 | # ) 30 | # q, k, v = ( 31 | # qkv[0], 32 | # qkv[1], 33 | # qkv[2], 34 | # ) # make torchscript happy (cannot use tensor as tuple) 35 | 36 | # attn = (q @ k.transpose(-2, -1)) * self.scale 37 | # attn = attn.softmax(dim=-1) 38 | # attn = self.attn_drop(attn) 39 | 40 | # x = (attn @ v).transpose(1, 2).reshape(B, N, C) 41 | # x = self.proj(x) 42 | # x = self.proj_drop(x) 43 | # return x 44 | 45 | class DecoderLayer(nn.Module): 46 | def __init__(self, self_attention, cross_attention, d_model, d_ff=None, 47 | dropout=0.1, activation="relu"): 48 | super(DecoderLayer, self).__init__() 49 | d_ff = d_ff or 4*d_model 50 | self.self_attention = self_attention 51 | self.cross_attention = cross_attention 52 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) 53 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) 54 | self.norm1 = nn.LayerNorm(d_model) 55 | self.norm2 = nn.LayerNorm(d_model) 56 | self.norm3 = nn.LayerNorm(d_model) 57 | self.dropout = nn.Dropout(dropout) 58 | self.activation = F.relu if activation == "relu" else F.gelu 59 | 60 | def forward(self, x, cross, x_mask=None, cross_mask=None): 61 | x = x + self.dropout(self.self_attention( 62 | x, x, x, 63 | attn_mask=x_mask 64 | )) 65 | x = self.norm1(x) 66 | 67 | x = x + self.dropout(self.cross_attention( 68 | x, cross, cross, 69 | attn_mask=cross_mask 70 | )) 71 | 72 | y = x = self.norm2(x) 73 | y = self.dropout(self.activation(self.conv1(y.transpose(-1,1)))) 74 | y = self.dropout(self.conv2(y).transpose(-1,1)) 75 | 76 | return self.norm3(x+y) 77 | 78 | 79 
| class Decoder(nn.Module): 80 | def __init__(self, layers, norm_layer=None): 81 | super(Decoder, self).__init__() 82 | self.layers = nn.ModuleList(layers) 83 | self.norm = norm_layer 84 | 85 | def forward(self, x, cross, x_mask=None, cross_mask=None): 86 | for layer in self.layers: 87 | x = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask) 88 | 89 | if self.norm is not None: 90 | x = self.norm(x) 91 | 92 | return x 93 | 94 | class TransformerDecoder(nn.Module): 95 | 96 | def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): 97 | super().__init__() 98 | self.layers = _get_clones(decoder_layer, num_layers) 99 | self.num_layers = num_layers 100 | self.norm = norm 101 | self.return_intermediate = return_intermediate 102 | 103 | def forward(self, tgt, memory, 104 | tgt_mask: Optional[Tensor] = None, 105 | memory_mask: Optional[Tensor] = None, 106 | tgt_key_padding_mask: Optional[Tensor] = None, 107 | memory_key_padding_mask: Optional[Tensor] = None, 108 | pos: Optional[Tensor] = None, 109 | query_pos: Optional[Tensor] = None): 110 | output = tgt 111 | T,B,C = memory.shape 112 | intermediate = [] 113 | 114 | for n,layer in enumerate(self.layers): 115 | 116 | residual=True 117 | output,ws = layer(output, memory, tgt_mask=tgt_mask, 118 | memory_mask=memory_mask, 119 | tgt_key_padding_mask=tgt_key_padding_mask, 120 | memory_key_padding_mask=memory_key_padding_mask, 121 | pos=pos, query_pos=query_pos,residual=residual) 122 | 123 | if self.return_intermediate: 124 | intermediate.append(self.norm(output)) 125 | if self.norm is not None: 126 | output = self.norm(output) 127 | if self.return_intermediate: 128 | intermediate.pop() 129 | intermediate.append(output) 130 | 131 | if self.return_intermediate: 132 | return torch.stack(intermediate) 133 | return output 134 | 135 | 136 | 137 | class TransformerDecoderLayer(nn.Module): 138 | 139 | def __init__(self, d_model, nhead, dim_feedforward=64, dropout=0.1, 140 | activation="relu", normalize_before=False): 141 | super().__init__() 142 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 143 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 144 | # Implementation of Feedforward model 145 | self.linear1 = nn.Linear(d_model, dim_feedforward) 146 | self.dropout = nn.Dropout(dropout) 147 | self.linear2 = nn.Linear(dim_feedforward, d_model) 148 | 149 | self.norm1 = nn.LayerNorm(d_model) 150 | self.norm2 = nn.LayerNorm(d_model) 151 | self.norm3 = nn.LayerNorm(d_model) 152 | self.dropout1 = nn.Dropout(dropout) 153 | self.dropout2 = nn.Dropout(dropout) 154 | self.dropout3 = nn.Dropout(dropout) 155 | 156 | self.activation = _get_activation_fn(activation) 157 | self.normalize_before = normalize_before 158 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 159 | return tensor if pos is None else tensor + pos 160 | 161 | def forward_post(self, tgt, memory, 162 | tgt_mask: Optional[Tensor] = None, 163 | memory_mask: Optional[Tensor] = None, 164 | tgt_key_padding_mask: Optional[Tensor] = None, 165 | memory_key_padding_mask: Optional[Tensor] = None, 166 | pos: Optional[Tensor] = None, 167 | query_pos: Optional[Tensor] = None, 168 | residual=True): 169 | q = k = self.with_pos_embed(tgt, query_pos) 170 | tgt2,ws = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, 171 | key_padding_mask=tgt_key_padding_mask) 172 | tgt = self.norm1(tgt) 173 | tgt2,ws = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), 174 | key=self.with_pos_embed(memory, pos), 175 | value=memory, 
attn_mask=memory_mask, 176 | key_padding_mask=memory_key_padding_mask) 177 | 178 | 179 | # attn_weights [B,NUM_Q,T] 180 | tgt = tgt + self.dropout2(tgt2) 181 | tgt = self.norm2(tgt) 182 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 183 | tgt = tgt + self.dropout3(tgt2) 184 | tgt = self.norm3(tgt) 185 | return tgt,ws 186 | 187 | def forward_pre(self, tgt, memory, 188 | tgt_mask: Optional[Tensor] = None, 189 | memory_mask: Optional[Tensor] = None, 190 | tgt_key_padding_mask: Optional[Tensor] = None, 191 | memory_key_padding_mask: Optional[Tensor] = None, 192 | pos: Optional[Tensor] = None, 193 | query_pos: Optional[Tensor] = None): 194 | 195 | 196 | # q = k = self.with_pos_embed(tgt2, query_pos) 197 | # # # print(q.size(), k.size(), tgt2.size()) 198 | # tgt2,ws = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, 199 | # key_padding_mask=tgt_key_padding_mask) 200 | # tgt = tgt + self.dropout1(tgt2) 201 | # print('1', tgt.size(), memory.size()) 202 | # sssss 203 | # tgt2 = self.norm2(tgt) 204 | # print(self.with_pos_embed(tgt2, query_pos).size(), self.with_pos_embed(memory, pos).size()) 205 | memory = memory.permute(2,0,1).contiguous() 206 | # print(memory.size()) 207 | # memory_mask = self._generate_square_subsequent_mask(memory.size(0),tgt2.size(0)) 208 | # memory_mask = memory_mask.cuda() 209 | # print(memory_mask.size()) 210 | # print(tgt2.size(),memory.size()) 211 | # attn_output_weights = torch.bmm(tgt2,memory.transpose(1, 2)) 212 | # print(attn_output_weights.size()) 213 | # sss 214 | tgt2,attn_weights = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), 215 | key=self.with_pos_embed(memory, pos), 216 | value=memory, attn_mask=memory_mask, 217 | key_padding_mask=memory_key_padding_mask) 218 | tgt2 = self.norm1(tgt2) 219 | # # print(tgt2.size(), memory.size()) 220 | # tgt2,attn_weights = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), 221 | # key=self.with_pos_embed(memory, pos), 222 | # value=memory, attn_mask=memory_mask, 223 | # key_padding_mask=memory_key_padding_mask) 224 | # # print(tgt2.size()) 225 | # # sss 226 | tgt2 = tgt + self.dropout2(tgt2) 227 | # # # print('2', tgt.size()) 228 | # tgt2 = self.norm3(tgt) 229 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 230 | # # print(tgt2.size()) 231 | # # tgt = tgt + self.dropout3(tgt2) 232 | # # print() 233 | # print(attn_weights.size()) 234 | # ssss 235 | return tgt2, attn_weights 236 | 237 | def forward(self, tgt, memory, 238 | tgt_mask: Optional[Tensor] = None, 239 | memory_mask: Optional[Tensor] = None, 240 | tgt_key_padding_mask: Optional[Tensor] = None, 241 | memory_key_padding_mask: Optional[Tensor] = None, 242 | pos: Optional[Tensor] = None, 243 | query_pos: Optional[Tensor] = None, 244 | residual=True): 245 | if self.normalize_before: 246 | return self.forward_pre(tgt, memory, tgt_mask, memory_mask, 247 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 248 | return self.forward_post(tgt, memory, tgt_mask, memory_mask, 249 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos,residual) 250 | 251 | def _generate_square_subsequent_mask(self, ls, sz): 252 | mask = (torch.triu(torch.ones(ls, sz)) == 1).transpose(0, 1) 253 | mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) 254 | return mask 255 | 256 | 257 | def _get_clones(module, N): 258 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 259 | 260 | 261 | 262 | def _get_activation_fn(activation): 263 | """Return an 
activation function given a string""" 264 | if activation == "relu": 265 | return F.relu 266 | if activation == "gelu": 267 | return F.gelu 268 | if activation == "glu": 269 | return F.glu 270 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") -------------------------------------------------------------------------------- /prototype.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import time 4 | from tracemalloc import start 5 | import torch 6 | import torchvision.models as models 7 | import torch.nn as nn 8 | from torch.utils.data import DataLoader 9 | import torch.nn.functional as F 10 | import pickle 11 | import os 12 | import re 13 | import argparse 14 | import numpy as np 15 | import random 16 | from thop import profile 17 | from utils import fusion,segment_bars_with_confidence_score 18 | 19 | from tqdm import tqdm 20 | f_path = os.path.abspath('..') 21 | root_path = f_path.split('surgical_code')[0] 22 | 23 | 24 | loss_layer = nn.CrossEntropyLoss() 25 | mse_layer = nn.MSELoss(reduction='none') 26 | 27 | 28 | def hierarch_train(args, model, train_loader, validation_loader, device, save_dir = 'models', debug = False): 29 | 30 | model.to(device) 31 | num_classes = args.num_classes 32 | if not os.path.exists(save_dir): 33 | os.makedirs(save_dir) 34 | 35 | best_epoch = 0 36 | best_acc = 0 37 | model.train() 38 | save_name = 'hier{}_msloss{}_trans{}'.format(args.hier,args.ms_loss,args.trans) 39 | save_dir = os.path.join(save_dir, args.model,save_name) 40 | for epoch in range(1, args.epochs + 1): 41 | if epoch % 30 == 0: 42 | args.learning_rate = args.learning_rate * 0.5 43 | 44 | 45 | correct = 0 46 | total = 0 47 | loss_item = 0 48 | ce_item = 0 49 | ms_item = 0 50 | lc_item = 0 51 | gl_item = 0 52 | optimizer = torch.optim.Adam(model.parameters(), args.learning_rate, weight_decay=1e-5) 53 | max_seq = 0 54 | mean_len = 0 55 | ans = 0 56 | max_phase = 0 57 | for (video, labels, video_name) in (train_loader): 58 | 59 | 60 | 61 | labels = torch.Tensor(labels).long() 62 | 63 | 64 | video, labels = video.to(device), labels.to(device) 65 | 66 | 67 | predicted_list, feature_list, prototype = model(video) 68 | 69 | mean_len += predicted_list[0].size(-1) 70 | ans += 1 71 | all_out, resize_list, labels_list = fusion(predicted_list,labels, args) 72 | 73 | max_seq = max(max_seq, video.size(1)) 74 | 75 | 76 | 77 | loss = 0 78 | 79 | if args.ms_loss: 80 | ms_loss = 0 81 | 82 | for p,l in zip(resize_list,labels_list): 83 | ms_loss += loss_layer(p.transpose(2, 1).contiguous().view(-1, args.num_classes), l.view(-1)) 84 | ms_loss += torch.mean(torch.clamp(mse_layer(F.log_softmax(p[:, :, 1:], dim=1), F.log_softmax(p.detach()[:, :, :-1], dim=1)), min=0, max=16)) 85 | loss = loss + ms_loss 86 | ms_item += ms_loss.item() 87 | 88 | optimizer.zero_grad() 89 | loss_item += loss.item() 90 | 91 | 92 | if args.last: 93 | all_out = resize_list[-1] 94 | if args.first: 95 | all_out = resize_list[0] 96 | 97 | # print(all_out.size()) 98 | loss.backward() 99 | 100 | optimizer.step() 101 | 102 | 103 | _, predicted = torch.max(all_out.data, 1) 104 | 105 | # labels = labels_list[-1] 106 | correct += ((predicted == labels).sum()).item() 107 | total += labels.shape[0] 108 | # total +=1 109 | 110 | print('Train Epoch {}: Acc {}, Loss {}, ms {}'.format(epoch, correct / total, loss_item /total, ms_item/total)) 111 | if debug: 112 | # save_dir 113 | test_acc, predicted, out_pro, test_video_name=hierarch_test(args, model, validation_loader, device) 
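# Model selection (as implemented below): keep the checkpoint with the highest
# validation accuracy seen so far; each improvement is saved as best_<epoch>.model under save_dir.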
114 | if test_acc > best_acc: 115 | best_acc = test_acc 116 | best_epoch = epoch 117 | if not os.path.exists(save_dir): 118 | os.makedirs(save_dir) 119 | torch.save(model.state_dict(), save_dir + '/best_{}.model'.format(epoch)) 120 | print('Best Test: Acc {}, Epoch {}'.format(best_acc, best_epoch)) 121 | 122 | def hierarch_test(args, model, test_loader, device, random_mask=False): 123 | 124 | model.to(device) 125 | num_classes = args.num_classes 126 | 127 | model.eval() 128 | 129 | with torch.no_grad(): 130 | 131 | 132 | correct = 0 133 | total = 0 134 | loss_item = 0 135 | all_preds = [] 136 | center = torch.ones((1, 64, num_classes), requires_grad=False) 137 | center = center.to(device) 138 | label_correct={} 139 | label_total= {} 140 | probabilty_list = [] 141 | video_name_list=[] 142 | precision=0 143 | recall = 0 144 | ce_item = 0 145 | ms_item = 0 146 | lc_item = 0 147 | gl_item = 0 148 | max_seq = 0 149 | for n_iter,(video, labels, video_name ) in enumerate(test_loader): 150 | 151 | 152 | labels = torch.Tensor(labels).long() 153 | 154 | 155 | video, labels = video.to(device), labels.to(device) 156 | max_seq = max(max_seq, video.size(1)) 157 | 158 | predicted_list, feature_list, _ = model(video) 159 | 160 | all_out, resize_list,labels_list = fusion(predicted_list,labels, args) 161 | 162 | loss = 0 163 | 164 | 165 | 166 | if args.ms_loss: 167 | ms_loss = 0 168 | for p,l in zip(resize_list,labels_list): 169 | # print(p.size()) 170 | ms_loss += loss_layer(p.transpose(2, 1).contiguous().view(-1, args.num_classes), l.view(-1)) 171 | ms_loss += torch.mean(torch.clamp(mse_layer(F.log_softmax(p[:, :, 1:], dim=1), F.log_softmax(p.detach()[:, :, :-1], dim=1)), min=0, max=16)) 172 | loss = loss + ms_loss 173 | ms_item += ms_loss.item() 174 | 175 | 176 | loss_item += loss.item() 177 | 178 | if args.last: 179 | all_out = resize_list[-1] 180 | if args.first: 181 | all_out = resize_list[0] 182 | 183 | _, predicted = torch.max(all_out.data, 1) 184 | 185 | 186 | predicted = predicted.squeeze() 187 | 188 | # labels = labels_list[-1] 189 | correct += ((predicted == labels).sum()).item() 190 | total += labels.shape[0] 191 | 192 | 193 | video_name_list.append(video_name) 194 | 195 | all_preds.append(predicted) 196 | 197 | all_out = F.softmax(all_out,dim=1) 198 | 199 | probabilty_list.append(all_out.transpose(1,2)) 200 | # print(max_seq) 201 | print('Test Acc {}, Loss {}, ms {}'.format( correct / total, loss_item /total, ms_item/total)) 202 | # print('BMG precision {}, BMG recall {}'.format(precision/(n_iter+1), recall/(n_iter+1) )) 203 | # print(len(label_total)) 204 | for (kc, vc), (kall, vall) in zip(label_correct.items(),label_total.items()): 205 | print("{} acc: {}".format(kc, vc/vall)) 206 | return correct / total, all_preds, probabilty_list, video_name_list 207 | 208 | 209 | def base_predict(model, args, device,test_loader, pki = False,split='test'): 210 | 211 | phase2label_dicts = { 212 | 'cholec80':{ 213 | 'Preparation':0, 214 | 'CalotTriangleDissection':1, 215 | 'ClippingCutting':2, 216 | 'GallbladderDissection':3, 217 | 'GallbladderPackaging':4, 218 | 'CleaningCoagulation':5, 219 | 'GallbladderRetraction':6}, 220 | 221 | 'm2cai16':{ 222 | 'TrocarPlacement':0, 223 | 'Preparation':1, 224 | 'CalotTriangleDissection':2, 225 | 'ClippingCutting':3, 226 | 'GallbladderDissection':4, 227 | 'GallbladderPackaging':5, 228 | 'CleaningCoagulation':6, 229 | 'GallbladderRetraction':7} 230 | } 231 | model.to(device) 232 | model.eval() 233 | save_name = 
'{}_hier{}_trans{}'.format(args.sample_rate,args.hier,args.trans)
234 | 
235 | 
236 |     pic_save_dir = 'results/{}/{}/vis/'.format(args.dataset,save_name)
237 |     results_dir = 'results/{}/{}/prediction_{}/'.format(args.dataset,save_name,args.sample_rate)
238 | 
239 |     gt_dir = root_path+'/datasets/surgical/workflow/{}/phase_annotations/'.format(args.dataset)
240 | 
241 |     if not os.path.exists(pic_save_dir):
242 |         os.makedirs(pic_save_dir)
243 |     if not os.path.exists(results_dir):
244 |         os.makedirs(results_dir)
245 | 
246 |     with torch.no_grad():
247 |         correct = 0
248 |         total = 0
249 |         for (video, labels, video_name) in tqdm(test_loader):  # TestVideoDataset yields (video, labels, video_name)
250 |             labels = torch.Tensor(labels).long()
251 | 
252 |             print(video.size(), video_name, labels.size())
253 |             video = video.to(device)
254 |             labels = labels.to(device)
255 | 
256 |             # re = model(video)
257 |             predicted_list, feature_list, _ = model(video)
258 | 
259 |             all_out, resize_list, labels_list = fusion(predicted_list, labels, args)
260 |             if args.last:
261 |                 all_out = resize_list[-1]
262 |             if args.first:
263 |                 all_out = resize_list[0]
264 |             confidence, predicted = torch.max(F.softmax(all_out.data,1), 1)
265 | 
266 | 
267 | 
268 | 
269 |             correct += ((predicted == labels).sum()).item()
270 |             total += labels.shape[0]
271 | 
272 | 
273 | 
274 |             predicted = predicted.squeeze(0).tolist()
275 |             confidence = confidence.squeeze(0).tolist()
276 | 
277 |             labels = [label.item() for label in labels]
278 | 
279 |             pic_file = video_name[0].split('.')[0] + '-vis.png'
280 |             pic_path = os.path.join(pic_save_dir, pic_file)
281 |             segment_bars_with_confidence_score(pic_path, confidence_score=confidence, labels=[labels, predicted])
282 | 
283 | 
284 |             predicted_phases_expand = []
285 | 
286 |             for i in predicted:
287 |                 predicted_phases_expand = np.concatenate((predicted_phases_expand, [i]*5)) # expand each prediction 5x: the features were sampled at 5 fps from the 25 fps videos
288 | 
289 | 
290 |             print(video_name)
291 | 
292 |             v_n = video_name[0]
293 | 
294 | 
295 |             v_n = re.findall(r"\d+\.?\d*", v_n)
296 | 
297 |             v_n = float(v_n[0])
298 |             target_video_file = "%02d_pred.txt"%(v_n)
299 |             print(target_video_file)
300 | 
301 |             if args.dataset == 'm2cai16':
302 | 
303 |                 gt_file = 'test_workflow_video_%02d.txt'%(v_n)
304 |             else:
305 | 
306 |                 gt_file = 'video%02d-phase.txt'%(v_n)
307 | 
308 |             g_ptr = open(os.path.join(gt_dir, gt_file), "r")
309 |             f_ptr = open(os.path.join(results_dir, target_video_file), 'w')
310 | 
311 | 
312 |             gt = g_ptr.readlines()[1:] ## skip the header line
313 | 
314 |             gt = gt[::5]
315 |             print(len(gt), len(predicted_phases_expand))
316 | 
317 |             if len(gt) > len(predicted_phases_expand):
318 |                 lst = predicted_phases_expand[-1]
319 |                 print(len(gt) - len(predicted_phases_expand))
320 |                 for i in range(0, len(gt) - len(predicted_phases_expand)):
321 |                     predicted_phases_expand = np.append(predicted_phases_expand, lst)
322 |             else:
323 |                 predicted_phases_expand = predicted_phases_expand[0:len(gt)]
324 |             print(len(gt), len(predicted_phases_expand))
325 |             assert len(predicted_phases_expand) == len(gt)
326 | 
327 |             # f_ptr.write("Frame\tPhase\n")
328 |             for index, line in enumerate(predicted_phases_expand):
329 |                 # print(int(line),args.dataset)
330 |                 phase_dict = phase2label_dicts[args.dataset]
331 |                 p_phase = ''
332 |                 for k,v in phase_dict.items():
333 |                     if v == int(line):
334 |                         p_phase = k
335 |                         break
336 | 
337 |                 # line = phase2label_dicts[args.dataset][int(line)]
338 |                 # f_ptr.write('{}\t{}\n'.format(index, int(line)))
339 |                 f_ptr.write('{}\t{}\n'.format(index, p_phase))
340 |             f_ptr.close()
341 | 
342 |             # 
--------------------------------------------------------------------------------
/hierarch_tcn2.py:
--------------------------------------------------------------------------------
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

from decoder import TransformerDecoder, TransformerDecoderLayer
from PositionalEncoding import FixedPositionalEncoding, LearnedPositionalEncoding


class FPN(nn.Module):
    def __init__(self, num_f_maps):
        super(FPN, self).__init__()
        # One 1x1 lateral convolution per pyramid level.
        self.latlayer1 = nn.Conv1d(num_f_maps, num_f_maps, kernel_size=1, stride=1, padding=0)
        self.latlayer2 = nn.Conv1d(num_f_maps, num_f_maps, kernel_size=1, stride=1, padding=0)
        self.latlayer3 = nn.Conv1d(num_f_maps, num_f_maps, kernel_size=1, stride=1, padding=0)

    def _upsample_add(self, x, y):
        '''Upsample x to the temporal length of y and add the two feature maps.

        Args:
          x: (Tensor) top feature map to be upsampled, shape (N, C, W_top).
          y: (Tensor) lateral feature map, shape (N, C, W).
        Returns:
          (Tensor) added feature map, shape (N, C, W).

        Nearest-neighbour upsampling with scale_factor=2 may not match the
        lateral feature map size when the input length is odd, so linear
        interpolation is used instead; it supports arbitrary output sizes.
        '''
        _, _, W = y.size()
        # F.upsample is deprecated; F.interpolate is the equivalent call.
        return F.interpolate(x, size=W, mode='linear', align_corners=False) + y

    def forward(self, out_list):
        p4 = out_list[3]
        c3 = out_list[2]
        c2 = out_list[1]
        c1 = out_list[0]
        # Top-down pathway; each level gets its own lateral convolution.
        p3 = self._upsample_add(p4, self.latlayer1(c3))
        p2 = self._upsample_add(p3, self.latlayer2(c2))
        p1 = self._upsample_add(p2, self.latlayer3(c1))
        return [p1, p2, p3, p4]
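

# A quick shape check for the FPN above (illustrative only: the feature
# lengths are made up, chosen to mimic the roughly 3x shrinkage produced by
# the pooled refinement stages; num_f_maps=64 is just an example value).
def _fpn_shape_demo():
    fpn = FPN(num_f_maps=64)
    feats = [torch.randn(1, 64, length) for length in (100, 33, 10, 3)]  # fine -> coarse
    p1, p2, p3, p4 = fpn(feats)
    # Every fused map keeps the temporal length of its own pyramid level.
    assert p1.shape == (1, 64, 100) and p4.shape == (1, 64, 3)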


class Hierarch_TCN2(nn.Module):

    def __init__(self, args, num_layers_PG, num_layers_R, num_R, num_f_maps, dim, num_classes):
        super(Hierarch_TCN2, self).__init__()
        self.PG = BaseCausalTCN(num_layers_PG, num_f_maps, dim, num_classes)
        self.conv_out = nn.Conv1d(num_f_maps, num_classes, 1)
        self.conv_out1 = nn.Conv1d(num_f_maps * 3, num_classes, 1)
        self.Rs = nn.ModuleList([copy.deepcopy(Refinement(args, num_layers_R, num_f_maps, num_classes, num_classes, self.conv_out)) for s in range(num_R)])
        self.use_fpn = args.fpn
        self.use_output = args.output
        self.use_feature = args.feature
        self.use_trans = args.trans
        if args.fpn:
            self.fpn = FPN(num_f_maps)
        if args.trans:
            self.query = nn.Embedding(num_classes, num_f_maps)

            if args.positional_encoding_type == "learned":
                self.position_encoding = LearnedPositionalEncoding(19971, num_f_maps)
            elif args.positional_encoding_type == "fixed":
                self.position_encoding = FixedPositionalEncoding(num_f_maps)
            else:
                self.position_encoding = None
            print('position encoding :', args.positional_encoding_type)

            decoder_layer = TransformerDecoderLayer(num_f_maps, args.head_num, args.embed_num,
                                                    0.1, 'relu', normalize_before=True)
            decoder_norm = nn.LayerNorm(num_f_maps)
            self.decoder = TransformerDecoder(decoder_layer, args.block_num, decoder_norm,
                                              return_intermediate=False)
        # Learnable class prototypes, one num_f_maps-dimensional vector per class.
        self.prototype = torch.nn.Parameter(torch.zeros(1, 64, num_classes), requires_grad=True)

    def forward(self, x):
        out_list = []
        f_list = []
        x = x.permute(0, 2, 1)  # (bs, l, c) -> (bs, c, l)

        # Frame-level prediction generation stage.
        f, out1 = self.PG(x)
        f_list.append(f)
        if not self.use_fpn:
            out_list.append(out1)

        # Refinement stages: each one pools its input, producing a coarser,
        # segment-level feature map and prediction.
        for R in self.Rs:
            if self.use_output:
                f, out1 = R(out1)
                out_list.append(out1)
            else:
                f, out1 = R(f)
            f_list.append(f)
            if not self.use_fpn:
                out_list.append(out1)

        if self.use_fpn:
            f_list = self.fpn(f_list)
            for f in f_list:
                out_list.append(self.conv_out(f))

        if self.use_feature:
            # Refine the coarsest prediction with the learnable class prototypes.
            last_feature = f_list[-1]
            refine_out = torch.matmul(self.prototype.transpose(1, 2), last_feature)
            out_list[-1] = 0.5 * out_list[-1] + 0.5 * refine_out

        if self.use_trans:
            if self.position_encoding is not None:
                for i in range(len(f_list)):
                    f_list[i] = self.position_encoding(f_list[i])

            # Hierarchical aggregation: the frame-level features iteratively
            # cross-attend to every coarser level via the transformer decoder.
            first_feature = f_list[0].permute(2, 0, 1)  # (l, bs, c)
            for i in range(1, len(f_list)):
                middle_feature = f_list[i]
                first_feature = self.decoder(first_feature, middle_feature,
                                             memory_key_padding_mask=None, pos=None, query_pos=None)

            refined_first_feature = first_feature.permute(1, 2, 0)  # (bs, c, l)
            out_list[0] = self.conv_out(refined_first_feature)

        return out_list, f_list, self.prototype
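

# A rough instantiation sketch for Hierarch_TCN2. All hyper-parameter values
# below are illustrative guesses packed into an argparse.Namespace; they are
# not the repository's actual defaults.
def _hierarch_tcn2_demo():
    from argparse import Namespace
    args = Namespace(fpn=False, output=False, feature=False, trans=False,
                     positional_encoding_type=None, head_num=8, embed_num=512,
                     block_num=1, hier=True)
    model = Hierarch_TCN2(args, num_layers_PG=11, num_layers_R=10, num_R=3,
                          num_f_maps=64, dim=2048, num_classes=7)
    x = torch.randn(1, 300, 2048)  # (batch, frames, feature_dim)
    out_list, f_list, prototype = model(x)
    # One prediction per level: the frame level plus num_R pooled levels,
    # each roughly 3x shorter than the previous one.
    assert len(out_list) == 4 and out_list[0].shape == (1, 7, 300)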


class BaseCausalTCN(nn.Module):
    def __init__(self, num_layers, num_f_maps, dim, num_classes):
        super(BaseCausalTCN, self).__init__()
        self.conv_1x1 = nn.Conv1d(dim, num_f_maps, 1)
        self.layers = nn.ModuleList(
            [copy.deepcopy(DilatedResidualCausalLayer(2 ** i, num_f_maps, num_f_maps)) for i in range(num_layers)])
        self.conv_out = nn.Conv1d(num_f_maps, num_classes, 1)
        self.channel_dropout = nn.Dropout2d()
        self.num_classes = num_classes

    def forward(self, x, labels=None, mask=None, test=False):
        # x is expected as (bs, c, l)
        if mask is not None:
            x = x * mask

        # Channel dropout: treat each channel as a 2D plane so that whole
        # feature channels are dropped together.
        x = x.unsqueeze(3)  # (bs, c, l, 1)
        x = self.channel_dropout(x)
        x = x.squeeze(3)

        out = self.conv_1x1(x)
        for layer in self.layers:
            out = layer(out)

        x = self.conv_out(out)  # (bs, num_classes, l)
        return out, x


class Prediction_Generation(nn.Module):
    def __init__(self, args, num_layers, num_f_maps, dim, num_classes):
        super(Prediction_Generation, self).__init__()

        self.num_layers = num_layers
        self.conv_1x1_in = nn.Conv1d(dim, num_f_maps, 1)

        # Two parallel stacks of causal dilated layers, one with dilation
        # decreasing from 2^(num_layers-1) to 1 and one increasing from 1 to
        # 2^(num_layers-1); their outputs are fused at every layer.
        self.conv_dilated_1 = nn.ModuleList(
            [copy.deepcopy(DilatedResidualCausalLayer(2 ** (num_layers - 1 - i), num_f_maps, num_f_maps))
             for i in range(num_layers)]
        )
        self.conv_dilated_2 = nn.ModuleList(
            [copy.deepcopy(DilatedResidualCausalLayer(2 ** i, num_f_maps, num_f_maps))
             for i in range(num_layers)]
        )
        self.conv_fusion = nn.ModuleList(
            nn.Conv1d(2 * num_f_maps, num_f_maps, 1)
            for i in range(num_layers)
        )

        self.dropout = nn.Dropout()
        self.conv_out = nn.Conv1d(num_f_maps, num_classes, 1)

    def forward(self, x):
        f = self.conv_1x1_in(x)

        for i in range(self.num_layers):
            f_in = f
            f = self.conv_fusion[i](torch.cat([self.conv_dilated_1[i](f), self.conv_dilated_2[i](f)], 1))
            f = F.relu(f)
            f = self.dropout(f)
            f = f + f_in

        out = self.conv_out(f)
        return f, out
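

# Causality sanity check (a minimal sketch with arbitrary sizes): because the
# dilated layers pad only the front of the sequence, perturbing future frames
# must leave all earlier outputs of BaseCausalTCN unchanged.
def _causality_demo():
    torch.manual_seed(0)
    tcn = BaseCausalTCN(num_layers=4, num_f_maps=16, dim=8, num_classes=5).eval()
    x = torch.randn(1, 8, 50)
    x2 = x.clone()
    x2[:, :, 30:] += 1.0  # perturb only frames 30..49
    with torch.no_grad():
        _, y = tcn(x)
        _, y2 = tcn(x2)
    assert torch.allclose(y[:, :, :30], y2[:, :, :30])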


class Refinement(nn.Module):
    def __init__(self, args, num_layers, num_f_maps, dim, num_classes, conv_out):
        super(Refinement, self).__init__()
        self.conv_1x1 = nn.Conv1d(dim, num_f_maps, 1)
        self.layers = nn.ModuleList([copy.deepcopy(DilatedResidualCausalLayer(2 ** i, num_f_maps, num_f_maps)) for i in range(num_layers)])
        self.conv_out = nn.Conv1d(num_f_maps, num_classes, 1)
        # Average pooling shortens the sequence roughly 3x, turning frame-level
        # features into segment-level ones.
        self.pool_1x1 = nn.AvgPool1d(kernel_size=7, stride=3)
        self.use_output = args.output
        self.hier = args.hier

    def forward(self, x):
        if self.use_output:
            out = self.conv_1x1(x)
        else:
            out = x
        for layer in self.layers:
            out = layer(out)
        if self.hier:
            f = self.pool_1x1(out)
        else:
            f = out
        out = self.conv_out(f)
        return f, out


class DilatedResidualLayer(nn.Module):
    def __init__(self, dilation, in_channels, out_channels):
        super(DilatedResidualLayer, self).__init__()
        self.conv_dilated = nn.Conv1d(in_channels, out_channels, 3, padding=dilation, dilation=dilation)
        self.conv_1x1 = nn.Conv1d(out_channels, out_channels, 1)
        self.dropout = nn.Dropout()

    def forward(self, x):
        out = F.relu(self.conv_dilated(x))
        out = self.conv_1x1(out)
        out = self.dropout(out)
        return x + out


class DilatedResidualCausalLayer(nn.Module):
    def __init__(self, dilation, in_channels, out_channels, padding=None):
        super(DilatedResidualCausalLayer, self).__init__()
        if padding is None:
            # A kernel of size 3 with dilation d looks 2*d frames into the past.
            self.padding = 2 * dilation
        else:
            self.padding = padding
        # Causal: no symmetric padding here; the input is padded at the front
        # in forward() instead, so no future frames are used.
        self.conv_dilated = nn.Conv1d(in_channels, out_channels, 3, padding=0, dilation=dilation)
        self.conv_1x1 = nn.Conv1d(out_channels, out_channels, 1)
        self.dropout = nn.Dropout()

    def forward(self, x):
        out = F.pad(x, [self.padding, 0], 'constant', 0)  # pad the front of the input
        out = F.relu(self.conv_dilated(out))
        out = self.conv_1x1(out)
        out = self.dropout(out)
        return x + out
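

# Padding arithmetic for the causal layer (illustrative): a kernel of size 3
# with dilation d reaches 2*d frames into the past, so left-padding by exactly
# 2*d preserves the sequence length without looking at future frames.
def _causal_padding_demo():
    layer = DilatedResidualCausalLayer(dilation=4, in_channels=8, out_channels=8)
    x = torch.randn(2, 8, 100)
    assert layer(x).shape == (2, 8, 100)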
--------------------------------------------------------------------------------