├── README.md
├── non_local.py
├── model.py
├── utils.py
├── Dataset.py
├── inference.py
├── dataset.py
└── trainRNN.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# KFGNet

This is the source code for the video-classification stage of KFGNet.

The code is still being uploaded; the remaining parts will be available soon.

Baidu Netdisk link for the ultrasound data: https://pan.baidu.com/s/1_9PRDvjaJbpZfcLL1XA5OA
Extraction code: wuj2

--------------------------------------------------------------------------------
/non_local.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class NonLocalBlock(nn.Module):
    """Non-local block over per-frame feature vectors.

    Note: with input of shape (B, T, C), the affinity computed below is
    channel-to-channel, of shape (B, C/2, C/2), rather than the frame-to-frame
    (B, T, T) affinity of the standard non-local block.
    """

    def __init__(self, channel):
        super(NonLocalBlock, self).__init__()
        self.inter_channel = channel // 2
        self.fc_phi = nn.Linear(in_features=channel, out_features=self.inter_channel)
        self.fc_theta = nn.Linear(in_features=channel, out_features=self.inter_channel)
        self.fc_g = nn.Linear(in_features=channel, out_features=self.inter_channel)
        self.softmax = nn.Softmax(dim=1)
        self.fc_mask = nn.Linear(in_features=self.inter_channel, out_features=channel)

    def forward(self, x):
        # x: (B, T, C)
        x_phi = self.fc_phi(x)                                     # (B, T, C/2)
        x_theta = self.fc_theta(x).permute(0, 2, 1).contiguous()  # (B, C/2, T)
        x_g = self.fc_g(x).permute(0, 2, 1).contiguous()          # (B, C/2, T)

        mul_theta_phi = torch.matmul(x_theta, x_phi)              # (B, C/2, C/2)
        mul_theta_phi = self.softmax(mul_theta_phi)
        mul_theta_phi_g = torch.matmul(mul_theta_phi, x_g).permute(0, 2, 1).contiguous()  # (B, T, C/2)

        mask = self.fc_mask(mul_theta_phi_g)                      # (B, T, C)
        out = mask + x                                            # residual connection

        return out


if __name__ == '__main__':
    model = NonLocalBlock(channel=512)

    input = torch.randn(32, 30, 512)

    out = model(input)
    print(out.shape)  # torch.Size([32, 30, 512])
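
For reference, a minimal shape walk-through of the forward pass (a sketch, reusing the block's own layers on the same (32, 30, 512) input as the __main__ check above; the asserts simply confirm the channel-to-channel attention shapes noted in the docstring):

import torch
from non_local import NonLocalBlock

B, T, C = 32, 30, 512
block = NonLocalBlock(channel=C)
x = torch.randn(B, T, C)

phi = block.fc_phi(x)                          # (B, T, C/2)
theta = block.fc_theta(x).permute(0, 2, 1)     # (B, C/2, T)
g = block.fc_g(x).permute(0, 2, 1)             # (B, C/2, T)

affinity = torch.matmul(theta, phi)            # (B, C/2, C/2) channel affinity
assert affinity.shape == (B, C // 2, C // 2)

attended = torch.matmul(affinity.softmax(dim=1), g)   # (B, C/2, T)
out = block.fc_mask(attended.permute(0, 2, 1)) + x    # residual, back to (B, T, C)
assert out.shape == (B, T, C)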
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from non_local import NonLocalBlock


class scoreNet(nn.Module):
    """Per-frame crucial-frame scoring network.

    Expects input of shape (B, T, 4102): for each frame, the first 6 values are
    [frame_num, x1, y1, x2, y2, det_score] and the remaining 4096 are the
    detector's appearance feature (see Dataset.py).
    """

    def __init__(self):
        super(scoreNet, self).__init__()
        self.fc_embed1 = nn.Linear(in_features=6, out_features=256, bias=True)
        self.ln1 = nn.LayerNorm(256)
        self.fc_embed2 = nn.Linear(in_features=4096, out_features=512, bias=True)
        self.ln2 = nn.LayerNorm(512)

        self.nonlocalblock = NonLocalBlock(channel=768)

        # conv1d alternative (unused)
        # self.conv1 = nn.Conv1d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1, groups=1)
        # self.conv2 = nn.Conv1d(in_channels=512, out_channels=256, kernel_size=3, stride=1, padding=1, groups=1)
        # self.conv3 = nn.Conv1d(in_channels=256, out_channels=1, kernel_size=1, stride=1, padding=0)

        # lstm
        self.lstm = nn.LSTM(input_size=768, hidden_size=512, num_layers=2, batch_first=True, bidirectional=True)
        self.fc1 = nn.Linear(in_features=1024, out_features=128, bias=True)
        self.fc2 = nn.Linear(in_features=128, out_features=1, bias=True)

    def forward(self, x):
        # Embed the 6-dim geometry/score part and the 4096-dim feature part separately.
        x1 = self.fc_embed1(x[:, :, :6])
        x1 = self.ln1(x1)
        x1 = F.relu(x1)
        x2 = self.fc_embed2(x[:, :, 6:])
        x2 = self.ln2(x2)
        x2 = F.relu(x2)

        x = torch.cat((x1, x2), 2)  # (B, T, 768)

        # conv1d alternative (unused)
        # x = x.permute(0, 2, 1)
        # x = F.relu(self.conv1(x))
        # x = F.relu(self.conv2(x))
        # x = torch.sigmoid(self.conv3(x))
        # x = x.permute(0, 2, 1)

        # non-local
        x = self.nonlocalblock(x)
        # lstm
        x, _ = self.lstm(x)             # (B, T, 1024): bidirectional, hidden size 512
        x = F.relu(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))  # per-frame score in (0, 1)

        x = x.squeeze(2)                # (B, T)

        return x
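
A quick sanity check of scoreNet's expected I/O (a sketch; the 4102-dim per-frame descriptor follows the layout built in Dataset.py):

import torch
from model import scoreNet

model = scoreNet().eval()
videos = torch.randn(2, 300, 6 + 4096)  # (batch, frames, [frame_num, bbox, score] + feature)
with torch.no_grad():
    scores = model(videos)
print(scores.shape)  # torch.Size([2, 300]); one score in (0, 1) per frame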
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import os
import time
import random
import numpy as np


def check_dir(path):
    if not os.path.exists(path):
        os.makedirs(path)


def tic():
    return time.time()


def toc(start):
    stop = time.time()
    print('\nUsed {:.2f} s\n'.format(stop - start))
    return stop - start


def compute_iou(rec1, rec2):
    # Boxes are (xmin, ymin, xmax, ymax).
    # Area of each rectangle
    S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
    S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])

    sum_area = S_rec1 + S_rec2

    # Edges of the intersection rectangle
    left_line = max(rec1[0], rec2[0])
    right_line = min(rec1[2], rec2[2])
    top_line = min(rec1[3], rec2[3])
    bottom_line = max(rec1[1], rec2[1])

    # Empty intersection
    if left_line >= right_line or top_line <= bottom_line:
        return 0
    else:
        intersect = (right_line - left_line) * (top_line - bottom_line)
        return intersect / (sum_area - intersect)


def similarity(v_data, crucial_frame_num, sfeature_weight=1, siou_weight=1, sframe_weight=1):
    """Per-frame similarity to the crucial frame, used as the soft training label.

    v_data rows are [frame_num, x1, y1, x2, y2, det_score, feature(4096)];
    the result averages a feature-distance term, a bbox-IoU term and a
    frame-distance term (equal weights by default).
    """
    crucial_frame = v_data[crucial_frame_num, :]

    # Feature similarity: exp(-d / d_max) over the 4096-dim features
    distance = np.sqrt(np.sum((crucial_frame[6:] - v_data[:, 6:]) ** 2, axis=1))
    max_distance = np.max(distance)
    sfeature = np.exp(-distance / (max_distance + 1e-8))

    # Bounding-box IoU with the crucial frame's box
    iou_list = []
    for each_rec in v_data[:, 1:5]:
        iou_list.append(compute_iou(crucial_frame[1:5], each_rec))
    siou = np.array(iou_list)

    # Temporal proximity of frame indices
    frame_distance = abs(v_data[:, 0] - crucial_frame[0])
    max_frame_distance = np.max(frame_distance)
    sframe = 1 - frame_distance / (max_frame_distance + 1e-8)

    sim = (sfeature * sfeature_weight + siou * siou_weight + sframe * sframe_weight) / 3

    return sim
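
A worked example of compute_iou (a sketch): two unit squares offset by half a side share a 0.5 × 1 region, so IoU = 0.5 / (1 + 1 − 0.5) = 1/3.

from utils import compute_iou

box_a = [0.0, 0.0, 1.0, 1.0]  # (xmin, ymin, xmax, ymax)
box_b = [0.5, 0.0, 1.5, 1.0]
print(compute_iou(box_a, box_b))  # 0.3333...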
--------------------------------------------------------------------------------
/Dataset.py:
--------------------------------------------------------------------------------
from torch.utils.data import Dataset
import glob
import json
import numpy as np
from sklearn.preprocessing import MinMaxScaler
from utils import similarity
import torch


TOTAL_FRAME_NUM = 300


class VideoDataset(Dataset):
    def __init__(self, data_dir):
        self.videos = glob.glob(data_dir + '/*.json')

    def __len__(self):
        return len(self.videos)

    def __getitem__(self, index):
        v_data = []

        video_data = self.videos[index]
        with open(video_data, encoding="utf-8") as f:
            json_data = json.load(f)
        crucial_frame = json_data.get('frame_list')[1]
        det_result = json_data.get('det_result')

        # Find the detection entry whose frame number is closest to the
        # annotated crucial frame, and build one row per frame:
        # [frame_num, x1, y1, x2, y2, det_score, feature(4096)].
        best_diff = 1000
        crucial_frame_index = 0
        for idx, frame in enumerate(det_result):
            frame_data = []
            frame_num = frame.get('frame_num')

            diff = abs(crucial_frame - frame_num)
            if diff < best_diff:
                best_diff = diff
                crucial_frame_index = idx

            frame_data.append(frame_num)
            frame_data += frame.get('boxes')[0]
            frame_data.append(frame.get('scores')[0])
            frame_data += frame.get('features')[0]

            v_data.append(np.array(frame_data))

        v_data = np.array(v_data).astype(np.float32)

        if len(v_data) >= TOTAL_FRAME_NUM:
            # Subsample to TOTAL_FRAME_NUM frames while keeping the crucial
            # frame: take TOTAL_FRAME_NUM - 1 evenly spaced indices, then
            # re-insert the crucial index. Note that if the crucial index is
            # already among the sampled ones this produces a duplicate, and
            # idxs.index() returns its first occurrence.
            v_len = len(v_data)
            idxs = np.linspace(0, v_len - 1, TOTAL_FRAME_NUM - 1, dtype=int).tolist()
            idxs.append(crucial_frame_index)
            idxs.sort()
            new_idx = idxs.index(crucial_frame_index)

            v_data = v_data[idxs, :]
            sim = similarity(v_data, new_idx)
        else:
            sim = similarity(v_data, crucial_frame_index)
        label = sim.astype(np.float32)

        # Per-video min-max normalization of frame number, bbox and score;
        # the 4096-dim features are left unscaled.
        scaler = MinMaxScaler()
        v_data[:, :6] = scaler.fit_transform(v_data[:, :6])

        return torch.Tensor(v_data), torch.Tensor(label)
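
The JSON layout consumed by __getitem__, reconstructed from the .get() calls above (a sketch; the field names come from the code, all values are purely illustrative):

example = {
    "video_name": "case_0001",     # used by inference.py
    "frame_list": [10, 42, 88],    # index [1] is the annotated crucial frame number
    "det_result": [
        {
            "frame_num": 0,
            "boxes": [[120.0, 80.0, 260.0, 220.0]],  # top box: (x1, y1, x2, y2)
            "scores": [0.97],                        # top detection score
            "features": [[0.1] * 4096],              # top box's 4096-dim feature
        },
        # ... one entry per detected frame
    ],
}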
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
import glob
import torch
import json
import numpy as np
from sklearn.preprocessing import MinMaxScaler
from utils import *
from model import scoreNet


def read_video(video_data):
    v_data = []
    frame_num_list = []

    with open(video_data, encoding="utf-8") as f:
        json_data = json.load(f)
    video_name = json_data.get('video_name')
    crucial_frame = json_data.get('frame_list')[1]
    det_result = json_data.get('det_result')

    # Same per-frame layout as Dataset.py:
    # [frame_num, x1, y1, x2, y2, det_score, feature(4096)]
    for frame in det_result:
        frame_data = []
        frame_num = frame.get('frame_num')
        frame_num_list.append(frame_num)

        frame_data.append(frame_num)
        frame_data += frame.get('boxes')[0]
        frame_data.append(frame.get('scores')[0])
        frame_data += frame.get('features')[0]

        v_data.append(np.array(frame_data))

    v_data = np.array(v_data).astype(np.float32)

    scaler = MinMaxScaler()
    v_data[:, :6] = scaler.fit_transform(v_data[:, :6])

    return torch.Tensor(v_data).unsqueeze(0), frame_num_list, video_name, crucial_frame


device = torch.device('cuda:0')

model = scoreNet()
state_dict = torch.load('./OutTrain/RNN-fold0-best.ckpt', map_location=device)
# If the checkpoint was saved from an nn.DataParallel model, the 'module.'
# prefix has to be stripped first:
# new_state_dict = OrderedDict()
# for k, v in state_dict.items():
#     name = k[7:]  # remove 'module.'
#     new_state_dict[name] = v
model.load_state_dict(state_dict)

model = model.to(device)
model.eval()

fwriter_result = open('./result.csv', 'w')

with torch.no_grad():
    for videos in glob.glob('../data_for_miccai/test/*.json'):

        video, frame_num_list, video_name, crucial_frame = read_video(videos)
        print(video_name)
        print('crucial_frame:' + str(crucial_frame))

        video = video.to(device)

        output = model(video)

        # Softmax is monotonic, so the argmax of the raw scores is equivalent;
        # .item() converts the one-element tensor to a plain int for list indexing.
        pred_idx = output.argmax(dim=1).item()
        pred_frame = frame_num_list[pred_idx]

        print('pred_frame:' + str(pred_frame))

        fwriter_result.write('{},{},{}\n'.format(video_name, crucial_frame, pred_frame))
        fwriter_result.flush()

fwriter_result.close()
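
If a checkpoint was saved from an nn.DataParallel-wrapped model (the case the commented-out block above anticipates), the 'module.' prefix can be stripped with a small helper (a sketch):

from collections import OrderedDict

def strip_module_prefix(state_dict):
    # Drop the 'module.' prefix that nn.DataParallel adds to parameter names.
    return OrderedDict(
        (k[len('module.'):] if k.startswith('module.') else k, v)
        for k, v in state_dict.items()
    )

# model.load_state_dict(strip_module_prefix(state_dict))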
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
import torch
from torch.utils.data import Dataset
import glob
import cv2
import numpy as np
import random
from scipy.ndimage import rotate  # scipy.ndimage.interpolation is deprecated
from utils import generate_sim  # note: generate_sim is not defined in the utils.py included here


class VideoDataset(Dataset):
    def __init__(self, file_names, transform=True, resize=(112, 112)):
        super(VideoDataset, self).__init__()
        self.data = file_names
        self.transform = transform
        self.resize = resize

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        volume = []

        volume_path = self.data[index]

        # sorted() keeps the frames in order; glob alone makes no ordering guarantee.
        for img_path in sorted(glob.glob(volume_path + '/*.jpg')):
            img = cv2.imread(img_path)
            img = cv2.resize(img, self.resize, interpolation=cv2.INTER_CUBIC)
            volume.append(img)

        volume = np.array(volume).astype(np.float32)  # (T, H, W, C)

        # Similarity labels are computed on the raw volume, before
        # normalization and augmentation.
        sim = generate_sim(volume)

        volume = volume / 255

        if self.transform:
            if random.choice([True, False]):
                volume = random_intensity_shift(volume, 0.1, 0.1)

            if random.choice([True, False]):
                volume = random_flip_3d(volume)

        # Binary class label from the directory-name suffix: '..._m' -> 1, else 0.
        if volume_path.split('_')[-1] == 'm':
            label = 1
        else:
            label = 0

        volume = torch.Tensor(volume).permute(3, 0, 1, 2)  # (C, T, H, W)
        return volume, label, sim


def random_flip_3d(volume):
    if random.choice([True, False]):
        volume = volume[::-1, :, :].copy()  # copy() is required: negative strides break torch.Tensor()
    if random.choice([True, False]):
        volume = volume[:, ::-1, :].copy()
    if random.choice([True, False]):
        volume = volume[:, :, ::-1].copy()

    return volume


def random_rotation_3d(volume, max_angles):
    # Note: only max_angles[1] and max_angles[2] are used; there is no
    # rotation in the (1, 2) plane.
    angle = random.uniform(-max_angles[2], max_angles[2])
    volume2 = rotate(volume, angle, order=2, mode='nearest', axes=(0, 1), reshape=False)

    angle = random.uniform(-max_angles[1], max_angles[1])
    volume_rot = rotate(volume2, angle, order=2, mode='nearest', axes=(0, 2), reshape=False)

    return volume_rot


def random_intensity_shift(volume, max_offset, max_scale_delta):
    offset = random.uniform(-max_offset, max_offset)
    scale = random.uniform(1 - max_scale_delta, 1 + max_scale_delta)

    volume = volume.copy()
    volume += offset
    volume *= scale

    return volume
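
A quick check of the augmentations on a dummy volume (a sketch; shapes follow the (T, H, W, C) layout used before the final permute, the angles are illustrative, and this assumes the missing generate_sim helper has been supplied in utils so that dataset.py imports cleanly):

import numpy as np
from dataset import random_flip_3d, random_intensity_shift, random_rotation_3d

volume = np.random.rand(30, 112, 112, 3).astype(np.float32)  # (frames, H, W, channels)

volume = random_intensity_shift(volume, max_offset=0.1, max_scale_delta=0.1)
volume = random_flip_3d(volume)
volume = random_rotation_3d(volume, max_angles=(0, 10, 10))  # only indices 1 and 2 are used

print(volume.shape)  # unchanged: (30, 112, 112, 3)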
--------------------------------------------------------------------------------
/trainRNN.py:
--------------------------------------------------------------------------------
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from model import scoreNet
import numpy as np
import argparse
from Dataset import VideoDataset
from utils import *


parser = argparse.ArgumentParser(description='RNN')
parser.add_argument('--fold', type=int, default=0)
parser.add_argument('--num-workers', type=int, default=12)
parser.add_argument('--use-gpus', default='1', type=str)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--data-dir', type=str, default='../data3')
parser.add_argument('--output-dir', type=str, default='OutTrain')
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--weight-decay', type=float, default=1e-8)
parser.add_argument('--epochs', type=int, default=20)
parser.add_argument('--val-interval', type=int, default=2)
parser.add_argument('--save-model-interval', type=int, default=4)

args = parser.parse_args()

os.environ['CUDA_VISIBLE_DEVICES'] = args.use_gpus
fold_k = args.fold
val_interval = args.val_interval
save_model_interval = args.save_model_interval
learning_rate = args.lr
weight_decay = args.weight_decay
batch_size = args.batch_size
epochs = args.epochs
num_workers = args.num_workers
data_dir = args.data_dir  # note: unused below; the loader paths in main() are hard-coded
output_dir = args.output_dir

# print('\n\n----------------------RNN-Fold-{}------------------------'.format(fold_k))
# print('Batch_Size: \t\t{}'.format(batch_size))
# print('Learning_Rate: \t\t{}'.format(learning_rate))
# print('Weight_Decay: \t\t{}'.format(weight_decay))
# print('Epochs: \t\t{}'.format(epochs))
# print('Val_Interval: \t\t{}'.format(val_interval))
# print('Data_DIR: \t\t{}'.format(data_dir))
# print('Out_DIR: \t\t{}'.format(output_dir))
# print('Using \t\t{} GPU(s).'.format(torch.cuda.device_count()))
# print('----------------------RNN-Fold-{}------------------------\n'.format(fold_k))

torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False

device = torch.device('cuda')


def train(epoch, model, train_loader, criterion, optimizer, scheduler, fwriter):
    epoch_loss = 0
    epoch_corrects = 0
    acc = 0
    acc_result = np.zeros((1, 33))  # cumulative accuracy at tolerances 0..32 frames
    model.train()

    print('\n--------------------Training-epoch-{:04}---------------------'.format(epoch + 1))

    for batch_idx, (videos, labels) in enumerate(train_loader):

        videos, labels = videos.to(device), labels.to(device)

        output = model(videos)

        # The predicted crucial frame is the argmax of the per-frame scores.
        preds = output.data.max(1)[1]

        # "Correct" if the prediction lands within 3 frames of the ground truth.
        corrects = torch.sum(abs(labels.data.max(1)[1] - preds) < 3).item()
        acc += torch.sum(labels.data.max(1)[1] == preds).item()
        acc_list = [torch.sum(abs(labels.data.max(1)[1] - preds) <= i).item() for i in range(33)]
        acc_result += np.array(acc_list)

        loss = criterion(output, labels)

        epoch_loss += loss.item()  # .item() avoids retaining the autograd graph

        epoch_corrects += corrects

        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        partial_epoch = epoch + (batch_idx + 1) / len(train_loader)
        # MSELoss already averages over the batch, so the raw value is reported.
        print('Train Epoch: {:5.2f} \tLoss: {:.4f} \tAcc: {:.4f}'.format(
            partial_epoch, loss.item(), corrects / labels.size(0)))

        fwriter.write('{:5.2f},{:.4f}\n'.format(partial_epoch, loss.item()))
        fwriter.flush()

    scheduler.step()

    epoch_acc = epoch_corrects / len(train_loader.dataset)
    acc /= len(train_loader.dataset)
    acc_result /= len(train_loader.dataset)
    print('Epoch Acc: {:.2f}'.format(epoch_acc))
    print('Train set: Acc: {:.4f}'.format(acc))
    print(acc_result)


def validate(epoch, model, val_loader, criterion, fwriter):
    loss = 0
    acc = 0
    acc_result = np.zeros((1, 33))

    model.eval()

    print('-------------------Validating-epoch-{:04}--------------------'.format(epoch + 1))

    with torch.no_grad():
        for videos, labels in val_loader:

            videos, labels = videos.to(device), labels.to(device)

            output = model(videos)

            preds = F.softmax(output, dim=1).data.max(1)[1]

            acc += torch.sum(labels.data.max(1)[1] == preds).item()
            acc_list = [torch.sum(abs(labels.data.max(1)[1] - preds) <= i).item() for i in range(33)]
            acc_result += np.array(acc_list)

            loss += criterion(output, labels).item()

    acc /= len(val_loader.dataset)
    acc_result /= len(val_loader.dataset)
    loss /= len(val_loader)  # criterion averages within a batch, so divide by batch count

    print('Val set: Average Loss: {:.4f}'.format(loss))
    print('Val set: Acc: {:.4f}'.format(acc))
    print(acc_result)

    fwriter.write('{},{:.4f},{:.4f}\n'.format(epoch + 1, loss, acc))
    fwriter.flush()

    return acc


def test(epoch, model, test_loader, criterion, fwriter):
    loss = 0
    acc = 0
    acc_result = np.zeros((1, 33))

    model.eval()

    print('-------------------Testing-epoch-{:04}--------------------'.format(epoch + 1))

    with torch.no_grad():
        for videos, labels in test_loader:

            videos, labels = videos.to(device), labels.to(device)

            output = model(videos)

            preds = F.softmax(output, dim=1).data.max(1)[1]

            acc += torch.sum(labels.data.max(1)[1] == preds).item()
            acc_list = [torch.sum(abs(labels.data.max(1)[1] - preds) <= i).item() for i in range(33)]
            acc_result += np.array(acc_list)

            loss += criterion(output, labels).item()

    acc /= len(test_loader.dataset)
    acc_result /= len(test_loader.dataset)
    loss /= len(test_loader)  # criterion averages within a batch, so divide by batch count

    print('Test set: Average Loss: {:.4f}'.format(loss))
    print('Test set: Acc: {:.4f}'.format(acc))
    print(acc_result)

    fwriter.write('{},{:.4f},{:.4f}\n'.format(epoch + 1, loss, acc))
    fwriter.flush()

    return acc


def weights_init(model):
    for m in model.modules():
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.kaiming_normal_(m.weight)


def collate_fn(data_and_label):
    # Videos have different lengths, so both the data and the soft labels are
    # zero-padded to the longest sequence in the batch.
    v_data = [i[0] for i in data_and_label]
    label = [i[1] for i in data_and_label]
    v_data = pad_sequence(v_data, batch_first=True)
    label = pad_sequence(label, batch_first=True)

    return v_data, label


def main():
    model = scoreNet()
    model = model.to(device)

    check_dir(output_dir)

    fwriter_train = open(output_dir + '/train.csv', 'w')
    fwriter_val = open(output_dir + '/val.csv', 'w')
    fwriter_test = open(output_dir + '/test.csv', 'w')

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10], gamma=0.1, last_epoch=-1)

    weights_init(model)

    # Note: these paths are hard-coded and ignore --data-dir.
    train_dataset = VideoDataset(data_dir='../data_for_miccai/train')
    train_loader = DataLoader(train_dataset, num_workers=num_workers, batch_size=batch_size,
                              shuffle=True, collate_fn=collate_fn)  # shuffle added for training

    val_dataset = VideoDataset(data_dir='../data_for_miccai/val')
    val_loader = DataLoader(val_dataset, num_workers=num_workers, batch_size=batch_size, collate_fn=collate_fn)

    test_dataset = VideoDataset(data_dir='../data_for_miccai/test')
    test_loader = DataLoader(test_dataset, num_workers=num_workers, batch_size=batch_size, collate_fn=collate_fn)

    train_time, val_time, test_time = 0, 0, 0

    criterion = nn.MSELoss()

    max_acc = 0

    for epoch in range(epochs):
        start = tic()
        train(epoch, model, train_loader, criterion, optimizer, scheduler, fwriter_train)
        train_time += toc(start)

        if (epoch + 1) % val_interval == 0:
            start = tic()
            acc = validate(epoch, model, val_loader, criterion, fwriter_val)
            if acc > max_acc:
                max_acc = acc
                torch.save(model.state_dict(), output_dir + '/RNN-fold{}-best.ckpt'.format(fold_k))
            val_time += toc(start)

        if (epoch + 1) % val_interval == 0:
            start = tic()
            acc = test(epoch, model, test_loader, criterion, fwriter_test)
            test_time += toc(start)

        if (epoch + 1) % save_model_interval == 0:
            torch.save(model.state_dict(), output_dir + '/RNN-fold{}-{}.ckpt'.format(fold_k, epoch + 1))

    print('\nTrain time accumulated: {:.2f}s, Val time accumulated: {:.2f}s, Test time accumulated: {:.2f}s.\n'.format(
        train_time, val_time, test_time))


if __name__ == '__main__':
    main()
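
The acc_result array printed during training and evaluation is a cumulative tolerance curve: entry i is the fraction of videos whose predicted crucial-frame index lies within i positions of the ground-truth index. The same computation as a standalone sketch:

import torch

def tolerance_curve(outputs, labels, max_tol=33):
    # outputs, labels: (batch, frames); the argmax over frames is the
    # predicted / ground-truth crucial-frame index of each video.
    preds = outputs.argmax(dim=1)
    gts = labels.argmax(dim=1)
    diff = (preds - gts).abs()
    return torch.tensor([(diff <= i).float().mean() for i in range(max_tol)])

print(tolerance_curve(torch.rand(4, 300), torch.rand(4, 300)))

--------------------------------------------------------------------------------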