├── .gitignore
├── .travis.yml
├── README.md
├── echonet
│   ├── __init__.py
│   ├── __main__.py
│   ├── __version__.py
│   ├── config.py
│   ├── datasets
│   │   ├── __init__.py
│   │   └── echo.py
│   ├── models
│   │   ├── __init__.py
│   │   └── rnet2dp1.py
│   ├── segmentation
│   │   ├── __init__.py
│   │   ├── _utils.py
│   │   ├── deeplabv3.py
│   │   └── segmentation.py
│   └── utils
│       ├── __init__.py
│       ├── seg_cycle.py
│       ├── video_segin.py
│       └── vidsegin_teachstd_kd.py
├── flow_a_tmi_revise_v2.PNG
├── flow_b_tmi_revise.PNG
├── flow_graph.PNG
├── requirements.txt
└── setup.py
/.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints/ 2 | __pycache__/ 3 | *.swp 4 | echonet.cfg 5 | .echonet.cfg 6 | *.pyc 7 | echonet.egg-info/ 8 | *.pth 9 | *.pt 10 | *.npy 11 | output/zdbg* 12 | output/*/size 13 | output/*/videos 14 | *.avi 15 | *.zip 16 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: minimal 2 | 3 | os: 4 | - linux 5 | 6 | env: 7 | # - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.1 TORCHVISION_VERSION=0.2 (torchvision 0.2 does not have VisionDataset) 8 | # - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.1 TORCHVISION_VERSION=0.3 (torchvision 0.3 has a cuda issue) 9 | - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.1 TORCHVISION_VERSION=0.4 10 | - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.1 TORCHVISION_VERSION=0.5 11 | # - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.2 TORCHVISION_VERSION=0.2 12 | # - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.2 TORCHVISION_VERSION=0.3 13 | - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.2 TORCHVISION_VERSION=0.4 14 | - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.2 TORCHVISION_VERSION=0.5 15 | # - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.3 TORCHVISION_VERSION=0.2 16 | # - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.3 TORCHVISION_VERSION=0.3 17 | - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.3 TORCHVISION_VERSION=0.4 18 | - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.3 TORCHVISION_VERSION=0.5 19 | # - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.4 TORCHVISION_VERSION=0.2 20 | # - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.4 TORCHVISION_VERSION=0.3 21 | - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.4 TORCHVISION_VERSION=0.4 22 | - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.4 TORCHVISION_VERSION=0.5 23 | # - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.1 TORCHVISION_VERSION=0.2 24 | # - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.1 TORCHVISION_VERSION=0.3 25 | - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.1 TORCHVISION_VERSION=0.4 26 | - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.1 TORCHVISION_VERSION=0.5 27 | # - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.2 TORCHVISION_VERSION=0.2 28 | # - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.2 TORCHVISION_VERSION=0.3 29 | - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.2 TORCHVISION_VERSION=0.4 30 | - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.2 TORCHVISION_VERSION=0.5 31 | # - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.3 TORCHVISION_VERSION=0.2 32 | # - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.3 TORCHVISION_VERSION=0.3 33 | - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.3 TORCHVISION_VERSION=0.4 34 | - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.3 TORCHVISION_VERSION=0.5 35 | # - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.4 TORCHVISION_VERSION=0.2 36 | # - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.4 TORCHVISION_VERSION=0.3 37 | - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.4 TORCHVISION_VERSION=0.4 38 | - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.4 TORCHVISION_VERSION=0.5 39 | 40 | install: 41 | - if [[ "$TRAVIS_OS_NAME" == "linux" ]]; 42 | then 43 | MINICONDA_OS=Linux; 44 | sudo apt-get update; 45 |
else 46 | MINICONDA_OS=MacOSX; 47 | brew update; 48 | fi 49 | - wget https://repo.anaconda.com/miniconda/Miniconda3-latest-${MINICONDA_OS}-x86_64.sh -O miniconda.sh 50 | - bash miniconda.sh -b -p $HOME/miniconda 51 | - source "$HOME/miniconda/etc/profile.d/conda.sh" 52 | - hash -r 53 | - conda config --set always_yes yes --set changeps1 no 54 | - conda update -q conda 55 | # Useful for debugging any issues with conda 56 | - conda info -a 57 | - conda search pytorch || true 58 | 59 | - conda create -q -n test-environment python=${PYTHON_VERSION} pytorch=${PYTORCH_VERSION} 60 | - conda activate test-environment 61 | - pip install -q torchvision==${TORCHVISION_VERSION} "pillow<7.0.0" 62 | - pip install -q . 63 | - pip install -q flake8 pylint 64 | 65 | script: 66 | - flake8 --ignore=E501 67 | - pylint --disable=C0103,C0301,R0401,R0801,R0902,R0912,R0913,R0914,R0915 --extension-pkg-whitelist=cv2,torch --generated-members=torch.* echonet/ scripts/*.py setup.py 68 | - python -c "import echonet" 69 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Cyclical Self-Supervision for Semi-Supervised Ejection Fraction Prediction from Echocardiogram Videos 3 | 4 | 5 | 6 | This is the implementation of CSS for Semi-Supervised Ejection Fraction Prediction for the paper ["Cyclical Self-Supervision for Semi-Supervised Ejection Fraction Prediction from Echocardiogram Videos"](). 7 | 8 | ![CSS_flow](flow_a_tmi_revise_v2.PNG)![CSS_flow](flow_b_tmi_revise.PNG) 9 | 10 |
11 |
12 | 13 | ## Data 14 | 15 | Researchers can request the EchoNet-Dynamic dataset at https://echonet.github.io/dynamic/ and set the directory path in the configuration file, `echonet.cfg`. 16 | 17 |
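The configuration file is read by `echonet/config.py` with Python's `configparser`, and only the `data_dir` key is used; the file can live in the working directory or in the home directory (optionally prefixed with a dot). If no configuration file is found, the path falls back to `../EchoNet/Heart-videos/`. A minimal `echonet.cfg`, assuming the dataset was extracted to `../EchoNet-Dynamic`, would be:

```
data_dir = ../EchoNet-Dynamic
```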
18 |
19 | 20 | ## Environment 21 | 22 | We recommend running the program in a `conda` environment with PyTorch installed. A `requirements.txt` file listing the Python dependencies is included; an example environment setup is sketched below. 23 | 24 |
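For reference, the CI configuration (`.travis.yml`) tests Python 3.6/3.7 with PyTorch 1.1–1.4 and torchvision 0.4/0.5. A comparable environment can be created along the following lines (this is a sketch using one tested combination; the environment name `echonet` is arbitrary):

```
conda create -n echonet python=3.7 pytorch=1.4
conda activate echonet
pip install torchvision==0.5 "pillow<7.0.0"
pip install -r requirements.txt
```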
25 |
26 | 27 | ## Training and testing 28 | 29 | The code must first be installed as a package. From the repository directory `CSS-SemiVideo`, run 30 | 31 | pip3 install --user . 32 | 33 | Training consists of three stages, described in order below. 34 | 35 |
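The `pip3 install --user .` command places the `echonet` console script in the user install location (typically `~/.local/bin`), which may not be on `PATH`. If the `echonet` command is not found, the same sub-commands can be invoked through the package's `__main__` module, for example:

```
python3 -m echonet seg_cycle --help
```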
36 | 37 | ### 1) To train the CSS semi-supervised segmentation model, run: 38 | 39 | ``` 40 | echonet seg_cycle --batch_size=20 --output=output/css_seg --loss_cyc_w=0.01 --num_epochs=25 --rd_label=920 --rd_unlabel=6440 --run_test --reduced_set 41 | ``` 42 | 43 | The LV segmentation prediction masks of all frames must be inferred for the second stage. To do so, run: 44 | 45 | ``` 46 | echonet seg_cycle --batch_size=20 --output=output/css_seg --loss_cyc_w=0.01 --num_epochs=25 --rd_label=920 --rd_unlabel=6440 --skip_test --reduced_set --run_inference=train 47 | 48 | echonet seg_cycle --batch_size=20 --output=output/css_seg --loss_cyc_w=0.01 --num_epochs=25 --rd_label=920 --rd_unlabel=6440 --skip_test --reduced_set --run_inference=val 49 | 50 | echonet seg_cycle --batch_size=20 --output=output/css_seg --loss_cyc_w=0.01 --num_epochs=25 --rd_label=920 --rd_unlabel=6440 --skip_test --reduced_set --run_inference=test 51 | ``` 52 | 53 | The segmentation prediction outputs will be located under the output folder `output/css_seg`. To reduce installation time for EchoNet-Dynamic, these are moved to a separate directory parallel to `CSS-SemiVideo`, i.e. `CSS-SemiVideo/../infer_buffers/css_seg`. Segmentation masks are also sourced from this location for Step 2 of the framework. 54 | 55 | To do this, run: 56 | 57 | ``` 58 | mkdir -p ../infer_buffers/css_seg 59 | mv output/css_seg/*_infer_cmpct ../infer_buffers/css_seg/ 60 | ``` 61 |
62 | 63 | ### 2) To train the multi-modal LVEF prediction model, run: 64 | 65 | ``` 66 | echonet video_segin --frames=32 --model_name=r2plus1d_18 --period=2 --batch_size=20 --output=output/teacher_model --num_epochs=25 --rd_label=920 --rd_unlabel=6440 --run_test --segsource=css_seg 67 | ``` 68 |
69 | 70 | ### 3) To train teacher-student distillation, run: 71 | 72 | ``` 73 | echonet vidsegin_teachstd_kd --frames=32 --model_name=r2plus1d_18 --period=2 --batch_size=20 --output=output/end2end_model --num_epochs=25 --rd_label=920 --rd_unlabel=6440 --run_test --reduced_set --max_block=20 --segsource=css_seg --w_unlb=5 --batch_size_unlb=10 --weights_0=output/teacher_model/best.pt 74 | ``` 75 | 76 | 77 | 78 |
79 |
80 | 81 | 82 | 83 | ## Pretrained models 84 | 85 | Trained checkpoints and models can be downloaded from: 86 | 87 | 1) CSS for semi-supervised segmentation: https://hkustconnect-my.sharepoint.com/:f:/g/personal/wdaiaj_connect_ust_hk/EqiP-N0MDRZGlwqr5PeZUrYBtLki8QWtBlMqRK1FNkjbcw?e=DIpkIm 88 | 89 | 2) Multi-modal LVEF regression: https://hkustconnect-my.sharepoint.com/:f:/g/personal/wdaiaj_connect_ust_hk/ErxaHepi4ndAnMcvSOwTH5wBDI6rHypqdcBiXF8B0XYvmg?e=Rud7Pf 90 | 91 | 3) Teacher-student distillation: https://hkustconnect-my.sharepoint.com/:f:/g/personal/wdaiaj_connect_ust_hk/Ev7mQ1ReI05LtiDIqQu1IpYBC6xN4R47PsYnhDUQr4n3fw?e=US4caq 92 | 93 | 94 | To run with the pretrained model weights, replace the `.pt` checkpoint files in the target output directory with the downloaded files. 95 | 96 |
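As a rough sketch, a downloaded segmentation checkpoint can be restored in the same way `echonet/utils/seg_cycle.py` loads `--weights`; the file name `best.pt` and the `state_dict_0` key below are assumptions mirroring that script, so verify them against the downloaded files:

```
import torch
import echonet

# Build the CSS segmentation network as in seg_cycle: DeepLabV3 with a
# ResNet-50 backbone and a single output channel for the LV mask.
model = echonet.segmentation.deeplabv3_resnet50_CSS(pretrained=False, aux_loss=False)
model.classifier[-1] = torch.nn.Conv2d(
    model.classifier[-1].in_channels, 1,
    kernel_size=model.classifier[-1].kernel_size)
model = torch.nn.DataParallel(model)

# Assumed checkpoint layout: a dict whose 'state_dict_0' entry holds the weights.
checkpoint = torch.load("output/css_seg/best.pt", map_location="cpu")
model.load_state_dict(checkpoint["state_dict_0"], strict=False)
model.eval()
```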
97 | 98 | | Experiments | MAE | RMSE | R² | 99 | | ---------- | :-----------: | :-----------: | :-----------: | 100 | | Multi-Modal | 5.13 ± 0.05 | 6.90 ± 0.07 | 67.6% ± 0.5 | 101 | | Teacher-student Distillation | 4.90 ± 0.04 | 6.57 ± 0.06 | 71.1% ± 0.4 | 102 | 103 |
104 |
105 | 106 | ## Notes 107 | * Contact: DAI Weihang (wdai03@gmail.com) 108 |
109 |
110 | 111 | ## Citation 112 | If this code is useful for your research, please consider citing: 113 | 114 | (to be released) 115 | -------------------------------------------------------------------------------- /echonet/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | The echonet package contains code for loading echocardiogram videos, and 3 | functions for training and testing segmentation and ejection fraction 4 | prediction models. 5 | """ 6 | 7 | import click 8 | 9 | from echonet.__version__ import __version__ 10 | from echonet.config import CONFIG as config 11 | import echonet.datasets as datasets 12 | import echonet.utils as utils 13 | import echonet.models as models 14 | import echonet.segmentation as segmentation 15 | 16 | @click.group() 17 | def main(): 18 | """Entry point for command line interface.""" 19 | 20 | 21 | del click 22 | 23 | main.add_command(utils.seg_cycle.run) 24 | main.add_command(utils.vidsegin_teachstd_kd.run) 25 | main.add_command(utils.video_segin.run) 26 | 27 | 28 | __all__ = ["__version__", "config", "datasets", "main", "utils", "models", "segmentation"] 29 | -------------------------------------------------------------------------------- /echonet/__main__.py: -------------------------------------------------------------------------------- 1 | """Entry point for command line.""" 2 | 3 | import echonet 4 | 5 | 6 | if __name__ == '__main__': 7 | echonet.main() 8 | -------------------------------------------------------------------------------- /echonet/__version__.py: -------------------------------------------------------------------------------- 1 | """Version number for Echonet package.""" 2 | 3 | __version__ = "1.0.0" 4 | -------------------------------------------------------------------------------- /echonet/config.py: -------------------------------------------------------------------------------- 1 | """Sets paths based on configuration files.""" 2 | 3 | import configparser 4 | import os 5 | import types 6 | 7 | _FILENAME = None 8 | _PARAM = {} 9 | for filename in ["echonet.cfg", 10 | ".echonet.cfg", 11 | os.path.expanduser("~/echonet.cfg"), 12 | os.path.expanduser("~/.echonet.cfg"), 13 | ]: 14 | if os.path.isfile(filename): 15 | _FILENAME = filename 16 | config = configparser.ConfigParser() 17 | with open(filename, "r") as f: 18 | config.read_string("[config]\n" + f.read()) 19 | _PARAM = config["config"] 20 | break 21 | 22 | CONFIG = types.SimpleNamespace( 23 | FILENAME=_FILENAME, 24 | DATA_DIR=_PARAM.get("data_dir", "../EchoNet/Heart-videos/")) 25 | -------------------------------------------------------------------------------- /echonet/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | The echonet.datasets submodule defines a Pytorch dataset for loading 3 | echocardiogram videos. 
4 | """ 5 | 6 | from .echo import Echo, Echo_tskd, Echo_CSS 7 | 8 | __all__ = ["Echo", "Echo_tskd", "Echo_CSS"] 9 | -------------------------------------------------------------------------------- /echonet/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .rnet2dp1 import r2plus1d_18_kd, r2plus1d_18 2 | 3 | __all__ = ["r2plus1d_18_kd", "r2plus1d_18"] -------------------------------------------------------------------------------- /echonet/models/rnet2dp1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | import math 5 | from torchvision.models.utils import load_state_dict_from_url 6 | import numpy as np 7 | # from ..utils import load_state_dict_from_url 8 | 9 | 10 | __all__ = ['r3d_18', 'mc3_18', 'r2plus1d_18', 'r2plus1d_18_kd'] 11 | 12 | model_urls = { 13 | 'r3d_18': 'https://download.pytorch.org/models/r3d_18-b3b3357e.pth', 14 | 'mc3_18': 'https://download.pytorch.org/models/mc3_18-a90a0ba3.pth', 15 | 'r2plus1d_18': 'https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth', 16 | } 17 | 18 | 19 | class Conv3DSimple(nn.Conv3d): 20 | def __init__(self, 21 | in_planes, 22 | out_planes, 23 | midplanes=None, 24 | stride=1, 25 | padding=1): 26 | 27 | super(Conv3DSimple, self).__init__( 28 | in_channels=in_planes, 29 | out_channels=out_planes, 30 | kernel_size=(3, 3, 3), 31 | stride=stride, 32 | padding=padding, 33 | bias=False) 34 | 35 | @staticmethod 36 | def get_downsample_stride(stride): 37 | return stride, stride, stride 38 | 39 | 40 | class Conv2Plus1D(nn.Sequential): 41 | 42 | def __init__(self, 43 | in_planes, 44 | out_planes, 45 | midplanes, 46 | stride=1, 47 | padding=1): 48 | super(Conv2Plus1D, self).__init__( 49 | nn.Conv3d(in_planes, midplanes, kernel_size=(1, 3, 3), 50 | stride=(1, stride, stride), padding=(0, padding, padding), 51 | bias=False), 52 | nn.BatchNorm3d(midplanes), 53 | nn.ReLU(inplace=True), 54 | nn.Conv3d(midplanes, out_planes, kernel_size=(3, 1, 1), 55 | stride=(stride, 1, 1), padding=(padding, 0, 0), 56 | bias=False)) 57 | 58 | @staticmethod 59 | def get_downsample_stride(stride): 60 | return stride, stride, stride 61 | 62 | 63 | class Conv3DNoTemporal(nn.Conv3d): 64 | 65 | def __init__(self, 66 | in_planes, 67 | out_planes, 68 | midplanes=None, 69 | stride=1, 70 | padding=1): 71 | 72 | super(Conv3DNoTemporal, self).__init__( 73 | in_channels=in_planes, 74 | out_channels=out_planes, 75 | kernel_size=(1, 3, 3), 76 | stride=(1, stride, stride), 77 | padding=(0, padding, padding), 78 | bias=False) 79 | 80 | @staticmethod 81 | def get_downsample_stride(stride): 82 | return 1, stride, stride 83 | 84 | 85 | class BasicBlock(nn.Module): 86 | 87 | expansion = 1 88 | 89 | def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None): 90 | midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes) 91 | 92 | super(BasicBlock, self).__init__() 93 | self.conv1 = nn.Sequential( 94 | conv_builder(inplanes, planes, midplanes, stride), 95 | nn.BatchNorm3d(planes), 96 | nn.ReLU(inplace=True) 97 | ) 98 | self.conv2 = nn.Sequential( 99 | conv_builder(planes, planes, midplanes), 100 | nn.BatchNorm3d(planes) 101 | ) 102 | self.relu = nn.ReLU(inplace=True) 103 | self.downsample = downsample 104 | self.stride = stride 105 | 106 | def forward(self, x): 107 | residual = x 108 | 109 | out = self.conv1(x) 110 | out = self.conv2(out) 111 | if self.downsample is not None: 112 
| residual = self.downsample(x) 113 | 114 | out += residual 115 | out = self.relu(out) 116 | 117 | return out 118 | 119 | 120 | class Bottleneck(nn.Module): 121 | expansion = 4 122 | 123 | def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None): 124 | 125 | super(Bottleneck, self).__init__() 126 | midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes) 127 | 128 | # 1x1x1 129 | self.conv1 = nn.Sequential( 130 | nn.Conv3d(inplanes, planes, kernel_size=1, bias=False), 131 | nn.BatchNorm3d(planes), 132 | nn.ReLU(inplace=True) 133 | ) 134 | # Second kernel 135 | self.conv2 = nn.Sequential( 136 | conv_builder(planes, planes, midplanes, stride), 137 | nn.BatchNorm3d(planes), 138 | nn.ReLU(inplace=True) 139 | ) 140 | 141 | # 1x1x1 142 | self.conv3 = nn.Sequential( 143 | nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False), 144 | nn.BatchNorm3d(planes * self.expansion) 145 | ) 146 | self.relu = nn.ReLU(inplace=True) 147 | self.downsample = downsample 148 | self.stride = stride 149 | 150 | def forward(self, x): 151 | residual = x 152 | 153 | out = self.conv1(x) 154 | out = self.conv2(out) 155 | out = self.conv3(out) 156 | 157 | if self.downsample is not None: 158 | residual = self.downsample(x) 159 | 160 | out += residual 161 | out = self.relu(out) 162 | 163 | return out 164 | 165 | 166 | class BasicStem(nn.Sequential): 167 | """The default conv-batchnorm-relu stem 168 | """ 169 | def __init__(self): 170 | super(BasicStem, self).__init__( 171 | nn.Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), 172 | padding=(1, 3, 3), bias=False), 173 | nn.BatchNorm3d(64), 174 | nn.ReLU(inplace=True)) 175 | 176 | 177 | class R2Plus1dStem(nn.Sequential): 178 | """R(2+1)D stem is different than the default one as it uses separated 3D convolution 179 | """ 180 | def __init__(self): 181 | super(R2Plus1dStem, self).__init__( 182 | nn.Conv3d(3, 45, kernel_size=(1, 7, 7), 183 | stride=(1, 2, 2), padding=(0, 3, 3), 184 | bias=False), 185 | nn.BatchNorm3d(45), 186 | nn.ReLU(inplace=True), 187 | nn.Conv3d(45, 64, kernel_size=(3, 1, 1), 188 | stride=(1, 1, 1), padding=(1, 0, 0), 189 | bias=False), 190 | nn.BatchNorm3d(64), 191 | nn.ReLU(inplace=True)) 192 | 193 | 194 | class VideoResNet(nn.Module): 195 | 196 | def __init__(self, block, conv_makers, layers, 197 | stem, num_classes=400, 198 | zero_init_residual=False): 199 | """Generic resnet video generator. 200 | 201 | Args: 202 | block (nn.Module): resnet building block 203 | conv_makers (list(functions)): generator function for each layer 204 | layers (List[int]): number of blocks per layer 205 | stem (nn.Module, optional): Resnet stem, if None, defaults to conv-bn-relu. Defaults to None. 206 | num_classes (int, optional): Dimension of the final FC layer. Defaults to 400. 207 | zero_init_residual (bool, optional): Zero init bottleneck residual BN. Defaults to False. 
208 | """ 209 | super(VideoResNet, self).__init__() 210 | self.inplanes = 64 211 | 212 | self.stem = stem() 213 | 214 | self.layer1 = self._make_layer(block, conv_makers[0], 64, layers[0], stride=1) 215 | self.layer2 = self._make_layer(block, conv_makers[1], 128, layers[1], stride=2) 216 | self.layer3 = self._make_layer(block, conv_makers[2], 256, layers[2], stride=2) 217 | self.layer4 = self._make_layer(block, conv_makers[3], 512, layers[3], stride=2) 218 | 219 | self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1)) 220 | self.fc = nn.Linear(512 * block.expansion, num_classes) 221 | 222 | # init weights 223 | self._initialize_weights() 224 | 225 | if zero_init_residual: 226 | for m in self.modules(): 227 | if isinstance(m, Bottleneck): 228 | nn.init.constant_(m.bn3.weight, 0) 229 | 230 | def forward(self, x): 231 | x = self.stem(x) 232 | 233 | x = self.layer1(x) 234 | x = self.layer2(x) 235 | x = self.layer3(x) 236 | x = self.layer4(x) 237 | 238 | x = self.avgpool(x) 239 | # Flatten the layer to fc 240 | x = x.flatten(1) 241 | x = self.fc(x) 242 | 243 | return x 244 | 245 | def _make_layer(self, block, conv_builder, planes, blocks, stride=1): 246 | downsample = None 247 | 248 | if stride != 1 or self.inplanes != planes * block.expansion: 249 | ds_stride = conv_builder.get_downsample_stride(stride) 250 | downsample = nn.Sequential( 251 | nn.Conv3d(self.inplanes, planes * block.expansion, 252 | kernel_size=1, stride=ds_stride, bias=False), 253 | nn.BatchNorm3d(planes * block.expansion) 254 | ) 255 | layers = [] 256 | layers.append(block(self.inplanes, planes, conv_builder, stride, downsample)) 257 | 258 | self.inplanes = planes * block.expansion 259 | for i in range(1, blocks): 260 | layers.append(block(self.inplanes, planes, conv_builder)) 261 | 262 | return nn.Sequential(*layers) 263 | 264 | def _initialize_weights(self): 265 | for m in self.modules(): 266 | if isinstance(m, nn.Conv3d): 267 | nn.init.kaiming_normal_(m.weight, mode='fan_out', 268 | nonlinearity='relu') 269 | if m.bias is not None: 270 | nn.init.constant_(m.bias, 0) 271 | elif isinstance(m, nn.BatchNorm3d): 272 | nn.init.constant_(m.weight, 1) 273 | nn.init.constant_(m.bias, 0) 274 | elif isinstance(m, nn.Linear): 275 | nn.init.normal_(m.weight, 0, 0.01) 276 | nn.init.constant_(m.bias, 0) 277 | 278 | 279 | 280 | 281 | 282 | 283 | 284 | 285 | class VideoResNet_kd(nn.Module): 286 | 287 | def __init__(self, block, conv_makers, layers, 288 | stem, num_classes=400, 289 | zero_init_residual=False): 290 | """Generic resnet video generator. 291 | 292 | Args: 293 | block (nn.Module): resnet building block 294 | conv_makers (list(functions)): generator function for each layer 295 | layers (List[int]): number of blocks per layer 296 | stem (nn.Module, optional): Resnet stem, if None, defaults to conv-bn-relu. Defaults to None. 297 | num_classes (int, optional): Dimension of the final FC layer. Defaults to 400. 298 | zero_init_residual (bool, optional): Zero init bottleneck residual BN. Defaults to False. 
299 | """ 300 | super(VideoResNet_kd, self).__init__() 301 | self.inplanes = 64 302 | 303 | self.stem = stem() 304 | 305 | self.layer1 = self._make_layer(block, conv_makers[0], 64, layers[0], stride=1) 306 | self.layer2 = self._make_layer(block, conv_makers[1], 128, layers[1], stride=2) 307 | self.layer3 = self._make_layer(block, conv_makers[2], 256, layers[2], stride=2) 308 | self.layer4 = self._make_layer(block, conv_makers[3], 512, layers[3], stride=2) 309 | 310 | self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1)) 311 | self.fc = nn.Linear(512 * block.expansion, num_classes) 312 | 313 | # init weights 314 | self._initialize_weights() 315 | 316 | if zero_init_residual: 317 | for m in self.modules(): 318 | if isinstance(m, Bottleneck): 319 | nn.init.constant_(m.bn3.weight, 0) 320 | 321 | def forward(self, x): 322 | x = self.stem(x) 323 | x_l0 = x 324 | x = self.layer1(x) 325 | x_l1 = x 326 | x = self.layer2(x) 327 | x_l2 = x 328 | x = self.layer3(x) 329 | x_l3 = x 330 | x = self.layer4(x) 331 | x_l4 = x 332 | 333 | x = self.avgpool(x) 334 | 335 | x = x.flatten(1) 336 | x_reg_feat = x 337 | x = self.fc(x) 338 | 339 | return x, x_l0, x_l1, x_l2, x_reg_feat, x_l4, x_l3 340 | 341 | def _make_layer(self, block, conv_builder, planes, blocks, stride=1): 342 | downsample = None 343 | 344 | if stride != 1 or self.inplanes != planes * block.expansion: 345 | ds_stride = conv_builder.get_downsample_stride(stride) 346 | downsample = nn.Sequential( 347 | nn.Conv3d(self.inplanes, planes * block.expansion, 348 | kernel_size=1, stride=ds_stride, bias=False), 349 | nn.BatchNorm3d(planes * block.expansion) 350 | ) 351 | layers = [] 352 | layers.append(block(self.inplanes, planes, conv_builder, stride, downsample)) 353 | 354 | self.inplanes = planes * block.expansion 355 | for i in range(1, blocks): 356 | layers.append(block(self.inplanes, planes, conv_builder)) 357 | 358 | return nn.Sequential(*layers) 359 | 360 | def _initialize_weights(self): 361 | for m in self.modules(): 362 | if isinstance(m, nn.Conv3d): 363 | nn.init.kaiming_normal_(m.weight, mode='fan_out', 364 | nonlinearity='relu') 365 | if m.bias is not None: 366 | nn.init.constant_(m.bias, 0) 367 | elif isinstance(m, nn.BatchNorm3d): 368 | nn.init.constant_(m.weight, 1) 369 | nn.init.constant_(m.bias, 0) 370 | elif isinstance(m, nn.Linear): 371 | nn.init.normal_(m.weight, 0, 0.01) 372 | nn.init.constant_(m.bias, 0) 373 | 374 | 375 | 376 | 377 | 378 | 379 | 380 | 381 | 382 | 383 | 384 | 385 | 386 | 387 | def _video_resnet(arch, pretrained=False, progress=True, **kwargs): 388 | model = VideoResNet(**kwargs) 389 | 390 | if pretrained: 391 | state_dict = load_state_dict_from_url(model_urls[arch], 392 | progress=progress) 393 | model.load_state_dict(state_dict) 394 | return model 395 | 396 | 397 | 398 | def _video_resnet_kd(arch, pretrained=False, progress=True, **kwargs): 399 | model = VideoResNet_kd(**kwargs) 400 | 401 | if pretrained: 402 | state_dict = load_state_dict_from_url(model_urls[arch], 403 | progress=progress) 404 | model.load_state_dict(state_dict) 405 | return model 406 | 407 | 408 | 409 | 410 | 411 | 412 | 413 | 414 | def r2plus1d_18(pretrained=False, progress=True, **kwargs): 415 | """Constructor for the 18 layer deep R(2+1)D network as in 416 | https://arxiv.org/abs/1711.11248 417 | 418 | Args: 419 | pretrained (bool): If True, returns a model pre-trained on Kinetics-400 420 | progress (bool): If True, displays a progress bar of the download to stderr 421 | 422 | Returns: 423 | nn.Module: R(2+1)D-18 network 424 | """ 425 | return 
_video_resnet('r2plus1d_18', 426 | pretrained, progress, 427 | block=BasicBlock, 428 | conv_makers=[Conv2Plus1D] * 4, 429 | layers=[2, 2, 2, 2], 430 | stem=R2Plus1dStem, **kwargs) 431 | 432 | 433 | 434 | 435 | 436 | def r2plus1d_18_kd(pretrained=False, progress=True, **kwargs): 437 | """Constructor for the 18 layer deep R(2+1)D network as in 438 | https://arxiv.org/abs/1711.11248 439 | 440 | Args: 441 | pretrained (bool): If True, returns a model pre-trained on Kinetics-400 442 | progress (bool): If True, displays a progress bar of the download to stderr 443 | 444 | Returns: 445 | nn.Module: R(2+1)D-18 network 446 | """ 447 | return _video_resnet_kd('r2plus1d_18', 448 | pretrained, progress, 449 | block=BasicBlock, 450 | conv_makers=[Conv2Plus1D] * 4, 451 | layers=[2, 2, 2, 2], 452 | stem=R2Plus1dStem, **kwargs) 453 | 454 | 455 | 456 | 457 | -------------------------------------------------------------------------------- /echonet/segmentation/__init__.py: -------------------------------------------------------------------------------- 1 | from .segmentation import * 2 | # from .fcn import * 3 | from .deeplabv3 import * 4 | # from .lraspp import * -------------------------------------------------------------------------------- /echonet/segmentation/_utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Optional, Dict 3 | 4 | from torch import nn, Tensor 5 | from torch.nn import functional as F 6 | import torch 7 | 8 | 9 | class _SimpleSegmentationModel(nn.Module): 10 | __constants__ = ['aux_classifier'] 11 | 12 | def __init__( 13 | self, 14 | backbone: nn.Module, 15 | classifier: nn.Module, 16 | aux_classifier: Optional[nn.Module] = None 17 | ) -> None: 18 | super(_SimpleSegmentationModel, self).__init__() 19 | self.backbone = backbone 20 | self.classifier = classifier 21 | # self.aux_classifier = aux_classifier 22 | 23 | self.ctr_avgpool = nn.AdaptiveAvgPool2d((1, 1)) 24 | self.ctr_fc = nn.Sequential(nn.Linear(2048, 2048), nn.ReLU(), nn.Linear(2048, 128)) 25 | 26 | 27 | def forward(self, x: Tensor) -> Dict[str, Tensor]: 28 | input_shape = x.shape[-2:] 29 | # contract: features is a dict of tensors 30 | 31 | # print("self.backbone", self.backbone.conv1) 32 | features = self.backbone(x) 33 | 34 | result = OrderedDict() 35 | x = features["out"] 36 | x = self.classifier(x) 37 | x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False) 38 | result["out"] = x 39 | 40 | x_ctr = features["out"] 41 | x_ctr = self.ctr_avgpool(x_ctr) 42 | x_ctr = x_ctr.flatten(1) 43 | x_ctr = self.ctr_fc(x_ctr) 44 | # print("x_ctr.shape, in _utils segmentation", x_ctr.shape) 45 | result['ctr_feat'] = F.normalize(x_ctr, dim = 1) 46 | result['feat_mid'] = features["out"] 47 | return result 48 | 49 | 50 | 51 | 52 | 53 | class _SimpleSegmentationModel_CSS(nn.Module): 54 | __constants__ = ['aux_classifier'] 55 | 56 | def __init__( 57 | self, 58 | backbone: nn.Module, 59 | classifier: nn.Module, 60 | aux_classifier: Optional[nn.Module] = None 61 | ) -> None: 62 | super(_SimpleSegmentationModel_CSS, self).__init__() 63 | self.backbone = backbone 64 | self.classifier = classifier 65 | # self.aux_classifier = aux_classifier 66 | 67 | 68 | def forward(self, x: Tensor) -> Dict[str, Tensor]: 69 | input_shape = x.shape[-2:] 70 | # contract: features is a dict of tensors 71 | 72 | # print("self.backbone", self.backbone.conv1) 73 | 74 | xtest = self.backbone.conv1(x) 75 | xtest = self.backbone.bn1(xtest) 76 | xtest 
= self.backbone.relu(xtest) 77 | xtest_layerbs = xtest 78 | xtest = self.backbone.maxpool(xtest) 79 | xtest_layer0 = xtest 80 | xtest = self.backbone.layer1(xtest) 81 | xtest_layer1 = xtest 82 | xtest = self.backbone.layer2(xtest) ### can just output here. 83 | xtest_layer2 = xtest 84 | xtest = self.backbone.layer3(xtest) 85 | xtest = self.backbone.layer4(xtest) 86 | # print("xtest_layerbs.shape", xtest_layerbs.shape)# xtest_layerbs.shape torch.Size([2, 64, 56, 56]) 87 | # print("xtest_layer0.shape", xtest_layer0.shape) #xtest_layer0.shape torch.Size([2, 64, 28, 28]) 88 | # print("xtest_layer1.shape", xtest_layer1.shape) #xtest_layer1.shape torch.Size([2, 256, 28, 28]) 89 | # print("xtest_layer2.shape", xtest_layer2.shape) #torch.Size([2, 512, 14, 14]) 90 | 91 | result = OrderedDict() 92 | x = xtest 93 | 94 | x = self.classifier(x) 95 | x_maskpre = x 96 | x_maskpre = F.interpolate(x_maskpre, size=[56,56], mode='bilinear', align_corners=False) 97 | x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False) 98 | result["out"] = x 99 | result['x_layerbs'] = xtest_layerbs 100 | result['x_layer1'] = xtest_layer1 101 | result['x_layer4'] = xtest 102 | result['maskfeat'] = x_maskpre 103 | return result 104 | 105 | 106 | -------------------------------------------------------------------------------- /echonet/segmentation/deeplabv3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from typing import List 5 | 6 | from ._utils import _SimpleSegmentationModel, _SimpleSegmentationModel_CSS 7 | 8 | 9 | __all__ = ["DeepLabV3", "DeepLabV3_CSS"] 10 | 11 | 12 | class DeepLabV3(_SimpleSegmentationModel): 13 | """ 14 | Implements DeepLabV3 model from 15 | `"Rethinking Atrous Convolution for Semantic Image Segmentation" 16 | `_. 17 | 18 | Args: 19 | backbone (nn.Module): the network used to compute the features for the model. 20 | The backbone should return an OrderedDict[Tensor], with the key being 21 | "out" for the last feature map used, and "aux" if an auxiliary classifier 22 | is used. 23 | classifier (nn.Module): module that takes the "out" element returned from 24 | the backbone and returns a dense prediction. 25 | aux_classifier (nn.Module, optional): auxiliary classifier used during training 26 | """ 27 | pass 28 | 29 | 30 | 31 | class DeepLabV3_CSS(_SimpleSegmentationModel_CSS): 32 | """ 33 | Implements DeepLabV3 model from 34 | `"Rethinking Atrous Convolution for Semantic Image Segmentation" 35 | `_. 36 | 37 | Args: 38 | backbone (nn.Module): the network used to compute the features for the model. 39 | The backbone should return an OrderedDict[Tensor], with the key being 40 | "out" for the last feature map used, and "aux" if an auxiliary classifier 41 | is used. 42 | classifier (nn.Module): module that takes the "out" element returned from 43 | the backbone and returns a dense prediction. 
44 | aux_classifier (nn.Module, optional): auxiliary classifier used during training 45 | """ 46 | pass 47 | 48 | 49 | 50 | class DeepLabHead(nn.Sequential): 51 | def __init__(self, in_channels: int, num_classes: int) -> None: 52 | super(DeepLabHead, self).__init__( 53 | ASPP(in_channels, [12, 24, 36]), 54 | nn.Conv2d(256, 256, 3, padding=1, bias=False), 55 | nn.BatchNorm2d(256), 56 | nn.ReLU(), 57 | nn.Conv2d(256, num_classes, 1) 58 | ) 59 | 60 | 61 | class ASPPConv(nn.Sequential): 62 | def __init__(self, in_channels: int, out_channels: int, dilation: int) -> None: 63 | modules = [ 64 | nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False), 65 | nn.BatchNorm2d(out_channels), 66 | nn.ReLU() 67 | ] 68 | super(ASPPConv, self).__init__(*modules) 69 | 70 | 71 | class ASPPPooling(nn.Sequential): 72 | def __init__(self, in_channels: int, out_channels: int) -> None: 73 | super(ASPPPooling, self).__init__( 74 | nn.AdaptiveAvgPool2d(1), 75 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 76 | nn.BatchNorm2d(out_channels), 77 | nn.ReLU()) 78 | 79 | def forward(self, x: torch.Tensor) -> torch.Tensor: 80 | size = x.shape[-2:] 81 | for mod in self: 82 | x = mod(x) 83 | return F.interpolate(x, size=size, mode='bilinear', align_corners=False) 84 | 85 | 86 | class ASPP(nn.Module): 87 | def __init__(self, in_channels: int, atrous_rates: List[int], out_channels: int = 256) -> None: 88 | super(ASPP, self).__init__() 89 | modules = [] 90 | modules.append(nn.Sequential( 91 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 92 | nn.BatchNorm2d(out_channels), 93 | nn.ReLU())) 94 | 95 | rates = tuple(atrous_rates) 96 | for rate in rates: 97 | modules.append(ASPPConv(in_channels, out_channels, rate)) 98 | 99 | modules.append(ASPPPooling(in_channels, out_channels)) 100 | 101 | self.convs = nn.ModuleList(modules) 102 | 103 | self.project = nn.Sequential( 104 | nn.Conv2d(len(self.convs) * out_channels, out_channels, 1, bias=False), 105 | nn.BatchNorm2d(out_channels), 106 | nn.ReLU(), 107 | nn.Dropout(0.5)) 108 | 109 | def forward(self, x: torch.Tensor) -> torch.Tensor: 110 | _res = [] 111 | for conv in self.convs: 112 | _res.append(conv(x)) 113 | res = torch.cat(_res, dim=1) 114 | return self.project(res) -------------------------------------------------------------------------------- /echonet/segmentation/segmentation.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | from torch import nn 4 | from typing import Any, Optional, Dict 5 | # from .._utils import IntermediateLayerGetter 6 | # from ..._internally_replaced_utils import load_state_dict_from_url 7 | from torch.hub import load_state_dict_from_url 8 | from torchvision.models import resnet 9 | from .deeplabv3 import DeepLabHead, DeepLabV3, DeepLabV3_CSS 10 | # from .fcn import FCN, FCNHead 11 | 12 | __all__ = ['deeplabv3_resnet50', 'deeplabv3_resnet50_CSS'] 13 | 14 | 15 | model_urls = { 16 | 'fcn_resnet50_coco': 'https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth', 17 | 'fcn_resnet101_coco': 'https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth', 18 | 'deeplabv3_resnet50_coco': 'https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth', 19 | 'deeplabv3_resnet101_coco': 'https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth', 20 | 'deeplabv3_mobilenet_v3_large_coco': 21 | 'https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth', 22 | 
'lraspp_mobilenet_v3_large_coco': 'https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth', 23 | } 24 | 25 | 26 | 27 | 28 | 29 | 30 | class IntermediateLayerGetter(nn.ModuleDict): 31 | """ 32 | Module wrapper that returns intermediate layers from a model 33 | It has a strong assumption that the modules have been registered 34 | into the model in the same order as they are used. 35 | This means that one should **not** reuse the same nn.Module 36 | twice in the forward if you want this to work. 37 | Additionally, it is only able to query submodules that are directly 38 | assigned to the model. So if `model` is passed, `model.feature1` can 39 | be returned, but not `model.feature1.layer2`. 40 | Args: 41 | model (nn.Module): model on which we will extract the features 42 | return_layers (Dict[name, new_name]): a dict containing the names 43 | of the modules for which the activations will be returned as 44 | the key of the dict, and the value of the dict is the name 45 | of the returned activation (which the user can specify). 46 | Examples:: 47 | >>> m = torchvision.models.resnet18(pretrained=True) 48 | >>> # extract layer1 and layer3, giving as names `feat1` and feat2` 49 | >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m, 50 | >>> {'layer1': 'feat1', 'layer3': 'feat2'}) 51 | >>> out = new_m(torch.rand(1, 3, 224, 224)) 52 | >>> print([(k, v.shape) for k, v in out.items()]) 53 | >>> [('feat1', torch.Size([1, 64, 56, 56])), 54 | >>> ('feat2', torch.Size([1, 256, 14, 14]))] 55 | """ 56 | _version = 2 57 | __annotations__ = { 58 | "return_layers": Dict[str, str], 59 | } 60 | 61 | def __init__(self, model: nn.Module, return_layers: Dict[str, str]) -> None: 62 | if not set(return_layers).issubset([name for name, _ in model.named_children()]): 63 | raise ValueError("return_layers are not present in model") 64 | orig_return_layers = return_layers 65 | return_layers = {str(k): str(v) for k, v in return_layers.items()} 66 | layers = OrderedDict() 67 | for name, module in model.named_children(): 68 | layers[name] = module 69 | if name in return_layers: 70 | del return_layers[name] 71 | if not return_layers: 72 | break 73 | 74 | super(IntermediateLayerGetter, self).__init__(layers) 75 | self.return_layers = orig_return_layers 76 | 77 | def forward(self, x): 78 | out = OrderedDict() 79 | for name, module in self.items(): 80 | x = module(x) 81 | if name in self.return_layers: 82 | out_name = self.return_layers[name] 83 | out[out_name] = x 84 | return out 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | def _segm_model( 93 | name: str, 94 | backbone_name: str, 95 | num_classes: int, 96 | aux: Optional[bool], 97 | pretrained_backbone: bool = True 98 | ) -> nn.Module: 99 | if 'resnet' in backbone_name: 100 | backbone = resnet.__dict__[backbone_name]( 101 | pretrained=pretrained_backbone, 102 | replace_stride_with_dilation=[False, True, True]) 103 | out_layer = 'layer4' 104 | out_inplanes = 2048 105 | aux_layer = 'layer3' 106 | aux_inplanes = 1024 107 | elif 'mobilenet_v3' in backbone_name: 108 | assert 1==2, "not using mobilenet" 109 | 110 | # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks. 111 | # The first and last blocks are always included because they are the C0 (conv1) and Cn. 
112 | stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1] 113 | out_pos = stage_indices[-1] # use C5 which has output_stride = 16 114 | out_layer = str(out_pos) 115 | out_inplanes = backbone[out_pos].out_channels 116 | aux_pos = stage_indices[-4] # use C2 here which has output_stride = 8 117 | aux_layer = str(aux_pos) 118 | aux_inplanes = backbone[aux_pos].out_channels 119 | else: 120 | raise NotImplementedError('backbone {} is not supported as of now'.format(backbone_name)) 121 | 122 | return_layers = {out_layer: 'out'} 123 | if aux: 124 | return_layers[aux_layer] = 'aux' 125 | backbone = IntermediateLayerGetter(backbone, return_layers=return_layers) 126 | 127 | aux_classifier = None 128 | # if aux: 129 | # aux_classifier = FCNHead(aux_inplanes, num_classes) 130 | 131 | model_map = { 132 | 'deeplabv3': (DeepLabHead, DeepLabV3) #, 133 | # 'fcn': (FCNHead, FCN), 134 | } 135 | classifier = model_map[name][0](out_inplanes, num_classes) 136 | base_model = model_map[name][1] 137 | 138 | model = base_model(backbone, classifier, aux_classifier) 139 | return model 140 | 141 | 142 | 143 | 144 | 145 | def _segm_model_CSS( 146 | name: str, 147 | backbone_name: str, 148 | num_classes: int, 149 | aux: Optional[bool], 150 | pretrained_backbone: bool = True 151 | ) -> nn.Module: 152 | if 'resnet' in backbone_name: 153 | backbone = resnet.__dict__[backbone_name]( 154 | pretrained=pretrained_backbone, 155 | replace_stride_with_dilation=[False, True, True]) 156 | out_layer = 'layer4' 157 | out_inplanes = 2048 158 | aux_layer = 'layer3' 159 | aux_inplanes = 1024 160 | elif 'mobilenet_v3' in backbone_name: 161 | assert 1==2, "not using mobilenet" 162 | 163 | # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks. 164 | # The first and last blocks are always included because they are the C0 (conv1) and Cn. 
165 | stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1] 166 | out_pos = stage_indices[-1] # use C5 which has output_stride = 16 167 | out_layer = str(out_pos) 168 | out_inplanes = backbone[out_pos].out_channels 169 | aux_pos = stage_indices[-4] # use C2 here which has output_stride = 8 170 | aux_layer = str(aux_pos) 171 | aux_inplanes = backbone[aux_pos].out_channels 172 | else: 173 | raise NotImplementedError('backbone {} is not supported as of now'.format(backbone_name)) 174 | 175 | return_layers = {out_layer: 'out'} 176 | if aux: 177 | return_layers[aux_layer] = 'aux' 178 | backbone = IntermediateLayerGetter(backbone, return_layers=return_layers) 179 | 180 | aux_classifier = None 181 | # if aux: 182 | # aux_classifier = FCNHead(aux_inplanes, num_classes) 183 | 184 | model_map = { 185 | 'deeplabv3': (DeepLabHead, DeepLabV3_CSS) #, 186 | # 'fcn': (FCNHead, FCN), 187 | } 188 | classifier = model_map[name][0](out_inplanes, num_classes) 189 | base_model = model_map[name][1] 190 | 191 | model = base_model(backbone, classifier, aux_classifier) 192 | return model 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | def _load_model( 202 | arch_type: str, 203 | backbone: str, 204 | pretrained: bool, 205 | progress: bool, 206 | num_classes: int, 207 | aux_loss: Optional[bool], 208 | **kwargs: Any 209 | ) -> nn.Module: 210 | if pretrained: 211 | aux_loss = True 212 | kwargs["pretrained_backbone"] = False 213 | model = _segm_model(arch_type, backbone, num_classes, aux_loss, **kwargs) 214 | if pretrained: 215 | _load_weights(model, arch_type, backbone, progress) 216 | return model 217 | 218 | 219 | 220 | def _load_model_CSS( 221 | arch_type: str, 222 | backbone: str, 223 | pretrained: bool, 224 | progress: bool, 225 | num_classes: int, 226 | aux_loss: Optional[bool], 227 | **kwargs: Any 228 | ) -> nn.Module: 229 | if pretrained: 230 | aux_loss = True 231 | kwargs["pretrained_backbone"] = False 232 | model = _segm_model_CSS(arch_type, backbone, num_classes, aux_loss, **kwargs) 233 | if pretrained: 234 | _load_weights(model, arch_type, backbone, progress) 235 | return model 236 | 237 | 238 | 239 | 240 | 241 | def _load_weights(model: nn.Module, arch_type: str, backbone: str, progress: bool) -> None: 242 | arch = arch_type + '_' + backbone + '_coco' 243 | model_url = model_urls.get(arch, None) 244 | if model_url is None: 245 | raise NotImplementedError('pretrained {} is not supported as of now'.format(arch)) 246 | else: 247 | # assert 1==2, "a bit mahfan, we don't allow pretrained for now, not needed in segmentation anyways" 248 | state_dict = load_state_dict_from_url(model_url, progress=progress) 249 | model.load_state_dict(state_dict, strict = False) 250 | 251 | 252 | def deeplabv3_resnet50( 253 | pretrained: bool = False, 254 | progress: bool = True, 255 | num_classes: int = 21, 256 | aux_loss: Optional[bool] = None, 257 | **kwargs: Any 258 | ) -> nn.Module: 259 | """Constructs a DeepLabV3 model with a ResNet-50 backbone. 
260 | 261 | Args: 262 | pretrained (bool): If True, returns a model pre-trained on COCO train2017 which 263 | contains the same classes as Pascal VOC 264 | progress (bool): If True, displays a progress bar of the download to stderr 265 | num_classes (int): number of output classes of the model (including the background) 266 | aux_loss (bool): If True, it uses an auxiliary loss 267 | """ 268 | return _load_model('deeplabv3', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs) 269 | 270 | 271 | 272 | 273 | def deeplabv3_resnet50_CSS( 274 | pretrained: bool = False, 275 | progress: bool = True, 276 | num_classes: int = 21, 277 | aux_loss: Optional[bool] = None, 278 | **kwargs: Any 279 | ) -> nn.Module: 280 | """Constructs a DeepLabV3 model with a ResNet-50 backbone. 281 | 282 | Args: 283 | pretrained (bool): If True, returns a model pre-trained on COCO train2017 which 284 | contains the same classes as Pascal VOC 285 | progress (bool): If True, displays a progress bar of the download to stderr 286 | num_classes (int): number of output classes of the model (including the background) 287 | aux_loss (bool): If True, it uses an auxiliary loss 288 | """ 289 | return _load_model_CSS('deeplabv3', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs) 290 | 291 | -------------------------------------------------------------------------------- /echonet/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Utility functions for videos, plotting and computing performance metrics.""" 2 | 3 | import os 4 | import typing 5 | import datetime 6 | 7 | import cv2 # pytype: disable=attribute-error 8 | import matplotlib 9 | import numpy as np 10 | import torch 11 | import tqdm 12 | 13 | 14 | from . import seg_cycle 15 | from . import video_segin 16 | from . import vidsegin_teachstd_kd 17 | 18 | 19 | def loadvideo(filename: str) -> np.ndarray: 20 | """Loads a video from a file. 21 | 22 | Args: 23 | filename (str): filename of video 24 | 25 | Returns: 26 | A np.ndarray with dimensions (channels=3, frames, height, width). The 27 | values will be uint8's ranging from 0 to 255. 28 | 29 | Raises: 30 | FileNotFoundError: Could not find `filename` 31 | ValueError: An error occurred while reading the video 32 | """ 33 | 34 | if not os.path.exists(filename): 35 | raise FileNotFoundError(filename) 36 | ###debug 37 | # print(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), "opening vid") 38 | capture = cv2.VideoCapture(filename) 39 | # print(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), "videocapture done") 40 | 41 | frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT)) 42 | frame_width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)) 43 | frame_height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)) 44 | 45 | v = np.zeros((frame_count, frame_height, frame_width, 3), np.uint8) 46 | # print(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), "reading capture") 47 | for count in range(frame_count): 48 | ret, frame = capture.read() 49 | if not ret: 50 | raise ValueError("Failed to load frame #{} of {}.".format(count, filename)) 51 | 52 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 53 | v[count, :, :] = frame 54 | 55 | v = v.transpose((3, 0, 1, 2)) 56 | # print(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), "finished opening vid") 57 | 58 | return v 59 | 60 | 61 | def savevideo(filename: str, array: np.ndarray, fps: typing.Union[float, int] = 1): 62 | """Saves a video to a file. 
63 | 64 | Args: 65 | filename (str): filename of video 66 | array (np.ndarray): video of uint8's with shape (channels=3, frames, height, width) 67 | fps (float or int): frames per second 68 | 69 | Returns: 70 | None 71 | """ 72 | 73 | c, _, height, width = array.shape 74 | 75 | if c != 3: 76 | raise ValueError("savevideo expects array of shape (channels=3, frames, height, width), got shape ({})".format(", ".join(map(str, array.shape)))) 77 | fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G') 78 | out = cv2.VideoWriter(filename, fourcc, fps, (width, height)) 79 | 80 | for frame in array.transpose((1, 2, 3, 0)): 81 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) 82 | out.write(frame) 83 | 84 | 85 | def get_mean_and_std(dataset: torch.utils.data.Dataset, 86 | samples: int = 128, 87 | batch_size: int = 8, 88 | num_workers: int = 4): 89 | """Computes mean and std from samples from a Pytorch dataset. 90 | 91 | Args: 92 | dataset (torch.utils.data.Dataset): A Pytorch dataset. 93 | ``dataset[i][0]'' is expected to be the i-th video in the dataset, which 94 | should be a ``torch.Tensor'' of dimensions (channels=3, frames, height, width) 95 | samples (int or None, optional): Number of samples to take from dataset. If ``None'', mean and 96 | standard deviation are computed over all elements. 97 | Defaults to 128. 98 | batch_size (int, optional): how many samples per batch to load 99 | Defaults to 8. 100 | num_workers (int, optional): how many subprocesses to use for data 101 | loading. If 0, the data will be loaded in the main process. 102 | Defaults to 4. 103 | 104 | Returns: 105 | A tuple of the mean and standard deviation. Both are represented as np.array's of dimension (channels,). 106 | """ 107 | 108 | if samples is not None and len(dataset) > samples: 109 | np.random.seed(0) 110 | indices = np.random.choice(len(dataset), samples, replace=False) 111 | dataset = torch.utils.data.Subset(dataset, indices) 112 | 113 | dataloader = torch.utils.data.DataLoader( 114 | dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False) 115 | 116 | n = 0 # number of elements taken (should be equal to samples by end of for loop) 117 | s1 = 0. # sum of elements along channels (ends up as np.array of dimension (channels,)) 118 | s2 = 0. # sum of squares of elements along channels (ends up as np.array of dimension (channels,)) 119 | # for (x, *_) in tqdm.tqdm(dataloader): 120 | for (x,_,*_) in tqdm.tqdm(dataloader): 121 | x = x.transpose(0, 1).contiguous().view(3, -1) 122 | n += x.shape[1] 123 | s1 += torch.sum(x, dim=1).numpy() 124 | s2 += torch.sum(x ** 2, dim=1).numpy() 125 | mean = s1 / n # type: np.ndarray 126 | std = np.sqrt(s2 / n - mean ** 2) # type: np.ndarray 127 | 128 | mean = mean.astype(np.float32) 129 | std = std.astype(np.float32) 130 | 131 | return mean, std 132 | 133 | 134 | def bootstrap(a, b, func, samples=10000): 135 | """Computes a bootstrapped confidence intervals for ``func(a, b)''. 136 | 137 | Args: 138 | a (array_like): first argument to `func`. 139 | b (array_like): second argument to `func`. 140 | func (callable): Function to compute confidence intervals for. 141 | ``dataset[i][0]'' is expected to be the i-th video in the dataset, which 142 | should be a ``torch.Tensor'' of dimensions (channels=3, frames, height, width) 143 | samples (int, optional): Number of samples to compute. 144 | Defaults to 10000. 145 | 146 | Returns: 147 | A tuple of (`func(a, b)`, estimated 5-th percentile, estimated 95-th percentile). 
148 | """ 149 | a = np.array(a) 150 | b = np.array(b) 151 | 152 | bootstraps = [] 153 | for _ in range(samples): 154 | ind = np.random.choice(len(a), len(a)) 155 | bootstraps.append(func(a[ind], b[ind])) 156 | bootstraps = sorted(bootstraps) 157 | 158 | return func(a, b), bootstraps[round(0.05 * len(bootstraps))], bootstraps[round(0.95 * len(bootstraps))] 159 | 160 | 161 | def latexify(): 162 | """Sets matplotlib params to appear more like LaTeX. 163 | 164 | Based on https://nipunbatra.github.io/blog/2014/latexify.html 165 | """ 166 | params = {'backend': 'pdf', 167 | 'axes.titlesize': 8, 168 | 'axes.labelsize': 8, 169 | 'font.size': 8, 170 | 'legend.fontsize': 8, 171 | 'xtick.labelsize': 8, 172 | 'ytick.labelsize': 8, 173 | 'font.family': 'DejaVu Serif', 174 | 'font.serif': 'Computer Modern', 175 | } 176 | matplotlib.rcParams.update(params) 177 | 178 | 179 | def dice_similarity_coefficient(inter, union): 180 | """Computes the dice similarity coefficient. 181 | 182 | Args: 183 | inter (iterable): iterable of the intersections 184 | union (iterable): iterable of the unions 185 | """ 186 | return 2 * sum(inter) / (sum(union) + sum(inter)) 187 | 188 | 189 | __all__ = ["video", "segmentation", "seg_ctrmlt", "seg_sslflw", "vidseg", "seg_cycle", "vidsegin_iekd_att", "vidseg_iekd_att_mult", "vidseg_iekd_att_mult_reg", "vidsegin_iekd_att_reg", "video_segin", "vidseg_iekd_att", "video_seginsegonly", "video_segin_hallucinate", "videossl", "loadvideo", "savevideo", "get_mean_and_std", "bootstrap", "latexify", "dice_similarity_coefficient", "vidseg_iekd_ftchck"] 190 | -------------------------------------------------------------------------------- /echonet/utils/seg_cycle.py: -------------------------------------------------------------------------------- 1 | """Functions for training and running segmentation.""" 2 | 3 | import math 4 | import os 5 | import time 6 | import shutil 7 | import datetime 8 | import pandas as pd 9 | 10 | import click 11 | import matplotlib.pyplot as plt 12 | import numpy as np 13 | import scipy.signal 14 | import skimage.draw 15 | from PIL import Image 16 | import torch 17 | import torchvision 18 | import tqdm 19 | 20 | import echonet 21 | 22 | 23 | @click.command("seg_cycle") 24 | @click.option("--data_dir", type=click.Path(exists=True, file_okay=False), default=None) 25 | @click.option("--output", type=click.Path(file_okay=False), default=None) 26 | @click.option("--model_name", type=click.Choice( 27 | sorted(name for name in torchvision.models.segmentation.__dict__ 28 | if name.islower() and not name.startswith("__") and callable(torchvision.models.segmentation.__dict__[name]))), 29 | default="deeplabv3_resnet50") 30 | @click.option("--pretrained/--random", default=False) 31 | @click.option("--weights", type=click.Path(exists=True, dir_okay=False), default=None) 32 | @click.option("--run_test/--skip_test", default=False) 33 | @click.option("--save_video", type=str, default=None) 34 | @click.option("--num_epochs", type=int, default=25) 35 | @click.option("--lr", type=float, default=1e-5) 36 | @click.option("--weight_decay", type=float, default=0) 37 | @click.option("--lr_step_period", type=int, default=None) 38 | @click.option("--num_train_patients", type=int, default=None) 39 | @click.option("--num_workers", type=int, default=4) 40 | @click.option("--batch_size", type=int, default=20) 41 | @click.option("--device", type=str, default=None) 42 | @click.option("--seed", type=int, default=0) 43 | @click.option("--reduced_set/--full_set", default=True) 44 | 
@click.option("--rd_label", type=int, default=920) 45 | @click.option("--rd_unlabel", type=int, default=6440) 46 | @click.option("--ssl_edesonly/--ssl_rndfrm", default=True) 47 | @click.option("--run_inference", type=str, default=None) 48 | @click.option("--chunk_size", type=int, default=3) 49 | @click.option("--cyc_off", type=int, default=2) 50 | @click.option("--target_region", type=int, default=15) 51 | @click.option("--temperature", type=int, default=10) 52 | @click.option("--val_chunk", type=int, default=40) 53 | @click.option("--loss_cyc_w", type=float, default=1) 54 | @click.option("--css_strtup", type=int, default=0) 55 | 56 | def run( 57 | data_dir=None, 58 | output=None, 59 | model_name="deeplabv3_resnet50", 60 | pretrained=False, 61 | weights=None, 62 | run_test=False, 63 | save_video=None, 64 | num_epochs=25, 65 | lr=1e-5, 66 | weight_decay=1e-5, 67 | lr_step_period=None, 68 | num_train_patients=None, 69 | num_workers=4, 70 | batch_size=20, 71 | device=None, 72 | seed=0, 73 | reduced_set = True, 74 | rd_label = 920, 75 | rd_unlabel = 6440, 76 | ssl_edesonly = True, 77 | run_inference = None, 78 | chunk_size = 3, 79 | cyc_off = 2, 80 | target_region = 15, 81 | temperature = 10, 82 | val_chunk = 40, 83 | loss_cyc_w = 1, 84 | css_strtup = 0 85 | ): 86 | 87 | if reduced_set: 88 | if not os.path.isfile(os.path.join(echonet.config.DATA_DIR, "FileList_ssl_{}_{}.csv".format(rd_label, rd_unlabel))): 89 | print("Generating new file list for ssl dataset") 90 | np.random.seed(0) 91 | 92 | data = pd.read_csv(os.path.join(echonet.config.DATA_DIR, "FileList.csv")) 93 | data["Split"].map(lambda x: x.upper()) 94 | 95 | file_name_list = np.array(data[data['Split']== 'TRAIN']['FileName']) 96 | np.random.shuffle(file_name_list) 97 | 98 | label_list = file_name_list[:rd_label] 99 | unlabel_list = file_name_list[rd_label:rd_label + rd_unlabel] 100 | 101 | data['SSL_SPLIT'] = "EXCLUDE" 102 | data.loc[data['FileName'].isin(label_list), 'SSL_SPLIT'] = "LABELED" 103 | data.loc[data['FileName'].isin(unlabel_list), 'SSL_SPLIT'] = "UNLABELED" 104 | 105 | data.to_csv(os.path.join(echonet.config.DATA_DIR, "FileList_ssl_{}_{}.csv".format(rd_label, rd_unlabel)),index = False) 106 | 107 | 108 | # Seed RNGs 109 | np.random.seed(seed) 110 | torch.manual_seed(seed) 111 | 112 | def worker_init_fn(worker_id): 113 | # print("worker id is", torch.utils.data.get_worker_info().id) 114 | # https://discuss.pytorch.org/t/in-what-order-do-dataloader-workers-do-their-job/88288/2 115 | np.random.seed(np.random.get_state()[1][0] + worker_id) 116 | 117 | 118 | # Set default output directory 119 | if output is None: 120 | output = os.path.join("output", "segmentation", "{}_{}".format(model_name, "pretrained" if pretrained else "random")) 121 | os.makedirs(output, exist_ok=True) 122 | 123 | bkup_tmstmp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") 124 | if os.path.isdir(os.path.join(output, "echonet_{}".format(bkup_tmstmp))): 125 | shutil.rmtree(os.path.join(output, "echonet_{}".format(bkup_tmstmp))) 126 | shutil.copytree("echonet", os.path.join(output, "echonet_{}".format(bkup_tmstmp))) 127 | 128 | 129 | # Set device for computations 130 | if device is None: 131 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 132 | elif device == "gpu": 133 | device = torch.device("cuda") 134 | elif device == "cpu": 135 | device = torch.device("cpu") 136 | else: 137 | assert 1==2, "wrong parameter for device" 138 | 139 | 140 | 141 | #### Setup model 142 | model_0 = 
echonet.segmentation.segmentation.deeplabv3_resnet50_CSS(pretrained=pretrained, aux_loss=False) 143 | model_0.classifier[-1] = torch.nn.Conv2d(model_0.classifier[-1].in_channels, 1, kernel_size=model_0.classifier[-1].kernel_size) # change number of outputs to 1 144 | model_0 = torch.nn.DataParallel(model_0) 145 | model_0.to(device) 146 | 147 | if weights: 148 | checkpoint = torch.load(weights) 149 | model_0.load_state_dict(checkpoint['state_dict_0'], strict = False) 150 | 151 | # Set up optimizer 152 | optim_0 = torch.optim.SGD(model_0.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay) 153 | if lr_step_period is None: 154 | lr_step_period = math.inf 155 | scheduler_0 = torch.optim.lr_scheduler.StepLR(optim_0, lr_step_period) 156 | 157 | 158 | # Compute mean and std 159 | mean, std = echonet.utils.get_mean_and_std(echonet.datasets.Echo(root=data_dir, split="train")) 160 | tasks_eval = ["LargeFrame", "SmallFrame", "LargeTrace", "SmallTrace"] 161 | kwargs_eval = {"target_type": tasks_eval, 162 | "mean": mean, 163 | "std": std 164 | } 165 | 166 | tasks_seg = ["LargeFrame", "SmallFrame", "LargeTrace", "SmallTrace"] 167 | kwargs_seg = {"target_type": tasks_seg, 168 | "mean": mean, 169 | "std": std 170 | } 171 | 172 | kwargs = {"target_type": ["EF", "CYCLE"], 173 | "mean": mean, 174 | "std": std, 175 | "length": 40, 176 | "period": 3, 177 | } 178 | 179 | 180 | dataset = {} 181 | dataset_trainsub = {} 182 | dataset_valsub = {} 183 | if reduced_set: 184 | dataset_trainsub['lb_seg'] = echonet.datasets.Echo(root=data_dir, split="train", **kwargs_seg, ssl_postfix="_ssl_{}_{}".format(rd_label, rd_unlabel), ssl_type = 1, ssl_edesonly = True) 185 | dataset_trainsub['lb_cyc'] = echonet.datasets.Echo_CSS(root=data_dir, split="train", **kwargs, ssl_postfix="_ssl_{}_{}".format(rd_label, rd_unlabel), ssl_type = 1) 186 | dataset_trainsub['unlb_cyc'] = echonet.datasets.Echo_CSS(root=data_dir, split="train", **kwargs, ssl_postfix="_ssl_{}_{}".format(rd_label, rd_unlabel), ssl_type = 2) 187 | else: 188 | assert not ssl_edesonly, "Check parameters, trying to conduct ssl with full datasest with EDES only" 189 | dataset_trainsub['lb_seg'] = echonet.datasets.Echo(root=data_dir, split="train", **kwargs_seg, ssl_postfix="", ssl_type = 0, ssl_edesonly = True) 190 | dataset_trainsub['lb_cyc'] = echonet.datasets.Echo_CSS(root=data_dir, split="train", **kwargs, ssl_postfix="", ssl_type = 0) 191 | dataset_trainsub['unlb_cyc'] = echonet.datasets.Echo_CSS(root=data_dir, split="train", **kwargs, ssl_postfix="", ssl_type = 0) 192 | dataset['train'] = dataset_trainsub 193 | 194 | 195 | if reduced_set: 196 | dataset_valsub["lb_seg"] = echonet.datasets.Echo(root=data_dir, split="val", **kwargs_seg, ssl_postfix="_ssl_{}_{}".format(rd_label, rd_unlabel)) 197 | dataset_valsub['lb_cyc'] = echonet.datasets.Echo_CSS(root=data_dir, split="val", **kwargs, ssl_postfix="_ssl_{}_{}".format(rd_label, rd_unlabel)) 198 | dataset_valsub['unlb_cyc'] = echonet.datasets.Echo_CSS(root=data_dir, split="val", **kwargs, ssl_postfix="_ssl_{}_{}".format(rd_label, rd_unlabel)) 199 | else: 200 | assert 1 == 2, "only run with reduced set for now " 201 | dataset["val"] = echonet.datasets.Echo_CSS(root=data_dir, split="val", **kwargs, ssl_postfix="") 202 | dataset['val'] = dataset_valsub 203 | 204 | 205 | # Run training and testing loops 206 | with open(os.path.join(output, "log.csv"), "a") as f: 207 | 208 | f.write("Run timestamp: {}\n".format(bkup_tmstmp)) 209 | 210 | epoch_resume = 0 211 | bestLoss = float("inf") 212 | try: 213 | # Attempt to 
load checkpoint 214 | checkpoint = torch.load(os.path.join(output, "checkpoint.pt")) 215 | print("checkpoint.keys", checkpoint.keys()) 216 | model_0.load_state_dict(checkpoint['state_dict']) 217 | optim_0.load_state_dict(checkpoint['opt_dict']) 218 | scheduler_0.load_state_dict(checkpoint['scheduler_dict']) 219 | 220 | np_rndstate_chkpt = checkpoint['np_rndstate'] 221 | trch_rndstate_chkpt = checkpoint['trch_rndstate'] 222 | 223 | np.random.set_state(np_rndstate_chkpt) 224 | torch.set_rng_state(trch_rndstate_chkpt) 225 | 226 | epoch_resume = checkpoint["epoch"] + 1 227 | bestLoss = checkpoint["best_loss"] 228 | f.write("Resuming from epoch {}\n".format(epoch_resume)) 229 | except FileNotFoundError: 230 | f.write("Starting run from scratch\n") 231 | 232 | for epoch in range(epoch_resume, num_epochs): 233 | print("Epoch #{}".format(epoch), flush=True) 234 | for phase in ['train', 'val']: 235 | start_time = time.time() 236 | 237 | if device.type == "cuda": 238 | for i in range(torch.cuda.device_count()): 239 | torch.cuda.reset_peak_memory_stats(i) 240 | 241 | ds = dataset[phase] 242 | 243 | if phase == "train": 244 | 245 | dataloader_lb_seg = torch.utils.data.DataLoader( 246 | ds['lb_seg'], batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=(device.type == "cuda"), drop_last=(phase == "train")) 247 | dataloader_lb_cyc = torch.utils.data.DataLoader( 248 | ds['lb_cyc'], batch_size=1, num_workers=num_workers, shuffle=True, pin_memory=(device.type == "cuda"), drop_last=(phase == "train")) 249 | dataloader_unlb_cyc = torch.utils.data.DataLoader( 250 | ds['unlb_cyc'], batch_size=1, num_workers=num_workers, shuffle=True, pin_memory=(device.type == "cuda"), drop_last=(phase == "train")) 251 | 252 | 253 | loss, loss_seg, lrgdice, smldice, loss_cyc, large_inter_0, large_union_0, small_inter_0, small_union_0 = echonet.utils.seg_cycle.run_epoch_ssl( model_0, 254 | dataloader_lb_seg, 255 | dataloader_lb_cyc, 256 | dataloader_unlb_cyc, 257 | phase == "train", 258 | optim_0, 259 | batch_size, 260 | device, 261 | output, 262 | phase, 263 | mean, 264 | std, 265 | epoch, 266 | chunk_size = chunk_size, 267 | cyc_off = cyc_off, 268 | target_region = target_region, 269 | temperature = temperature, 270 | val_chunk = val_chunk, 271 | loss_cyc_w = loss_cyc_w, 272 | css_strtup = css_strtup 273 | ) 274 | 275 | 276 | overall_dice_0 = 2 * (large_inter_0.sum() + small_inter_0.sum()) / (large_union_0.sum() + large_inter_0.sum() + small_union_0.sum() + small_inter_0.sum()) 277 | large_dice_0 = 2 * large_inter_0.sum() / (large_union_0.sum() + large_inter_0.sum()) 278 | small_dice_0 = 2 * small_inter_0.sum() / (small_union_0.sum() + small_inter_0.sum()) 279 | 280 | f.write("{},{},{},{},{},{},{},{},{},{},{},{}\n".format(epoch, 281 | phase, 282 | loss, 283 | loss_seg, 284 | loss_cyc, 285 | overall_dice_0, 286 | large_dice_0, 287 | small_dice_0, 288 | time.time() - start_time, 289 | sum(torch.cuda.max_memory_allocated() for i in range(torch.cuda.device_count())), 290 | sum(torch.cuda.max_memory_reserved() for i in range(torch.cuda.device_count())), 291 | batch_size)) 292 | f.flush() 293 | 294 | else: 295 | dataloader_lb_seg = torch.utils.data.DataLoader( 296 | ds['lb_seg'], batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=(device.type == "cuda"), drop_last=(phase == "train")) 297 | dataloader_lb_cyc = torch.utils.data.DataLoader( 298 | ds['lb_cyc'], batch_size=1, num_workers=num_workers, shuffle=False, pin_memory=(device.type == "cuda"), drop_last=(phase == "train")) 299 | 
dataloader_unlb_cyc = torch.utils.data.DataLoader( 300 | ds['unlb_cyc'], batch_size=1, num_workers=num_workers, shuffle=False, pin_memory=(device.type == "cuda"), drop_last=(phase == "train")) 301 | 302 | 303 | 304 | loss, loss_seg_val, lrgdice_val, smldice_val, loss_cyc_val, large_inter_0, large_union_0, small_inter_0, small_union_0 = echonet.utils.seg_cycle.run_epoch_ssl( model_0, 305 | dataloader_lb_seg, 306 | dataloader_lb_cyc, 307 | dataloader_unlb_cyc, 308 | phase == "train", 309 | optim_0, 310 | batch_size, 311 | device, 312 | output, 313 | phase, 314 | mean, 315 | std, 316 | epoch, 317 | chunk_size = chunk_size, 318 | cyc_off = cyc_off, 319 | target_region = target_region, 320 | temperature = temperature, 321 | val_chunk = val_chunk, 322 | loss_cyc_w = loss_cyc_w, 323 | css_strtup = css_strtup 324 | ) 325 | 326 | overall_dice_0 = 2 * (large_inter_0.sum() + small_inter_0.sum()) / (large_union_0.sum() + large_inter_0.sum() + small_union_0.sum() + small_inter_0.sum()) 327 | large_dice_0 = 2 * large_inter_0.sum() / (large_union_0.sum() + large_inter_0.sum()) 328 | small_dice_0 = 2 * small_inter_0.sum() / (small_union_0.sum() + small_inter_0.sum()) 329 | 330 | f.write("{},{},{},{},{},{},{},{},{},{},{},{}\n".format(epoch, 331 | phase, 332 | loss, 333 | loss_seg_val, 334 | loss_cyc_val, 335 | overall_dice_0, 336 | large_dice_0, 337 | small_dice_0, 338 | time.time() - start_time, 339 | sum(torch.cuda.max_memory_allocated() for i in range(torch.cuda.device_count())), 340 | sum(torch.cuda.max_memory_reserved() for i in range(torch.cuda.device_count())), 341 | batch_size)) 342 | 343 | 344 | 345 | f.flush() 346 | 347 | 348 | scheduler_0.step() 349 | 350 | # Save checkpoint 351 | save = { 352 | 'epoch': epoch, 353 | 'state_dict': model_0.state_dict(), 354 | 'best_loss': bestLoss, 355 | 'loss': loss, 356 | 'opt_dict': optim_0.state_dict(), 357 | 'scheduler_dict': scheduler_0.state_dict(), 358 | 'np_rndstate': np.random.get_state(), 359 | 'trch_rndstate': torch.get_rng_state() 360 | } 361 | torch.save(save, os.path.join(output, "checkpoint.pt")) 362 | if loss_seg_val < bestLoss: 363 | print("saved best because {} < {}".format(loss_seg_val, bestLoss)) 364 | torch.save(save, os.path.join(output, "best.pt")) 365 | bestLoss = loss_seg_val 366 | 367 | # Load best weights 368 | if num_epochs != 0: 369 | checkpoint = torch.load(os.path.join(output, "best.pt")) 370 | model_0.load_state_dict(checkpoint['state_dict']) 371 | 372 | f.write("Best validation loss {} from epoch {}\n".format(checkpoint["loss"], checkpoint["epoch"])) 373 | f.flush() 374 | 375 | if run_test: 376 | for split in ["val", "test"]: 377 | if reduced_set: 378 | if split == "train": 379 | dataset = echonet.datasets.Echo(root=data_dir, split=split, **kwargs_seg, ssl_postfix="_ssl_{}_{}".format(rd_label, rd_unlabel), ssl_type = 2) 380 | else: 381 | dataset = echonet.datasets.Echo(root=data_dir, split=split, **kwargs_seg, ssl_postfix="_ssl_{}_{}".format(rd_label, rd_unlabel)) 382 | else: 383 | dataset = echonet.datasets.Echo(root=data_dir, split=split, **kwargs_seg, ssl_postfix="") 384 | 385 | dataloader = torch.utils.data.DataLoader(dataset, 386 | batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=(device.type == "cuda")) 387 | loss, large_inter, large_union, small_inter, small_union = echonet.utils.seg_cycle.run_epoch(model_0, dataloader, False, None, device) 388 | 389 | overall_dice = 2 * (large_inter + small_inter) / (large_union + large_inter + small_union + small_inter) 390 | large_dice = 2 * large_inter / 
(large_union + large_inter) 391 | small_dice = 2 * small_inter / (small_union + small_inter) 392 | with open(os.path.join(output, "{}_dice.csv".format(split)), "w") as g: 393 | g.write("Filename, Overall, Large, Small\n") 394 | for (filename, overall, large, small) in zip(dataset.fnames, overall_dice, large_dice, small_dice): 395 | g.write("{},{},{},{}\n".format(filename, overall, large, small)) 396 | 397 | f.write("{} dice (overall): {:.4f} ({:.4f} - {:.4f})\n".format(split, *echonet.utils.bootstrap(np.concatenate((large_inter, small_inter)), np.concatenate((large_union, small_union)), echonet.utils.dice_similarity_coefficient))) 398 | f.write("{} dice (large): {:.4f} ({:.4f} - {:.4f})\n".format(split, *echonet.utils.bootstrap(large_inter, large_union, echonet.utils.dice_similarity_coefficient))) 399 | f.write("{} dice (small): {:.4f} ({:.4f} - {:.4f})\n".format(split, *echonet.utils.bootstrap(small_inter, small_union, echonet.utils.dice_similarity_coefficient))) 400 | f.flush() 401 | 402 | if run_inference: 403 | if run_inference == "all": 404 | run_inference_range = ['train', 'val', 'test'] 405 | else: 406 | run_inference_range = [run_inference] 407 | 408 | for run_inference_itr in run_inference_range: 409 | if run_inference_itr != "train" or True: 410 | dataset = echonet.datasets.Echo(root=data_dir, split=run_inference_itr, 411 | target_type=["Filename", "LargeIndex", "SmallIndex"], # Need filename for saving, and human-selected frames to annotate 412 | mean=mean, std=std, # Normalization 413 | length=None, max_length=None, period=1 # Take all frames 414 | ) 415 | else: 416 | if reduced_set: 417 | dataset = echonet.datasets.Echo(root=data_dir, split=run_inference_itr, 418 | target_type=["Filename", "LargeIndex", "SmallIndex"], # Need filename for saving, and human-selected frames to annotate 419 | mean=mean, std=std, # Normalization 420 | ssl_postfix="_ssl_{}_{}".format(rd_label, rd_unlabel), ssl_type = 1, ssl_mult = 1, 421 | length=None, max_length=None, period=1 # Take all frames 422 | ) 423 | else: 424 | dataset = echonet.datasets.Echo(root=data_dir, split=run_inference_itr, 425 | target_type=["Filename", "LargeIndex", "SmallIndex"], # Need filename for saving, and human-selected frames to annotate 426 | mean=mean, std=std, # Normalization 427 | length=None, max_length=None, period=1 # Take all frames 428 | ) 429 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=10, num_workers=num_workers, shuffle=False, pin_memory=False, collate_fn=_video_collate_fn) 430 | 431 | output_dir = os.path.join(output, "{}_infer_cmpct".format(run_inference_itr)) 432 | 433 | os.makedirs(output_dir, exist_ok = True) 434 | 435 | checkpoint = torch.load(os.path.join(output, "best.pt")) 436 | 437 | model_0.load_state_dict(checkpoint['state_dict']) 438 | 439 | model_0.eval() 440 | 441 | with torch.no_grad(): 442 | for (x, (filenames, large_index, small_index), length) in tqdm.tqdm(dataloader): 443 | # Run segmentation model on blocks of frames one-by-one 444 | # The whole concatenated video may be too long to run together 445 | 446 | print(os.path.join(output_dir, "{}_{}.npy".format(filenames[-1].replace(".avi", ""), length[-1] - 1))) 447 | 448 | 449 | if os.path.isfile(os.path.join(output_dir, "{}.npy".format(filenames[-1].replace(".avi", "")))): 450 | # print("already exists") 451 | continue 452 | 453 | y = np.concatenate([model_0(x[i:(i + batch_size), :, :, :].to(device))["out"].detach().cpu().numpy() for i in range(0, x.shape[0], batch_size)]) 454 | 455 | y_idx = 0 456 | for batch_idx in 
range(len(filenames)): 457 | filename_itr = filenames[batch_idx] 458 | 459 | logit = y[y_idx:y_idx + length[batch_idx], 0, :, :] 460 | 461 | logit_out_path = os.path.join(output_dir, "{}.npy".format(filename_itr.replace(".avi", ""))) 462 | np.save(logit_out_path, logit) 463 | y_idx = y_idx + length[batch_idx] 464 | 465 | pass 466 | 467 | 468 | 469 | 470 | 471 | 472 | def run_epoch_ssl(model_0, 473 | dataloader_lb_seg, 474 | dataloader_lb_cyc, 475 | dataloader_unlb_cyc, 476 | train, 477 | optim_0, 478 | batch_size, 479 | device, 480 | output, 481 | phase, 482 | mean, 483 | std, 484 | epoch, 485 | chunk_size = 3, 486 | cyc_off = 2, 487 | target_region = 15, 488 | temperature = 10, 489 | val_chunk = 40, 490 | loss_cyc_w = 1, 491 | css_strtup = 0 492 | ): 493 | 494 | 495 | n = 0 496 | n_seg = 0 497 | 498 | total = 0 499 | total_cyc = 0 500 | 501 | total_seg = 0 502 | 503 | model_0.train(train) 504 | output_dir = os.path.join(output, "{}_feat_comp".format(phase)) 505 | os.makedirs(output_dir, exist_ok = True) 506 | 507 | large_inter_0 = 0 508 | large_union_0 = 0 509 | small_inter_0 = 0 510 | small_union_0 = 0 511 | large_inter_list_0 = [] 512 | large_union_list_0 = [] 513 | small_inter_list_0 = [] 514 | small_union_list_0 = [] 515 | 516 | torch.set_grad_enabled(train) 517 | 518 | total_itr_num = len(dataloader_lb_seg) 519 | 520 | dataloader_lb_seg_itr = iter(dataloader_lb_seg) 521 | dataloader_unlb_cyc_itr = iter(dataloader_unlb_cyc) 522 | 523 | for train_iter in range(total_itr_num): 524 | 525 | #### Supervised segmentation 526 | _, (large_frame, small_frame, large_trace, small_trace) = dataloader_lb_seg_itr.next() 527 | 528 | large_frame = large_frame.to(device) 529 | large_trace = large_trace.to(device) 530 | 531 | small_frame = small_frame.to(device) 532 | small_trace = small_trace.to(device) 533 | 534 | if not train: 535 | with torch.no_grad(): 536 | y_large_0 = model_0(large_frame)["out"] 537 | else: 538 | y_large_0 = model_0(large_frame)["out"] 539 | 540 | loss_large_0 = torch.nn.functional.binary_cross_entropy_with_logits(y_large_0[:, 0, :, :], large_trace, reduction="sum") 541 | # Compute pixel intersection and union between human and computer segmentations 542 | large_inter_0 += np.logical_and(y_large_0[:, 0, :, :].detach().cpu().numpy() > 0., large_trace[:, :, :].detach().cpu().numpy() > 0.).sum() 543 | large_union_0 += np.logical_or(y_large_0[:, 0, :, :].detach().cpu().numpy() > 0., large_trace[:, :, :].detach().cpu().numpy() > 0.).sum() 544 | large_inter_list_0.extend(np.logical_and(y_large_0[:, 0, :, :].detach().cpu().numpy() > 0., large_trace[:, :, :].detach().cpu().numpy() > 0.).sum((1, 2))) 545 | large_union_list_0.extend(np.logical_or(y_large_0[:, 0, :, :].detach().cpu().numpy() > 0., large_trace[:, :, :].detach().cpu().numpy() > 0.).sum((1, 2))) 546 | 547 | y_small_0 = model_0(small_frame)["out"] 548 | loss_small_0 = torch.nn.functional.binary_cross_entropy_with_logits(y_small_0[:, 0, :, :], small_trace, reduction="sum") 549 | # Compute pixel intersection and union between human and computer segmentations 550 | small_inter_0 += np.logical_and(y_small_0[:, 0, :, :].detach().cpu().numpy() > 0., small_trace[:, :, :].detach().cpu().numpy() > 0.).sum() 551 | small_union_0 += np.logical_or(y_small_0[:, 0, :, :].detach().cpu().numpy() > 0., small_trace[:, :, :].detach().cpu().numpy() > 0.).sum() 552 | small_inter_list_0.extend(np.logical_and(y_small_0[:, 0, :, :].detach().cpu().numpy() > 0., small_trace[:, :, :].detach().cpu().numpy() > 0.).sum((1, 2))) 553 | 
small_union_list_0.extend(np.logical_or(y_small_0[:, 0, :, :].detach().cpu().numpy() > 0., small_trace[:, :, :].detach().cpu().numpy() > 0.).sum((1, 2))) 554 | 555 | loss_seg = (loss_large_0 + loss_small_0) / 2 556 | 557 | loss_seg_item = loss_seg.item() 558 | large_trace_size = large_trace.size(0) 559 | total_seg += loss_seg_item * large_trace.size(0) 560 | 561 | 562 | 563 | ### CSS training 564 | X_raw, target, target_iekd, start, video_path, i1, j1 = dataloader_unlb_cyc_itr.next() 565 | X_bfcwh = X_raw.permute(0,2,1,3,4) 566 | X_segfeed = X_bfcwh.reshape(-1, X_bfcwh.shape[2], X_bfcwh.shape[3], X_bfcwh.shape[4]) 567 | 568 | ####### get feature output 569 | if not train: 570 | with torch.no_grad(): 571 | feat_out = model_0(X_segfeed)['x_layer4'].sum(dim=(2,3)) 572 | else: 573 | feat_out = model_0(X_segfeed)['x_layer4'].sum(dim=(2,3)) 574 | 575 | 576 | feat_out_query = feat_out[:target_region] # Template region P 577 | feat_out_query_cyc = feat_out[cyc_off:target_region] # Template region with offset 578 | feat_out_key = feat_out[target_region:] # Search region Q 579 | 580 | target_strtpt = np.random.choice(target_region - (chunk_size + cyc_off) + 1) ## choosing p* 581 | target_strtpt_1ht = torch.eye(target_region - (chunk_size + cyc_off) + 1)[target_strtpt] 582 | target_strtpt_1ht = target_strtpt_1ht.to(device) 583 | 584 | query_feat = feat_out_query[target_strtpt:target_strtpt + chunk_size, ...] ### choosing E^p* 585 | 586 | key_size = feat_out_key.shape[0] 587 | feat_size = feat_out.shape[1] 588 | 589 | ### feature-wise distance calculation 590 | dist_mat = feat_out_key.unsqueeze(1).repeat((1,chunk_size, 1)) - query_feat.unsqueeze(1).transpose(0,1).repeat(key_size, 1, 1) 591 | dist_mat_sq = dist_mat.pow(2) 592 | dist_mat_sq_ftsm = dist_mat_sq.sum(dim = -1) 593 | 594 | ### distance calculation per phase 595 | indices_ftsm = torch.arange(chunk_size) 596 | gather_indx_ftsm = torch.arange(key_size).view((key_size, 1)).repeat((1,chunk_size)) 597 | gather_indx_shft_ftsm = (gather_indx_ftsm + indices_ftsm) % (key_size) ### gets a index corresponding to the feature vectors included in each phase 598 | gather_indx_shft_ftsm = gather_indx_shft_ftsm.to(device) 599 | dist_mat_sq_shft_ftsm = torch.gather(dist_mat_sq_ftsm, 0, gather_indx_shft_ftsm)[:key_size - (chunk_size + cyc_off) + 1] ### gathers the feature-wise distance values to calculate the distance for the phase 600 | dist_mat_sq_total_ftsm = dist_mat_sq_shft_ftsm.sum(dim=(1)) 601 | 602 | ### calculating similarity value 603 | similarity = - dist_mat_sq_total_ftsm 604 | similarity_averaged = similarity / feat_size / chunk_size * temperature 605 | alpha_raw = torch.nn.functional.softmax(similarity_averaged, dim = 0) 606 | alpha_weights = alpha_raw.unsqueeze(1).unsqueeze(1).repeat([1, chunk_size, feat_size]) 607 | 608 | 609 | #### calculate shifted phase values 610 | indices_beta = torch.arange(chunk_size).view((1, chunk_size, 1)).repeat((key_size,1, feat_size)) 611 | gather_indx_beta = torch.arange(key_size).view((key_size, 1, 1)).repeat((1,chunk_size, feat_size)) 612 | gather_indx_alpha_shft = (gather_indx_beta + indices_beta) % (key_size) 613 | gather_indx_alpha_shft = gather_indx_alpha_shft.to(device) 614 | feat_out_key_beta = torch.gather(feat_out_key.unsqueeze(1).repeat(1, chunk_size, 1), 0, gather_indx_alpha_shft)[cyc_off:key_size - chunk_size + 1] 615 | 616 | ### calculate \tilde{E}^{q+c} 617 | weighted_features = alpha_weights * feat_out_key_beta 618 | weighted_features_averaged = weighted_features.sum(dim=0) 619 | 620 | 621 | #### 
match back to template region and find distance value 622 | q_dist_mat = feat_out_query_cyc.unsqueeze(1).repeat((1,chunk_size, 1)) - weighted_features_averaged.unsqueeze(1).transpose(0,1).repeat((target_region - cyc_off), 1, 1) 623 | q_dist_mat_sq = q_dist_mat.pow(2) 624 | q_dist_mat_sq_ftsm = q_dist_mat_sq.sum(dim = -1) 625 | 626 | indices_query_ftsm = torch.arange(chunk_size) 627 | gather_indx_query_ftsm = torch.arange(target_region - cyc_off).view((target_region - cyc_off, 1)).repeat((1,chunk_size)) 628 | gather_indx_query_shft_ftsm = (gather_indx_query_ftsm + indices_query_ftsm) % (target_region - cyc_off) 629 | gather_indx_query_shft_ftsm = gather_indx_query_shft_ftsm.to(device) 630 | q_dist_mat_sq_shft_ftsm = torch.gather(q_dist_mat_sq_ftsm, 0, gather_indx_query_shft_ftsm)[:(target_region - cyc_off) - chunk_size + 1] 631 | q_dist_mat_sq_total_ftsm = q_dist_mat_sq_shft_ftsm.sum(dim=(1)) 632 | 633 | ### calculate similarity value 634 | q_similarity = - q_dist_mat_sq_total_ftsm 635 | q_similarity_averaged = q_similarity / feat_size / chunk_size * temperature 636 | 637 | ### calculate cross-entropy loss 638 | frm_prd = torch.argmax(q_similarity_averaged) 639 | frm_lb = torch.argmax(target_strtpt_1ht) 640 | 641 | loss_cyc_raw = torch.nn.functional.cross_entropy(q_similarity_averaged.unsqueeze(0), frm_lb.unsqueeze(0)) 642 | loss_cyc_wght = loss_cyc_raw * loss_cyc_w 643 | 644 | loss_cyc_raw_item = loss_cyc_raw.item() 645 | total_cyc += loss_cyc_raw_item 646 | 647 | 648 | if train: 649 | if epoch < css_strtup: 650 | loss_total = loss_seg 651 | else: 652 | loss_total = loss_seg + loss_cyc_wght 653 | optim_0.zero_grad() 654 | loss_total.backward() 655 | optim_0.step() 656 | 657 | 658 | loss_total_item = loss_seg_item + loss_cyc_raw_item * loss_cyc_w 659 | 660 | total += loss_total_item 661 | 662 | n += 1 663 | n_seg += large_trace_size 664 | 665 | 666 | # Show info on process bar 667 | if train_iter % 5 == 0: 668 | print("Itr trainphase {} - {}/{} - ttl {:.4f} ({:.4f}) seg {:.4f} ({:.4f}) dlrg {:.4f} dsml {:.4f} cyc {:.4f} ({:.4f}) ".format( 669 | train, 670 | train_iter, 671 | total_itr_num, 672 | total / n_seg , # total 673 | loss_total_item, # total_item 674 | total_seg / n_seg / 112 / 112 , # total seg 675 | loss_seg_item, # seg item 676 | 2 * large_inter_0 / (large_union_0 + large_inter_0 + 0.000001), 677 | 2 * small_inter_0 / (small_union_0 + small_inter_0 + 0.000001), 678 | total_cyc / n, 679 | loss_cyc_raw_item 680 | ), flush = True) 681 | 682 | large_inter_list_0 = np.array(large_inter_list_0) 683 | large_union_list_0 = np.array(large_union_list_0) 684 | small_inter_list_0 = np.array(small_inter_list_0) 685 | small_union_list_0 = np.array(small_union_list_0) 686 | 687 | return (total / n_seg, 688 | total_seg / n_seg / 112 / 112, 689 | 2 * large_inter_0 / (large_union_0 + large_inter_0 + 0.000001), 690 | 2 * small_inter_0 / (small_union_0 + small_inter_0 + 0.000001), 691 | total_cyc / n, 692 | large_inter_list_0, 693 | large_union_list_0, 694 | small_inter_list_0, 695 | small_union_list_0 696 | ) 697 | 698 | 699 | 700 | 701 | def run_epoch(model, dataloader, train, optim, device): 702 | """Run one epoch of training/evaluation for segmentation. 703 | 704 | Args: 705 | model (torch.nn.Module): Model to train/evaulate. 706 | dataloder (torch.utils.data.DataLoader): Dataloader for dataset. 707 | train (bool): Whether or not to train model. 708 | optim (torch.optim.Optimizer): Optimizer 709 | device (torch.device): Device to run on 710 | """ 711 | 712 | total = 0. 
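# Per-epoch accumulators for the counters initialised below: total sums the
# (sum-reduced) BCE loss, n counts processed ED/ES frame pairs, pos/neg count
# pixels inside/outside the human traces (used for the entropy baseline shown on
# the progress bar), and the inter/union counters and lists feed the Dice scores
# returned at the end of the epoch.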
713 | n = 0 714 | 715 | pos = 0 716 | neg = 0 717 | pos_pix = 0 718 | neg_pix = 0 719 | 720 | model.train(train) 721 | 722 | large_inter = 0 723 | large_union = 0 724 | small_inter = 0 725 | small_union = 0 726 | large_inter_list = [] 727 | large_union_list = [] 728 | small_inter_list = [] 729 | small_union_list = [] 730 | 731 | with torch.set_grad_enabled(train): 732 | with tqdm.tqdm(total=len(dataloader)) as pbar: 733 | for (_, (large_frame, small_frame, large_trace, small_trace)) in dataloader: 734 | # Count number of pixels in/out of human segmentation 735 | pos += (large_trace == 1).sum().item() 736 | pos += (small_trace == 1).sum().item() 737 | neg += (large_trace == 0).sum().item() 738 | neg += (small_trace == 0).sum().item() 739 | 740 | # Count number of pixels in/out of computer segmentation 741 | pos_pix += (large_trace == 1).sum(0).to("cpu").detach().numpy() 742 | pos_pix += (small_trace == 1).sum(0).to("cpu").detach().numpy() 743 | neg_pix += (large_trace == 0).sum(0).to("cpu").detach().numpy() 744 | neg_pix += (small_trace == 0).sum(0).to("cpu").detach().numpy() 745 | 746 | # Run prediction for diastolic frames and compute loss 747 | large_frame = large_frame.to(device) 748 | large_trace = large_trace.to(device) 749 | y_large = model(large_frame)["out"] 750 | loss_large = torch.nn.functional.binary_cross_entropy_with_logits(y_large[:, 0, :, :], large_trace, reduction="sum") 751 | # Compute pixel intersection and union between human and computer segmentations 752 | large_inter += np.logical_and(y_large[:, 0, :, :].detach().cpu().numpy() > 0., large_trace[:, :, :].detach().cpu().numpy() > 0.).sum() 753 | large_union += np.logical_or(y_large[:, 0, :, :].detach().cpu().numpy() > 0., large_trace[:, :, :].detach().cpu().numpy() > 0.).sum() 754 | large_inter_list.extend(np.logical_and(y_large[:, 0, :, :].detach().cpu().numpy() > 0., large_trace[:, :, :].detach().cpu().numpy() > 0.).sum((1, 2))) 755 | large_union_list.extend(np.logical_or(y_large[:, 0, :, :].detach().cpu().numpy() > 0., large_trace[:, :, :].detach().cpu().numpy() > 0.).sum((1, 2))) 756 | 757 | # Run prediction for systolic frames and compute loss 758 | small_frame = small_frame.to(device) 759 | small_trace = small_trace.to(device) 760 | y_small = model(small_frame)["out"] 761 | loss_small = torch.nn.functional.binary_cross_entropy_with_logits(y_small[:, 0, :, :], small_trace, reduction="sum") 762 | # Compute pixel intersection and union between human and computer segmentations 763 | small_inter += np.logical_and(y_small[:, 0, :, :].detach().cpu().numpy() > 0., small_trace[:, :, :].detach().cpu().numpy() > 0.).sum() 764 | small_union += np.logical_or(y_small[:, 0, :, :].detach().cpu().numpy() > 0., small_trace[:, :, :].detach().cpu().numpy() > 0.).sum() 765 | small_inter_list.extend(np.logical_and(y_small[:, 0, :, :].detach().cpu().numpy() > 0., small_trace[:, :, :].detach().cpu().numpy() > 0.).sum((1, 2))) 766 | small_union_list.extend(np.logical_or(y_small[:, 0, :, :].detach().cpu().numpy() > 0., small_trace[:, :, :].detach().cpu().numpy() > 0.).sum((1, 2))) 767 | 768 | # Take gradient step if training 769 | loss = (loss_large + loss_small) / 2 770 | if train: 771 | optim.zero_grad() 772 | loss.backward() 773 | optim.step() 774 | 775 | # Accumulate losses and compute baselines 776 | total += loss.item() 777 | n += large_trace.size(0) 778 | p = pos / (pos + neg) 779 | p_pix = (pos_pix + 1) / (pos_pix + neg_pix + 2) 780 | 781 | # Show info on process bar 782 | pbar.set_postfix_str("{:.4f} ({:.4f}) / {:.4f} {:.4f}, 
{:.4f}, {:.4f}".format(total / n / 112 / 112, loss.item() / large_trace.size(0) / 112 / 112, -p * math.log(p) - (1 - p) * math.log(1 - p), (-p_pix * np.log(p_pix) - (1 - p_pix) * np.log(1 - p_pix)).mean(), 2 * large_inter / (large_union + large_inter), 2 * small_inter / (small_union + small_inter))) 783 | pbar.update() 784 | 785 | large_inter_list = np.array(large_inter_list) 786 | large_union_list = np.array(large_union_list) 787 | small_inter_list = np.array(small_inter_list) 788 | small_union_list = np.array(small_union_list) 789 | 790 | return (total / n / 112 / 112, 791 | large_inter_list, 792 | large_union_list, 793 | small_inter_list, 794 | small_union_list, 795 | ) 796 | 797 | 798 | def _video_collate_fn(x): 799 | """Collate function for Pytorch dataloader to merge multiple videos. 800 | 801 | This function should be used in a dataloader for a dataset that returns 802 | a video as the first element, along with some (non-zero) tuple of 803 | targets. Then, the input x is a list of tuples: 804 | - x[i][0] is the i-th video in the batch 805 | - x[i][1] are the targets for the i-th video 806 | 807 | This function returns a 3-tuple: 808 | - The first element is the videos concatenated along the frames 809 | dimension. This is done so that videos of different lengths can be 810 | processed together (tensors cannot be "jagged", so we cannot have 811 | a dimension for video, and another for frames). 812 | - The second element is contains the targets with no modification. 813 | - The third element is a list of the lengths of the videos in frames. 814 | """ 815 | video, target = zip(*x) # Extract the videos and targets 816 | 817 | # ``video'' is a tuple of length ``batch_size'' 818 | # Each element has shape (channels=3, frames, height, width) 819 | # height and width are expected to be the same across videos, but 820 | # frames can be different. 821 | 822 | # ``target'' is also a tuple of length ``batch_size'' 823 | # Each element is a tuple of the targets for the item. 824 | 825 | i = list(map(lambda t: t.shape[1], video)) # Extract lengths of videos in frames 826 | 827 | # This contatenates the videos along the the frames dimension (basically 828 | # playing the videos one after another). The frames dimension is then 829 | # moved to be first. 
830 | # Resulting shape is (total frames, channels=3, height, width) 831 | video = torch.as_tensor(np.swapaxes(np.concatenate(video, 1), 0, 1)) 832 | 833 | # Swap dimensions (approximately a transpose) 834 | # Before: target[i][j] is the j-th target of element i 835 | # After: target[i][j] is the i-th target of element j 836 | target = zip(*target) 837 | 838 | return video, target, i 839 | 840 | -------------------------------------------------------------------------------- /echonet/utils/video_segin.py: -------------------------------------------------------------------------------- 1 | """EF regression from video with Segmentation prediction mask inputs """ 2 | 3 | 4 | import math 5 | import os 6 | import time 7 | import shutil 8 | import datetime 9 | import pandas as pd 10 | from PIL import Image 11 | 12 | import click 13 | import matplotlib.pyplot as plt 14 | import numpy as np 15 | import sklearn.metrics 16 | import torch 17 | import torchvision 18 | import tqdm 19 | 20 | import echonet 21 | import echonet.models 22 | 23 | from scipy.special import expit 24 | 25 | @click.command("video_segin") 26 | @click.option("--data_dir", type=click.Path(exists=True, file_okay=False), default=None) 27 | @click.option("--output", type=click.Path(file_okay=False), default=None) 28 | @click.option("--task", type=str, default="EF") 29 | @click.option("--model_name", type=click.Choice(['mc3_18', 'r2plus1d_18', 'r3d_18', 'r2plus1d_18_ncor']), 30 | default="r2plus1d_18") 31 | @click.option("--pretrained/--random", default=True) 32 | @click.option("--weights", type=click.Path(exists=True, dir_okay=False), default=None) 33 | @click.option("--run_test/--skip_test", default=False) 34 | @click.option("--num_epochs", type=int, default=30) 35 | @click.option("--lr", type=float, default=1e-4) 36 | @click.option("--weight_decay", type=float, default=1e-4) 37 | @click.option("--lr_step_period", type=int, default=15) 38 | @click.option("--frames", type=int, default=32) 39 | @click.option("--period", type=int, default=2) 40 | @click.option("--num_train_patients", type=int, default=None) 41 | @click.option("--num_workers", type=int, default=4) 42 | @click.option("--batch_size", type=int, default=20) 43 | @click.option("--device", type=str, default=None) 44 | @click.option("--seed", type=int, default=0) 45 | @click.option("--full_test/--quick_test", default=True) 46 | @click.option("--val_samp", type=int, default=3) 47 | @click.option("--reduced_set/--full_set", default=True) 48 | @click.option("--rd_label", type=int, default=100) 49 | @click.option("--rd_unlabel", type=int, default=100) 50 | @click.option("--segsource", type=str, default=None) 51 | 52 | def run( 53 | data_dir=None, 54 | output=None, 55 | task="EF", 56 | model_name="r2plus1d_18", 57 | pretrained=True, 58 | weights=None, 59 | run_test=False, 60 | num_epochs=30, 61 | lr=1e-4, 62 | weight_decay=1e-4, 63 | lr_step_period=15, 64 | frames=32, 65 | period=2, 66 | num_train_patients=None, 67 | num_workers=4, 68 | batch_size=20, 69 | device=None, 70 | seed=0, 71 | full_test = True, 72 | val_samp = 3, 73 | reduced_set = True, 74 | rd_label = 100, 75 | rd_unlabel = 100, 76 | segsource = None 77 | ): 78 | 79 | assert segsource, "for video_segin needs segsource option" 80 | 81 | if reduced_set: 82 | if not os.path.isfile(os.path.join(echonet.config.DATA_DIR, "FileList_ssl_{}_{}.csv".format(rd_label, rd_unlabel))): 83 | print("Generating new file list for ssl dataset") 84 | np.random.seed(0) 85 | 86 | 87 | data = pd.read_csv(os.path.join(echonet.config.DATA_DIR, 
"FileList.csv")) 88 | data["Split"].map(lambda x: x.upper()) 89 | 90 | file_name_list = np.array(data[data['Split']== 'TRAIN']['FileName']) 91 | np.random.shuffle(file_name_list) 92 | 93 | label_list = file_name_list[:rd_label] 94 | unlabel_list = file_name_list[rd_label:rd_label + rd_unlabel] 95 | 96 | data['SSL_SPLIT'] = "EXCLUDE" 97 | data.loc[data['FileName'].isin(label_list), 'SSL_SPLIT'] = "LABELED" 98 | data.loc[data['FileName'].isin(unlabel_list), 'SSL_SPLIT'] = "UNLABELED" 99 | 100 | data.to_csv(os.path.join(echonet.config.DATA_DIR, "FileList_ssl_{}_{}.csv".format(rd_label, rd_unlabel)),index = False) 101 | 102 | 103 | # Seed RNGs 104 | np.random.seed(seed) 105 | torch.manual_seed(seed) 106 | 107 | def worker_init_fn(worker_id): 108 | # print("worker id is", torch.utils.data.get_worker_info().id) 109 | # https://discuss.pytorch.org/t/in-what-order-do-dataloader-workers-do-their-job/88288/2 110 | np.random.seed(np.random.get_state()[1][0] + worker_id) 111 | 112 | # Set default output directory 113 | if output is None: 114 | output = os.path.join("output", "video", "{}_{}_{}_{}".format(model_name, frames, period, "pretrained" if pretrained else "random")) 115 | os.makedirs(output, exist_ok=True) 116 | 117 | bkup_tmstmp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") 118 | if os.path.isdir(os.path.join(output, "echonet_{}".format(bkup_tmstmp))): 119 | shutil.rmtree(os.path.join(output, "echonet_{}".format(bkup_tmstmp))) 120 | shutil.copytree("echonet", os.path.join(output, "echonet_{}".format(bkup_tmstmp))) 121 | 122 | # Set device for computations 123 | if device is None: 124 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 125 | elif device == "gpu": 126 | device = torch.device("cuda") 127 | elif device == "cpu": 128 | device = torch.device("cpu") 129 | else: 130 | assert 1==2, "wrong parameter for device" 131 | 132 | 133 | model = echonet.models.rnet2dp1.r2plus1d_18(pretrained=pretrained) 134 | model.fc = torch.nn.Linear(model.fc.in_features, 1) 135 | model.fc.bias.data[0] = 55.6 136 | 137 | model_ref = echonet.models.rnet2dp1.r2plus1d_18(pretrained=pretrained) 138 | 139 | #### add additional channel to pre-trained model 140 | model.stem = torch.nn.Sequential( 141 | torch.nn.Conv3d(4, 45, kernel_size=(1, 7, 7), 142 | stride=(1, 2, 2), padding=(0, 3, 3), 143 | bias=False), 144 | torch.nn.BatchNorm3d(45), 145 | torch.nn.ReLU(inplace=True), 146 | torch.nn.Conv3d(45, 64, kernel_size=(3, 1, 1), 147 | stride=(1, 1, 1), padding=(1, 0, 0), 148 | bias=False), 149 | torch.nn.BatchNorm3d(64), 150 | torch.nn.ReLU(inplace=True)) 151 | 152 | for weight_itr in range(1,6): 153 | model.stem[weight_itr].load_state_dict(model_ref.stem[weight_itr].state_dict()) 154 | 155 | model.stem[0].weight.data[:,:3,:,:,:] = model_ref.stem[0].weight.data[:,:,:,:,:] 156 | model.stem[0].weight.data[:,3,:,:,:] = torch.tensor(np.random.uniform(low = -1, high = 1, size = model.stem[0].weight.data[:,3,:,:,:].shape)).float() 157 | 158 | model = torch.nn.DataParallel(model) 159 | model.to(device) 160 | 161 | if weights is not None: 162 | checkpoint = torch.load(weights) 163 | model.load_state_dict(checkpoint['state_dict']) 164 | 165 | # Set up optimizer 166 | optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay) 167 | if lr_step_period is None: 168 | lr_step_period = math.inf 169 | scheduler = torch.optim.lr_scheduler.StepLR(optim, lr_step_period) 170 | 171 | mean, std = echonet.utils.get_mean_and_std(echonet.datasets.Echo(root=data_dir, split="train")) 172 
| print("mean std", mean, std) 173 | kwargs = {"target_type": task, 174 | "mean": mean, 175 | "std": std, 176 | "length": frames, 177 | "period": period, 178 | } 179 | 180 | # Set up datasets and dataloaders 181 | dataset = {} 182 | 183 | if reduced_set: 184 | dataset["train"] = echonet.datasets.Echo(root=data_dir, split="train", **kwargs, pad=12, ssl_postfix="_ssl_{}_{}".format(rd_label, rd_unlabel), ssl_type = 1, segin_dir = "../infer_buffers/{}/train_infer_cmpct".format(segsource)) 185 | dataset["val"] = echonet.datasets.Echo(root=data_dir, split="val", **kwargs, ssl_postfix="_ssl_{}_{}".format(rd_label, rd_unlabel), segin_dir = "../infer_buffers/{}/val_infer_cmpct".format(segsource)) 186 | else: 187 | dataset["train"] = echonet.datasets.Echo(root=data_dir, split="train", **kwargs, pad=12, ssl_postfix="", segin_dir = "../infer_buffers/{}/train_infer_cmpct".format(segsource)) 188 | dataset["val"] = echonet.datasets.Echo(root=data_dir, split="val", **kwargs, ssl_postfix="", segin_dir = "../infer_buffers/{}/val_infer_cmpct".format(segsource)) 189 | 190 | # Run training and testing loops 191 | with open(os.path.join(output, "log.csv"), "a") as f: 192 | 193 | f.write("Run timestamp: {}\n".format(bkup_tmstmp)) 194 | 195 | epoch_resume = 0 196 | bestLoss = float("inf") 197 | try: 198 | # Attempt to load checkpoint 199 | checkpoint = torch.load(os.path.join(output, "checkpoint.pt")) 200 | model.load_state_dict(checkpoint['state_dict'], strict = False) 201 | optim.load_state_dict(checkpoint['opt_dict']) 202 | scheduler.load_state_dict(checkpoint['scheduler_dict']) 203 | 204 | np_rndstate_chkpt = checkpoint['np_rndstate'] 205 | trch_rndstate_chkpt = checkpoint['trch_rndstate'] 206 | 207 | np.random.set_state(np_rndstate_chkpt) 208 | torch.set_rng_state(trch_rndstate_chkpt) 209 | 210 | epoch_resume = checkpoint["epoch"] + 1 211 | bestLoss = checkpoint["best_loss"] 212 | f.write("Resuming from epoch {}\n".format(epoch_resume)) 213 | except FileNotFoundError: 214 | f.write("Starting run from scratch\n") 215 | 216 | 217 | for epoch in range(epoch_resume, num_epochs): 218 | print("Epoch #{}".format(epoch), flush=True) 219 | for phase in ['train', 'val']: 220 | 221 | start_time = time.time() 222 | 223 | if device.type == "cuda": 224 | for i in range(torch.cuda.device_count()): 225 | torch.cuda.reset_peak_memory_stats(i) 226 | 227 | if phase == "train": 228 | ds = dataset[phase] 229 | dataloader = torch.utils.data.DataLoader( 230 | ds, batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=(device.type == "cuda"), drop_last=(phase == "train"), worker_init_fn=worker_init_fn) 231 | 232 | loss, loss_reg, loss_ctr, yhat, y, _, _ = echonet.utils.video_segin.run_epoch(model, 233 | dataloader, 234 | phase == "train", 235 | optim, 236 | device) 237 | 238 | r2_value = sklearn.metrics.r2_score(y, yhat) 239 | 240 | f.write("{},{},{},{},{},{},{},{},{},{},{}\n".format(epoch, 241 | phase, 242 | loss, 243 | r2_value, 244 | time.time() - start_time, 245 | y.size, 246 | sum(torch.cuda.max_memory_allocated() for i in range(torch.cuda.device_count())), 247 | sum(torch.cuda.max_memory_reserved() for i in range(torch.cuda.device_count())), 248 | batch_size, 249 | loss_reg, 250 | loss_ctr)) 251 | f.flush() 252 | 253 | 254 | else: 255 | ### for validation 256 | ### store seeds 257 | np_rndstate = np.random.get_state() 258 | trch_rndstate = torch.get_rng_state() 259 | 260 | r2_track = [] 261 | loss_track = [] 262 | lossreg_track = [] 263 | losscor_track = [] 264 | 265 | 266 | for val_samp_itr in 
range(val_samp): 267 | 268 | print("running validation batch for seed =", val_samp_itr) 269 | 270 | np.random.seed(val_samp_itr) 271 | torch.manual_seed(val_samp_itr) 272 | 273 | ds = dataset[phase] 274 | dataloader = torch.utils.data.DataLoader( 275 | ds, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=(device.type == "cuda"), drop_last=(phase == "train")) 276 | 277 | loss_valit, loss_reg_valit, loss_ctr_valit, yhat, y, _, _ = echonet.utils.video_segin.run_epoch(model, 278 | dataloader, 279 | phase == "train", 280 | optim, 281 | device) 282 | 283 | r2_track.append(sklearn.metrics.r2_score(y, yhat)) 284 | loss_track.append(loss_valit) 285 | lossreg_track.append(loss_reg_valit) 286 | losscor_track.append(loss_ctr_valit) 287 | 288 | r2_value = np.average(np.array(r2_track)) 289 | loss = np.average(np.array(loss_track)) 290 | lossreg = np.average(np.array(lossreg_track)) 291 | losscor = np.average(np.array(losscor_track)) 292 | 293 | f.write("{},{},{},{},{},{},{},{},{},{},{}".format(epoch, 294 | phase, 295 | loss, 296 | r2_value, 297 | time.time() - start_time, 298 | y.size, 299 | sum(torch.cuda.max_memory_allocated() for i in range(torch.cuda.device_count())), 300 | sum(torch.cuda.max_memory_reserved() for i in range(torch.cuda.device_count())), 301 | batch_size, 302 | lossreg, 303 | losscor)) 304 | 305 | for trck_write in range(len(r2_track)): 306 | f.write(",{}".format(r2_track[trck_write])) 307 | 308 | for trck_write in range(len(loss_track)): 309 | f.write(",{}".format(loss_track[trck_write])) 310 | 311 | f.write("\n") 312 | f.flush() 313 | 314 | np.random.set_state(np_rndstate) 315 | torch.set_rng_state(trch_rndstate) 316 | 317 | 318 | scheduler.step() 319 | 320 | # Save checkpoint 321 | save = { 322 | 'epoch': epoch, 323 | 'state_dict': model.state_dict(), 324 | 'period': period, 325 | 'frames': frames, 326 | 'best_loss': bestLoss, 327 | 'loss': loss, 328 | 'r2': r2_value, 329 | 'opt_dict': optim.state_dict(), 330 | 'scheduler_dict': scheduler.state_dict(), 331 | 'np_rndstate': np.random.get_state(), 332 | 'trch_rndstate': torch.get_rng_state() 333 | } 334 | torch.save(save, os.path.join(output, "checkpoint.pt")) 335 | 336 | if lossreg < bestLoss: 337 | print("saved best because {} < {}".format(lossreg, bestLoss)) 338 | torch.save(save, os.path.join(output, "best.pt")) 339 | bestLoss = lossreg 340 | 341 | 342 | if num_epochs != 0: 343 | checkpoint = torch.load(os.path.join(output, "best.pt")) 344 | model.load_state_dict(checkpoint['state_dict'], strict = False) 345 | f.write("Best validation loss {} from epoch {}, R2 {}\n".format(checkpoint["loss"], checkpoint["epoch"], checkpoint["r2"])) 346 | f.flush() 347 | 348 | if run_test: 349 | # for split in ["val", "test"]: 350 | for split in ["test", "val"]: 351 | # Performance without test-time augmentation 352 | 353 | if not full_test: 354 | 355 | for seed_itr in range(5): 356 | np.random.seed(seed_itr) 357 | torch.manual_seed(seed_itr) 358 | 359 | if reduced_set: 360 | dataloader = torch.utils.data.DataLoader( 361 | echonet.datasets.Echo(root=data_dir, split=split, **kwargs, ssl_postfix="_ssl_{}_{}".format(rd_label, rd_unlabel), segin_dir = "../infer_buffers/{}/{}_infer_cmpct".format(segsource, split)), 362 | batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=(device.type == "cuda"), worker_init_fn=worker_init_fn) 363 | else: 364 | dataloader = torch.utils.data.DataLoader( 365 | echonet.datasets.Echo(root=data_dir, split=split, **kwargs, ssl_postfix="", segin_dir = 
"../infer_buffers/{}/{}_infer_cmpct".format(segsource, split)), 366 | batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=(device.type == "cuda"), worker_init_fn=worker_init_fn) 367 | 368 | 369 | loss, loss_reg, loss_ctr, yhat, y, start_frame_record, vidpath_record = echonet.utils.video_segin.run_epoch(model, 370 | dataloader, 371 | False, 372 | None, 373 | device, 374 | run_dir = output, 375 | test_val = split) 376 | 377 | f.write("Seed is {}\n".format(seed_itr)) 378 | f.write("{} - {} (one clip) R2: {:.3f} ({:.3f} - {:.3f})\n".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), split, *echonet.utils.bootstrap(y, yhat, sklearn.metrics.r2_score))) 379 | f.write("{} - {} (one clip) MAE: {:.2f} ({:.2f} - {:.2f})\n".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), split, *echonet.utils.bootstrap(y, yhat, sklearn.metrics.mean_absolute_error))) 380 | f.write("{} - {} (one clip) RMSE: {:.2f} ({:.2f} - {:.2f})\n".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), split, *tuple(map(math.sqrt, echonet.utils.bootstrap(y, yhat, sklearn.metrics.mean_squared_error))))) 381 | f.flush() 382 | 383 | with open(os.path.join(output, "z_{}_{}_s{}_strtfrmchk.csv".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), split, seed_itr)), "a") as f_start_frame: 384 | for frame_itr in start_frame_record: 385 | f_start_frame.write("{}\n".format(frame_itr)) 386 | f_start_frame.flush() 387 | 388 | with open(os.path.join(output, "z_{}_{}_s{}_vidpthchk.csv".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), split, seed_itr)), "a") as f_vidpath: 389 | for vidpath_itr in vidpath_record: 390 | f_vidpath.write("{}\n".format(vidpath_itr)) 391 | f_vidpath.flush() 392 | 393 | else: 394 | # Performance with test-time augmentation 395 | if reduced_set: 396 | ds = echonet.datasets.Echo(root=data_dir, split=split, **kwargs, clips="all", ssl_postfix="_ssl_{}_{}".format(rd_label, rd_unlabel), segin_dir = "../infer_buffers/{}/{}_infer_cmpct".format(segsource, split)) 397 | else: 398 | ds = echonet.datasets.Echo(root=data_dir, split=split, **kwargs, clips="all", ssl_postfix="", segin_dir = "../infer_buffers/{}/{}_infer_cmpct".format(segsource, split)) 399 | 400 | yhat, y = echonet.utils.video_segin.test_epoch_all(model, 401 | ds, 402 | False, 403 | None, 404 | device, 405 | save_all=True, 406 | block_size=batch_size, 407 | run_dir = output, 408 | test_val = split, 409 | **kwargs, 410 | segsource = segsource) 411 | 412 | f.write("Seed is {} \n".format(seed)) 413 | f.write("{} - {} (all clips, mod) R2: {:.3f} ({:.3f} - {:.3f})\n".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), split, *echonet.utils.bootstrap(y, yhat, sklearn.metrics.r2_score))) 414 | f.write("{} - {} (all clips, mod) MAE: {:.2f} ({:.2f} - {:.2f})\n".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), split, *echonet.utils.bootstrap(y, yhat, sklearn.metrics.mean_absolute_error))) 415 | f.write("{} - {} (all clips, mod) RMSE: {:.2f} ({:.2f} - {:.2f})\n".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), split, *tuple(map(math.sqrt, echonet.utils.bootstrap(y, yhat, sklearn.metrics.mean_squared_error))))) 416 | f.flush() 417 | 418 | 419 | 420 | 421 | def test_epoch_all(model, dataset, train, optim, device, save_all=False, block_size=None, run_dir = None, test_val = None, target_type = None, mean = None, std = None, length = None, period = None, segsource = None): 422 | 423 | assert segsource, "need to feed segsource argument to test_epoch_all" 424 | 425 | model.train(False) 426 | 427 | total = 0 # 
total training loss 428 | total_reg = 0 429 | total_ncor = 0 430 | 431 | n = 0 # number of videos processed 432 | s1 = 0 # sum of ground truth EF 433 | s2 = 0 # Sum of ground truth EF squared 434 | 435 | yhat = [] 436 | y = [] 437 | 438 | #### some params in the dataloader 439 | 440 | if (mean is None) or (std is None) or (length is None) or (period is None): 441 | assert 1==2, "missing key params" 442 | 443 | max_length = 250 444 | 445 | if run_dir: 446 | 447 | temp_savefile = os.path.join(run_dir, "temp_inference_{}.csv".format(test_val)) 448 | 449 | with torch.set_grad_enabled(False): 450 | orig_filelist = dataset.fnames 451 | 452 | if os.path.isfile(temp_savefile): 453 | exist_data = pd.read_csv(temp_savefile) 454 | exist_file = list(exist_data['fnames']) 455 | target_filelist = sorted(list(set(orig_filelist) - set(exist_file))) 456 | else: 457 | target_filelist = sorted(list(orig_filelist)) 458 | exist_data = pd.DataFrame(columns = ['fnames', 'yhat']) 459 | 460 | for filelistitr_idx in range(len(target_filelist)): 461 | filelistitr = target_filelist[filelistitr_idx] 462 | 463 | video_path = os.path.join(echonet.config.DATA_DIR, "Videos", filelistitr) 464 | ### Get data 465 | video = echonet.utils.loadvideo(video_path).astype(np.float32) 466 | 467 | seg_infer_path = os.path.join("../infer_buffers/{}/{}_infer_cmpct".format(segsource, test_val), filelistitr.replace(".avi", ".npy")) 468 | seg_infer_logits = np.load(seg_infer_path) 469 | seg_infer_probs = expit(seg_infer_logits) 470 | seg_infer_prob_norm = seg_infer_probs * 2 - 1 471 | 472 | seg_infer_prob_norm = np.expand_dims(seg_infer_prob_norm, axis=0) 473 | 474 | if isinstance(mean, (float, int)): 475 | video -= mean 476 | else: 477 | video -= mean.reshape(3, 1, 1, 1) 478 | 479 | if isinstance(std, (float, int)): 480 | video /= std 481 | else: 482 | video /= std.reshape(3, 1, 1, 1) 483 | 484 | c, f, h, w = video.shape 485 | if length is None: 486 | # Take as many frames as possible 487 | length = f // period 488 | else: 489 | # Take specified number of frames 490 | length = length 491 | 492 | if max_length is not None: 493 | # Shorten videos to max_length 494 | length = min(length, max_length) 495 | 496 | f_old = f 497 | 498 | if f < length * period: 499 | # Pad video with frames filled with zeros if too short 500 | # 0 represents the mean color (dark grey), since this is after normalization 501 | video = np.concatenate((video, np.zeros((c, length * period - f, h, w), video.dtype)), axis=1) 502 | seg_infer_prob_norm = np.concatenate((seg_infer_prob_norm, np.ones((1, length * period - f, h, w), video.dtype) * -1) , axis=1) 503 | c, f, h, w = video.shape # pylint: disable=E0633 504 | 505 | start = np.arange(f - (length - 1) * period) 506 | #### Do looping starting from here 507 | 508 | reg1 = [] 509 | n_clips = start.shape[0] 510 | batch = 1 511 | for s_itr in range(0, start.shape[0], block_size): 512 | print("{}, processing file {} out of {}, block {} out of {}".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), filelistitr_idx, len(target_filelist), s_itr, start.shape[0]), flush=True) 513 | # print("s range", start[s_itr: s_itr + block_size]) 514 | # print("frame range", s + period * np.arange(length)) 515 | vid_samp = tuple(video[:, s + period * np.arange(length), :, :] for s in start[s_itr: s_itr + block_size]) 516 | seg_infer_samp = tuple(seg_infer_prob_norm[:, s + period * np.arange(length), :, :] for s in start[s_itr: s_itr + block_size]) 517 | 518 | vid_in = np.concatenate((np.stack(vid_samp), np.stack(seg_infer_samp)), 
axis=1) 519 | 520 | X1 = torch.tensor(np.stack(vid_in)) 521 | if X1.dtype == torch.double: 522 | X1 = X1.float() 523 | 524 | X1 = X1.to(device) 525 | 526 | if device.type == "cuda": 527 | all_output = model(X1) 528 | else: 529 | #### we only ever use cpu for testing 530 | all_output = torch.ones((X1.shape[0])) 531 | 532 | reg1.append(all_output.detach().cpu().numpy()) 533 | 534 | reg1 = np.vstack(reg1) 535 | reg1_mean = reg1.reshape(batch, n_clips, -1).mean(1) 536 | 537 | exist_data = exist_data.append({'fnames':filelistitr, 'yhat':reg1_mean[0,0]}, ignore_index=True) 538 | 539 | if filelistitr_idx % 20 == 0: 540 | exist_data.to_csv(temp_savefile, index = False) 541 | 542 | label_data_path = os.path.join(echonet.config.DATA_DIR, "FileList.csv") 543 | label_data = pd.read_csv(label_data_path) 544 | label_data_select = label_data[['FileName','EF']] 545 | label_data_select.columns = ['fnames','EF'] 546 | with_predict = exist_data.merge(label_data_select, on='fnames') 547 | 548 | predict_out_path = os.path.join(run_dir, "{}_predictions.csv".format(test_val)) 549 | with_predict.to_csv(predict_out_path, index=False) 550 | 551 | 552 | return with_predict['yhat'].to_numpy(), with_predict['EF'].to_numpy() 553 | 554 | 555 | def run_epoch(model, dataloader, train, optim, device, save_all=False, block_size=None, run_dir = None, test_val = None): 556 | 557 | model.train(train) 558 | 559 | total = 0 # total training loss 560 | total_reg = 0 561 | total_ncor = 0 562 | 563 | n = 0 # number of videos processed 564 | s1 = 0 # sum of ground truth EF 565 | s2 = 0 # Sum of ground truth EF squared 566 | 567 | yhat = [] 568 | y = [] 569 | start_frame_record = [] 570 | vidpath_record = [] 571 | 572 | with torch.set_grad_enabled(train): 573 | with tqdm.tqdm(total=len(dataloader)) as pbar: 574 | # samples_cnt = 0 575 | for (X, outcome, start_frame, video_path, _, _) in dataloader: 576 | 577 | if not train: 578 | start_frame_record.append(start_frame.view(-1).to("cpu").detach().numpy()) 579 | vidpath_record.append(video_path) 580 | 581 | y.append(outcome.detach().cpu().numpy()) 582 | 583 | if X.dtype == torch.double: 584 | X = X.float() 585 | 586 | X = X.to(device) 587 | 588 | outcome = outcome.to(device) 589 | 590 | s1 += outcome.sum() 591 | s2 += (outcome ** 2).sum() 592 | 593 | assert block_size is None, "block_size should be none, not used" 594 | 595 | if device.type == "cuda": 596 | all_output = model(X) 597 | else: 598 | ### We only ever use cpu for testing 599 | all_output = model(X) 600 | 601 | 602 | loss_cor_item = 0 603 | total_ncor = 0 604 | 605 | loss_reg = torch.nn.functional.mse_loss(all_output.view(-1), outcome) 606 | loss = loss_reg 607 | 608 | yhat.append(all_output.view(-1).to("cpu").detach().numpy()) 609 | 610 | if train: 611 | optim.zero_grad() 612 | loss.backward() 613 | optim.step() 614 | 615 | total += loss.item() * outcome.size(0) 616 | total_reg += loss_reg.item() * outcome.size(0) 617 | 618 | n += outcome.size(0) 619 | 620 | pbar.set_postfix_str("{:.2f} {:.2f} {:.4f} ({:.2f}) / {:.2f} {}".format(total / n, loss_reg.item(), loss_cor_item, loss.item(), s2 / n - (s1 / n) ** 2, 0)) 621 | pbar.update() 622 | 623 | if not save_all: 624 | yhat = np.concatenate(yhat) 625 | if not train: 626 | start_frame_record = np.concatenate(start_frame_record) 627 | 628 | y = np.concatenate(y) 629 | 630 | return total / n, total_reg / n, total_ncor / n, yhat, y, start_frame_record, vidpath_record 631 | 632 | 633 | 634 | -------------------------------------------------------------------------------- 
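A note on how video_segin.py (above) consumes the mask predictions written out by seg_cycle.py's --run_inference pass: the per-frame logits are loaded from ../infer_buffers/{segsource}/{split}_infer_cmpct, squashed with a sigmoid, rescaled to [-1, 1] (with -1 used to pad short clips in test_epoch_all), and stacked under the RGB clip as a fourth input channel for the modified 4-channel Conv3d stem. A minimal sketch of that preprocessing follows; it assumes arrays shaped as loadvideo and the inference pass produce them, and the helper name build_four_channel_clip is illustrative, not part of the repository.

import numpy as np
from scipy.special import expit  # sigmoid


def build_four_channel_clip(video, seg_logits):
    """Stack a segmentation-probability channel under an RGB echo clip.

    video:      float array of shape (3, frames, H, W), already mean/std normalised.
    seg_logits: float array of shape (frames, H, W) with per-frame LV logits, as
                saved to .npy by the segmentation inference pass.
    Returns a (4, frames, H, W) float32 array for the 4-channel r2plus1d_18 stem.
    """
    probs = expit(seg_logits)            # logits -> probabilities in [0, 1]
    mask = probs * 2.0 - 1.0             # rescale to [-1, 1], as in test_epoch_all
    mask = np.expand_dims(mask, axis=0)  # add channel axis: (1, frames, H, W)
    return np.concatenate((video, mask), axis=0).astype(np.float32)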
/echonet/utils/vidsegin_teachstd_kd.py: -------------------------------------------------------------------------------- 1 | """Teacher Student Distillation""" 2 | 3 | import math 4 | import os 5 | import time 6 | import shutil 7 | import datetime 8 | import pandas as pd 9 | import cv2 10 | 11 | import click 12 | import matplotlib.pyplot as plt 13 | import numpy as np 14 | import sklearn.metrics 15 | import torch 16 | import torchvision 17 | import tqdm 18 | import subprocess 19 | 20 | import echonet 21 | import echonet.models 22 | 23 | @click.command("vidsegin_teachstd_kd") 24 | @click.option("--data_dir", type=click.Path(exists=True, file_okay=False), default=None) 25 | @click.option("--output", type=click.Path(file_okay=False), default=None) 26 | @click.option("--task", type=str, default="EF") 27 | @click.option("--model_name", type=click.Choice(['mc3_18', 'r2plus1d_18', "r2plus1d_18_segin", 'r3d_18', 'r2plus1d_18_ncor']), 28 | default="r2plus1d_18") 29 | @click.option("--pretrained/--random", default=True) 30 | @click.option("--weights", type=click.Path(exists=True, dir_okay=False), default=None) 31 | @click.option("--weights_0", type=click.Path(exists=True, dir_okay=False), default=None) 32 | @click.option("--run_test/--skip_test", default=False) 33 | @click.option("--num_epochs", type=int, default=30) 34 | @click.option("--lr", type=float, default=1e-4) 35 | @click.option("--weight_decay", type=float, default=1e-4) 36 | @click.option("--lr_step_period", type=int, default=15) 37 | @click.option("--frames", type=int, default=32) 38 | @click.option("--period", type=int, default=2) 39 | @click.option("--num_train_patients", type=int, default=None) 40 | @click.option("--num_workers", type=int, default=4) 41 | @click.option("--batch_size", type=int, default=20) 42 | @click.option("--device", type=str, default=None) 43 | @click.option("--seed", type=int, default=0) 44 | @click.option("--full_test/--quick_test", default=False) 45 | @click.option("--val_samp", type=int, default=3) 46 | @click.option("--reduced_set/--full_set", default=True) 47 | @click.option("--rd_label", type=int, default=920) 48 | @click.option("--rd_unlabel", type=int, default=6440) 49 | @click.option("--max_block", type=int, default=20) 50 | @click.option("--segsource", type=str, default=None) 51 | @click.option("--w_unlb", type=float, default=2.5) 52 | @click.option("--batch_size_unlb", type=int, default=10) 53 | @click.option("--notcamus/--camus", default=True) 54 | 55 | def run( 56 | data_dir=None, 57 | output=None, 58 | task="EF", 59 | model_name="r2plus1d_18", 60 | pretrained=True, 61 | weights=None, 62 | weights_0=None, 63 | run_test=False, 64 | num_epochs=30, 65 | lr=1e-4, 66 | weight_decay=1e-4, 67 | lr_step_period=15, 68 | frames=32, 69 | period=2, 70 | num_train_patients=None, 71 | num_workers=4, 72 | batch_size=20, 73 | device=None, 74 | seed=0, 75 | full_test = False, 76 | val_samp = 3, 77 | reduced_set = True, 78 | rd_label = 920, 79 | rd_unlabel = 6440, 80 | max_block = 20, 81 | segsource = None, 82 | w_unlb = 2.5, 83 | batch_size_unlb=10, 84 | notcamus = True 85 | ): 86 | 87 | assert segsource, "function needs segsource option" 88 | 89 | 90 | if reduced_set: 91 | if not os.path.isfile(os.path.join(echonet.config.DATA_DIR, "FileList_ssl_{}_{}.csv".format(rd_label, rd_unlabel))): 92 | print("Generating new file list for ssl dataset") 93 | np.random.seed(0) 94 | 95 | data = pd.read_csv(os.path.join(echonet.config.DATA_DIR, "FileList.csv")) 96 | data["Split"].map(lambda x: x.upper()) 97 | 98 | file_name_list = 
np.array(data[data['Split']== 'TRAIN']['FileName']) 99 | np.random.shuffle(file_name_list) 100 | 101 | label_list = file_name_list[:rd_label] 102 | unlabel_list = file_name_list[rd_label:rd_label + rd_unlabel] 103 | 104 | data['SSL_SPLIT'] = "EXCLUDE" 105 | data.loc[data['FileName'].isin(label_list), 'SSL_SPLIT'] = "LABELED" 106 | data.loc[data['FileName'].isin(unlabel_list), 'SSL_SPLIT'] = "UNLABELED" 107 | 108 | data.to_csv(os.path.join(echonet.config.DATA_DIR, "FileList_ssl_{}_{}.csv".format(rd_label, rd_unlabel)),index = False) 109 | 110 | 111 | # Seed RNGs 112 | np.random.seed(seed) 113 | torch.manual_seed(seed) 114 | 115 | def worker_init_fn(worker_id): 116 | # print("worker id is", torch.utils.data.get_worker_info().id) 117 | # https://discuss.pytorch.org/t/in-what-order-do-dataloader-workers-do-their-job/88288/2 118 | np.random.seed(np.random.get_state()[1][0] + worker_id) 119 | 120 | # Set default output directory 121 | if output is None: 122 | output = os.path.join("output", "video", "{}_{}_{}_{}".format(model_name, frames, period, "pretrained" if pretrained else "random")) 123 | os.makedirs(output, exist_ok=True) 124 | 125 | bkup_tmstmp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") 126 | if os.path.isdir(os.path.join(output, "echonet_{}".format(bkup_tmstmp))): 127 | shutil.rmtree(os.path.join(output, "echonet_{}".format(bkup_tmstmp))) 128 | shutil.copytree("echonet", os.path.join(output, "echonet_{}".format(bkup_tmstmp))) 129 | 130 | # Set device for computations 131 | if device is None: 132 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 133 | elif device == "gpu": 134 | device = torch.device("cuda") 135 | elif device == "cpu": 136 | device = torch.device("cpu") 137 | else: 138 | raise ValueError("wrong parameter for device") 139 | 140 | model_0 = echonet.models.rnet2dp1.r2plus1d_18_kd(pretrained=pretrained) 141 | model_0.fc = torch.nn.Linear(model_0.fc.in_features, 1) 142 | model_0.fc.bias.data[0] = 55.6 143 | 144 | model_ref = echonet.models.rnet2dp1.r2plus1d_18(pretrained=pretrained) 145 | 146 | model_0.stem = torch.nn.Sequential( 147 | torch.nn.Conv3d(4, 45, kernel_size=(1, 7, 7), 148 | stride=(1, 2, 2), padding=(0, 3, 3), 149 | bias=False), 150 | torch.nn.BatchNorm3d(45), 151 | torch.nn.ReLU(inplace=True), 152 | torch.nn.Conv3d(45, 64, kernel_size=(3, 1, 1), 153 | stride=(1, 1, 1), padding=(1, 0, 0), 154 | bias=False), 155 | torch.nn.BatchNorm3d(64), 156 | torch.nn.ReLU(inplace=True)) 157 | 158 | for weight_itr in range(1,6): 159 | model_0.stem[weight_itr].load_state_dict(model_ref.stem[weight_itr].state_dict()) 160 | 161 | model_0.stem[0].weight.data[:,:3,:,:,:] = model_ref.stem[0].weight.data[:,:,:,:,:] 162 | model_0.stem[0].weight.data[:,3,:,:,:] = torch.tensor(np.random.uniform(low = -1, high = 1, size = model_0.stem[0].weight.data[:,3,:,:,:].shape)).float() 163 | 164 | model_0 = torch.nn.DataParallel(model_0) 165 | 166 | model_0.to(device) 167 | 168 | 169 | if weights is not None: 170 | checkpoint = torch.load(weights) 171 | model.load_state_dict(checkpoint['state_dict'])  # note: 'model' (the student) is only constructed in the weights_0 branch below 172 | 173 | ### we initialize teacher and student weights.
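# The branch below initialises the teacher (model_0, whose 4-channel input stacks the
# video frames with a segmentation channel) and, when a checkpoint is supplied via
# --weights_0, warm-starts the video-only student (model, 3-channel input) from it.
# Two steps recur in that branch; they are summarised here as illustrative sketches
# with assumed helper names, not functions defined in this repository (torch is
# already imported at the top of this module):
from collections import OrderedDict

def strip_module_prefix(state_dict):
    """Undo the 'module.' prefix that torch.nn.DataParallel prepends to parameter names."""
    return OrderedDict((k[len("module."):] if k.startswith("module.") else k, v)
                       for k, v in state_dict.items())

def inflate_first_conv_to_4ch(conv_rgb):
    """Copy pretrained RGB filters into a 4-input-channel Conv3d and randomly
    initialise the extra segmentation channel, mirroring the stem surgery above."""
    conv4 = torch.nn.Conv3d(4, conv_rgb.out_channels, kernel_size=conv_rgb.kernel_size,
                            stride=conv_rgb.stride, padding=conv_rgb.padding, bias=False)
    with torch.no_grad():
        conv4.weight[:, :3] = conv_rgb.weight   # reuse the pretrained RGB filters
        conv4.weight[:, 3].uniform_(-1, 1)      # fresh weights for the segmentation channel
    return conv4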
174 | if weights_0 is not None: 175 | checkpoint_0 = torch.load(weights_0) 176 | if checkpoint_0.get("state_dict_0"): 177 | ## initialize teacher weights 178 | print("loading from state_dict_0") 179 | model_0.load_state_dict(checkpoint_0['state_dict_0']) 180 | 181 | ## initialize student weights where transferable to speed up training 182 | state_dict = checkpoint_0['state_dict_0'] 183 | from collections import OrderedDict 184 | new_state_dict = OrderedDict() 185 | for k, v in state_dict.items(): 186 | name = k[7:] # remove `module.` 187 | new_state_dict[name] = v 188 | 189 | model = echonet.models.rnet2dp1.r2plus1d_18_kd(pretrained=pretrained) 190 | model.fc = torch.nn.Linear(model.fc.in_features, 1) 191 | model.fc.bias.data[0] = 55.6 192 | 193 | model.stem = torch.nn.Sequential( 194 | torch.nn.Conv3d(4, 45, kernel_size=(1, 7, 7), 195 | stride=(1, 2, 2), padding=(0, 3, 3), 196 | bias=False), 197 | torch.nn.BatchNorm3d(45), 198 | torch.nn.ReLU(inplace=True), 199 | torch.nn.Conv3d(45, 64, kernel_size=(3, 1, 1), 200 | stride=(1, 1, 1), padding=(1, 0, 0), 201 | bias=False), 202 | torch.nn.BatchNorm3d(64), 203 | torch.nn.ReLU(inplace=True)) 204 | 205 | model.load_state_dict(new_state_dict) 206 | 207 | model.stem = torch.nn.Sequential( 208 | torch.nn.Conv3d(3, 45, kernel_size=(1, 7, 7), 209 | stride=(1, 2, 2), padding=(0, 3, 3), 210 | bias=False), 211 | torch.nn.BatchNorm3d(45), 212 | torch.nn.ReLU(inplace=True), 213 | torch.nn.Conv3d(45, 64, kernel_size=(3, 1, 1), 214 | stride=(1, 1, 1), padding=(1, 0, 0), 215 | bias=False), 216 | torch.nn.BatchNorm3d(64), 217 | torch.nn.ReLU(inplace=True)) 218 | 219 | for weight_itr in range(1,6): 220 | model.stem[weight_itr].load_state_dict(model_0.module.stem[weight_itr].state_dict()) 221 | 222 | model.stem[0].weight.data[:,:3,:,:,:] = model_0.module.stem[0].weight.data[:,:3,:,:,:] 223 | 224 | model = torch.nn.DataParallel(model) 225 | model.to(device) 226 | 227 | elif checkpoint_0.get("state_dict"): 228 | ## initialize teacher weights 229 | print("loading from state_dict") 230 | model_0.load_state_dict(checkpoint_0['state_dict']) 231 | 232 | ## initialize student weights where transferable to speed up training 233 | state_dict = checkpoint_0['state_dict'] 234 | from collections import OrderedDict 235 | new_state_dict = OrderedDict() 236 | for k, v in state_dict.items(): 237 | name = k[7:] # remove `module.` 238 | new_state_dict[name] = v 239 | 240 | model = echonet.models.rnet2dp1.r2plus1d_18_kd(pretrained=pretrained) 241 | model.fc = torch.nn.Linear(model.fc.in_features, 1) 242 | model.fc.bias.data[0] = 55.6 243 | 244 | model.stem = torch.nn.Sequential( 245 | torch.nn.Conv3d(4, 45, kernel_size=(1, 7, 7), 246 | stride=(1, 2, 2), padding=(0, 3, 3), 247 | bias=False), 248 | torch.nn.BatchNorm3d(45), 249 | torch.nn.ReLU(inplace=True), 250 | torch.nn.Conv3d(45, 64, kernel_size=(3, 1, 1), 251 | stride=(1, 1, 1), padding=(1, 0, 0), 252 | bias=False), 253 | torch.nn.BatchNorm3d(64), 254 | torch.nn.ReLU(inplace=True)) 255 | 256 | model.load_state_dict(new_state_dict) 257 | 258 | model.stem = torch.nn.Sequential( 259 | torch.nn.Conv3d(3, 45, kernel_size=(1, 7, 7), 260 | stride=(1, 2, 2), padding=(0, 3, 3), 261 | bias=False), 262 | torch.nn.BatchNorm3d(45), 263 | torch.nn.ReLU(inplace=True), 264 | torch.nn.Conv3d(45, 64, kernel_size=(3, 1, 1), 265 | stride=(1, 1, 1), padding=(1, 0, 0), 266 | bias=False), 267 | torch.nn.BatchNorm3d(64), 268 | torch.nn.ReLU(inplace=True)) 269 | 270 | for weight_itr in range(1,6): 271 | 
model.stem[weight_itr].load_state_dict(model_0.module.stem[weight_itr].state_dict()) 272 | 273 | model.stem[0].weight.data[:,:3,:,:,:] = model_0.module.stem[0].weight.data[:,:3,:,:,:] 274 | 275 | model = torch.nn.DataParallel(model) 276 | model.to(device) 277 | else: 278 | assert 1==2, "missing key" 279 | 280 | 281 | optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay) 282 | if lr_step_period is None: 283 | lr_step_period = math.inf 284 | scheduler = torch.optim.lr_scheduler.StepLR(optim, lr_step_period) 285 | 286 | optim_0 = torch.optim.SGD(model_0.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay) 287 | if lr_step_period is None: 288 | lr_step_period = math.inf 289 | scheduler_0 = torch.optim.lr_scheduler.StepLR(optim_0, lr_step_period) 290 | 291 | 292 | mean, std = echonet.utils.get_mean_and_std(echonet.datasets.Echo_tskd(root=data_dir, split="train")) 293 | print("mean std", mean, std) 294 | kwargs = {"target_type": ["EF", "IEKD"], 295 | "mean": mean, 296 | "std": std, 297 | "length": frames, 298 | "period": period, 299 | } 300 | 301 | kwargs_testall = {"target_type": "EF", 302 | "mean": mean, 303 | "std": std, 304 | "length": frames, 305 | "period": period, 306 | } 307 | 308 | dataset = {} 309 | dataset_trainsub = {} 310 | if reduced_set: 311 | dataset_trainsub['lb'] = echonet.datasets.Echo_tskd(root=data_dir, split="train", **kwargs, pad=12, ssl_postfix="_ssl_{}_{}".format(rd_label, rd_unlabel), ssl_type = 1, segin_dir = "../infer_buffers/{}/train_infer_cmpct".format(segsource)) 312 | dataset_trainsub["unlb_0"] = echonet.datasets.Echo_tskd(root=data_dir, split="train", **kwargs, pad=12, ssl_postfix="_ssl_{}_{}".format(rd_label, rd_unlabel), ssl_type = 2, segin_dir = "../infer_buffers/{}/train_infer_cmpct".format(segsource)) 313 | dataset["train"] = dataset_trainsub 314 | dataset["val"] = echonet.datasets.Echo_tskd(root=data_dir, split="val", **kwargs, ssl_postfix="_ssl_{}_{}".format(rd_label, rd_unlabel), segin_dir = "../infer_buffers/{}/val_infer_cmpct".format(segsource)) 315 | else: 316 | dataset["train"] = echonet.datasets.Echo_tskd(root=data_dir, split="train", **kwargs, pad=12, ssl_postfix="", segin_dir = "../infer_buffers/{}/train_infer_cmpct".format(segsource)) 317 | dataset["val"] = echonet.datasets.Echo_tskd(root=data_dir, split="val", **kwargs, ssl_postfix="", segin_dir = "../infer_buffers/{}/val_infer_cmpct".format(segsource)) 318 | 319 | 320 | # Run training and testing loops 321 | with open(os.path.join(output, "log.csv"), "a") as f: 322 | 323 | f.write("Run timestamp: {}\n".format(bkup_tmstmp)) 324 | 325 | epoch_resume = 0 326 | bestLoss = float("inf") 327 | try: 328 | # Attempt to load checkpoint 329 | checkpoint = torch.load(os.path.join(output, "checkpoint.pt")) 330 | model.load_state_dict(checkpoint['state_dict'], strict = False) 331 | optim.load_state_dict(checkpoint['opt_dict']) 332 | scheduler.load_state_dict(checkpoint['scheduler_dict']) 333 | 334 | model_0.load_state_dict(checkpoint['state_dict_0'], strict = False) 335 | optim_0.load_state_dict(checkpoint['opt_dict_0']) 336 | scheduler_0.load_state_dict(checkpoint['scheduler_dict_0']) 337 | 338 | np_rndstate_chkpt = checkpoint['np_rndstate'] 339 | trch_rndstate_chkpt = checkpoint['trch_rndstate'] 340 | 341 | np.random.set_state(np_rndstate_chkpt) 342 | torch.set_rng_state(trch_rndstate_chkpt) 343 | 344 | epoch_resume = checkpoint["epoch"] + 1 345 | bestLoss = checkpoint["best_loss"] 346 | f.write("Resuming from epoch {}\n".format(epoch_resume)) 347 | except 
FileNotFoundError: 348 | f.write("Starting run from scratch\n") 349 | 350 | 351 | for epoch in range(epoch_resume, num_epochs): 352 | print("Epoch #{}".format(epoch), flush=True) 353 | for phase in ['train', 'val']: 354 | 355 | start_time = time.time() 356 | 357 | if device.type == "cuda": 358 | for i in range(torch.cuda.device_count()): 359 | torch.cuda.reset_peak_memory_stats(i) 360 | 361 | if phase == "train": 362 | ds_lb = dataset[phase]['lb'] 363 | dataloader_lb = torch.utils.data.DataLoader( 364 | ds_lb, batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=(device.type == "cuda"), drop_last=(phase == "train"), worker_init_fn=worker_init_fn) 365 | 366 | ds_unlb_0 = dataset[phase]['unlb_0'] 367 | dataloader_unlb_0 = torch.utils.data.DataLoader( 368 | ds_unlb_0, batch_size=batch_size_unlb, num_workers=num_workers, shuffle=True, pin_memory=(device.type == "cuda"), drop_last=(phase == "train"), worker_init_fn=worker_init_fn) 369 | 370 | 371 | total, total_reg, total_reg_reg, total_reg_unlb, yhat, y = echonet.utils.vidsegin_teachstd_kd.run_epoch(model = model, 372 | model_0 = model_0, 373 | dataloader = dataloader_lb, 374 | dataloader_unlb_0 = dataloader_unlb_0, 375 | train = phase == "train", 376 | optim = optim, 377 | optim_0 = optim_0, 378 | device = device, 379 | w_unlb = w_unlb) 380 | 381 | 382 | r2_value = sklearn.metrics.r2_score(y, yhat) 383 | 384 | f.write("{},{},{},{},{},{},{},{},{},{},{}\n".format(epoch, 385 | phase, 386 | total, 387 | total_reg_reg, 388 | r2_value, 389 | total_reg_unlb, 390 | time.time() - start_time, 391 | y.size, 392 | sum(torch.cuda.max_memory_allocated() for i in range(torch.cuda.device_count())), 393 | sum(torch.cuda.max_memory_reserved() for i in range(torch.cuda.device_count())), 394 | batch_size)) 395 | f.flush() 396 | 397 | # print("successful run until exit") 398 | # exit() 399 | 400 | else: 401 | ### for validation 402 | ### store seeds 403 | np_rndstate = np.random.get_state() 404 | trch_rndstate = torch.get_rng_state() 405 | 406 | r2_track = [] 407 | lossreg_track = [] 408 | 409 | 410 | for val_samp_itr in range(val_samp): 411 | 412 | print("running validation batch for seed =", val_samp_itr) 413 | 414 | np.random.seed(val_samp_itr) 415 | torch.manual_seed(val_samp_itr) 416 | 417 | ds = dataset[phase] 418 | dataloader = torch.utils.data.DataLoader( 419 | ds, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=(device.type == "cuda"), drop_last=(phase == "train")) 420 | 421 | 422 | total, total_reg, total_reg_reg, total_reg_unlb, yhat, y = echonet.utils.vidsegin_teachstd_kd.run_epoch(model = model, 423 | model_0 = model_0, 424 | dataloader = dataloader, 425 | dataloader_unlb_0 = None, 426 | train = phase == "train", 427 | optim = optim, 428 | optim_0 = optim_0, 429 | device = device, 430 | w_unlb = w_unlb) 431 | 432 | 433 | r2_track.append(sklearn.metrics.r2_score(y, yhat)) 434 | lossreg_track.append(total_reg_reg) 435 | 436 | 437 | r2_value = np.average(np.array(r2_track)) 438 | lossreg = np.average(np.array(lossreg_track)) 439 | 440 | f.write("{},{},{},{},{},{},{},{},{},{}".format(epoch, 441 | phase, 442 | lossreg, 443 | r2_value, 444 | total_reg_unlb, 445 | time.time() - start_time, 446 | y.size, 447 | sum(torch.cuda.max_memory_allocated() for i in range(torch.cuda.device_count())), 448 | sum(torch.cuda.max_memory_reserved() for i in range(torch.cuda.device_count())), 449 | batch_size)) 450 | 451 | for trck_write in range(len(r2_track)): 452 | f.write(",{}".format(r2_track[trck_write])) 453 | 454 | for 
trck_write in range(len(lossreg_track)): 455 | f.write(",{}".format(lossreg_track[trck_write])) 456 | 457 | 458 | f.write("\n") 459 | f.flush() 460 | 461 | np.random.set_state(np_rndstate) 462 | torch.set_rng_state(trch_rndstate) 463 | 464 | 465 | scheduler.step() 466 | scheduler_0.step() 467 | 468 | # Save checkpoint 469 | save = { 470 | 'epoch': epoch, 471 | 'state_dict': model.state_dict(), 472 | 'state_dict_0': model_0.state_dict(), 473 | 'period': period, 474 | 'frames': frames, 475 | 'best_loss': bestLoss, 476 | 'loss': lossreg, 477 | 'r2': r2_value, 478 | 'opt_dict': optim.state_dict(), 479 | 'opt_dict_0': optim_0.state_dict(), 480 | 'scheduler_dict': scheduler.state_dict(), 481 | 'scheduler_dict_0': scheduler_0.state_dict(), 482 | 'np_rndstate': np.random.get_state(), 483 | 'trch_rndstate': torch.get_rng_state() 484 | } 485 | torch.save(save, os.path.join(output, "checkpoint.pt")) 486 | 487 | #### save based on reg loss 488 | if lossreg < bestLoss: 489 | print("saved best because {} < {}".format(lossreg, bestLoss)) 490 | torch.save(save, os.path.join(output, "best.pt")) 491 | bestLoss = lossreg 492 | 493 | 494 | # Load best weights 495 | if num_epochs != 0: 496 | checkpoint = torch.load(os.path.join(output, "best.pt")) 497 | model.load_state_dict(checkpoint['state_dict'], strict = False) 498 | model_0.load_state_dict(checkpoint['state_dict_0'], strict = False) 499 | f.write("Best validation loss {} from epoch {}, R2 {}\n".format(checkpoint["loss"], checkpoint["epoch"], checkpoint["r2"])) 500 | f.flush() 501 | 502 | if run_test: 503 | if notcamus: 504 | split_list = ["test", "val"] 505 | # split_list = ["test"] 506 | else: 507 | split_list = ["train", "test"] 508 | 509 | for split in split_list: 510 | # Performance without test-time augmentation 511 | 512 | if not full_test: 513 | 514 | for seed_itr in range(5): 515 | np.random.seed(seed_itr) 516 | torch.manual_seed(seed_itr) 517 | 518 | dataloader = torch.utils.data.DataLoader( 519 | echonet.datasets.Echo(root=data_dir, split=split, **kwargs_testall, ssl_postfix="_ssl_{}_{}".format(rd_label, rd_unlabel), segin_dir = "../infer_buffers/{}/{}_infer_cmpct".format(segsource, split)), 520 | batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=(device.type == "cuda"), worker_init_fn=worker_init_fn) 521 | 522 | 523 | 524 | loss, yhat, y = echonet.utils.vidsegin_teachstd_kd.run_epoch_val(model = model, 525 | model_0 = model_0, 526 | dataloader = dataloader, 527 | train = False, 528 | optim = None, 529 | optim_0 = None, 530 | device = device) 531 | 532 | f.write("Seed is {}\n".format(seed_itr)) 533 | f.write("{} - {} (one clip) R2: {:.3f} ({:.3f} - {:.3f})\n".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), split, *echonet.utils.bootstrap(y, yhat, sklearn.metrics.r2_score))) 534 | f.write("{} - {} (one clip) MAE: {:.2f} ({:.2f} - {:.2f})\n".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), split, *echonet.utils.bootstrap(y, yhat, sklearn.metrics.mean_absolute_error))) 535 | f.write("{} - {} (one clip) RMSE: {:.2f} ({:.2f} - {:.2f})\n".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), split, *tuple(map(math.sqrt, echonet.utils.bootstrap(y, yhat, sklearn.metrics.mean_squared_error))))) 536 | f.flush() 537 | 538 | 539 | else: 540 | # Performance with test-time augmentation 541 | if reduced_set: 542 | ds = echonet.datasets.Echo(root=data_dir, split=split, **kwargs_testall, clips="all", ssl_postfix="_ssl_{}_{}".format(rd_label, rd_unlabel)) 543 | else: 544 | ds = echonet.datasets.Echo(root=data_dir, 
split=split, **kwargs_testall, clips="all", ssl_postfix="") 545 | 546 | yhat, y = echonet.utils.vidsegin_teachstd_kd.test_epoch_all(model, 547 | ds, 548 | False, 549 | None, 550 | device, 551 | save_all=True, 552 | block_size=batch_size, 553 | run_dir = output, 554 | test_val = split, 555 | **kwargs) 556 | 557 | f.write("Seed is {} \n".format(seed)) 558 | f.write("{} - {} (all clips, mod) R2: {:.3f} ({:.3f} - {:.3f})\n".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), split, *echonet.utils.bootstrap(y, yhat, sklearn.metrics.r2_score))) 559 | f.write("{} - {} (all clips, mod) MAE: {:.2f} ({:.2f} - {:.2f})\n".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), split, *echonet.utils.bootstrap(y, yhat, sklearn.metrics.mean_absolute_error))) 560 | f.write("{} - {} (all clips, mod) RMSE: {:.2f} ({:.2f} - {:.2f})\n".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), split, *tuple(map(math.sqrt, echonet.utils.bootstrap(y, yhat, sklearn.metrics.mean_squared_error))))) 561 | f.flush() 562 | 563 | 564 | 565 | 566 | def test_epoch_all(model, dataset, train, optim, device, save_all=False, block_size=None, run_dir = None, test_val = None, mean = None, std = None, length = None, period = None, target_type = None): 567 | model.train(False) 568 | 569 | total = 0 # total training loss 570 | total_reg = 0 571 | total_ncor = 0 572 | 573 | n = 0 # number of videos processed 574 | s1 = 0 # sum of ground truth EF 575 | s2 = 0 # Sum of ground truth EF squared 576 | 577 | yhat = [] 578 | y = [] 579 | 580 | #### some params in the dataloader 581 | 582 | if (mean is None) or (std is None) or (length is None) or (period is None): 583 | assert 1==2, "missing key params" 584 | 585 | max_length = 250 586 | 587 | if run_dir: 588 | 589 | temp_savefile = os.path.join(run_dir, "temp_inference_{}.csv".format(test_val)) 590 | 591 | with torch.set_grad_enabled(False): 592 | orig_filelist = dataset.fnames 593 | 594 | if os.path.isfile(temp_savefile): 595 | exist_data = pd.read_csv(temp_savefile) 596 | exist_file = list(exist_data['fnames']) 597 | target_filelist = sorted(list(set(orig_filelist) - set(exist_file))) 598 | else: 599 | target_filelist = sorted(list(orig_filelist)) 600 | exist_data = pd.DataFrame(columns = ['fnames', 'yhat']) 601 | 602 | for filelistitr_idx in range(len(target_filelist)): 603 | filelistitr = target_filelist[filelistitr_idx] 604 | 605 | video_path = os.path.join(echonet.config.DATA_DIR, "Videos", filelistitr) 606 | ### Get data 607 | video = echonet.utils.loadvideo(video_path).astype(np.float32) 608 | 609 | if isinstance(mean, (float, int)): 610 | video -= mean 611 | else: 612 | video -= mean.reshape(3, 1, 1, 1) 613 | 614 | if isinstance(std, (float, int)): 615 | video /= std 616 | else: 617 | video /= std.reshape(3, 1, 1, 1) 618 | 619 | c, f, h, w = video.shape 620 | if length is None: 621 | # Take as many frames as possible 622 | length = f // period 623 | else: 624 | # Take specified number of frames 625 | length = length 626 | 627 | if max_length is not None: 628 | # Shorten videos to max_length 629 | length = min(length, max_length) 630 | 631 | if f < length * period: 632 | # Pad video with frames filled with zeros if too short 633 | # 0 represents the mean color (dark grey), since this is after normalization 634 | video = np.concatenate((video, np.zeros((c, length * period - f, h, w), video.dtype)), axis=1) 635 | c, f, h, w = video.shape # pylint: disable=E0633 636 | 637 | start = np.arange(f - (length - 1) * period) 638 | #### Do looping starting from here 639 | 640 | 
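# For orientation: the loop below enumerates every clip of `length` frames taken at a
# stride of `period` frames, scores each clip with the model, and averages the
# per-clip predictions into a single EF estimate per video. Conceptually it computes
# the following (an illustrative sketch with an assumed helper name, not a repository
# function; the actual loop walks through the clips in chunks of `block_size` to
# bound GPU memory):
def enumerate_clips(video, length, period):
    """video: (C, F, H, W) array -> (n_clips, C, length, H, W) array of all sampled clips."""
    starts = np.arange(video.shape[1] - (length - 1) * period)
    return np.stack([video[:, s + period * np.arange(length)] for s in starts])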
reg1 = [] 641 | n_clips = start.shape[0] 642 | batch = 1 643 | for s_itr in range(0, start.shape[0], block_size): 644 | print("{}, processing file {} out of {}, block {} out of {}".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), filelistitr_idx, len(target_filelist), s_itr, start.shape[0]), flush=True) 645 | # print("s range", start[s_itr: s_itr + block_size]) 646 | # print("frame range", s + period * np.arange(length)) 647 | vid_samp = tuple(video[:, s + period * np.arange(length), :, :] for s in start[s_itr: s_itr + block_size]) 648 | X1 = torch.tensor(np.stack(vid_samp)) 649 | X1 = X1.to(device) 650 | 651 | if device.type == "cuda": 652 | all_output = model(X1) 653 | else: 654 | #### we only ever use cpu for testing 655 | all_output = torch.ones((X1.shape[0])) 656 | 657 | reg1.append(all_output[0].detach().cpu().numpy()) 658 | 659 | reg1 = np.vstack(reg1) 660 | reg1_mean = reg1.reshape(batch, n_clips, -1).mean(1) 661 | 662 | exist_data = exist_data.append({'fnames':filelistitr, 'yhat':reg1_mean[0,0]}, ignore_index=True) 663 | 664 | if filelistitr_idx % 20 == 0: 665 | exist_data.to_csv(temp_savefile, index = False) 666 | 667 | label_data_path = os.path.join(echonet.config.DATA_DIR, "FileList.csv") 668 | label_data = pd.read_csv(label_data_path) 669 | label_data_select = label_data[['FileName','EF']] 670 | label_data_select.columns = ['fnames','EF'] 671 | with_predict = exist_data.merge(label_data_select, on='fnames') 672 | 673 | predict_out_path = os.path.join(run_dir, "{}_predictions.csv".format(test_val)) 674 | with_predict.to_csv(predict_out_path, index=False) 675 | 676 | 677 | # print(with_predict) 678 | # exit() 679 | return with_predict['yhat'].to_numpy(), with_predict['EF'].to_numpy() 680 | 681 | 682 | def run_epoch(model, 683 | model_0, 684 | dataloader, 685 | dataloader_unlb_0, 686 | train, 687 | optim, 688 | optim_0, 689 | device, 690 | save_all=False, 691 | block_size=None, 692 | run_dir = None, 693 | test_val = None, 694 | w_unlb = 0): 695 | 696 | 697 | total = 0 # total training loss 698 | total_reg = 0 699 | total_reg_reg = 0 700 | total_loss_reg_reg_unlb = 0 701 | 702 | 703 | n = 0 # number of videos processed 704 | n_frm = 0 705 | s1 = 0 # sum of ground truth EF 706 | s2 = 0 # Sum of ground truth EF squared 707 | 708 | yhat = [] 709 | yhat_seg = [] 710 | y = [] 711 | start_frame_record = [] 712 | vidpath_record = [] 713 | 714 | if dataloader_unlb_0: 715 | dataloader_unlb_0_itr = iter(dataloader_unlb_0) 716 | 717 | with torch.set_grad_enabled(train): 718 | with tqdm.tqdm(total=len(dataloader)) as pbar: 719 | enum_idx = 0 720 | for (X_all, outcome, seg_info, start_frame, video_path, _, _) in dataloader: 721 | enum_idx = enum_idx + 1 722 | 723 | 724 | if not train: 725 | start_frame_record.append(start_frame.view(-1).to("cpu").detach().numpy()) 726 | vidpath_record.append(video_path) 727 | 728 | y.append(outcome.detach().cpu().numpy()) 729 | 730 | if X_all.dtype == torch.double: 731 | X_all = X_all.float() 732 | 733 | X_noseg = X_all[:,:3,...].to(device) 734 | X_wseg = X_all.to(device) 735 | 736 | outcome = outcome.to(device) 737 | 738 | s1 += outcome.sum() 739 | s2 += (outcome ** 2).sum() 740 | 741 | if dataloader_unlb_0: 742 | 743 | (X_all_unlb, outcome_unlb, seg_info_unlb, start_frame_unlb, video_path_unlb, _, _) = dataloader_unlb_0_itr.next() 744 | if X_all_unlb.dtype == torch.double: 745 | X_all_unlb = X_all_unlb.float() 746 | 747 | X_noseg_unlb = X_all_unlb[:,:3,...].to(device) 748 | X_wseg_unlb = X_all_unlb.to(device) 749 | 750 | 751 | if train: 752 | 
model.train(True) 753 | else: 754 | model.train(False) 755 | model_0.train(False) 756 | 757 | if train: 758 | all_output_unlb = model(X_noseg_unlb) 759 | else: 760 | with torch.no_grad(): 761 | all_output_unlb = model(X_noseg_unlb) 762 | 763 | y_pred_unlb = all_output_unlb[0] 764 | 765 | with torch.no_grad(): 766 | all_output_seg_unlb = model_0(X_wseg_unlb) 767 | 768 | y_pred_seg_unlb = all_output_seg_unlb[0] 769 | 770 | y_pred_avg_unlb = y_pred_seg_unlb.view(-1).detach() 771 | 772 | loss_reg_reg_unlb_vid = torch.nn.functional.mse_loss(y_pred_unlb.view(-1), y_pred_avg_unlb) 773 | loss_reg_reg_unlb_seg = torch.nn.functional.mse_loss(y_pred_seg_unlb.view(-1), y_pred_avg_unlb) 774 | 775 | loss_reg_reg_unlb = loss_reg_reg_unlb_vid / 2 776 | 777 | loss_reg_reg_unlb_item = loss_reg_reg_unlb.item() 778 | 779 | else: 780 | loss_reg_reg_unlb_item = 0 781 | 782 | 783 | #### train video model 784 | if train: 785 | model.train(True) 786 | # attKD.train(True) 787 | else: 788 | model.train(False) 789 | # attKD.train(False) 790 | model_0.train(False) 791 | 792 | if train: 793 | all_output = model(X_noseg) 794 | else: 795 | with torch.no_grad(): 796 | all_output = model(X_noseg) 797 | 798 | y_pred = all_output[0] 799 | 800 | with torch.no_grad(): 801 | all_output_seg = model_0(X_wseg) 802 | 803 | y_pred_seg = all_output_seg[0] 804 | 805 | 806 | loss_reg_reg = torch.nn.functional.mse_loss(y_pred.view(-1), outcome) 807 | yhat.append(y_pred.view(-1).to("cpu").detach().numpy()) 808 | 809 | if dataloader_unlb_0: 810 | loss_reg = loss_reg_reg + w_unlb * loss_reg_reg_unlb 811 | else: 812 | loss_reg = loss_reg_reg 813 | 814 | 815 | if train: 816 | optim.zero_grad() 817 | loss_reg.backward() 818 | optim.step() 819 | 820 | total_reg += loss_reg.item() * outcome.size(0) 821 | total_reg_reg += loss_reg_reg.item() * outcome.size(0) 822 | 823 | 824 | loss_reg_item = loss_reg.item() 825 | loss_reg_reg_item = loss_reg_reg.item() 826 | 827 | 828 | total = total_reg 829 | if dataloader_unlb_0: 830 | total_loss_reg_reg_unlb = total_loss_reg_reg_unlb + loss_reg_reg_unlb.item() 831 | 832 | n += outcome.size(0) 833 | 834 | pbar.set_postfix_str("total {:.4f} / reg {:.4f} ({:.4f}) regrg {:.4f} ({:.4f}) regulb {:.4f} ({:.4f})".format(total / n , 835 | total_reg / n , loss_reg_item, 836 | total_reg_reg / n , loss_reg_reg_item, 837 | total_loss_reg_reg_unlb / n, loss_reg_reg_unlb_item 838 | )) 839 | pbar.update() 840 | yhat_cat = np.concatenate(yhat) 841 | y = np.concatenate(y) 842 | 843 | 844 | return (total / n , 845 | total_reg / n , 846 | total_reg_reg / n , 847 | total_loss_reg_reg_unlb / n , 848 | yhat_cat, y) 849 | 850 | 851 | 852 | 853 | 854 | def run_epoch_val(model, model_0, dataloader, train, optim, optim_0, device, save_all=False, block_size=None): 855 | """Run one epoch of training/evaluation for segmentation. 856 | 857 | Args: 858 | model (torch.nn.Module): Model to train/evaulate. 859 | dataloder (torch.utils.data.DataLoader): Dataloader for dataset. 860 | train (bool): Whether or not to train model. 861 | optim (torch.optim.Optimizer): Optimizer 862 | device (torch.device): Device to run on 863 | save_all (bool, optional): If True, return predictions for all 864 | test-time augmentations separately. If False, return only 865 | the mean prediction. 866 | Defaults to False. 867 | block_size (int or None, optional): Maximum number of augmentations 868 | to run on at the same time. Use to limit the amount of memory 869 | used. If None, always run on all augmentations simultaneously. 870 | Default is None. 
871 | """ 872 | 873 | total = 0 # total training loss 874 | total_reg = 0 875 | start_frame_record = [] 876 | vidpath_record = [] 877 | yhat = [] 878 | y = [] 879 | 880 | n = 0 881 | with torch.set_grad_enabled(train): 882 | with tqdm.tqdm(total=len(dataloader)) as pbar: 883 | enum_idx = 0 884 | for (X, outcome, start_frame, video_path, _, _) in dataloader: 885 | enum_idx = enum_idx + 1 886 | 887 | if not train: 888 | start_frame_record.append(start_frame.view(-1).to("cpu").detach().numpy()) 889 | vidpath_record.append(video_path) 890 | 891 | y.append(outcome.detach().cpu().numpy()) 892 | 893 | if X.dtype == torch.double: 894 | X = X.float() 895 | 896 | X = X[:,:3,...].to(device) 897 | 898 | outcome = outcome.to(device) 899 | 900 | model.train(False) 901 | model_0.train(False) 902 | all_output = model(X) 903 | 904 | y_pred = all_output[0] 905 | 906 | loss_reg = torch.nn.functional.mse_loss(y_pred.view(-1), outcome) 907 | yhat.append(y_pred.view(-1).to("cpu").detach().numpy()) 908 | 909 | total_reg += loss_reg.item() * outcome.size(0) 910 | total = total_reg 911 | n += outcome.size(0) 912 | 913 | pbar.set_postfix_str("total {:.4f}".format(total / n)) 914 | pbar.update() 915 | 916 | yhat = np.concatenate(yhat) 917 | 918 | y = np.concatenate(y) 919 | 920 | 921 | return (total / n, yhat, y) 922 | 923 | 924 | -------------------------------------------------------------------------------- /flow_a_tmi_revise_v2.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmed-lab/CSS-SemiVideo/5707b6085cb61807253710766c816e3cd1ae6a25/flow_a_tmi_revise_v2.PNG -------------------------------------------------------------------------------- /flow_b_tmi_revise.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmed-lab/CSS-SemiVideo/5707b6085cb61807253710766c816e3cd1ae6a25/flow_b_tmi_revise.PNG -------------------------------------------------------------------------------- /flow_graph.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmed-lab/CSS-SemiVideo/5707b6085cb61807253710766c816e3cd1ae6a25/flow_graph.PNG -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | _libgcc_mutex=0.1=main 5 | ca-certificates=2021.4.13=h06a4308_1 6 | certifi=2020.12.5=py39h06a4308_0 7 | ld_impl_linux-64=2.33.1=h53a641e_7 8 | libffi=3.3=he6710b0_2 9 | libgcc-ng=9.1.0=hdf63c60_0 10 | libstdcxx-ng=9.1.0=hdf63c60_0 11 | ncurses=6.2=he6710b0_1 12 | numpy=1.21.2=pypi_0 13 | openssl=1.1.1k=h27cfd23_0 14 | pillow=8.3.1=pypi_0 15 | pip=21.0.1=py39h06a4308_0 16 | python=3.9.4=hdb3f193_0 17 | readline=8.1=h27cfd23_0 18 | setuptools=52.0.0=py39h06a4308_0 19 | sqlite=3.35.4=hdfb4753_0 20 | tk=8.6.10=hbc83047_0 21 | torch=1.7.1+cu110=pypi_0 22 | torchaudio=0.7.2=pypi_0 23 | torchvision=0.8.2+cu110=pypi_0 24 | typing-extensions=3.10.0.0=pypi_0 25 | tzdata=2020f=h52ac0ba_0 26 | wheel=0.36.2=pyhd3eb1b0_0 27 | xz=5.2.5=h7b6447c_0 28 | zlib=1.2.11=h7b6447c_3 29 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Metadata for package to allow 
installation with pip.""" 3 | 4 | import os 5 | 6 | import setuptools 7 | 8 | with open("README.md", "r") as fh: 9 | long_description = fh.read() 10 | 11 | # Use same version from code 12 | # See 3 from 13 | # https://packaging.python.org/guides/single-sourcing-package-version/ 14 | version = {} 15 | with open(os.path.join("echonet", "__version__.py")) as f: 16 | exec(f.read(), version) # pylint: disable=W0122 17 | 18 | setuptools.setup( 19 | name="echonet", 20 | description="Video-based AI for beat-to-beat cardiac function assessment.", 21 | version=version["__version__"], 22 | url="https://echonet.github.io/dynamic", 23 | packages=setuptools.find_packages(exclude=["output.*", "output*", "*output.*", "*output*", "*output", "output"]), 24 | install_requires=[ 25 | "click", 26 | "numpy", 27 | "pandas", 28 | "torch", 29 | "torchvision", 30 | "opencv-python", 31 | "scikit-image", 32 | "tqdm", 33 | "sklearn" 34 | ], 35 | classifiers=[ 36 | "Programming Language :: Python :: 3", 37 | ], 38 | entry_points={ 39 | "console_scripts": [ 40 | "echonet=echonet:main", 41 | ], 42 | } 43 | 44 | ) 45 | --------------------------------------------------------------------------------
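setup.py single-sources the package version: it executes echonet/__version__.py into a dictionary and reads its __version__ entry, so the version string is defined in exactly one place. The contents of that file are not reproduced in this listing; it is expected to hold a single module-level string, along the lines of the sketch below (the actual version number is an assumption):

    # echonet/__version__.py -- assumed contents; the real value is not shown in this listing
    __version__ = "1.0.0"

The console_scripts entry point then installs an `echonet` command mapped to echonet:main, which is presumably where the click sub-commands defined under echonet/utils (such as vidsegin_teachstd_kd above) are registered; that wiring lives in echonet/__init__.py and echonet/__main__.py, which are not shown here.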