├── images └── BabyNet.png ├── README.md ├── .gitignore └── src ├── models └── babynet.py ├── train_video.py └── video_data_loader.py /images/BabyNet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SanoScience/BabyNet/HEAD/images/BabyNet.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BabyNet: Residual Transformer Module for Birth Weight Prediction on Fetal Ultrasound Video 2 | 3 | This is the official code release for "BabyNet: Residual Transformer Module for Birth Weight Prediction on Fetal 4 | Ultrasound Video", early accepted at the 25th International Conference on Medical Image Computing and 5 | Computer Assisted Intervention (MICCAI) 2022 in Singapore. 6 | 7 | ![BabyNet](./images/BabyNet.png) 8 | 9 | ### Abstract 10 | 11 | Predicting fetal weight at birth is an important aspect of perinatal care, particularly in the context of antenatal 12 | management, which includes the planned timing and mode of delivery. Accurate prediction of weight using prenatal 13 | ultrasound is challenging, as it requires images of specific fetal body parts during advanced pregnancy - this, however, 14 | is complicated by the poor quality of images caused by the lack of amniotic fluid. It follows that predictions which 15 | rely on standard methods often suffer from significant errors. In this paper we propose the Residual Transformer Module, 16 | which extends a 3D ResNet-based network for analysis of 2D+t spatio-temporal ultrasound video scans. Our end-to-end 17 | method, called BabyNet, fully automatically predicts fetal birth weight based on fetal ultrasound video scans. We 18 | evaluate BabyNet using a dedicated clinical set comprising 225 2D fetal ultrasound videos of pregnancies from 75 19 | patients, acquired one day prior to delivery. Experimental results show that BabyNet outperforms several 20 | state-of-the-art methods and estimates the weight at birth with accuracy comparable to that of human experts. Furthermore, 21 | combining estimates provided by human experts with those computed by BabyNet yields the best results, outperforming 22 | either method by a significant margin.
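### Quick example

The snippet below is a minimal, illustrative sketch (it is not part of the released code) showing how the `BabyNet` model from `src/models/babynet.py` can be instantiated and applied to a dummy clip. The shapes follow the defaults used in `src/train_video.py` (single-channel, 16-frame, 64x64 input), and the import assumes the interpreter is started from the `src/` directory, as `train_video.py` expects.

```python
import torch

from models.babynet import BabyNet  # resolves when the working directory is src/

# 3D ResNet backbone (layers [2, 2, 2, 2]) with the MHSA3D-based Residual
# Transformer block enabled in the last stage via msha=True.
model = BabyNet(msha=True, n_frames=16, input_size=(64, 64))
model.eval()

# Dummy clip shaped (batch, channels, frames, height, width); train_video.py
# permutes the loader output from (B, T, H, W, C) into this layout.
clip = torch.randn(1, 1, 16, 64, 64)

with torch.no_grad():
    weight = model(clip)  # 1-element tensor: predicted birth weight (same unit as the labels)

print(weight)
```

Note that with the MHSA3D module enabled, `n_frames` and `input_size` must match the clip actually fed to the network, because the relative position embeddings in the last stage are sized from them (`n_frames // 8` frames and `input_size // 16` spatial resolution).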
23 | 24 | ### Usage 25 | 26 | #### Train the model 27 | 28 | > cd src && python3 train_video.py 29 | 30 | If you use our code, please cite our work: 31 | ``` 32 | @inproceedings{plotka2022babynet, 33 | title={BabyNet: Residual Transformer Module for Birth Weight Prediction on Fetal Ultrasound Video}, 34 | author={P{\l}otka, Szymon and Grzeszczyk, Michal K and Brawura-Biskupski-Samaha, Robert and Gutaj, Pawe{\l} and Lipa, Micha{\l} and Trzci{\'n}ski, Tomasz and Sitek, Arkadiusz}, 35 | booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention}, 36 | pages={350--359}, 37 | year={2022}, 38 | organization={Springer} 39 | } 40 | ``` -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /src/models/babynet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | class MHSA3D(nn.Module): 6 | def __init__(self, n_dims, n_frames=2, width=14, height=14, heads=4): 7 | super(MHSA3D, self).__init__() 8 | self.scale = (n_dims // heads) ** -0.5 9 | self.heads = heads 10 | 11 | self.query = nn.Conv3d(in_channels=n_dims, out_channels=n_dims, kernel_size=(1, 1, 1), stride=1, padding=0, 12 | bias=False) 13 | self.key = nn.Conv3d(in_channels=n_dims, out_channels=n_dims, kernel_size=(1, 1, 1), stride=1, padding=0, 14 | bias=False) 15 | self.value = nn.Conv3d(in_channels=n_dims, out_channels=n_dims, kernel_size=(1, 1, 1), stride=1, padding=0, 16 | bias=False) 17 | 18 | self.rel_h = nn.Parameter(torch.randn([1, heads, n_dims // heads, 1, 1, height]), requires_grad=True) 19 | self.rel_w = nn.Parameter(torch.randn([1, heads, n_dims // heads, 1, width, 1]), requires_grad=True) 20 | self.rel_t = nn.Parameter(torch.randn([1, heads, n_dims // heads, n_frames, 1, 1]), requires_grad=True) 21 | 22 | self.softmax = nn.Softmax(dim=-1) 23 | 24 | def forward(self, x): 25 | n_batch, C, F, height, width = x.size() 26 | q = self.query(x).view(n_batch, self.heads, F, C // self.heads, -1) 27 | k = self.key(x).view(n_batch, self.heads, F, C // self.heads, -1) 28 | v = self.value(x).view(n_batch, self.heads, F, C // self.heads, -1) 29 | 30 | content_content = torch.matmul(q.permute(0, 1, 2, 4, 3), k) 31 | content_rel_pos = self.rel_h + self.rel_w + self.rel_t 32 | content_position = content_rel_pos.view(1, self.heads, C // self.heads, F, -1).permute(0, 1, 3, 4, 2) 33 | content_position2 = torch.matmul(content_position, q) 34 | 35 | energy = (content_content + content_position2) * self.scale 36 | attention = self.softmax(energy) 37 | 38 | out = torch.matmul(v, attention.permute(0, 1, 2, 4, 3)) 39 | out = out.view(n_batch, C, F, height, width) 40 | 41 | return out 42 | 43 | 44 | class Conv3DSimple(nn.Conv3d): 45 | def __init__(self, 46 | in_planes, 47 | out_planes, 48 | midplanes=None, 49 | stride=1, 50 | padding=1): 51 | super(Conv3DSimple, self).__init__( 52 | in_channels=in_planes, 53 | out_channels=out_planes, 54 | kernel_size=(3, 3, 3), 55 | stride=stride, 56 | padding=padding, 57 | bias=False) 58 | 59 | @staticmethod 60 | def get_downsample_stride(stride): 61 | return stride, stride, stride 62 | 63 | 64 | class BasicBlock(nn.Module): 65 | expansion = 1 66 | 67 | def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None, mhsa=False, n_frames_last_layer=2, 68 | input_last_layer=(7, 7)): 69 | midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes) 70 | 71 | super(BasicBlock, self).__init__() 72 | self.conv1 = nn.Sequential( 73 | conv_builder(inplanes, planes, midplanes, stride), 74 | nn.BatchNorm3d(planes), 75 | 
nn.ReLU(inplace=True) 76 | ) 77 | if not mhsa: 78 | self.conv2 = nn.Sequential( 79 | conv_builder(planes, planes, midplanes), 80 | nn.BatchNorm3d(planes) 81 | ) 82 | else: 83 | self.conv2 = nn.Sequential(MHSA3D(planes, n_frames=n_frames_last_layer, width=input_last_layer[1], 84 | height=input_last_layer[0]), nn.BatchNorm3d(planes)) 85 | 86 | self.relu = nn.ReLU(inplace=True) 87 | self.downsample = downsample 88 | self.stride = stride 89 | 90 | def forward(self, x): 91 | residual = x 92 | 93 | out = self.conv1(x) 94 | out = self.conv2(out) 95 | if self.downsample is not None: 96 | residual = self.downsample(x) 97 | 98 | out += residual 99 | out = self.relu(out) 100 | 101 | return out 102 | 103 | 104 | class BasicStem(nn.Sequential): 105 | def __init__(self): 106 | super(BasicStem, self).__init__( 107 | nn.Conv3d(1, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), 108 | padding=(1, 3, 3), bias=False), 109 | nn.BatchNorm3d(64), 110 | nn.ReLU(inplace=True)) 111 | 112 | 113 | class VideoResNet(nn.Module): 114 | 115 | def __init__(self, block, conv_makers, layers, 116 | stem, num_classes=1, msha=False, n_frames_last_layer=2, input_last_layer=(7, 7)): 117 | """Generic resnet video generator. 118 | 119 | Args: 120 | block (nn.Module): resnet building block 121 | conv_makers (list(functions)): generator function for each layer 122 | layers (List[int]): number of blocks per layer 123 | stem (nn.Module, optional): Resnet stem, if None, defaults to conv-bn-relu. Defaults to None. 124 | num_classes (int, optional): Dimension of the final FC layer. Defaults to 400. 125 | zero_init_residual (bool, optional): Zero init bottleneck residual BN. Defaults to False. 126 | """ 127 | super(VideoResNet, self).__init__() 128 | self.inplanes = 64 129 | 130 | self.stem = stem() 131 | 132 | self.layer1 = self._make_layer(block, conv_makers[0], 64, layers[0], stride=1) 133 | 134 | self.layer2 = self._make_layer(block, conv_makers[1], 128, layers[1], stride=2) 135 | self.layer3 = self._make_layer(block, conv_makers[2], 256, layers[2], stride=2) 136 | self.layer4 = self._make_layer(block, conv_makers[3], 512, layers[3], stride=2, mhsa=msha, 137 | n_frames_last_layer=n_frames_last_layer, input_last_layer=input_last_layer) 138 | 139 | self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1)) 140 | self.fc = nn.Sequential( 141 | nn.Linear(512 * block.expansion, num_classes) 142 | ) 143 | 144 | # init weights 145 | self._initialize_weights() 146 | 147 | def forward(self, x): 148 | x = self.stem(x) 149 | 150 | x = self.layer1(x) 151 | x = self.layer2(x) 152 | x = self.layer3(x) 153 | x = self.layer4(x) 154 | 155 | x = self.avgpool(x) 156 | # Flatten the layer to fc 157 | x = x.flatten(1) 158 | x = self.fc(x) 159 | x = x.flatten() 160 | return x 161 | 162 | def _make_layer(self, block, conv_builder, planes, blocks, stride=1, mhsa=False, n_frames_last_layer=2, 163 | input_last_layer=(7, 7)): 164 | downsample = None 165 | 166 | if stride != 1 or self.inplanes != planes * block.expansion: 167 | ds_stride = conv_builder.get_downsample_stride(stride) 168 | downsample = nn.Sequential( 169 | nn.Conv3d(self.inplanes, planes * block.expansion, 170 | kernel_size=1, stride=ds_stride, bias=False), 171 | nn.BatchNorm3d(planes * block.expansion) 172 | ) 173 | layers = [] 174 | layers.append(block(self.inplanes, planes, conv_builder, stride, downsample, mhsa=mhsa, 175 | n_frames_last_layer=n_frames_last_layer, 176 | input_last_layer=input_last_layer)) 177 | 178 | self.inplanes = planes * block.expansion 179 | for i in range(1, blocks): 180 | 
layers.append( 181 | block(self.inplanes, planes, conv_builder, mhsa=mhsa, n_frames_last_layer=n_frames_last_layer, 182 | input_last_layer=input_last_layer)) 183 | 184 | return nn.Sequential(*layers) 185 | 186 | def _initialize_weights(self): 187 | for m in self.modules(): 188 | if isinstance(m, nn.Conv3d): 189 | nn.init.kaiming_normal_(m.weight, mode='fan_out', 190 | nonlinearity='relu') 191 | if m.bias is not None: 192 | nn.init.constant_(m.bias, 0) 193 | elif isinstance(m, nn.BatchNorm3d): 194 | nn.init.constant_(m.weight, 1) 195 | nn.init.constant_(m.bias, 0) 196 | elif isinstance(m, nn.Linear): 197 | nn.init.normal_(m.weight, 0, 0.01) 198 | nn.init.constant_(m.bias, 0) 199 | 200 | 201 | def _video_resnet(**kwargs): 202 | model = VideoResNet(**kwargs) 203 | 204 | return model 205 | 206 | 207 | # VideoResNet based on: https://github.com/pytorch/vision/blob/5a315453da5089d66de94604ea49334a66552524/torchvision/models/video/resnet.py#L285 208 | # MHSA based on: https://github.com/leaderj1001/BottleneckTransformers/blob/main/model.py 209 | def BabyNet(msha=False, n_frames=16, input_size=(224, 224), **kwargs): 210 | n_frames_last_layer = max(n_frames // 8, 1) 211 | input_last_layer = (input_size[0] // 16, input_size[1] // 16) 212 | return _video_resnet(block=BasicBlock, 213 | conv_makers=[Conv3DSimple] * 4, 214 | layers=[2, 2, 2, 2], 215 | stem=BasicStem, msha=msha, n_frames_last_layer=n_frames_last_layer, 216 | input_last_layer=input_last_layer, **kwargs) 217 | -------------------------------------------------------------------------------- /src/train_video.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | dir_path = os.path.dirname(os.path.realpath(__file__)) 5 | parent_dir_path = os.path.abspath(os.path.join(dir_path, os.pardir)) 6 | sys.path.insert(0, parent_dir_path) 7 | 8 | import torch 9 | import torch.nn as nn 10 | import numpy as np 11 | import torch.optim as optim 12 | from torch.optim.lr_scheduler import StepLR 13 | from torch.utils.data import DataLoader 14 | from torch.utils.data import Sampler, SubsetRandomSampler 15 | from sklearn.model_selection import KFold 16 | from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error 17 | import time 18 | 19 | from models.babynet import BabyNet 20 | from video_data_loader import FetalWeightVideo 21 | import argparse 22 | 23 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 24 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 25 | 26 | parser = argparse.ArgumentParser(description="BabyNet for Fetal Birth Weight prediction") 27 | parser.add_argument("--data", 28 | type=str, 29 | default="../data/", 30 | help="Path to the data directory.") 31 | parser.add_argument("--x_img_size", 32 | type=int, 33 | default=64, 34 | help="Input X image size.") 35 | parser.add_argument("--y_img_size", 36 | type=int, 37 | default=64, 38 | help="Input Y image size") 39 | parser.add_argument("--batch_size", 40 | type=int, 41 | default=2, 42 | help="Number of batch size.") 43 | parser.add_argument("--epochs", 44 | type=int, 45 | default=20, 46 | help="Number of epochs.") 47 | parser.add_argument("--lr", 48 | type=float, 49 | default=0.0001, 50 | help="Number of learning rate.") 51 | parser.add_argument("--step_lr", 52 | type=int, 53 | default=16, 54 | help="Step of learning rate") 55 | parser.add_argument("--w_decay", 56 | type=float, 57 | default=0.0001, 58 | help="Number of weight decay.") 59 | parser.add_argument("--GPU", 60 | type=bool, 61 | default=True, 
62 | help="Use GPU.") 63 | parser.add_argument("--display_steps", 64 | type=int, 65 | default=20, 66 | help="Number of display steps.") 67 | parser.add_argument("--model_name", 68 | type=str, 69 | default="BabyNet", 70 | help="Name of trained model.") 71 | parser.add_argument("--frames_num", 72 | type=int, 73 | default=16, 74 | help="Number of frames in chunk") 75 | parser.add_argument("--skip_frames", 76 | type=int, 77 | default=0, 78 | help="Number of frames to skip") 79 | parser.add_argument("--pixels_crop", 80 | type=int, 81 | default=0, 82 | help="Number of frames in chunk") 83 | parser.add_argument("--msha3D", 84 | type=bool, 85 | default=True, 86 | help='Add MSHA to ResNet3D') 87 | args = parser.parse_args() 88 | 89 | dataset = FetalWeightVideo(input_path=args.data, 90 | x_image_size=args.x_img_size, 91 | y_image_size=args.y_img_size, 92 | pixels_crop=args.pixels_crop, 93 | skip_frames=args.skip_frames, 94 | n_frames=args.frames_num) 95 | 96 | if args.GPU and torch.cuda.is_available(): 97 | device = torch.device("cuda") 98 | else: 99 | device = torch.device("cpu") 100 | 101 | 102 | class CustomSequentialSampler(Sampler[int]): 103 | 104 | def __init__(self, data_source) -> None: 105 | self.data_source = data_source 106 | 107 | def __iter__(self): 108 | for i in range(len(self.data_source)): 109 | yield self.data_source[i] 110 | 111 | def __len__(self) -> int: 112 | return len(self.data_source) 113 | 114 | 115 | overlapping_ids_in_test = [] 116 | 117 | 118 | def ensure_no_patient_split(train_ids, valid_ids): 119 | patient_ids = dataset.patient_id_by_chunk 120 | patient_ids_train = set([patient_ids[i] for i in train_ids]) 121 | patient_ids_valid = set([patient_ids[i] for i in valid_ids]) 122 | overlapping_ids = patient_ids_train.intersection(patient_ids_valid) 123 | if len(overlapping_ids) == 0: 124 | return train_ids, valid_ids 125 | 126 | for overlapping_id in overlapping_ids: 127 | if overlapping_id not in overlapping_ids_in_test: 128 | indices_to_move = [ind for ind in train_ids if patient_ids[ind] == overlapping_id] 129 | valid_ids = np.append(valid_ids, indices_to_move).flatten() 130 | train_ids = np.delete(train_ids, np.searchsorted(train_ids, indices_to_move)) 131 | overlapping_ids_in_test.append(overlapping_id) 132 | else: 133 | indices_to_move = [ind for ind in valid_ids if patient_ids[ind] == overlapping_id] 134 | train_ids = np.append(train_ids, indices_to_move).flatten() 135 | valid_ids = np.delete(valid_ids, np.searchsorted(valid_ids, indices_to_move)) 136 | return list(sorted(train_ids)), list(sorted(valid_ids)) 137 | 138 | 139 | def calculate_metrics(y_true, y_pred): 140 | mse = mean_squared_error(y_true, y_pred, squared=True) 141 | rmse = mean_squared_error(y_true, y_pred, squared=False) 142 | mae = mean_absolute_error(y_true, y_pred) 143 | mape = mean_absolute_percentage_error(y_true, y_pred) 144 | print(f"RMSE: {rmse:.2f}") 145 | print(f"MSE: {mse:.2f}") 146 | print(f"MAE: {mae:.2f}") 147 | print(f"MAPE: {mape:.4f}") 148 | 149 | 150 | kfold = KFold(n_splits=5, shuffle=False) 151 | criterion_reg = nn.MSELoss() 152 | loss_min = np.inf 153 | train_dataset = FetalWeightVideo(input_path=args.data, 154 | x_image_size=args.x_img_size, 155 | y_image_size=args.y_img_size, 156 | pixels_crop=args.pixels_crop, 157 | skip_frames=args.skip_frames, 158 | n_frames=args.frames_num, 159 | mode="train") 160 | val_dataset = FetalWeightVideo(input_path=args.data, 161 | x_image_size=args.x_img_size, 162 | y_image_size=args.y_img_size, 163 | pixels_crop=args.pixels_crop, 164 | 
skip_frames=args.skip_frames, 165 | n_frames=args.frames_num, 166 | mode="val") 167 | 168 | print("---------------") 169 | 170 | # Start time of learning 171 | total_start_training = time.time() 172 | 173 | for fold, (train_ids, valid_ids) in enumerate(kfold.split(dataset)): 174 | print(f"FOLD {fold}") 175 | print("----------------") 176 | train_ids, valid_ids = ensure_no_patient_split(train_ids, valid_ids) 177 | train_subsampler = SubsetRandomSampler(train_ids) 178 | valid_subsampler = CustomSequentialSampler(valid_ids) 179 | 180 | train_loader = DataLoader(dataset=train_dataset, 181 | batch_size=args.batch_size, 182 | sampler=train_subsampler) 183 | valid_loader = DataLoader(dataset=val_dataset, 184 | batch_size=args.batch_size, 185 | sampler=valid_subsampler) 186 | 187 | model = BabyNet(msha=args.msha3D, n_frames=args.frames_num, 188 | input_size=(args.y_img_size - args.pixels_crop, args.x_img_size - args.pixels_crop)) 189 | model.to(device) 190 | 191 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.w_decay) 192 | 193 | scheduler = StepLR(optimizer=optimizer, step_size=args.step_lr, gamma=0.1, verbose=True) 194 | 195 | best_val_score = np.inf 196 | best_val_preds = None 197 | for epoch in range(args.epochs): 198 | start_time_epoch = time.time() 199 | print(f"Starting epoch {epoch + 1}") 200 | model.train() 201 | running_loss = 0.0 202 | 203 | y_true = [] 204 | y_pred = [] 205 | patient_running_loss = [] 206 | for batch_idx, (videos, weights, patient_id, body_part, first_frame) in enumerate(train_loader): 207 | optimizer.zero_grad() 208 | videos = torch.permute(videos, (0, 4, 1, 2, 3)) 209 | videos = videos.to(device=device).float() 210 | y_true.extend(weights.flatten().tolist()) 211 | weights = weights.to(device=device).float() 212 | 213 | reg_out = model(videos) 214 | y_pred.extend(reg_out.flatten().cpu().tolist()) 215 | loss_reg = criterion_reg(reg_out, weights) 216 | loss = loss_reg 217 | loss.backward() 218 | optimizer.step() 219 | 220 | running_loss += loss.item() 221 | 222 | if batch_idx % args.display_steps == 0: 223 | print(' ', end='') 224 | print(f"Batch: {batch_idx + 1}/{len(train_loader)} " 225 | f"Loss: {loss.item():.4f} " 226 | f"Learning time: {(time.time() - start_time_epoch):.2f}s " 227 | f"First frame: {first_frame[0]}") 228 | 229 | # evalute 230 | calculate_metrics(y_true, y_pred) 231 | print(f"Finished epoch {epoch + 1}, starting evaluation.") 232 | 233 | model.eval() 234 | val_running_loss = 0.0 235 | y_true = [] 236 | y_pred = [] 237 | for batch_idx, (videos, weights, patient_id, body_part, first_frame) in enumerate(valid_loader): 238 | videos = torch.permute(videos, (0, 4, 1, 2, 3)) 239 | videos = videos.to(device=device).float() 240 | y_true.extend(weights.flatten().tolist()) 241 | weights = weights.to(device=device).float() 242 | 243 | reg_out = model(videos) 244 | y_pred.extend(reg_out.flatten().cpu().tolist()) 245 | loss_reg = criterion_reg(reg_out, weights) 246 | loss = loss_reg 247 | 248 | val_running_loss += loss.item() 249 | 250 | calculate_metrics(y_true, y_pred) 251 | 252 | train_loss = running_loss / len(train_loader) 253 | val_loss = val_running_loss / len(valid_loader) 254 | 255 | if best_val_score > val_loss: 256 | save_path = f"{args.model_name}-fold-{fold}.pt" 257 | torch.save(model.state_dict(), save_path) 258 | best_val_score = val_loss 259 | print(f"Current best val score {best_val_score}. 
Model saved!") 260 | 261 | scheduler.step() 262 | 263 | print(' ', end='') 264 | print(f"Train Loss: {train_loss:.3f} " 265 | f"Val Loss: {val_loss:.3f}") 266 | 267 | print('Training finished, took {:.2f}s'.format(time.time() - total_start_training)) 268 | -------------------------------------------------------------------------------- /src/video_data_loader.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import re 4 | 5 | import cv2 6 | import albumentations as A 7 | from albumentations.pytorch import ToTensorV2 8 | import numpy as np 9 | import pandas as pd 10 | import torch 11 | import torchio as tio 12 | from torch.utils.data import Dataset 13 | from itertools import groupby 14 | from collections import defaultdict 15 | 16 | 17 | class FetalWeightVideo(Dataset): 18 | """ 19 | FetalWeightVideo class. 20 | """ 21 | 22 | def __init__(self, 23 | input_path: str = None, 24 | resample: bool = True, 25 | n_frames: int = 48, 26 | x_image_size: int = 512, 27 | y_image_size: int = 384, 28 | pixels_crop: int = 0, 29 | max_width: int = 1680, 30 | max_height: int = 1260, 31 | skip_frames: int = 0, 32 | multiple_planes: bool = False, 33 | background_frames_filename: str = "background_frames.csv", 34 | resize_only: bool = False, 35 | padding: bool = False, 36 | mode: str = None) -> None: 37 | """ 38 | 39 | Args: 40 | input_path: 41 | resample: 42 | n_frames: if multiple_planes, n_frames must be divisible by 3 43 | x_image_size: the assumption is that x>y and x/y= ~0.75 44 | y_image_size: 45 | pixels_crop: 46 | max_width: 47 | max_height: 48 | skip_frames: 49 | multiple_planes: 50 | background_frames_filename: 51 | resize_only: 52 | padding: 53 | mode: 54 | """ 55 | self.input_path = os.path.join(input_path, "BabyWeightVideo_frames") 56 | self.n_frames = n_frames 57 | self.x_image_size = x_image_size 58 | self.y_image_size = y_image_size 59 | self.max_width = max_width 60 | self.max_height = max_height 61 | self.pixels_crop = pixels_crop 62 | self.multiple_planes = multiple_planes 63 | self.weight_labels = pd.read_csv("../data/baby_weight.csv", sep=";", usecols=["ID", "weight"]) 64 | self.resize_only = resize_only 65 | 66 | assert (self.resize_only and pixels_crop == 0) or (not self.resize_only) 67 | assert mode in ["train", "val", None] 68 | self.mode = mode 69 | self.patient_id = [] 70 | for file in os.listdir(self.input_path): 71 | if not file.startswith("."): 72 | pid = file.split("_")[1] 73 | self.patient_id.append(pid) 74 | 75 | self.resample = resample 76 | 77 | self.chunk_frames = [] 78 | self.patient_id_by_chunk = [] 79 | self.frames_to_omit = [] 80 | if background_frames_filename: 81 | background_frames = pd.read_csv(os.path.join(input_path, background_frames_filename)) 82 | background_frames = background_frames["background_frames"].tolist() 83 | background_frames = [f"{f}.png" for f in background_frames] 84 | self.frames_to_omit = background_frames 85 | 86 | for i in self.patient_id: 87 | frames = glob.glob(f"{self.input_path}/video_{i}/*.png") 88 | frames = sorted(frames, key=numericalSort) 89 | frames = [f for f in frames if self._get_filename(f) not in self.frames_to_omit] 90 | if multiple_planes: 91 | chunk_frames = self.chunk_frames_multiple_planes(frames, n_frames) 92 | else: 93 | chunk_frames = chunks(frames, n_frames, skip_frames) 94 | 95 | for chunk in chunk_frames: 96 | if len(chunk) != n_frames and not padding: 97 | continue 98 | chunk.extend(["padding" for _ in range(n_frames - len(chunk))]) 99 | 
self.chunk_frames.append([chunk, i]) 100 | self.patient_id_by_chunk.append(i) 101 | 102 | def __getitem__(self, x): 103 | """ 104 | 105 | Args: 106 | x: 107 | 108 | Returns: 109 | 110 | """ 111 | chunk_frames, patient_id = self.chunk_frames[x][0], self.chunk_frames[x][1] 112 | patient_dir = f"{self.input_path}/video_{patient_id}" 113 | video_stack = self.load_video(patient_dir, chunk_frames, resample=True) 114 | weight = self.weight_labels.loc[self.weight_labels["ID"] == int(patient_id), "weight"] 115 | weight = weight.item() 116 | 117 | video_stack = torch.tensor(video_stack) 118 | 119 | body_part = self.get_body_part(chunk_frames) 120 | first_frame = self._first_frame(chunk_frames) 121 | return video_stack, weight, self.patient_id_by_chunk[x], body_part, first_frame 122 | 123 | def __len__(self) -> int: 124 | return len(self.chunk_frames) 125 | 126 | def _get_filename(self, full_path): 127 | return os.path.basename(full_path) 128 | 129 | def get_body_part(self, chunk_frames): 130 | body_parts = [self._body_part_from_frame_path(frame) for frame in chunk_frames if "padding" not in frame] 131 | return max(body_parts, key=body_parts.count) 132 | 133 | def _body_part_from_frame_path(self, path): 134 | body_part = path.split('video_')[-1].split("_")[-2] 135 | if "abdomen" in body_part: 136 | return "abdomen" 137 | if "head" in body_part: 138 | return "head" 139 | if "femur" in body_part: 140 | return "femur" 141 | 142 | raise ValueError(f"Unrecognized body part: {body_part}") 143 | 144 | def load_video(self, patient_dir: str = None, 145 | chunk_frames: list = None, 146 | resample: bool = True): 147 | """ 148 | 149 | Args: 150 | resample: 151 | chunk_frames: 152 | patient_dir: 153 | 154 | Returns: 155 | 156 | """ 157 | if not os.path.exists(patient_dir): 158 | raise FileNotFoundError(patient_dir) 159 | 160 | video = np.zeros( 161 | (len(chunk_frames), self.y_image_size - self.pixels_crop, self.x_image_size - self.pixels_crop, 1), 162 | np.float32) 163 | video_frames = {} 164 | for count, chunk in enumerate(chunk_frames): 165 | if resample: 166 | if "padding" in chunk: 167 | continue 168 | frame = cv2.imread(chunk, 0) 169 | 170 | if self.resize_only: 171 | transformations = A.Compose([ 172 | A.Resize(self.y_image_size, self.x_image_size), 173 | ]) 174 | else: 175 | height, width = frame.shape[:2] 176 | transformed_height = int((height / self.max_width) * self.x_image_size) 177 | transformed_width = int((width / self.max_width) * self.x_image_size) 178 | 179 | assert transformed_width <= self.x_image_size 180 | assert transformed_height <= self.y_image_size 181 | transformations = A.Compose([ 182 | A.Resize(transformed_height, transformed_width), 183 | A.PadIfNeeded(self.y_image_size, self.x_image_size, border_mode=cv2.BORDER_CONSTANT, value=0), 184 | A.CenterCrop(self.y_image_size - self.pixels_crop, self.x_image_size - self.pixels_crop), 185 | ]) 186 | 187 | transformed_frame = transformations(image=frame) 188 | transformed_frame = transformed_frame["image"] 189 | transformed_frame = np.expand_dims(transformed_frame, axis=2) 190 | frame_name = f"image{count - 1}" if count != 0 else "image" 191 | video_frames[frame_name] = transformed_frame 192 | 193 | if self.mode == "val": 194 | augmentations = A.Compose([]) 195 | else: 196 | augmentations = A.Compose([A.Rotate(limit=(-25, 25)), 197 | A.HorizontalFlip(p=0.5), 198 | A.RandomBrightnessContrast(), 199 | A.ImageCompression(p=0.1), 200 | A.OneOf([ 201 | A.MotionBlur(p=0.5), 202 | A.MedianBlur(blur_limit=3, p=0.5), 203 | A.Blur(blur_limit=3, p=0.5), 
204 | A.GaussianBlur(p=0.5)], 205 | p=0.5), 206 | ], 207 | additional_targets={f"image{i}": "image" for i in range(len(video_frames) - 1)}) 208 | 209 | transformed_video = augmentations(**video_frames) 210 | for i, (frame_name, frame) in enumerate(video_frames.items()): 211 | augmented_frame = transformed_video[frame_name] 212 | video[i] = augmented_frame 213 | 214 | m = np.max(video) 215 | video = (video - 0.5 * m) / (0.5 * m) # img = (img - mean * max_pixel_value) / (std * max_pixel_value) 216 | return video 217 | 218 | def chunk_frames_multiple_planes(self, frames, n_frames): 219 | by_body_parts = groupby(frames, self._body_part_from_frame_path) 220 | body_parts_dict = defaultdict(list) 221 | max_length = 0 222 | for b_p, fr in by_body_parts: 223 | b_p_list = list(el for el in fr) 224 | if len(b_p_list) > max_length: 225 | max_length = len(b_p_list) 226 | body_parts_dict[b_p] = b_p_list 227 | max_length_with_padding = max_length + (n_frames // 3 - (max_length % (n_frames // 3))) 228 | abdomen_frames = self._pad(body_parts_dict["abdomen"], max_length_with_padding) 229 | femur_frames = self._pad(body_parts_dict["femur"], max_length_with_padding) 230 | head_frames = self._pad(body_parts_dict["head"], max_length_with_padding) 231 | chunk_frames = [] 232 | body_part_frames_num = n_frames // 3 233 | for i in range(max_length_with_padding // (n_frames // 3)): 234 | new_chunk = [] 235 | new_chunk.extend(abdomen_frames[i * body_part_frames_num: i * body_part_frames_num + body_part_frames_num]) 236 | new_chunk.extend(femur_frames[i * body_part_frames_num: i * body_part_frames_num + body_part_frames_num]) 237 | new_chunk.extend(head_frames[i * body_part_frames_num: i * body_part_frames_num + body_part_frames_num]) 238 | all_padding = True 239 | for f in new_chunk: 240 | if "padding" not in f: 241 | all_padding = False 242 | if all_padding: 243 | continue 244 | chunk_frames.append(new_chunk) 245 | 246 | return chunk_frames 247 | 248 | def _pad(self, l: list, pad_to: int): 249 | l.extend(["padding" for _ in range(pad_to - len(l))]) 250 | return l 251 | 252 | def _first_frame(self, frames): 253 | for f in frames: 254 | if "padding" not in f: 255 | return "video" + f.split("video")[-1] 256 | 257 | 258 | def chunks(L, n, skip): 259 | """ 260 | 261 | Args: 262 | L: 263 | n: 264 | skip: 265 | 266 | Returns: 267 | 268 | """ 269 | if skip > 0: 270 | return [L[x: x + (skip+1) * (n-1)+1: skip + 1] for x in range(0, len(L), (skip+1) * (n-1)+1)] 271 | else: 272 | return [L[x: x + n: skip + 1] for x in range(0, len(L), n)] 273 | 274 | 275 | def numericalSort(value): 276 | numbers = re.compile(r'(\d+)') 277 | parts = numbers.split(value) 278 | parts[1::2] = map(int, parts[1::2]) 279 | return parts 280 | --------------------------------------------------------------------------------