├── .gitignore
├── .travis.yml
├── README.md
├── echonet
│   ├── __init__.py
│   ├── __main__.py
│   ├── __version__.py
│   ├── config.py
│   ├── datasets
│   │   ├── __init__.py
│   │   └── echo.py
│   ├── models
│   │   ├── __init__.py
│   │   └── rnet2dp1.py
│   ├── segmentation
│   │   ├── __init__.py
│   │   ├── _utils.py
│   │   ├── deeplabv3.py
│   │   └── segmentation.py
│   └── utils
│       ├── __init__.py
│       ├── seg_cycle.py
│       ├── video_segin.py
│       └── vidsegin_teachstd_kd.py
├── flow_a_tmi_revise_v2.PNG
├── flow_b_tmi_revise.PNG
├── flow_graph.PNG
├── requirements.txt
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .ipynb_checkpoints/
2 | __pycache__/
3 | *.swp
4 | echonet.cfg
5 | .echonet.cfg
6 | *.pyc
7 | echonet.egg-info/
8 | *.pth
9 | *.pt
10 | *.npy
11 | output/zdbg*
12 | output/*/size
13 | output/*/videos
14 | *.avi
15 | *.zip
16 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: minimal
2 |
3 | os:
4 | - linux
5 |
6 | env:
7 | # - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.1 TORCHVISION_VERSION=0.2 (torchvision 0.2 does not have VisionDataset)
8 | # - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.1 TORCHVISION_VERSION=0.3 (torchvision 0.3 has a cuda issue)
9 | - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.1 TORCHVISION_VERSION=0.4
10 | - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.1 TORCHVISION_VERSION=0.5
11 | # - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.2 TORCHVISION_VERSION=0.2
12 | # - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.2 TORCHVISION_VERSION=0.3
13 | - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.2 TORCHVISION_VERSION=0.4
14 | - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.2 TORCHVISION_VERSION=0.5
15 | # - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.3 TORCHVISION_VERSION=0.2
16 | # - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.3 TORCHVISION_VERSION=0.3
17 | - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.3 TORCHVISION_VERSION=0.4
18 | - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.3 TORCHVISION_VERSION=0.5
19 | # - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.4 TORCHVISION_VERSION=0.2
20 | # - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.4 TORCHVISION_VERSION=0.3
21 | - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.4 TORCHVISION_VERSION=0.4
22 | - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.4 TORCHVISION_VERSION=0.5
23 | # - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.1 TORCHVISION_VERSION=0.2
24 | # - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.1 TORCHVISION_VERSION=0.3
25 | - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.1 TORCHVISION_VERSION=0.4
26 | - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.1 TORCHVISION_VERSION=0.5
27 | # - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.2 TORCHVISION_VERSION=0.2
28 | # - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.2 TORCHVISION_VERSION=0.3
29 | - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.2 TORCHVISION_VERSION=0.4
30 | - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.2 TORCHVISION_VERSION=0.5
31 | # - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.3 TORCHVISION_VERSION=0.2
32 | # - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.3 TORCHVISION_VERSION=0.3
33 | - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.3 TORCHVISION_VERSION=0.4
34 | - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.3 TORCHVISION_VERSION=0.5
35 | # - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.4 TORCHVISION_VERSION=0.2
36 | # - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.4 TORCHVISION_VERSION=0.3
37 | - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.4 TORCHVISION_VERSION=0.4
38 | - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.4 TORCHVISION_VERSION=0.5
39 |
40 | install:
41 | - if [[ "$TRAVIS_OS_NAME" == "linux" ]];
42 | then
43 | MINICONDA_OS=Linux;
44 | sudo apt-get update;
45 | else
46 | MINICONDA_OS=MacOSX;
47 | brew update;
48 | fi
49 | - wget https://repo.anaconda.com/miniconda/Miniconda3-latest-${MINICONDA_OS}-x86_64.sh -O miniconda.sh
50 | - bash miniconda.sh -b -p $HOME/miniconda
51 | - source "$HOME/miniconda/etc/profile.d/conda.sh"
52 | - hash -r
53 | - conda config --set always_yes yes --set changeps1 no
54 | - conda update -q conda
55 | # Useful for debugging any issues with conda
56 | - conda info -a
57 | - conda search pytorch || true
58 |
59 | - conda create -q -n test-environment python=${PYTHON_VERSION} pytorch=${PYTORCH_VERSION}
60 | - conda activate test-environment
61 | - pip install -q torchvision==${TORCHVISION_VERSION} "pillow<7.0.0"
62 | - pip install -q .
63 | - pip install -q flake8 pylint
64 |
65 | script:
66 | - flake8 --ignore=E501
67 | - pylint --disable=C0103,C0301,R0401,R0801,R0902,R0912,R0913,R0914,R0915 --extension-pkg-whitelist=cv2,torch --generated-members=torch.* echonet/ scripts/*.py setup.py
68 | - python -c "import echonet"
69 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Cyclical Self-Supervision for Semi-Supervised Ejection Fraction Prediction from Echocardiogram Videos
3 |
4 |
5 |
6 | This is the implementation of Cyclical Self-Supervision (CSS) for semi-supervised ejection fraction prediction, as described in the paper ["Cyclical Self-Supervision for Semi-Supervised Ejection Fraction Prediction from Echocardiogram Videos"]().
7 |
8 | 
9 |
10 |
11 |
12 |
13 | ## Data
14 |
15 | Researchers can request the EchoNet-Dynamic dataset at https://echonet.github.io/dynamic/ and set the directory path in the configuration file, `echonet.cfg`.
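A minimal `echonet.cfg` contains a single `data_dir` entry pointing to the dataset root; the file may also be placed in your home directory. The path below is a placeholder:

```
data_dir = /path/to/EchoNet-Dynamic
```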
16 |
17 |
18 |
19 |
20 | ## Environment
21 |
22 | We recommend running the code in a PyTorch `conda` environment. A `requirements.txt` file listing the Python dependencies is included.
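For example, an environment matching one of the configurations tested in `.travis.yml` could be created as follows (the environment name and exact versions are illustrative):

```
conda create -n css-semivideo python=3.7 pytorch=1.4
conda activate css-semivideo
pip install torchvision==0.5 "pillow<7.0.0"
pip install -r requirements.txt
```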
23 |
24 |
25 |
26 |
27 | ## Training and testing
28 |
29 | The code must first be installed by running
30 |
31 | pip3 install --user .
32 |
33 | from the repository directory `CSS-SemiVideo`. Training consists of three stages:
34 |
35 |
36 |
37 | ### 1) To train the CSS semi-supervised segmentation model, run:
38 |
39 | ```
40 | echonet seg_cycle --batch_size=20 --output=output/css_seg --loss_cyc_w=0.01 --num_epochs=25 --rd_label=920 --rd_unlabel=6440 --run_test --reduced_set
41 | ```
42 |
43 | The predicted LV segmentation masks for all frames must then be inferred for use in the second stage. To do so, run:
44 |
45 | ```
46 | echonet seg_cycle --batch_size=20 --output=output/css_seg --loss_cyc_w=0.01 --num_epochs=25 --rd_label=920 --rd_unlabel=6440 --skip_test --reduced_set --run_inference=train
47 |
48 | echonet seg_cycle --batch_size=20 --output=output/css_seg --loss_cyc_w=0.01 --num_epochs=25 --rd_label=920 --rd_unlabel=6440 --skip_test --reduced_set --run_inference=val
49 |
50 | echonet seg_cycle --batch_size=20 --output=output/css_seg --loss_cyc_w=0.01 --num_epochs=25 --rd_label=920 --rd_unlabel=6440 --skip_test --reduced_set --run_inference=test
51 | ```
52 |
53 | The segmentation prediction outputs will be located under the output folder `output/css_seg`. To reduce installation time, these are moved to a separate directory parallel to `CSS-SemiVideo`, i.e. `CSS-SemiVideo/../infer_buffers/css_seg`. Segmentation masks are also read from this location in Step 2 of the framework.
54 |
55 | To do this, run:
56 |
57 | ```
58 | mkdir ../infer_buffers/css_seg
59 | mv output/css_seg/*_infer_cmpct ../infer_buffers/css_seg/
60 | ```
61 |
62 |
63 | ### 2) To train the multi-modal LVEF prediction model, run:
64 |
65 | ```
66 | echonet video_segin --frames=32 --model_name=r2plus1d_18 --period=2 --batch_size=20 --output=output/teacher_model --num_epochs=25 --rd_label=920 --rd_unlabel=6440 --run_test --segsource=css_seg
67 | ```
68 |
69 |
70 | ### 3) To train teacher-student distillation, run:
71 |
72 | ```
73 | echonet vidsegin_teachstd_kd --frames=32 --model_name=r2plus1d_18 --period=2 --batch_size=20 --output=output/end2end_model --num_epochs=25 --rd_label=920 --rd_unlabel=6440 --run_test --reduced_set --max_block=20 --segsource=css_seg --w_unlb=5 --batch_size_unlb=10 --weights_0=output/teacher_model/best.pt
74 | ```
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 | ## Pretrained models
84 |
85 | Trained checkpoints and models can be downloaded from:
86 |
87 | 1) CSS for semi-supervised segmentation: https://hkustconnect-my.sharepoint.com/:f:/g/personal/wdaiaj_connect_ust_hk/EqiP-N0MDRZGlwqr5PeZUrYBtLki8QWtBlMqRK1FNkjbcw?e=DIpkIm
88 |
89 | 2) Multi-modal LVEF regression: https://hkustconnect-my.sharepoint.com/:f:/g/personal/wdaiaj_connect_ust_hk/ErxaHepi4ndAnMcvSOwTH5wBDI6rHypqdcBiXF8B0XYvmg?e=Rud7Pf
90 |
91 | 3) Teacher-student distillation: https://hkustconnect-my.sharepoint.com/:f:/g/personal/wdaiaj_connect_ust_hk/Ev7mQ1ReI05LtiDIqQu1IpYBC6xN4R47PsYnhDUQr4n3fw?e=US4caq
92 |
93 |
94 | To run with the pretrained model weights, replace the `.pt` files in the target output directory with the downloaded files.
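For example (paths and file names below are illustrative):

```
mkdir -p output/teacher_model
cp /path/to/downloads/best.pt output/teacher_model/best.pt
```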
95 |
96 |
97 |
98 | | Experiment | MAE | RMSE | R² |
99 | | ---------- | :-----------: | :-----------: | :-----------: |
100 | | Multi-Modal | 5.13 ± 0.05 | 6.90 ± 0.07 | 67.6% ± 0.5 |
101 | | Teacher-student Distillation | 4.90 ± 0.04 | 6.57 ± 0.06 | 71.1% ± 0.4 |
102 |
103 |
104 |
105 |
106 | ## Notes
107 | * Contact: DAI Weihang (wdai03@gmail.com)
108 |
109 |
110 |
111 | ## Citation
112 | If this code is useful for your research, please consider citing:
113 |
114 | (to be released)
115 |
--------------------------------------------------------------------------------
/echonet/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | The echonet package contains code for loading echocardiogram videos, and
3 | functions for training and testing segmentation and ejection fraction
4 | prediction models.
5 | """
6 |
7 | import click
8 |
9 | from echonet.__version__ import __version__
10 | from echonet.config import CONFIG as config
11 | import echonet.datasets as datasets
12 | import echonet.utils as utils
13 | import echonet.models as models
14 | import echonet.segmentation as segmentation
15 |
16 | @click.group()
17 | def main():
18 | """Entry point for command line interface."""
19 |
20 |
21 | del click
22 |
23 | main.add_command(utils.seg_cycle.run)
24 | main.add_command(utils.vidsegin_teachstd_kd.run)
25 | main.add_command(utils.video_segin.run)
26 |
27 |
28 | __all__ = ["__version__", "config", "datasets", "main", "utils", "models", "segmentation"]
29 |
--------------------------------------------------------------------------------
/echonet/__main__.py:
--------------------------------------------------------------------------------
1 | """Entry point for command line."""
2 |
3 | import echonet
4 |
5 |
6 | if __name__ == '__main__':
7 | echonet.main()
8 |
--------------------------------------------------------------------------------
/echonet/__version__.py:
--------------------------------------------------------------------------------
1 | """Version number for Echonet package."""
2 |
3 | __version__ = "1.0.0"
4 |
--------------------------------------------------------------------------------
/echonet/config.py:
--------------------------------------------------------------------------------
1 | """Sets paths based on configuration files."""
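# The configuration file uses simple "key = value" lines and is searched for as
# "echonet.cfg" or ".echonet.cfg" in the working directory and the home directory.
# A placeholder example:
#     data_dir = /path/to/EchoNet-Dynamic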
2 |
3 | import configparser
4 | import os
5 | import types
6 |
7 | _FILENAME = None
8 | _PARAM = {}
9 | for filename in ["echonet.cfg",
10 | ".echonet.cfg",
11 | os.path.expanduser("~/echonet.cfg"),
12 | os.path.expanduser("~/.echonet.cfg"),
13 | ]:
14 | if os.path.isfile(filename):
15 | _FILENAME = filename
16 | config = configparser.ConfigParser()
17 | with open(filename, "r") as f:
18 | config.read_string("[config]\n" + f.read())
19 | _PARAM = config["config"]
20 | break
21 |
22 | CONFIG = types.SimpleNamespace(
23 | FILENAME=_FILENAME,
24 | DATA_DIR=_PARAM.get("data_dir", "../EchoNet/Heart-videos/"))
25 |
--------------------------------------------------------------------------------
/echonet/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | The echonet.datasets submodule defines a Pytorch dataset for loading
3 | echocardiogram videos.
4 | """
5 |
6 | from .echo import Echo, Echo_tskd, Echo_CSS
7 |
8 | __all__ = ["Echo", "Echo_tskd", "Echo_CSS"]
9 |
--------------------------------------------------------------------------------
/echonet/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .rnet2dp1 import r2plus1d_18_kd, r2plus1d_18
2 |
3 | __all__ = ["r2plus1d_18_kd", "r2plus1d_18"]
--------------------------------------------------------------------------------
/echonet/models/rnet2dp1.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 | import math
5 | from torchvision.models.utils import load_state_dict_from_url
6 | import numpy as np
7 | # from ..utils import load_state_dict_from_url
8 |
9 |
10 | __all__ = ['r3d_18', 'mc3_18', 'r2plus1d_18', 'r2plus1d_18_kd']
11 |
12 | model_urls = {
13 | 'r3d_18': 'https://download.pytorch.org/models/r3d_18-b3b3357e.pth',
14 | 'mc3_18': 'https://download.pytorch.org/models/mc3_18-a90a0ba3.pth',
15 | 'r2plus1d_18': 'https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth',
16 | }
17 |
18 |
19 | class Conv3DSimple(nn.Conv3d):
20 | def __init__(self,
21 | in_planes,
22 | out_planes,
23 | midplanes=None,
24 | stride=1,
25 | padding=1):
26 |
27 | super(Conv3DSimple, self).__init__(
28 | in_channels=in_planes,
29 | out_channels=out_planes,
30 | kernel_size=(3, 3, 3),
31 | stride=stride,
32 | padding=padding,
33 | bias=False)
34 |
35 | @staticmethod
36 | def get_downsample_stride(stride):
37 | return stride, stride, stride
38 |
39 |
40 | class Conv2Plus1D(nn.Sequential):
41 |
42 | def __init__(self,
43 | in_planes,
44 | out_planes,
45 | midplanes,
46 | stride=1,
47 | padding=1):
48 | super(Conv2Plus1D, self).__init__(
49 | nn.Conv3d(in_planes, midplanes, kernel_size=(1, 3, 3),
50 | stride=(1, stride, stride), padding=(0, padding, padding),
51 | bias=False),
52 | nn.BatchNorm3d(midplanes),
53 | nn.ReLU(inplace=True),
54 | nn.Conv3d(midplanes, out_planes, kernel_size=(3, 1, 1),
55 | stride=(stride, 1, 1), padding=(padding, 0, 0),
56 | bias=False))
57 |
58 | @staticmethod
59 | def get_downsample_stride(stride):
60 | return stride, stride, stride
61 |
62 |
63 | class Conv3DNoTemporal(nn.Conv3d):
64 |
65 | def __init__(self,
66 | in_planes,
67 | out_planes,
68 | midplanes=None,
69 | stride=1,
70 | padding=1):
71 |
72 | super(Conv3DNoTemporal, self).__init__(
73 | in_channels=in_planes,
74 | out_channels=out_planes,
75 | kernel_size=(1, 3, 3),
76 | stride=(1, stride, stride),
77 | padding=(0, padding, padding),
78 | bias=False)
79 |
80 | @staticmethod
81 | def get_downsample_stride(stride):
82 | return 1, stride, stride
83 |
84 |
85 | class BasicBlock(nn.Module):
86 |
87 | expansion = 1
88 |
89 | def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
90 | midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)
91 |
92 | super(BasicBlock, self).__init__()
93 | self.conv1 = nn.Sequential(
94 | conv_builder(inplanes, planes, midplanes, stride),
95 | nn.BatchNorm3d(planes),
96 | nn.ReLU(inplace=True)
97 | )
98 | self.conv2 = nn.Sequential(
99 | conv_builder(planes, planes, midplanes),
100 | nn.BatchNorm3d(planes)
101 | )
102 | self.relu = nn.ReLU(inplace=True)
103 | self.downsample = downsample
104 | self.stride = stride
105 |
106 | def forward(self, x):
107 | residual = x
108 |
109 | out = self.conv1(x)
110 | out = self.conv2(out)
111 | if self.downsample is not None:
112 | residual = self.downsample(x)
113 |
114 | out += residual
115 | out = self.relu(out)
116 |
117 | return out
118 |
119 |
120 | class Bottleneck(nn.Module):
121 | expansion = 4
122 |
123 | def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
124 |
125 | super(Bottleneck, self).__init__()
126 | midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)
127 |
128 | # 1x1x1
129 | self.conv1 = nn.Sequential(
130 | nn.Conv3d(inplanes, planes, kernel_size=1, bias=False),
131 | nn.BatchNorm3d(planes),
132 | nn.ReLU(inplace=True)
133 | )
134 | # Second kernel
135 | self.conv2 = nn.Sequential(
136 | conv_builder(planes, planes, midplanes, stride),
137 | nn.BatchNorm3d(planes),
138 | nn.ReLU(inplace=True)
139 | )
140 |
141 | # 1x1x1
142 | self.conv3 = nn.Sequential(
143 | nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False),
144 | nn.BatchNorm3d(planes * self.expansion)
145 | )
146 | self.relu = nn.ReLU(inplace=True)
147 | self.downsample = downsample
148 | self.stride = stride
149 |
150 | def forward(self, x):
151 | residual = x
152 |
153 | out = self.conv1(x)
154 | out = self.conv2(out)
155 | out = self.conv3(out)
156 |
157 | if self.downsample is not None:
158 | residual = self.downsample(x)
159 |
160 | out += residual
161 | out = self.relu(out)
162 |
163 | return out
164 |
165 |
166 | class BasicStem(nn.Sequential):
167 | """The default conv-batchnorm-relu stem
168 | """
169 | def __init__(self):
170 | super(BasicStem, self).__init__(
171 | nn.Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2),
172 | padding=(1, 3, 3), bias=False),
173 | nn.BatchNorm3d(64),
174 | nn.ReLU(inplace=True))
175 |
176 |
177 | class R2Plus1dStem(nn.Sequential):
178 | """R(2+1)D stem is different than the default one as it uses separated 3D convolution
179 | """
180 | def __init__(self):
181 | super(R2Plus1dStem, self).__init__(
182 | nn.Conv3d(3, 45, kernel_size=(1, 7, 7),
183 | stride=(1, 2, 2), padding=(0, 3, 3),
184 | bias=False),
185 | nn.BatchNorm3d(45),
186 | nn.ReLU(inplace=True),
187 | nn.Conv3d(45, 64, kernel_size=(3, 1, 1),
188 | stride=(1, 1, 1), padding=(1, 0, 0),
189 | bias=False),
190 | nn.BatchNorm3d(64),
191 | nn.ReLU(inplace=True))
192 |
193 |
194 | class VideoResNet(nn.Module):
195 |
196 | def __init__(self, block, conv_makers, layers,
197 | stem, num_classes=400,
198 | zero_init_residual=False):
199 | """Generic resnet video generator.
200 |
201 | Args:
202 | block (nn.Module): resnet building block
203 | conv_makers (list(functions)): generator function for each layer
204 | layers (List[int]): number of blocks per layer
205 | stem (nn.Module, optional): Resnet stem, if None, defaults to conv-bn-relu. Defaults to None.
206 | num_classes (int, optional): Dimension of the final FC layer. Defaults to 400.
207 | zero_init_residual (bool, optional): Zero init bottleneck residual BN. Defaults to False.
208 | """
209 | super(VideoResNet, self).__init__()
210 | self.inplanes = 64
211 |
212 | self.stem = stem()
213 |
214 | self.layer1 = self._make_layer(block, conv_makers[0], 64, layers[0], stride=1)
215 | self.layer2 = self._make_layer(block, conv_makers[1], 128, layers[1], stride=2)
216 | self.layer3 = self._make_layer(block, conv_makers[2], 256, layers[2], stride=2)
217 | self.layer4 = self._make_layer(block, conv_makers[3], 512, layers[3], stride=2)
218 |
219 | self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
220 | self.fc = nn.Linear(512 * block.expansion, num_classes)
221 |
222 | # init weights
223 | self._initialize_weights()
224 |
225 | if zero_init_residual:
226 | for m in self.modules():
227 | if isinstance(m, Bottleneck):
228 | nn.init.constant_(m.bn3.weight, 0)
229 |
230 | def forward(self, x):
231 | x = self.stem(x)
232 |
233 | x = self.layer1(x)
234 | x = self.layer2(x)
235 | x = self.layer3(x)
236 | x = self.layer4(x)
237 |
238 | x = self.avgpool(x)
239 | # Flatten the layer to fc
240 | x = x.flatten(1)
241 | x = self.fc(x)
242 |
243 | return x
244 |
245 | def _make_layer(self, block, conv_builder, planes, blocks, stride=1):
246 | downsample = None
247 |
248 | if stride != 1 or self.inplanes != planes * block.expansion:
249 | ds_stride = conv_builder.get_downsample_stride(stride)
250 | downsample = nn.Sequential(
251 | nn.Conv3d(self.inplanes, planes * block.expansion,
252 | kernel_size=1, stride=ds_stride, bias=False),
253 | nn.BatchNorm3d(planes * block.expansion)
254 | )
255 | layers = []
256 | layers.append(block(self.inplanes, planes, conv_builder, stride, downsample))
257 |
258 | self.inplanes = planes * block.expansion
259 | for i in range(1, blocks):
260 | layers.append(block(self.inplanes, planes, conv_builder))
261 |
262 | return nn.Sequential(*layers)
263 |
264 | def _initialize_weights(self):
265 | for m in self.modules():
266 | if isinstance(m, nn.Conv3d):
267 | nn.init.kaiming_normal_(m.weight, mode='fan_out',
268 | nonlinearity='relu')
269 | if m.bias is not None:
270 | nn.init.constant_(m.bias, 0)
271 | elif isinstance(m, nn.BatchNorm3d):
272 | nn.init.constant_(m.weight, 1)
273 | nn.init.constant_(m.bias, 0)
274 | elif isinstance(m, nn.Linear):
275 | nn.init.normal_(m.weight, 0, 0.01)
276 | nn.init.constant_(m.bias, 0)
277 |
278 |
279 |
280 |
281 |
282 |
283 |
284 |
285 | class VideoResNet_kd(nn.Module):
286 |
287 | def __init__(self, block, conv_makers, layers,
288 | stem, num_classes=400,
289 | zero_init_residual=False):
290 | """Generic resnet video generator.
291 |
292 | Args:
293 | block (nn.Module): resnet building block
294 | conv_makers (list(functions)): generator function for each layer
295 | layers (List[int]): number of blocks per layer
296 | stem (nn.Module, optional): Resnet stem, if None, defaults to conv-bn-relu. Defaults to None.
297 | num_classes (int, optional): Dimension of the final FC layer. Defaults to 400.
298 | zero_init_residual (bool, optional): Zero init bottleneck residual BN. Defaults to False.
299 | """
300 | super(VideoResNet_kd, self).__init__()
301 | self.inplanes = 64
302 |
303 | self.stem = stem()
304 |
305 | self.layer1 = self._make_layer(block, conv_makers[0], 64, layers[0], stride=1)
306 | self.layer2 = self._make_layer(block, conv_makers[1], 128, layers[1], stride=2)
307 | self.layer3 = self._make_layer(block, conv_makers[2], 256, layers[2], stride=2)
308 | self.layer4 = self._make_layer(block, conv_makers[3], 512, layers[3], stride=2)
309 |
310 | self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
311 | self.fc = nn.Linear(512 * block.expansion, num_classes)
312 |
313 | # init weights
314 | self._initialize_weights()
315 |
316 | if zero_init_residual:
317 | for m in self.modules():
318 | if isinstance(m, Bottleneck):
319 | nn.init.constant_(m.bn3.weight, 0)
320 |
321 | def forward(self, x):
322 | x = self.stem(x)
323 | x_l0 = x
324 | x = self.layer1(x)
325 | x_l1 = x
326 | x = self.layer2(x)
327 | x_l2 = x
328 | x = self.layer3(x)
329 | x_l3 = x
330 | x = self.layer4(x)
331 | x_l4 = x
332 |
333 | x = self.avgpool(x)
334 |
335 | x = x.flatten(1)
336 | x_reg_feat = x
337 | x = self.fc(x)
338 |
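# Return the regression output together with intermediate activations
# (stem output, layer1, layer2, flattened pooled features, layer4, layer3)
# for feature-level supervision in the teacher-student distillation stage.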
339 | return x, x_l0, x_l1, x_l2, x_reg_feat, x_l4, x_l3
340 |
341 | def _make_layer(self, block, conv_builder, planes, blocks, stride=1):
342 | downsample = None
343 |
344 | if stride != 1 or self.inplanes != planes * block.expansion:
345 | ds_stride = conv_builder.get_downsample_stride(stride)
346 | downsample = nn.Sequential(
347 | nn.Conv3d(self.inplanes, planes * block.expansion,
348 | kernel_size=1, stride=ds_stride, bias=False),
349 | nn.BatchNorm3d(planes * block.expansion)
350 | )
351 | layers = []
352 | layers.append(block(self.inplanes, planes, conv_builder, stride, downsample))
353 |
354 | self.inplanes = planes * block.expansion
355 | for i in range(1, blocks):
356 | layers.append(block(self.inplanes, planes, conv_builder))
357 |
358 | return nn.Sequential(*layers)
359 |
360 | def _initialize_weights(self):
361 | for m in self.modules():
362 | if isinstance(m, nn.Conv3d):
363 | nn.init.kaiming_normal_(m.weight, mode='fan_out',
364 | nonlinearity='relu')
365 | if m.bias is not None:
366 | nn.init.constant_(m.bias, 0)
367 | elif isinstance(m, nn.BatchNorm3d):
368 | nn.init.constant_(m.weight, 1)
369 | nn.init.constant_(m.bias, 0)
370 | elif isinstance(m, nn.Linear):
371 | nn.init.normal_(m.weight, 0, 0.01)
372 | nn.init.constant_(m.bias, 0)
373 |
374 |
375 |
376 |
377 |
378 |
379 |
380 |
381 |
382 |
383 |
384 |
385 |
386 |
387 | def _video_resnet(arch, pretrained=False, progress=True, **kwargs):
388 | model = VideoResNet(**kwargs)
389 |
390 | if pretrained:
391 | state_dict = load_state_dict_from_url(model_urls[arch],
392 | progress=progress)
393 | model.load_state_dict(state_dict)
394 | return model
395 |
396 |
397 |
398 | def _video_resnet_kd(arch, pretrained=False, progress=True, **kwargs):
399 | model = VideoResNet_kd(**kwargs)
400 |
401 | if pretrained:
402 | state_dict = load_state_dict_from_url(model_urls[arch],
403 | progress=progress)
404 | model.load_state_dict(state_dict)
405 | return model
406 |
407 |
408 |
409 |
410 |
411 |
412 |
413 |
414 | def r2plus1d_18(pretrained=False, progress=True, **kwargs):
415 | """Constructor for the 18 layer deep R(2+1)D network as in
416 | https://arxiv.org/abs/1711.11248
417 |
418 | Args:
419 | pretrained (bool): If True, returns a model pre-trained on Kinetics-400
420 | progress (bool): If True, displays a progress bar of the download to stderr
421 |
422 | Returns:
423 | nn.Module: R(2+1)D-18 network
424 | """
425 | return _video_resnet('r2plus1d_18',
426 | pretrained, progress,
427 | block=BasicBlock,
428 | conv_makers=[Conv2Plus1D] * 4,
429 | layers=[2, 2, 2, 2],
430 | stem=R2Plus1dStem, **kwargs)
431 |
432 |
433 |
434 |
435 |
436 | def r2plus1d_18_kd(pretrained=False, progress=True, **kwargs):
437 | """Constructor for the 18 layer deep R(2+1)D network as in
438 | https://arxiv.org/abs/1711.11248
439 |
440 | Args:
441 | pretrained (bool): If True, returns a model pre-trained on Kinetics-400
442 | progress (bool): If True, displays a progress bar of the download to stderr
443 |
444 | Returns:
445 | nn.Module: R(2+1)D-18 network
446 | """
447 | return _video_resnet_kd('r2plus1d_18',
448 | pretrained, progress,
449 | block=BasicBlock,
450 | conv_makers=[Conv2Plus1D] * 4,
451 | layers=[2, 2, 2, 2],
452 | stem=R2Plus1dStem, **kwargs)
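# Illustrative usage of the distillation variant (shapes are examples, not taken
# from the original training scripts):
#   model = r2plus1d_18_kd(num_classes=1)
#   clip = torch.rand(2, 3, 32, 112, 112)  # (batch, channels, frames, height, width)
#   ef, f_stem, f_l1, f_l2, f_reg, f_l4, f_l3 = model(clip)  # ef has shape (2, 1)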
453 |
454 |
455 |
456 |
457 |
--------------------------------------------------------------------------------
/echonet/segmentation/__init__.py:
--------------------------------------------------------------------------------
1 | from .segmentation import *
2 | # from .fcn import *
3 | from .deeplabv3 import *
4 | # from .lraspp import *
--------------------------------------------------------------------------------
/echonet/segmentation/_utils.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from typing import Optional, Dict
3 |
4 | from torch import nn, Tensor
5 | from torch.nn import functional as F
6 | import torch
7 |
8 |
9 | class _SimpleSegmentationModel(nn.Module):
10 | __constants__ = ['aux_classifier']
11 |
12 | def __init__(
13 | self,
14 | backbone: nn.Module,
15 | classifier: nn.Module,
16 | aux_classifier: Optional[nn.Module] = None
17 | ) -> None:
18 | super(_SimpleSegmentationModel, self).__init__()
19 | self.backbone = backbone
20 | self.classifier = classifier
21 | # self.aux_classifier = aux_classifier
22 |
23 | self.ctr_avgpool = nn.AdaptiveAvgPool2d((1, 1))
24 | self.ctr_fc = nn.Sequential(nn.Linear(2048, 2048), nn.ReLU(), nn.Linear(2048, 128))
25 |
26 |
27 | def forward(self, x: Tensor) -> Dict[str, Tensor]:
28 | input_shape = x.shape[-2:]
29 | # contract: features is a dict of tensors
30 |
31 | # print("self.backbone", self.backbone.conv1)
32 | features = self.backbone(x)
33 |
34 | result = OrderedDict()
35 | x = features["out"]
36 | x = self.classifier(x)
37 | x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
38 | result["out"] = x
39 |
40 | x_ctr = features["out"]
41 | x_ctr = self.ctr_avgpool(x_ctr)
42 | x_ctr = x_ctr.flatten(1)
43 | x_ctr = self.ctr_fc(x_ctr)
44 | # print("x_ctr.shape, in _utils segmentation", x_ctr.shape)
45 | result['ctr_feat'] = F.normalize(x_ctr, dim = 1)
46 | result['feat_mid'] = features["out"]
47 | return result
48 |
49 |
50 |
51 |
52 |
53 | class _SimpleSegmentationModel_CSS(nn.Module):
54 | __constants__ = ['aux_classifier']
55 |
56 | def __init__(
57 | self,
58 | backbone: nn.Module,
59 | classifier: nn.Module,
60 | aux_classifier: Optional[nn.Module] = None
61 | ) -> None:
62 | super(_SimpleSegmentationModel_CSS, self).__init__()
63 | self.backbone = backbone
64 | self.classifier = classifier
65 | # self.aux_classifier = aux_classifier
66 |
67 |
68 | def forward(self, x: Tensor) -> Dict[str, Tensor]:
69 | input_shape = x.shape[-2:]
70 | # contract: features is a dict of tensors
71 |
72 | # print("self.backbone", self.backbone.conv1)
73 |
74 | xtest = self.backbone.conv1(x)
75 | xtest = self.backbone.bn1(xtest)
76 | xtest = self.backbone.relu(xtest)
77 | xtest_layerbs = xtest
78 | xtest = self.backbone.maxpool(xtest)
79 | xtest_layer0 = xtest
80 | xtest = self.backbone.layer1(xtest)
81 | xtest_layer1 = xtest
82 | xtest = self.backbone.layer2(xtest) ### can just output here.
83 | xtest_layer2 = xtest
84 | xtest = self.backbone.layer3(xtest)
85 | xtest = self.backbone.layer4(xtest)
86 | # print("xtest_layerbs.shape", xtest_layerbs.shape)# xtest_layerbs.shape torch.Size([2, 64, 56, 56])
87 | # print("xtest_layer0.shape", xtest_layer0.shape) #xtest_layer0.shape torch.Size([2, 64, 28, 28])
88 | # print("xtest_layer1.shape", xtest_layer1.shape) #xtest_layer1.shape torch.Size([2, 256, 28, 28])
89 | # print("xtest_layer2.shape", xtest_layer2.shape) #torch.Size([2, 512, 14, 14])
90 |
91 | result = OrderedDict()
92 | x = xtest
93 |
94 | x = self.classifier(x)
95 | x_maskpre = x
96 | x_maskpre = F.interpolate(x_maskpre, size=[56,56], mode='bilinear', align_corners=False)
97 | x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
98 | result["out"] = x
99 | result['x_layerbs'] = xtest_layerbs
100 | result['x_layer1'] = xtest_layer1
101 | result['x_layer4'] = xtest
102 | result['maskfeat'] = x_maskpre
103 | return result
104 |
105 |
106 |
--------------------------------------------------------------------------------
/echonet/segmentation/deeplabv3.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | from typing import List
5 |
6 | from ._utils import _SimpleSegmentationModel, _SimpleSegmentationModel_CSS
7 |
8 |
9 | __all__ = ["DeepLabV3", "DeepLabV3_CSS"]
10 |
11 |
12 | class DeepLabV3(_SimpleSegmentationModel):
13 | """
14 | Implements DeepLabV3 model from
15 | `"Rethinking Atrous Convolution for Semantic Image Segmentation"
16 | <https://arxiv.org/abs/1706.05587>`_.
17 |
18 | Args:
19 | backbone (nn.Module): the network used to compute the features for the model.
20 | The backbone should return an OrderedDict[Tensor], with the key being
21 | "out" for the last feature map used, and "aux" if an auxiliary classifier
22 | is used.
23 | classifier (nn.Module): module that takes the "out" element returned from
24 | the backbone and returns a dense prediction.
25 | aux_classifier (nn.Module, optional): auxiliary classifier used during training
26 | """
27 | pass
28 |
29 |
30 |
31 | class DeepLabV3_CSS(_SimpleSegmentationModel_CSS):
32 | """
33 | Implements DeepLabV3 model from
34 | `"Rethinking Atrous Convolution for Semantic Image Segmentation"
35 | <https://arxiv.org/abs/1706.05587>`_.
36 |
37 | Args:
38 | backbone (nn.Module): the network used to compute the features for the model.
39 | The backbone should return an OrderedDict[Tensor], with the key being
40 | "out" for the last feature map used, and "aux" if an auxiliary classifier
41 | is used.
42 | classifier (nn.Module): module that takes the "out" element returned from
43 | the backbone and returns a dense prediction.
44 | aux_classifier (nn.Module, optional): auxiliary classifier used during training
45 | """
46 | pass
47 |
48 |
49 |
50 | class DeepLabHead(nn.Sequential):
51 | def __init__(self, in_channels: int, num_classes: int) -> None:
52 | super(DeepLabHead, self).__init__(
53 | ASPP(in_channels, [12, 24, 36]),
54 | nn.Conv2d(256, 256, 3, padding=1, bias=False),
55 | nn.BatchNorm2d(256),
56 | nn.ReLU(),
57 | nn.Conv2d(256, num_classes, 1)
58 | )
59 |
60 |
61 | class ASPPConv(nn.Sequential):
62 | def __init__(self, in_channels: int, out_channels: int, dilation: int) -> None:
63 | modules = [
64 | nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
65 | nn.BatchNorm2d(out_channels),
66 | nn.ReLU()
67 | ]
68 | super(ASPPConv, self).__init__(*modules)
69 |
70 |
71 | class ASPPPooling(nn.Sequential):
72 | def __init__(self, in_channels: int, out_channels: int) -> None:
73 | super(ASPPPooling, self).__init__(
74 | nn.AdaptiveAvgPool2d(1),
75 | nn.Conv2d(in_channels, out_channels, 1, bias=False),
76 | nn.BatchNorm2d(out_channels),
77 | nn.ReLU())
78 |
79 | def forward(self, x: torch.Tensor) -> torch.Tensor:
80 | size = x.shape[-2:]
81 | for mod in self:
82 | x = mod(x)
83 | return F.interpolate(x, size=size, mode='bilinear', align_corners=False)
84 |
85 |
86 | class ASPP(nn.Module):
87 | def __init__(self, in_channels: int, atrous_rates: List[int], out_channels: int = 256) -> None:
88 | super(ASPP, self).__init__()
89 | modules = []
90 | modules.append(nn.Sequential(
91 | nn.Conv2d(in_channels, out_channels, 1, bias=False),
92 | nn.BatchNorm2d(out_channels),
93 | nn.ReLU()))
94 |
95 | rates = tuple(atrous_rates)
96 | for rate in rates:
97 | modules.append(ASPPConv(in_channels, out_channels, rate))
98 |
99 | modules.append(ASPPPooling(in_channels, out_channels))
100 |
101 | self.convs = nn.ModuleList(modules)
102 |
103 | self.project = nn.Sequential(
104 | nn.Conv2d(len(self.convs) * out_channels, out_channels, 1, bias=False),
105 | nn.BatchNorm2d(out_channels),
106 | nn.ReLU(),
107 | nn.Dropout(0.5))
108 |
109 | def forward(self, x: torch.Tensor) -> torch.Tensor:
110 | _res = []
111 | for conv in self.convs:
112 | _res.append(conv(x))
113 | res = torch.cat(_res, dim=1)
114 | return self.project(res)
--------------------------------------------------------------------------------
/echonet/segmentation/segmentation.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | from torch import nn
4 | from typing import Any, Optional, Dict
5 | # from .._utils import IntermediateLayerGetter
6 | # from ..._internally_replaced_utils import load_state_dict_from_url
7 | from torch.hub import load_state_dict_from_url
8 | from torchvision.models import resnet
9 | from .deeplabv3 import DeepLabHead, DeepLabV3, DeepLabV3_CSS
10 | # from .fcn import FCN, FCNHead
11 |
12 | __all__ = ['deeplabv3_resnet50', 'deeplabv3_resnet50_CSS']
13 |
14 |
15 | model_urls = {
16 | 'fcn_resnet50_coco': 'https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth',
17 | 'fcn_resnet101_coco': 'https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth',
18 | 'deeplabv3_resnet50_coco': 'https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth',
19 | 'deeplabv3_resnet101_coco': 'https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth',
20 | 'deeplabv3_mobilenet_v3_large_coco':
21 | 'https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth',
22 | 'lraspp_mobilenet_v3_large_coco': 'https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth',
23 | }
24 |
25 |
26 |
27 |
28 |
29 |
30 | class IntermediateLayerGetter(nn.ModuleDict):
31 | """
32 | Module wrapper that returns intermediate layers from a model
33 | It has a strong assumption that the modules have been registered
34 | into the model in the same order as they are used.
35 | This means that one should **not** reuse the same nn.Module
36 | twice in the forward if you want this to work.
37 | Additionally, it is only able to query submodules that are directly
38 | assigned to the model. So if `model` is passed, `model.feature1` can
39 | be returned, but not `model.feature1.layer2`.
40 | Args:
41 | model (nn.Module): model on which we will extract the features
42 | return_layers (Dict[name, new_name]): a dict containing the names
43 | of the modules for which the activations will be returned as
44 | the key of the dict, and the value of the dict is the name
45 | of the returned activation (which the user can specify).
46 | Examples::
47 | >>> m = torchvision.models.resnet18(pretrained=True)
48 | >>> # extract layer1 and layer3, giving as names `feat1` and `feat2`
49 | >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m,
50 | >>> {'layer1': 'feat1', 'layer3': 'feat2'})
51 | >>> out = new_m(torch.rand(1, 3, 224, 224))
52 | >>> print([(k, v.shape) for k, v in out.items()])
53 | >>> [('feat1', torch.Size([1, 64, 56, 56])),
54 | >>> ('feat2', torch.Size([1, 256, 14, 14]))]
55 | """
56 | _version = 2
57 | __annotations__ = {
58 | "return_layers": Dict[str, str],
59 | }
60 |
61 | def __init__(self, model: nn.Module, return_layers: Dict[str, str]) -> None:
62 | if not set(return_layers).issubset([name for name, _ in model.named_children()]):
63 | raise ValueError("return_layers are not present in model")
64 | orig_return_layers = return_layers
65 | return_layers = {str(k): str(v) for k, v in return_layers.items()}
66 | layers = OrderedDict()
67 | for name, module in model.named_children():
68 | layers[name] = module
69 | if name in return_layers:
70 | del return_layers[name]
71 | if not return_layers:
72 | break
73 |
74 | super(IntermediateLayerGetter, self).__init__(layers)
75 | self.return_layers = orig_return_layers
76 |
77 | def forward(self, x):
78 | out = OrderedDict()
79 | for name, module in self.items():
80 | x = module(x)
81 | if name in self.return_layers:
82 | out_name = self.return_layers[name]
83 | out[out_name] = x
84 | return out
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 | def _segm_model(
93 | name: str,
94 | backbone_name: str,
95 | num_classes: int,
96 | aux: Optional[bool],
97 | pretrained_backbone: bool = True
98 | ) -> nn.Module:
99 | if 'resnet' in backbone_name:
100 | backbone = resnet.__dict__[backbone_name](
101 | pretrained=pretrained_backbone,
102 | replace_stride_with_dilation=[False, True, True])
103 | out_layer = 'layer4'
104 | out_inplanes = 2048
105 | aux_layer = 'layer3'
106 | aux_inplanes = 1024
107 | elif 'mobilenet_v3' in backbone_name:
108 | assert 1==2, "not using mobilenet"
109 |
110 | # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
111 | # The first and last blocks are always included because they are the C0 (conv1) and Cn.
112 | stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
113 | out_pos = stage_indices[-1] # use C5 which has output_stride = 16
114 | out_layer = str(out_pos)
115 | out_inplanes = backbone[out_pos].out_channels
116 | aux_pos = stage_indices[-4] # use C2 here which has output_stride = 8
117 | aux_layer = str(aux_pos)
118 | aux_inplanes = backbone[aux_pos].out_channels
119 | else:
120 | raise NotImplementedError('backbone {} is not supported as of now'.format(backbone_name))
121 |
122 | return_layers = {out_layer: 'out'}
123 | if aux:
124 | return_layers[aux_layer] = 'aux'
125 | backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
126 |
127 | aux_classifier = None
128 | # if aux:
129 | # aux_classifier = FCNHead(aux_inplanes, num_classes)
130 |
131 | model_map = {
132 | 'deeplabv3': (DeepLabHead, DeepLabV3) #,
133 | # 'fcn': (FCNHead, FCN),
134 | }
135 | classifier = model_map[name][0](out_inplanes, num_classes)
136 | base_model = model_map[name][1]
137 |
138 | model = base_model(backbone, classifier, aux_classifier)
139 | return model
140 |
141 |
142 |
143 |
144 |
145 | def _segm_model_CSS(
146 | name: str,
147 | backbone_name: str,
148 | num_classes: int,
149 | aux: Optional[bool],
150 | pretrained_backbone: bool = True
151 | ) -> nn.Module:
152 | if 'resnet' in backbone_name:
153 | backbone = resnet.__dict__[backbone_name](
154 | pretrained=pretrained_backbone,
155 | replace_stride_with_dilation=[False, True, True])
156 | out_layer = 'layer4'
157 | out_inplanes = 2048
158 | aux_layer = 'layer3'
159 | aux_inplanes = 1024
160 | elif 'mobilenet_v3' in backbone_name:
161 | assert 1==2, "not using mobilenet"
162 |
163 | # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
164 | # The first and last blocks are always included because they are the C0 (conv1) and Cn.
165 | stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
166 | out_pos = stage_indices[-1] # use C5 which has output_stride = 16
167 | out_layer = str(out_pos)
168 | out_inplanes = backbone[out_pos].out_channels
169 | aux_pos = stage_indices[-4] # use C2 here which has output_stride = 8
170 | aux_layer = str(aux_pos)
171 | aux_inplanes = backbone[aux_pos].out_channels
172 | else:
173 | raise NotImplementedError('backbone {} is not supported as of now'.format(backbone_name))
174 |
175 | return_layers = {out_layer: 'out'}
176 | if aux:
177 | return_layers[aux_layer] = 'aux'
178 | backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
179 |
180 | aux_classifier = None
181 | # if aux:
182 | # aux_classifier = FCNHead(aux_inplanes, num_classes)
183 |
184 | model_map = {
185 | 'deeplabv3': (DeepLabHead, DeepLabV3_CSS) #,
186 | # 'fcn': (FCNHead, FCN),
187 | }
188 | classifier = model_map[name][0](out_inplanes, num_classes)
189 | base_model = model_map[name][1]
190 |
191 | model = base_model(backbone, classifier, aux_classifier)
192 | return model
193 |
194 |
195 |
196 |
197 |
198 |
199 |
200 |
201 | def _load_model(
202 | arch_type: str,
203 | backbone: str,
204 | pretrained: bool,
205 | progress: bool,
206 | num_classes: int,
207 | aux_loss: Optional[bool],
208 | **kwargs: Any
209 | ) -> nn.Module:
210 | if pretrained:
211 | aux_loss = True
212 | kwargs["pretrained_backbone"] = False
213 | model = _segm_model(arch_type, backbone, num_classes, aux_loss, **kwargs)
214 | if pretrained:
215 | _load_weights(model, arch_type, backbone, progress)
216 | return model
217 |
218 |
219 |
220 | def _load_model_CSS(
221 | arch_type: str,
222 | backbone: str,
223 | pretrained: bool,
224 | progress: bool,
225 | num_classes: int,
226 | aux_loss: Optional[bool],
227 | **kwargs: Any
228 | ) -> nn.Module:
229 | if pretrained:
230 | aux_loss = True
231 | kwargs["pretrained_backbone"] = False
232 | model = _segm_model_CSS(arch_type, backbone, num_classes, aux_loss, **kwargs)
233 | if pretrained:
234 | _load_weights(model, arch_type, backbone, progress)
235 | return model
236 |
237 |
238 |
239 |
240 |
241 | def _load_weights(model: nn.Module, arch_type: str, backbone: str, progress: bool) -> None:
242 | arch = arch_type + '_' + backbone + '_coco'
243 | model_url = model_urls.get(arch, None)
244 | if model_url is None:
245 | raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
246 | else:
247 | # assert 1==2, "a bit mahfan, we don't allow pretrained for now, not needed in segmentation anyways"
248 | state_dict = load_state_dict_from_url(model_url, progress=progress)
249 | model.load_state_dict(state_dict, strict = False)
250 |
251 |
252 | def deeplabv3_resnet50(
253 | pretrained: bool = False,
254 | progress: bool = True,
255 | num_classes: int = 21,
256 | aux_loss: Optional[bool] = None,
257 | **kwargs: Any
258 | ) -> nn.Module:
259 | """Constructs a DeepLabV3 model with a ResNet-50 backbone.
260 |
261 | Args:
262 | pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
263 | contains the same classes as Pascal VOC
264 | progress (bool): If True, displays a progress bar of the download to stderr
265 | num_classes (int): number of output classes of the model (including the background)
266 | aux_loss (bool): If True, it uses an auxiliary loss
267 | """
268 | return _load_model('deeplabv3', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs)
269 |
270 |
271 |
272 |
273 | def deeplabv3_resnet50_CSS(
274 | pretrained: bool = False,
275 | progress: bool = True,
276 | num_classes: int = 21,
277 | aux_loss: Optional[bool] = None,
278 | **kwargs: Any
279 | ) -> nn.Module:
280 | """Constructs a DeepLabV3 model with a ResNet-50 backbone.
281 |
282 | Args:
283 | pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
284 | contains the same classes as Pascal VOC
285 | progress (bool): If True, displays a progress bar of the download to stderr
286 | num_classes (int): number of output classes of the model (including the background)
287 | aux_loss (bool): If True, it uses an auxiliary loss
288 | """
289 | return _load_model_CSS('deeplabv3', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs)
290 |
291 |
--------------------------------------------------------------------------------
/echonet/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """Utility functions for videos, plotting and computing performance metrics."""
2 |
3 | import os
4 | import typing
5 | import datetime
6 |
7 | import cv2 # pytype: disable=attribute-error
8 | import matplotlib
9 | import numpy as np
10 | import torch
11 | import tqdm
12 |
13 |
14 | from . import seg_cycle
15 | from . import video_segin
16 | from . import vidsegin_teachstd_kd
17 |
18 |
19 | def loadvideo(filename: str) -> np.ndarray:
20 | """Loads a video from a file.
21 |
22 | Args:
23 | filename (str): filename of video
24 |
25 | Returns:
26 | A np.ndarray with dimensions (channels=3, frames, height, width). The
27 | values will be uint8's ranging from 0 to 255.
28 |
29 | Raises:
30 | FileNotFoundError: Could not find `filename`
31 | ValueError: An error occurred while reading the video
32 | """
33 |
34 | if not os.path.exists(filename):
35 | raise FileNotFoundError(filename)
36 | ###debug
37 | # print(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), "opening vid")
38 | capture = cv2.VideoCapture(filename)
39 | # print(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), "videocapture done")
40 |
41 | frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
42 | frame_width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
43 | frame_height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
44 |
45 | v = np.zeros((frame_count, frame_height, frame_width, 3), np.uint8)
46 | # print(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), "reading capture")
47 | for count in range(frame_count):
48 | ret, frame = capture.read()
49 | if not ret:
50 | raise ValueError("Failed to load frame #{} of {}.".format(count, filename))
51 |
52 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
53 | v[count, :, :] = frame
54 |
55 | v = v.transpose((3, 0, 1, 2))
56 | # print(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), "finished opening vid")
57 |
58 | return v
59 |
60 |
61 | def savevideo(filename: str, array: np.ndarray, fps: typing.Union[float, int] = 1):
62 | """Saves a video to a file.
63 |
64 | Args:
65 | filename (str): filename of video
66 | array (np.ndarray): video of uint8's with shape (channels=3, frames, height, width)
67 | fps (float or int): frames per second
68 |
69 | Returns:
70 | None
71 | """
72 |
73 | c, _, height, width = array.shape
74 |
75 | if c != 3:
76 | raise ValueError("savevideo expects array of shape (channels=3, frames, height, width), got shape ({})".format(", ".join(map(str, array.shape))))
77 | fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
78 | out = cv2.VideoWriter(filename, fourcc, fps, (width, height))
79 |
80 | for frame in array.transpose((1, 2, 3, 0)):
81 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
82 | out.write(frame)
83 |
84 |
85 | def get_mean_and_std(dataset: torch.utils.data.Dataset,
86 | samples: int = 128,
87 | batch_size: int = 8,
88 | num_workers: int = 4):
89 | """Computes mean and std from samples from a Pytorch dataset.
90 |
91 | Args:
92 | dataset (torch.utils.data.Dataset): A Pytorch dataset.
93 | ``dataset[i][0]'' is expected to be the i-th video in the dataset, which
94 | should be a ``torch.Tensor'' of dimensions (channels=3, frames, height, width)
95 | samples (int or None, optional): Number of samples to take from dataset. If ``None'', mean and
96 | standard deviation are computed over all elements.
97 | Defaults to 128.
98 | batch_size (int, optional): how many samples per batch to load
99 | Defaults to 8.
100 | num_workers (int, optional): how many subprocesses to use for data
101 | loading. If 0, the data will be loaded in the main process.
102 | Defaults to 4.
103 |
104 | Returns:
105 | A tuple of the mean and standard deviation. Both are represented as np.array's of dimension (channels,).
106 | """
107 |
108 | if samples is not None and len(dataset) > samples:
109 | np.random.seed(0)
110 | indices = np.random.choice(len(dataset), samples, replace=False)
111 | dataset = torch.utils.data.Subset(dataset, indices)
112 |
113 | dataloader = torch.utils.data.DataLoader(
114 | dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)
115 |
116 | n = 0 # number of elements taken (should be equal to samples by end of for loop)
117 | s1 = 0. # sum of elements along channels (ends up as np.array of dimension (channels,))
118 | s2 = 0. # sum of squares of elements along channels (ends up as np.array of dimension (channels,))
119 | # for (x, *_) in tqdm.tqdm(dataloader):
120 | for (x,_,*_) in tqdm.tqdm(dataloader):
121 | x = x.transpose(0, 1).contiguous().view(3, -1)
122 | n += x.shape[1]
123 | s1 += torch.sum(x, dim=1).numpy()
124 | s2 += torch.sum(x ** 2, dim=1).numpy()
125 | mean = s1 / n # type: np.ndarray
126 | std = np.sqrt(s2 / n - mean ** 2) # type: np.ndarray
127 |
128 | mean = mean.astype(np.float32)
129 | std = std.astype(np.float32)
130 |
131 | return mean, std
132 |
133 |
134 | def bootstrap(a, b, func, samples=10000):
135 | """Computes a bootstrapped confidence intervals for ``func(a, b)''.
136 |
137 | Args:
138 | a (array_like): first argument to `func`.
139 | b (array_like): second argument to `func`.
140 | func (callable): Function to compute confidence intervals for.
141 | It is called as ``func(a, b)'' on bootstrap resamples of
142 | ``a'' and ``b''.
143 | samples (int, optional): Number of samples to compute.
144 | Defaults to 10000.
145 |
146 | Returns:
147 | A tuple of (`func(a, b)`, estimated 5-th percentile, estimated 95-th percentile).
148 | """
149 | a = np.array(a)
150 | b = np.array(b)
151 |
152 | bootstraps = []
153 | for _ in range(samples):
154 | ind = np.random.choice(len(a), len(a))
155 | bootstraps.append(func(a[ind], b[ind]))
156 | bootstraps = sorted(bootstraps)
157 |
158 | return func(a, b), bootstraps[round(0.05 * len(bootstraps))], bootstraps[round(0.95 * len(bootstraps))]
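# Illustrative usage (the metric is an example, not imported by this module):
#   estimate, lower, upper = bootstrap(y_true, y_pred, sklearn.metrics.r2_score)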
159 |
160 |
161 | def latexify():
162 | """Sets matplotlib params to appear more like LaTeX.
163 |
164 | Based on https://nipunbatra.github.io/blog/2014/latexify.html
165 | """
166 | params = {'backend': 'pdf',
167 | 'axes.titlesize': 8,
168 | 'axes.labelsize': 8,
169 | 'font.size': 8,
170 | 'legend.fontsize': 8,
171 | 'xtick.labelsize': 8,
172 | 'ytick.labelsize': 8,
173 | 'font.family': 'DejaVu Serif',
174 | 'font.serif': 'Computer Modern',
175 | }
176 | matplotlib.rcParams.update(params)
177 |
178 |
179 | def dice_similarity_coefficient(inter, union):
180 | """Computes the dice similarity coefficient.
181 |
182 | Args:
183 | inter (iterable): iterable of the intersections
184 | union (iterable): iterable of the unions
185 | """
186 | return 2 * sum(inter) / (sum(union) + sum(inter))
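# For example, with sum(inter) == 30 and sum(union) == 50, the coefficient is
# 2 * 30 / (50 + 30) = 0.75.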
187 |
188 |
189 | __all__ = ["seg_cycle", "video_segin", "vidsegin_teachstd_kd", "loadvideo", "savevideo", "get_mean_and_std", "bootstrap", "latexify", "dice_similarity_coefficient"]
190 |
--------------------------------------------------------------------------------
/echonet/utils/seg_cycle.py:
--------------------------------------------------------------------------------
1 | """Functions for training and running segmentation."""
2 |
3 | import math
4 | import os
5 | import time
6 | import shutil
7 | import datetime
8 | import pandas as pd
9 |
10 | import click
11 | import matplotlib.pyplot as plt
12 | import numpy as np
13 | import scipy.signal
14 | import skimage.draw
15 | from PIL import Image
16 | import torch
17 | import torchvision
18 | import tqdm
19 |
20 | import echonet
21 |
22 |
23 | @click.command("seg_cycle")
24 | @click.option("--data_dir", type=click.Path(exists=True, file_okay=False), default=None)
25 | @click.option("--output", type=click.Path(file_okay=False), default=None)
26 | @click.option("--model_name", type=click.Choice(
27 | sorted(name for name in torchvision.models.segmentation.__dict__
28 | if name.islower() and not name.startswith("__") and callable(torchvision.models.segmentation.__dict__[name]))),
29 | default="deeplabv3_resnet50")
30 | @click.option("--pretrained/--random", default=False)
31 | @click.option("--weights", type=click.Path(exists=True, dir_okay=False), default=None)
32 | @click.option("--run_test/--skip_test", default=False)
33 | @click.option("--save_video", type=str, default=None)
34 | @click.option("--num_epochs", type=int, default=25)
35 | @click.option("--lr", type=float, default=1e-5)
36 | @click.option("--weight_decay", type=float, default=0)
37 | @click.option("--lr_step_period", type=int, default=None)
38 | @click.option("--num_train_patients", type=int, default=None)
39 | @click.option("--num_workers", type=int, default=4)
40 | @click.option("--batch_size", type=int, default=20)
41 | @click.option("--device", type=str, default=None)
42 | @click.option("--seed", type=int, default=0)
43 | @click.option("--reduced_set/--full_set", default=True)
44 | @click.option("--rd_label", type=int, default=920)
45 | @click.option("--rd_unlabel", type=int, default=6440)
46 | @click.option("--ssl_edesonly/--ssl_rndfrm", default=True)
47 | @click.option("--run_inference", type=str, default=None)
48 | @click.option("--chunk_size", type=int, default=3)
49 | @click.option("--cyc_off", type=int, default=2)
50 | @click.option("--target_region", type=int, default=15)
51 | @click.option("--temperature", type=int, default=10)
52 | @click.option("--val_chunk", type=int, default=40)
53 | @click.option("--loss_cyc_w", type=float, default=1)
54 | @click.option("--css_strtup", type=int, default=0)
55 |
56 | def run(
57 | data_dir=None,
58 | output=None,
59 | model_name="deeplabv3_resnet50",
60 | pretrained=False,
61 | weights=None,
62 | run_test=False,
63 | save_video=None,
64 | num_epochs=25,
65 | lr=1e-5,
66 | weight_decay=1e-5,
67 | lr_step_period=None,
68 | num_train_patients=None,
69 | num_workers=4,
70 | batch_size=20,
71 | device=None,
72 | seed=0,
73 | reduced_set = True,
74 | rd_label = 920,
75 | rd_unlabel = 6440,
76 | ssl_edesonly = True,
77 | run_inference = None,
78 | chunk_size = 3,
79 | cyc_off = 2,
80 | target_region = 15,
81 | temperature = 10,
82 | val_chunk = 40,
83 | loss_cyc_w = 1,
84 | css_strtup = 0
85 | ):
86 |
87 | if reduced_set:
88 | if not os.path.isfile(os.path.join(echonet.config.DATA_DIR, "FileList_ssl_{}_{}.csv".format(rd_label, rd_unlabel))):
89 | print("Generating new file list for ssl dataset")
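# Randomly assign rd_label training studies to LABELED and rd_unlabel studies to
# UNLABELED (the rest are marked EXCLUDE), then save the split so that later runs
# with the same settings reuse the identical partition.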
90 | np.random.seed(0)
91 |
92 | data = pd.read_csv(os.path.join(echonet.config.DATA_DIR, "FileList.csv"))
93 | data["Split"] = data["Split"].map(lambda x: x.upper())
94 |
95 | file_name_list = np.array(data[data['Split']== 'TRAIN']['FileName'])
96 | np.random.shuffle(file_name_list)
97 |
98 | label_list = file_name_list[:rd_label]
99 | unlabel_list = file_name_list[rd_label:rd_label + rd_unlabel]
100 |
101 | data['SSL_SPLIT'] = "EXCLUDE"
102 | data.loc[data['FileName'].isin(label_list), 'SSL_SPLIT'] = "LABELED"
103 | data.loc[data['FileName'].isin(unlabel_list), 'SSL_SPLIT'] = "UNLABELED"
104 |
105 | data.to_csv(os.path.join(echonet.config.DATA_DIR, "FileList_ssl_{}_{}.csv".format(rd_label, rd_unlabel)),index = False)
106 |
107 |
108 | # Seed RNGs
109 | np.random.seed(seed)
110 | torch.manual_seed(seed)
111 |
112 | def worker_init_fn(worker_id):
113 | # print("worker id is", torch.utils.data.get_worker_info().id)
114 | # https://discuss.pytorch.org/t/in-what-order-do-dataloader-workers-do-their-job/88288/2
115 | np.random.seed(np.random.get_state()[1][0] + worker_id)
116 |
117 |
118 | # Set default output directory
119 | if output is None:
120 | output = os.path.join("output", "segmentation", "{}_{}".format(model_name, "pretrained" if pretrained else "random"))
121 | os.makedirs(output, exist_ok=True)
122 |
123 | bkup_tmstmp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
124 | if os.path.isdir(os.path.join(output, "echonet_{}".format(bkup_tmstmp))):
125 | shutil.rmtree(os.path.join(output, "echonet_{}".format(bkup_tmstmp)))
126 | shutil.copytree("echonet", os.path.join(output, "echonet_{}".format(bkup_tmstmp)))
127 |
128 |
129 | # Set device for computations
130 | if device is None:
131 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
132 | elif device == "gpu":
133 | device = torch.device("cuda")
134 | elif device == "cpu":
135 | device = torch.device("cpu")
136 | else:
137 | raise ValueError("wrong parameter for device: {}".format(device))
138 |
139 |
140 |
141 | #### Setup model
142 | model_0 = echonet.segmentation.segmentation.deeplabv3_resnet50_CSS(pretrained=pretrained, aux_loss=False)
143 | model_0.classifier[-1] = torch.nn.Conv2d(model_0.classifier[-1].in_channels, 1, kernel_size=model_0.classifier[-1].kernel_size) # change number of outputs to 1
144 | model_0 = torch.nn.DataParallel(model_0)
145 | model_0.to(device)
146 |
147 | if weights:
148 | checkpoint = torch.load(weights)
149 | model_0.load_state_dict(checkpoint['state_dict_0'], strict = False)
150 |
151 | # Set up optimizer
152 | optim_0 = torch.optim.SGD(model_0.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
153 | if lr_step_period is None:
154 | lr_step_period = math.inf
155 | scheduler_0 = torch.optim.lr_scheduler.StepLR(optim_0, lr_step_period)
156 |
157 |
158 | # Compute mean and std
159 | mean, std = echonet.utils.get_mean_and_std(echonet.datasets.Echo(root=data_dir, split="train"))
160 | tasks_eval = ["LargeFrame", "SmallFrame", "LargeTrace", "SmallTrace"]
161 | kwargs_eval = {"target_type": tasks_eval,
162 | "mean": mean,
163 | "std": std
164 | }
165 |
166 | tasks_seg = ["LargeFrame", "SmallFrame", "LargeTrace", "SmallTrace"]
167 | kwargs_seg = {"target_type": tasks_seg,
168 | "mean": mean,
169 | "std": std
170 | }
171 |
172 | kwargs = {"target_type": ["EF", "CYCLE"],
173 | "mean": mean,
174 | "std": std,
175 | "length": 40,
176 | "period": 3,
177 | }
178 |
179 |
180 | dataset = {}
181 | dataset_trainsub = {}
182 | dataset_valsub = {}
183 | if reduced_set:
184 | dataset_trainsub['lb_seg'] = echonet.datasets.Echo(root=data_dir, split="train", **kwargs_seg, ssl_postfix="_ssl_{}_{}".format(rd_label, rd_unlabel), ssl_type = 1, ssl_edesonly = True)
185 | dataset_trainsub['lb_cyc'] = echonet.datasets.Echo_CSS(root=data_dir, split="train", **kwargs, ssl_postfix="_ssl_{}_{}".format(rd_label, rd_unlabel), ssl_type = 1)
186 | dataset_trainsub['unlb_cyc'] = echonet.datasets.Echo_CSS(root=data_dir, split="train", **kwargs, ssl_postfix="_ssl_{}_{}".format(rd_label, rd_unlabel), ssl_type = 2)
187 | else:
188 | assert not ssl_edesonly, "Check parameters: trying to conduct SSL on the full dataset with ED/ES frames only"
189 | dataset_trainsub['lb_seg'] = echonet.datasets.Echo(root=data_dir, split="train", **kwargs_seg, ssl_postfix="", ssl_type = 0, ssl_edesonly = True)
190 | dataset_trainsub['lb_cyc'] = echonet.datasets.Echo_CSS(root=data_dir, split="train", **kwargs, ssl_postfix="", ssl_type = 0)
191 | dataset_trainsub['unlb_cyc'] = echonet.datasets.Echo_CSS(root=data_dir, split="train", **kwargs, ssl_postfix="", ssl_type = 0)
192 | dataset['train'] = dataset_trainsub
193 |
194 |
195 | if reduced_set:
196 | dataset_valsub["lb_seg"] = echonet.datasets.Echo(root=data_dir, split="val", **kwargs_seg, ssl_postfix="_ssl_{}_{}".format(rd_label, rd_unlabel))
197 | dataset_valsub['lb_cyc'] = echonet.datasets.Echo_CSS(root=data_dir, split="val", **kwargs, ssl_postfix="_ssl_{}_{}".format(rd_label, rd_unlabel))
198 | dataset_valsub['unlb_cyc'] = echonet.datasets.Echo_CSS(root=data_dir, split="val", **kwargs, ssl_postfix="_ssl_{}_{}".format(rd_label, rd_unlabel))
199 | else:
200 | raise NotImplementedError("only run with reduced set for now")
201 | dataset["val"] = echonet.datasets.Echo_CSS(root=data_dir, split="val", **kwargs, ssl_postfix="")
202 | dataset['val'] = dataset_valsub
203 |
204 |
205 | # Run training and testing loops
206 | with open(os.path.join(output, "log.csv"), "a") as f:
207 |
208 | f.write("Run timestamp: {}\n".format(bkup_tmstmp))
209 |
210 | epoch_resume = 0
211 | bestLoss = float("inf")
212 | try:
213 | # Attempt to load checkpoint
214 | checkpoint = torch.load(os.path.join(output, "checkpoint.pt"))
215 | print("checkpoint.keys", checkpoint.keys())
216 | model_0.load_state_dict(checkpoint['state_dict'])
217 | optim_0.load_state_dict(checkpoint['opt_dict'])
218 | scheduler_0.load_state_dict(checkpoint['scheduler_dict'])
219 |
220 | np_rndstate_chkpt = checkpoint['np_rndstate']
221 | trch_rndstate_chkpt = checkpoint['trch_rndstate']
222 |
223 | np.random.set_state(np_rndstate_chkpt)
224 | torch.set_rng_state(trch_rndstate_chkpt)
225 |
226 | epoch_resume = checkpoint["epoch"] + 1
227 | bestLoss = checkpoint["best_loss"]
228 | f.write("Resuming from epoch {}\n".format(epoch_resume))
229 | except FileNotFoundError:
230 | f.write("Starting run from scratch\n")
231 |
232 | for epoch in range(epoch_resume, num_epochs):
233 | print("Epoch #{}".format(epoch), flush=True)
234 | for phase in ['train', 'val']:
235 | start_time = time.time()
236 |
237 | if device.type == "cuda":
238 | for i in range(torch.cuda.device_count()):
239 | torch.cuda.reset_peak_memory_stats(i)
240 |
241 | ds = dataset[phase]
242 |
243 | if phase == "train":
244 |
245 | dataloader_lb_seg = torch.utils.data.DataLoader(
246 | ds['lb_seg'], batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=(device.type == "cuda"), drop_last=(phase == "train"))
247 | dataloader_lb_cyc = torch.utils.data.DataLoader(
248 | ds['lb_cyc'], batch_size=1, num_workers=num_workers, shuffle=True, pin_memory=(device.type == "cuda"), drop_last=(phase == "train"))
249 | dataloader_unlb_cyc = torch.utils.data.DataLoader(
250 | ds['unlb_cyc'], batch_size=1, num_workers=num_workers, shuffle=True, pin_memory=(device.type == "cuda"), drop_last=(phase == "train"))
251 |
252 |
253 | loss, loss_seg, lrgdice, smldice, loss_cyc, large_inter_0, large_union_0, small_inter_0, small_union_0 = echonet.utils.seg_cycle.run_epoch_ssl( model_0,
254 | dataloader_lb_seg,
255 | dataloader_lb_cyc,
256 | dataloader_unlb_cyc,
257 | phase == "train",
258 | optim_0,
259 | batch_size,
260 | device,
261 | output,
262 | phase,
263 | mean,
264 | std,
265 | epoch,
266 | chunk_size = chunk_size,
267 | cyc_off = cyc_off,
268 | target_region = target_region,
269 | temperature = temperature,
270 | val_chunk = val_chunk,
271 | loss_cyc_w = loss_cyc_w,
272 | css_strtup = css_strtup
273 | )
274 |
275 |
276 | overall_dice_0 = 2 * (large_inter_0.sum() + small_inter_0.sum()) / (large_union_0.sum() + large_inter_0.sum() + small_union_0.sum() + small_inter_0.sum())
277 | large_dice_0 = 2 * large_inter_0.sum() / (large_union_0.sum() + large_inter_0.sum())
278 | small_dice_0 = 2 * small_inter_0.sum() / (small_union_0.sum() + small_inter_0.sum())
279 |
280 | f.write("{},{},{},{},{},{},{},{},{},{},{},{}\n".format(epoch,
281 | phase,
282 | loss,
283 | loss_seg,
284 | loss_cyc,
285 | overall_dice_0,
286 | large_dice_0,
287 | small_dice_0,
288 | time.time() - start_time,
289 | sum(torch.cuda.max_memory_allocated(i) for i in range(torch.cuda.device_count())),
290 | sum(torch.cuda.max_memory_reserved(i) for i in range(torch.cuda.device_count())),
291 | batch_size))
292 | f.flush()
293 |
294 | else:
295 | dataloader_lb_seg = torch.utils.data.DataLoader(
296 | ds['lb_seg'], batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=(device.type == "cuda"), drop_last=(phase == "train"))
297 | dataloader_lb_cyc = torch.utils.data.DataLoader(
298 | ds['lb_cyc'], batch_size=1, num_workers=num_workers, shuffle=False, pin_memory=(device.type == "cuda"), drop_last=(phase == "train"))
299 | dataloader_unlb_cyc = torch.utils.data.DataLoader(
300 | ds['unlb_cyc'], batch_size=1, num_workers=num_workers, shuffle=False, pin_memory=(device.type == "cuda"), drop_last=(phase == "train"))
301 |
302 |
303 |
304 | loss, loss_seg_val, lrgdice_val, smldice_val, loss_cyc_val, large_inter_0, large_union_0, small_inter_0, small_union_0 = echonet.utils.seg_cycle.run_epoch_ssl( model_0,
305 | dataloader_lb_seg,
306 | dataloader_lb_cyc,
307 | dataloader_unlb_cyc,
308 | phase == "train",
309 | optim_0,
310 | batch_size,
311 | device,
312 | output,
313 | phase,
314 | mean,
315 | std,
316 | epoch,
317 | chunk_size = chunk_size,
318 | cyc_off = cyc_off,
319 | target_region = target_region,
320 | temperature = temperature,
321 | val_chunk = val_chunk,
322 | loss_cyc_w = loss_cyc_w,
323 | css_strtup = css_strtup
324 | )
325 |
326 | overall_dice_0 = 2 * (large_inter_0.sum() + small_inter_0.sum()) / (large_union_0.sum() + large_inter_0.sum() + small_union_0.sum() + small_inter_0.sum())
327 | large_dice_0 = 2 * large_inter_0.sum() / (large_union_0.sum() + large_inter_0.sum())
328 | small_dice_0 = 2 * small_inter_0.sum() / (small_union_0.sum() + small_inter_0.sum())
329 |
330 | f.write("{},{},{},{},{},{},{},{},{},{},{},{}\n".format(epoch,
331 | phase,
332 | loss,
333 | loss_seg_val,
334 | loss_cyc_val,
335 | overall_dice_0,
336 | large_dice_0,
337 | small_dice_0,
338 | time.time() - start_time,
339 | sum(torch.cuda.max_memory_allocated(i) for i in range(torch.cuda.device_count())),
340 | sum(torch.cuda.max_memory_reserved(i) for i in range(torch.cuda.device_count())),
341 | batch_size))
342 |
343 |
344 |
345 | f.flush()
346 |
347 |
348 | scheduler_0.step()
349 |
350 | # Save checkpoint
351 | save = {
352 | 'epoch': epoch,
353 | 'state_dict': model_0.state_dict(),
354 | 'best_loss': bestLoss,
355 | 'loss': loss,
356 | 'opt_dict': optim_0.state_dict(),
357 | 'scheduler_dict': scheduler_0.state_dict(),
358 | 'np_rndstate': np.random.get_state(),
359 | 'trch_rndstate': torch.get_rng_state()
360 | }
361 | torch.save(save, os.path.join(output, "checkpoint.pt"))
362 | if loss_seg_val < bestLoss:
363 | print("saved best because {} < {}".format(loss_seg_val, bestLoss))
364 | torch.save(save, os.path.join(output, "best.pt"))
365 | bestLoss = loss_seg_val
366 |
367 | # Load best weights
368 | if num_epochs != 0:
369 | checkpoint = torch.load(os.path.join(output, "best.pt"))
370 | model_0.load_state_dict(checkpoint['state_dict'])
371 |
372 | f.write("Best validation loss {} from epoch {}\n".format(checkpoint["loss"], checkpoint["epoch"]))
373 | f.flush()
374 |
375 | if run_test:
376 | for split in ["val", "test"]:
377 | if reduced_set:
378 | if split == "train":
379 | dataset = echonet.datasets.Echo(root=data_dir, split=split, **kwargs_seg, ssl_postfix="_ssl_{}_{}".format(rd_label, rd_unlabel), ssl_type = 2)
380 | else:
381 | dataset = echonet.datasets.Echo(root=data_dir, split=split, **kwargs_seg, ssl_postfix="_ssl_{}_{}".format(rd_label, rd_unlabel))
382 | else:
383 | dataset = echonet.datasets.Echo(root=data_dir, split=split, **kwargs_seg, ssl_postfix="")
384 |
385 | dataloader = torch.utils.data.DataLoader(dataset,
386 | batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=(device.type == "cuda"))
387 | loss, large_inter, large_union, small_inter, small_union = echonet.utils.seg_cycle.run_epoch(model_0, dataloader, False, None, device)
388 |
389 | overall_dice = 2 * (large_inter + small_inter) / (large_union + large_inter + small_union + small_inter)
390 | large_dice = 2 * large_inter / (large_union + large_inter)
391 | small_dice = 2 * small_inter / (small_union + small_inter)
392 | with open(os.path.join(output, "{}_dice.csv".format(split)), "w") as g:
393 | g.write("Filename, Overall, Large, Small\n")
394 | for (filename, overall, large, small) in zip(dataset.fnames, overall_dice, large_dice, small_dice):
395 | g.write("{},{},{},{}\n".format(filename, overall, large, small))
396 |
397 | f.write("{} dice (overall): {:.4f} ({:.4f} - {:.4f})\n".format(split, *echonet.utils.bootstrap(np.concatenate((large_inter, small_inter)), np.concatenate((large_union, small_union)), echonet.utils.dice_similarity_coefficient)))
398 | f.write("{} dice (large): {:.4f} ({:.4f} - {:.4f})\n".format(split, *echonet.utils.bootstrap(large_inter, large_union, echonet.utils.dice_similarity_coefficient)))
399 | f.write("{} dice (small): {:.4f} ({:.4f} - {:.4f})\n".format(split, *echonet.utils.bootstrap(small_inter, small_union, echonet.utils.dice_similarity_coefficient)))
400 | f.flush()
401 |
402 | if run_inference:
403 | if run_inference == "all":
404 | run_inference_range = ['train', 'val', 'test']
405 | else:
406 | run_inference_range = [run_inference]
407 |
408 | for run_inference_itr in run_inference_range:
409 | if run_inference_itr != "train" or True:  # note: 'or True' makes this branch always taken, so all frames are used for every split
410 | dataset = echonet.datasets.Echo(root=data_dir, split=run_inference_itr,
411 | target_type=["Filename", "LargeIndex", "SmallIndex"], # Need filename for saving, and human-selected frames to annotate
412 | mean=mean, std=std, # Normalization
413 | length=None, max_length=None, period=1 # Take all frames
414 | )
415 | else:
416 | if reduced_set:
417 | dataset = echonet.datasets.Echo(root=data_dir, split=run_inference_itr,
418 | target_type=["Filename", "LargeIndex", "SmallIndex"], # Need filename for saving, and human-selected frames to annotate
419 | mean=mean, std=std, # Normalization
420 | ssl_postfix="_ssl_{}_{}".format(rd_label, rd_unlabel), ssl_type = 1, ssl_mult = 1,
421 | length=None, max_length=None, period=1 # Take all frames
422 | )
423 | else:
424 | dataset = echonet.datasets.Echo(root=data_dir, split=run_inference_itr,
425 | target_type=["Filename", "LargeIndex", "SmallIndex"], # Need filename for saving, and human-selected frames to annotate
426 | mean=mean, std=std, # Normalization
427 | length=None, max_length=None, period=1 # Take all frames
428 | )
429 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=10, num_workers=num_workers, shuffle=False, pin_memory=False, collate_fn=_video_collate_fn)
430 |
431 | output_dir = os.path.join(output, "{}_infer_cmpct".format(run_inference_itr))
432 |
433 | os.makedirs(output_dir, exist_ok = True)
434 |
435 | checkpoint = torch.load(os.path.join(output, "best.pt"))
436 |
437 | model_0.load_state_dict(checkpoint['state_dict'])
438 |
439 | model_0.eval()
440 |
441 | with torch.no_grad():
442 | for (x, (filenames, large_index, small_index), length) in tqdm.tqdm(dataloader):
443 | # Run segmentation model on blocks of frames one-by-one
444 | # The whole concatenated video may be too long to run together
445 |
446 | print(os.path.join(output_dir, "{}_{}.npy".format(filenames[-1].replace(".avi", ""), length[-1] - 1)))
447 |
448 |
449 | if os.path.isfile(os.path.join(output_dir, "{}.npy".format(filenames[-1].replace(".avi", "")))):
450 | # print("already exists")
451 | continue
452 |
453 | y = np.concatenate([model_0(x[i:(i + batch_size), :, :, :].to(device))["out"].detach().cpu().numpy() for i in range(0, x.shape[0], batch_size)])
454 |
455 | y_idx = 0
456 | for batch_idx in range(len(filenames)):
457 | filename_itr = filenames[batch_idx]
458 |
459 | logit = y[y_idx:y_idx + length[batch_idx], 0, :, :]
460 |
461 | logit_out_path = os.path.join(output_dir, "{}.npy".format(filename_itr.replace(".avi", "")))
462 | np.save(logit_out_path, logit)
463 | y_idx = y_idx + length[batch_idx]
464 |
465 | pass
466 |
467 |
468 |
469 |
470 |
471 |
472 | def run_epoch_ssl(model_0,
473 | dataloader_lb_seg,
474 | dataloader_lb_cyc,
475 | dataloader_unlb_cyc,
476 | train,
477 | optim_0,
478 | batch_size,
479 | device,
480 | output,
481 | phase,
482 | mean,
483 | std,
484 | epoch,
485 | chunk_size = 3,
486 | cyc_off = 2,
487 | target_region = 15,
488 | temperature = 10,
489 | val_chunk = 40,
490 | loss_cyc_w = 1,
491 | css_strtup = 0
492 | ):
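  | """Run one epoch of semi-supervised segmentation training/evaluation.
  |
  | Each iteration draws a labeled batch of ED/ES frames (with human traces) for
  | the supervised segmentation loss and one unlabeled clip for the
  | cycle-consistency (CSS) loss. During training the two losses are summed, the
  | CSS term weighted by loss_cyc_w and disabled for the first css_strtup epochs.
  | Returns averaged losses, Dice statistics and per-sample intersection/union lists.
  | """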
493 |
494 |
495 | n = 0
496 | n_seg = 0
497 |
498 | total = 0
499 | total_cyc = 0
500 |
501 | total_seg = 0
502 |
503 | model_0.train(train)
504 | output_dir = os.path.join(output, "{}_feat_comp".format(phase))
505 | os.makedirs(output_dir, exist_ok = True)
506 |
507 | large_inter_0 = 0
508 | large_union_0 = 0
509 | small_inter_0 = 0
510 | small_union_0 = 0
511 | large_inter_list_0 = []
512 | large_union_list_0 = []
513 | small_inter_list_0 = []
514 | small_union_list_0 = []
515 |
516 | torch.set_grad_enabled(train)
517 |
518 | total_itr_num = len(dataloader_lb_seg)
519 |
520 | dataloader_lb_seg_itr = iter(dataloader_lb_seg)
521 | dataloader_unlb_cyc_itr = iter(dataloader_unlb_cyc)
522 |
523 | for train_iter in range(total_itr_num):
524 |
525 | #### Supervised segmentation
526 | _, (large_frame, small_frame, large_trace, small_trace) = next(dataloader_lb_seg_itr)
527 |
528 | large_frame = large_frame.to(device)
529 | large_trace = large_trace.to(device)
530 |
531 | small_frame = small_frame.to(device)
532 | small_trace = small_trace.to(device)
533 |
534 | if not train:
535 | with torch.no_grad():
536 | y_large_0 = model_0(large_frame)["out"]
537 | else:
538 | y_large_0 = model_0(large_frame)["out"]
539 |
540 | loss_large_0 = torch.nn.functional.binary_cross_entropy_with_logits(y_large_0[:, 0, :, :], large_trace, reduction="sum")
541 | # Compute pixel intersection and union between human and computer segmentations
542 | large_inter_0 += np.logical_and(y_large_0[:, 0, :, :].detach().cpu().numpy() > 0., large_trace[:, :, :].detach().cpu().numpy() > 0.).sum()
543 | large_union_0 += np.logical_or(y_large_0[:, 0, :, :].detach().cpu().numpy() > 0., large_trace[:, :, :].detach().cpu().numpy() > 0.).sum()
544 | large_inter_list_0.extend(np.logical_and(y_large_0[:, 0, :, :].detach().cpu().numpy() > 0., large_trace[:, :, :].detach().cpu().numpy() > 0.).sum((1, 2)))
545 | large_union_list_0.extend(np.logical_or(y_large_0[:, 0, :, :].detach().cpu().numpy() > 0., large_trace[:, :, :].detach().cpu().numpy() > 0.).sum((1, 2)))
546 |
547 | y_small_0 = model_0(small_frame)["out"]
548 | loss_small_0 = torch.nn.functional.binary_cross_entropy_with_logits(y_small_0[:, 0, :, :], small_trace, reduction="sum")
549 | # Compute pixel intersection and union between human and computer segmentations
550 | small_inter_0 += np.logical_and(y_small_0[:, 0, :, :].detach().cpu().numpy() > 0., small_trace[:, :, :].detach().cpu().numpy() > 0.).sum()
551 | small_union_0 += np.logical_or(y_small_0[:, 0, :, :].detach().cpu().numpy() > 0., small_trace[:, :, :].detach().cpu().numpy() > 0.).sum()
552 | small_inter_list_0.extend(np.logical_and(y_small_0[:, 0, :, :].detach().cpu().numpy() > 0., small_trace[:, :, :].detach().cpu().numpy() > 0.).sum((1, 2)))
553 | small_union_list_0.extend(np.logical_or(y_small_0[:, 0, :, :].detach().cpu().numpy() > 0., small_trace[:, :, :].detach().cpu().numpy() > 0.).sum((1, 2)))
554 |
555 | loss_seg = (loss_large_0 + loss_small_0) / 2
556 |
557 | loss_seg_item = loss_seg.item()
558 | large_trace_size = large_trace.size(0)
559 | total_seg += loss_seg_item * large_trace.size(0)
560 |
561 |
562 |
563 | ### CSS training
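  | # An unlabeled clip is split into a template region P (the first target_region
  | # frames) and a search region Q (the rest). A random phase E^p* of chunk_size
  | # frames from P is soft-matched into Q, the match is shifted forward by cyc_off
  | # frames, and the blended feature \tilde{E}^{q+c} is matched back into P; the
  | # cycle loss is the cross-entropy between the cycled-back position and p*.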
564 | X_raw, target, target_iekd, start, video_path, i1, j1 = next(dataloader_unlb_cyc_itr)
565 | X_bfcwh = X_raw.permute(0,2,1,3,4)
566 | X_segfeed = X_bfcwh.reshape(-1, X_bfcwh.shape[2], X_bfcwh.shape[3], X_bfcwh.shape[4])
567 |
568 | ####### get feature output
569 | if not train:
570 | with torch.no_grad():
571 | feat_out = model_0(X_segfeed)['x_layer4'].sum(dim=(2,3))
572 | else:
573 | feat_out = model_0(X_segfeed)['x_layer4'].sum(dim=(2,3))
574 |
575 |
576 | feat_out_query = feat_out[:target_region] # Template region P
577 | feat_out_query_cyc = feat_out[cyc_off:target_region] # Template region with offset
578 | feat_out_key = feat_out[target_region:] # Search region Q
579 |
580 | target_strtpt = np.random.choice(target_region - (chunk_size + cyc_off) + 1) ## choosing p*
581 | target_strtpt_1ht = torch.eye(target_region - (chunk_size + cyc_off) + 1)[target_strtpt]
582 | target_strtpt_1ht = target_strtpt_1ht.to(device)
583 |
584 | query_feat = feat_out_query[target_strtpt:target_strtpt + chunk_size, ...] ### choosing E^p*
585 |
586 | key_size = feat_out_key.shape[0]
587 | feat_size = feat_out.shape[1]
588 |
589 | ### feature-wise distance calculation
590 | dist_mat = feat_out_key.unsqueeze(1).repeat((1,chunk_size, 1)) - query_feat.unsqueeze(1).transpose(0,1).repeat(key_size, 1, 1)
591 | dist_mat_sq = dist_mat.pow(2)
592 | dist_mat_sq_ftsm = dist_mat_sq.sum(dim = -1)
593 |
594 | ### distance calculation per phase
595 | indices_ftsm = torch.arange(chunk_size)
596 | gather_indx_ftsm = torch.arange(key_size).view((key_size, 1)).repeat((1,chunk_size))
597 | gather_indx_shft_ftsm = (gather_indx_ftsm + indices_ftsm) % (key_size) ### gets the indices of the feature vectors included in each phase
598 | gather_indx_shft_ftsm = gather_indx_shft_ftsm.to(device)
599 | dist_mat_sq_shft_ftsm = torch.gather(dist_mat_sq_ftsm, 0, gather_indx_shft_ftsm)[:key_size - (chunk_size + cyc_off) + 1] ### gathers the feature-wise distance values to calculate the distance for the phase
600 | dist_mat_sq_total_ftsm = dist_mat_sq_shft_ftsm.sum(dim=(1))
601 |
602 | ### calculating similarity value
603 | similarity = - dist_mat_sq_total_ftsm
604 | similarity_averaged = similarity / feat_size / chunk_size * temperature
605 | alpha_raw = torch.nn.functional.softmax(similarity_averaged, dim = 0)
606 | alpha_weights = alpha_raw.unsqueeze(1).unsqueeze(1).repeat([1, chunk_size, feat_size])
607 |
608 |
609 | #### calculate shifted phase values
610 | indices_beta = torch.arange(chunk_size).view((1, chunk_size, 1)).repeat((key_size,1, feat_size))
611 | gather_indx_beta = torch.arange(key_size).view((key_size, 1, 1)).repeat((1,chunk_size, feat_size))
612 | gather_indx_alpha_shft = (gather_indx_beta + indices_beta) % (key_size)
613 | gather_indx_alpha_shft = gather_indx_alpha_shft.to(device)
614 | feat_out_key_beta = torch.gather(feat_out_key.unsqueeze(1).repeat(1, chunk_size, 1), 0, gather_indx_alpha_shft)[cyc_off:key_size - chunk_size + 1]
615 |
616 | ### calculate \tilde{E}^{q+c}
617 | weighted_features = alpha_weights * feat_out_key_beta
618 | weighted_features_averaged = weighted_features.sum(dim=0)
619 |
620 |
621 | #### match back to template region and find distance value
622 | q_dist_mat = feat_out_query_cyc.unsqueeze(1).repeat((1,chunk_size, 1)) - weighted_features_averaged.unsqueeze(1).transpose(0,1).repeat((target_region - cyc_off), 1, 1)
623 | q_dist_mat_sq = q_dist_mat.pow(2)
624 | q_dist_mat_sq_ftsm = q_dist_mat_sq.sum(dim = -1)
625 |
626 | indices_query_ftsm = torch.arange(chunk_size)
627 | gather_indx_query_ftsm = torch.arange(target_region - cyc_off).view((target_region - cyc_off, 1)).repeat((1,chunk_size))
628 | gather_indx_query_shft_ftsm = (gather_indx_query_ftsm + indices_query_ftsm) % (target_region - cyc_off)
629 | gather_indx_query_shft_ftsm = gather_indx_query_shft_ftsm.to(device)
630 | q_dist_mat_sq_shft_ftsm = torch.gather(q_dist_mat_sq_ftsm, 0, gather_indx_query_shft_ftsm)[:(target_region - cyc_off) - chunk_size + 1]
631 | q_dist_mat_sq_total_ftsm = q_dist_mat_sq_shft_ftsm.sum(dim=(1))
632 |
633 | ### calculate similarity value
634 | q_similarity = - q_dist_mat_sq_total_ftsm
635 | q_similarity_averaged = q_similarity / feat_size / chunk_size * temperature
636 |
637 | ### calculate cross-entropy loss
638 | frm_prd = torch.argmax(q_similarity_averaged)
639 | frm_lb = torch.argmax(target_strtpt_1ht)
640 |
641 | loss_cyc_raw = torch.nn.functional.cross_entropy(q_similarity_averaged.unsqueeze(0), frm_lb.unsqueeze(0))
642 | loss_cyc_wght = loss_cyc_raw * loss_cyc_w
643 |
644 | loss_cyc_raw_item = loss_cyc_raw.item()
645 | total_cyc += loss_cyc_raw_item
646 |
647 |
648 | if train:
649 | if epoch < css_strtup:
650 | loss_total = loss_seg
651 | else:
652 | loss_total = loss_seg + loss_cyc_wght
653 | optim_0.zero_grad()
654 | loss_total.backward()
655 | optim_0.step()
656 |
657 |
658 | loss_total_item = loss_seg_item + loss_cyc_raw_item * loss_cyc_w
659 |
660 | total += loss_total_item
661 |
662 | n += 1
663 | n_seg += large_trace_size
664 |
665 |
666 | # Show info on process bar
667 | if train_iter % 5 == 0:
668 | print("Itr trainphase {} - {}/{} - ttl {:.4f} ({:.4f}) seg {:.4f} ({:.4f}) dlrg {:.4f} dsml {:.4f} cyc {:.4f} ({:.4f}) ".format(
669 | train,
670 | train_iter,
671 | total_itr_num,
672 | total / n_seg , # total
673 | loss_total_item, # total_item
674 | total_seg / n_seg / 112 / 112 , # total seg
675 | loss_seg_item, # seg item
676 | 2 * large_inter_0 / (large_union_0 + large_inter_0 + 0.000001),
677 | 2 * small_inter_0 / (small_union_0 + small_inter_0 + 0.000001),
678 | total_cyc / n,
679 | loss_cyc_raw_item
680 | ), flush = True)
681 |
682 | large_inter_list_0 = np.array(large_inter_list_0)
683 | large_union_list_0 = np.array(large_union_list_0)
684 | small_inter_list_0 = np.array(small_inter_list_0)
685 | small_union_list_0 = np.array(small_union_list_0)
686 |
687 | return (total / n_seg,
688 | total_seg / n_seg / 112 / 112,
689 | 2 * large_inter_0 / (large_union_0 + large_inter_0 + 0.000001),
690 | 2 * small_inter_0 / (small_union_0 + small_inter_0 + 0.000001),
691 | total_cyc / n,
692 | large_inter_list_0,
693 | large_union_list_0,
694 | small_inter_list_0,
695 | small_union_list_0
696 | )
697 |
698 |
699 |
700 |
701 | def run_epoch(model, dataloader, train, optim, device):
702 | """Run one epoch of training/evaluation for segmentation.
703 |
704 | Args:
705 | model (torch.nn.Module): Model to train/evaluate.
706 | dataloader (torch.utils.data.DataLoader): Dataloader for the dataset.
707 | train (bool): Whether or not to train model.
708 | optim (torch.optim.Optimizer): Optimizer
709 | device (torch.device): Device to run on
710 | """
711 |
712 | total = 0.
713 | n = 0
714 |
715 | pos = 0
716 | neg = 0
717 | pos_pix = 0
718 | neg_pix = 0
719 |
720 | model.train(train)
721 |
722 | large_inter = 0
723 | large_union = 0
724 | small_inter = 0
725 | small_union = 0
726 | large_inter_list = []
727 | large_union_list = []
728 | small_inter_list = []
729 | small_union_list = []
730 |
731 | with torch.set_grad_enabled(train):
732 | with tqdm.tqdm(total=len(dataloader)) as pbar:
733 | for (_, (large_frame, small_frame, large_trace, small_trace)) in dataloader:
734 | # Count number of pixels in/out of human segmentation
735 | pos += (large_trace == 1).sum().item()
736 | pos += (small_trace == 1).sum().item()
737 | neg += (large_trace == 0).sum().item()
738 | neg += (small_trace == 0).sum().item()
739 |
740 | # Count number of pixels in/out of computer segmentation
741 | pos_pix += (large_trace == 1).sum(0).to("cpu").detach().numpy()
742 | pos_pix += (small_trace == 1).sum(0).to("cpu").detach().numpy()
743 | neg_pix += (large_trace == 0).sum(0).to("cpu").detach().numpy()
744 | neg_pix += (small_trace == 0).sum(0).to("cpu").detach().numpy()
745 |
746 | # Run prediction for diastolic frames and compute loss
747 | large_frame = large_frame.to(device)
748 | large_trace = large_trace.to(device)
749 | y_large = model(large_frame)["out"]
750 | loss_large = torch.nn.functional.binary_cross_entropy_with_logits(y_large[:, 0, :, :], large_trace, reduction="sum")
751 | # Compute pixel intersection and union between human and computer segmentations
752 | large_inter += np.logical_and(y_large[:, 0, :, :].detach().cpu().numpy() > 0., large_trace[:, :, :].detach().cpu().numpy() > 0.).sum()
753 | large_union += np.logical_or(y_large[:, 0, :, :].detach().cpu().numpy() > 0., large_trace[:, :, :].detach().cpu().numpy() > 0.).sum()
754 | large_inter_list.extend(np.logical_and(y_large[:, 0, :, :].detach().cpu().numpy() > 0., large_trace[:, :, :].detach().cpu().numpy() > 0.).sum((1, 2)))
755 | large_union_list.extend(np.logical_or(y_large[:, 0, :, :].detach().cpu().numpy() > 0., large_trace[:, :, :].detach().cpu().numpy() > 0.).sum((1, 2)))
756 |
757 | # Run prediction for systolic frames and compute loss
758 | small_frame = small_frame.to(device)
759 | small_trace = small_trace.to(device)
760 | y_small = model(small_frame)["out"]
761 | loss_small = torch.nn.functional.binary_cross_entropy_with_logits(y_small[:, 0, :, :], small_trace, reduction="sum")
762 | # Compute pixel intersection and union between human and computer segmentations
763 | small_inter += np.logical_and(y_small[:, 0, :, :].detach().cpu().numpy() > 0., small_trace[:, :, :].detach().cpu().numpy() > 0.).sum()
764 | small_union += np.logical_or(y_small[:, 0, :, :].detach().cpu().numpy() > 0., small_trace[:, :, :].detach().cpu().numpy() > 0.).sum()
765 | small_inter_list.extend(np.logical_and(y_small[:, 0, :, :].detach().cpu().numpy() > 0., small_trace[:, :, :].detach().cpu().numpy() > 0.).sum((1, 2)))
766 | small_union_list.extend(np.logical_or(y_small[:, 0, :, :].detach().cpu().numpy() > 0., small_trace[:, :, :].detach().cpu().numpy() > 0.).sum((1, 2)))
767 |
768 | # Take gradient step if training
769 | loss = (loss_large + loss_small) / 2
770 | if train:
771 | optim.zero_grad()
772 | loss.backward()
773 | optim.step()
774 |
775 | # Accumulate losses and compute baselines
776 | total += loss.item()
777 | n += large_trace.size(0)
778 | p = pos / (pos + neg)
779 | p_pix = (pos_pix + 1) / (pos_pix + neg_pix + 2)
780 |
781 | # Show info on process bar
782 | pbar.set_postfix_str("{:.4f} ({:.4f}) / {:.4f} {:.4f}, {:.4f}, {:.4f}".format(total / n / 112 / 112, loss.item() / large_trace.size(0) / 112 / 112, -p * math.log(p) - (1 - p) * math.log(1 - p), (-p_pix * np.log(p_pix) - (1 - p_pix) * np.log(1 - p_pix)).mean(), 2 * large_inter / (large_union + large_inter), 2 * small_inter / (small_union + small_inter)))
783 | pbar.update()
784 |
785 | large_inter_list = np.array(large_inter_list)
786 | large_union_list = np.array(large_union_list)
787 | small_inter_list = np.array(small_inter_list)
788 | small_union_list = np.array(small_union_list)
789 |
790 | return (total / n / 112 / 112,
791 | large_inter_list,
792 | large_union_list,
793 | small_inter_list,
794 | small_union_list,
795 | )
796 |
797 |
798 | def _video_collate_fn(x):
799 | """Collate function for Pytorch dataloader to merge multiple videos.
800 |
801 | This function should be used in a dataloader for a dataset that returns
802 | a video as the first element, along with some (non-zero) tuple of
803 | targets. Then, the input x is a list of tuples:
804 | - x[i][0] is the i-th video in the batch
805 | - x[i][1] are the targets for the i-th video
806 |
807 | This function returns a 3-tuple:
808 | - The first element is the videos concatenated along the frames
809 | dimension. This is done so that videos of different lengths can be
810 | processed together (tensors cannot be "jagged", so we cannot have
811 | a dimension for video, and another for frames).
812 | - The second element contains the targets with no modification.
813 | - The third element is a list of the lengths of the videos in frames.
814 | """
815 | video, target = zip(*x) # Extract the videos and targets
816 |
817 | # ``video'' is a tuple of length ``batch_size''
818 | # Each element has shape (channels=3, frames, height, width)
819 | # height and width are expected to be the same across videos, but
820 | # frames can be different.
821 |
822 | # ``target'' is also a tuple of length ``batch_size''
823 | # Each element is a tuple of the targets for the item.
824 |
825 | i = list(map(lambda t: t.shape[1], video)) # Extract lengths of videos in frames
826 |
827 | # This concatenates the videos along the frames dimension (basically
828 | # playing the videos one after another). The frames dimension is then
829 | # moved to be first.
830 | # Resulting shape is (total frames, channels=3, height, width)
831 | video = torch.as_tensor(np.swapaxes(np.concatenate(video, 1), 0, 1))
832 |
833 | # Swap dimensions (approximately a transpose)
834 | # Before: target[i][j] is the j-th target of element i
835 | # After: target[i][j] is the i-th target of element j
836 | target = zip(*target)
837 |
838 | return video, target, i
839 |
840 |
--------------------------------------------------------------------------------
/echonet/utils/video_segin.py:
--------------------------------------------------------------------------------
1 | """EF regression from video with Segmentation prediction mask inputs """
2 |
3 |
4 | import math
5 | import os
6 | import time
7 | import shutil
8 | import datetime
9 | import pandas as pd
10 | from PIL import Image
11 |
12 | import click
13 | import matplotlib.pyplot as plt
14 | import numpy as np
15 | import sklearn.metrics
16 | import torch
17 | import torchvision
18 | import tqdm
19 |
20 | import echonet
21 | import echonet.models
22 |
23 | from scipy.special import expit
24 |
25 | @click.command("video_segin")
26 | @click.option("--data_dir", type=click.Path(exists=True, file_okay=False), default=None)
27 | @click.option("--output", type=click.Path(file_okay=False), default=None)
28 | @click.option("--task", type=str, default="EF")
29 | @click.option("--model_name", type=click.Choice(['mc3_18', 'r2plus1d_18', 'r3d_18', 'r2plus1d_18_ncor']),
30 | default="r2plus1d_18")
31 | @click.option("--pretrained/--random", default=True)
32 | @click.option("--weights", type=click.Path(exists=True, dir_okay=False), default=None)
33 | @click.option("--run_test/--skip_test", default=False)
34 | @click.option("--num_epochs", type=int, default=30)
35 | @click.option("--lr", type=float, default=1e-4)
36 | @click.option("--weight_decay", type=float, default=1e-4)
37 | @click.option("--lr_step_period", type=int, default=15)
38 | @click.option("--frames", type=int, default=32)
39 | @click.option("--period", type=int, default=2)
40 | @click.option("--num_train_patients", type=int, default=None)
41 | @click.option("--num_workers", type=int, default=4)
42 | @click.option("--batch_size", type=int, default=20)
43 | @click.option("--device", type=str, default=None)
44 | @click.option("--seed", type=int, default=0)
45 | @click.option("--full_test/--quick_test", default=True)
46 | @click.option("--val_samp", type=int, default=3)
47 | @click.option("--reduced_set/--full_set", default=True)
48 | @click.option("--rd_label", type=int, default=100)
49 | @click.option("--rd_unlabel", type=int, default=100)
50 | @click.option("--segsource", type=str, default=None)
51 |
52 | def run(
53 | data_dir=None,
54 | output=None,
55 | task="EF",
56 | model_name="r2plus1d_18",
57 | pretrained=True,
58 | weights=None,
59 | run_test=False,
60 | num_epochs=30,
61 | lr=1e-4,
62 | weight_decay=1e-4,
63 | lr_step_period=15,
64 | frames=32,
65 | period=2,
66 | num_train_patients=None,
67 | num_workers=4,
68 | batch_size=20,
69 | device=None,
70 | seed=0,
71 | full_test = True,
72 | val_samp = 3,
73 | reduced_set = True,
74 | rd_label = 100,
75 | rd_unlabel = 100,
76 | segsource = None
77 | ):
78 |
79 | assert segsource, "video_segin requires the --segsource option"
80 |
81 | if reduced_set:
82 | if not os.path.isfile(os.path.join(echonet.config.DATA_DIR, "FileList_ssl_{}_{}.csv".format(rd_label, rd_unlabel))):
83 | print("Generating new file list for ssl dataset")
84 | np.random.seed(0)
85 |
86 |
87 | data = pd.read_csv(os.path.join(echonet.config.DATA_DIR, "FileList.csv"))
88 | data["Split"].map(lambda x: x.upper())
89 |
90 | file_name_list = np.array(data[data['Split']== 'TRAIN']['FileName'])
91 | np.random.shuffle(file_name_list)
92 |
93 | label_list = file_name_list[:rd_label]
94 | unlabel_list = file_name_list[rd_label:rd_label + rd_unlabel]
95 |
96 | data['SSL_SPLIT'] = "EXCLUDE"
97 | data.loc[data['FileName'].isin(label_list), 'SSL_SPLIT'] = "LABELED"
98 | data.loc[data['FileName'].isin(unlabel_list), 'SSL_SPLIT'] = "UNLABELED"
99 |
100 | data.to_csv(os.path.join(echonet.config.DATA_DIR, "FileList_ssl_{}_{}.csv".format(rd_label, rd_unlabel)),index = False)
101 |
102 |
103 | # Seed RNGs
104 | np.random.seed(seed)
105 | torch.manual_seed(seed)
106 |
107 | def worker_init_fn(worker_id):
108 | # print("worker id is", torch.utils.data.get_worker_info().id)
109 | # https://discuss.pytorch.org/t/in-what-order-do-dataloader-workers-do-their-job/88288/2
110 | np.random.seed(np.random.get_state()[1][0] + worker_id)
111 |
112 | # Set default output directory
113 | if output is None:
114 | output = os.path.join("output", "video", "{}_{}_{}_{}".format(model_name, frames, period, "pretrained" if pretrained else "random"))
115 | os.makedirs(output, exist_ok=True)
116 |
117 | bkup_tmstmp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
118 | if os.path.isdir(os.path.join(output, "echonet_{}".format(bkup_tmstmp))):
119 | shutil.rmtree(os.path.join(output, "echonet_{}".format(bkup_tmstmp)))
120 | shutil.copytree("echonet", os.path.join(output, "echonet_{}".format(bkup_tmstmp)))
121 |
122 | # Set device for computations
123 | if device is None:
124 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
125 | elif device == "gpu":
126 | device = torch.device("cuda")
127 | elif device == "cpu":
128 | device = torch.device("cpu")
129 | else:
130 | raise ValueError("wrong parameter for device: {}".format(device))
131 |
132 |
133 | model = echonet.models.rnet2dp1.r2plus1d_18(pretrained=pretrained)
134 | model.fc = torch.nn.Linear(model.fc.in_features, 1)
135 | model.fc.bias.data[0] = 55.6
136 |
137 | model_ref = echonet.models.rnet2dp1.r2plus1d_18(pretrained=pretrained)
138 |
139 | #### add additional channel to pre-trained model
140 | model.stem = torch.nn.Sequential(
141 | torch.nn.Conv3d(4, 45, kernel_size=(1, 7, 7),
142 | stride=(1, 2, 2), padding=(0, 3, 3),
143 | bias=False),
144 | torch.nn.BatchNorm3d(45),
145 | torch.nn.ReLU(inplace=True),
146 | torch.nn.Conv3d(45, 64, kernel_size=(3, 1, 1),
147 | stride=(1, 1, 1), padding=(1, 0, 0),
148 | bias=False),
149 | torch.nn.BatchNorm3d(64),
150 | torch.nn.ReLU(inplace=True))
151 |
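  | # Reuse the pretrained stem: the BatchNorm/ReLU/second-conv weights are copied
  | # from the reference model, the first conv keeps its pretrained RGB filters,
  | # and the extra 4th input channel (the segmentation-mask input) is randomly
  | # initialized.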
152 | for weight_itr in range(1,6):
153 | model.stem[weight_itr].load_state_dict(model_ref.stem[weight_itr].state_dict())
154 |
155 | model.stem[0].weight.data[:,:3,:,:,:] = model_ref.stem[0].weight.data[:,:,:,:,:]
156 | model.stem[0].weight.data[:,3,:,:,:] = torch.tensor(np.random.uniform(low = -1, high = 1, size = model.stem[0].weight.data[:,3,:,:,:].shape)).float()
157 |
158 | model = torch.nn.DataParallel(model)
159 | model.to(device)
160 |
161 | if weights is not None:
162 | checkpoint = torch.load(weights)
163 | model.load_state_dict(checkpoint['state_dict'])
164 |
165 | # Set up optimizer
166 | optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
167 | if lr_step_period is None:
168 | lr_step_period = math.inf
169 | scheduler = torch.optim.lr_scheduler.StepLR(optim, lr_step_period)
170 |
171 | mean, std = echonet.utils.get_mean_and_std(echonet.datasets.Echo(root=data_dir, split="train"))
172 | print("mean std", mean, std)
173 | kwargs = {"target_type": task,
174 | "mean": mean,
175 | "std": std,
176 | "length": frames,
177 | "period": period,
178 | }
179 |
180 | # Set up datasets and dataloaders
181 | dataset = {}
182 |
183 | if reduced_set:
184 | dataset["train"] = echonet.datasets.Echo(root=data_dir, split="train", **kwargs, pad=12, ssl_postfix="_ssl_{}_{}".format(rd_label, rd_unlabel), ssl_type = 1, segin_dir = "../infer_buffers/{}/train_infer_cmpct".format(segsource))
185 | dataset["val"] = echonet.datasets.Echo(root=data_dir, split="val", **kwargs, ssl_postfix="_ssl_{}_{}".format(rd_label, rd_unlabel), segin_dir = "../infer_buffers/{}/val_infer_cmpct".format(segsource))
186 | else:
187 | dataset["train"] = echonet.datasets.Echo(root=data_dir, split="train", **kwargs, pad=12, ssl_postfix="", segin_dir = "../infer_buffers/{}/train_infer_cmpct".format(segsource))
188 | dataset["val"] = echonet.datasets.Echo(root=data_dir, split="val", **kwargs, ssl_postfix="", segin_dir = "../infer_buffers/{}/val_infer_cmpct".format(segsource))
189 |
190 | # Run training and testing loops
191 | with open(os.path.join(output, "log.csv"), "a") as f:
192 |
193 | f.write("Run timestamp: {}\n".format(bkup_tmstmp))
194 |
195 | epoch_resume = 0
196 | bestLoss = float("inf")
197 | try:
198 | # Attempt to load checkpoint
199 | checkpoint = torch.load(os.path.join(output, "checkpoint.pt"))
200 | model.load_state_dict(checkpoint['state_dict'], strict = False)
201 | optim.load_state_dict(checkpoint['opt_dict'])
202 | scheduler.load_state_dict(checkpoint['scheduler_dict'])
203 |
204 | np_rndstate_chkpt = checkpoint['np_rndstate']
205 | trch_rndstate_chkpt = checkpoint['trch_rndstate']
206 |
207 | np.random.set_state(np_rndstate_chkpt)
208 | torch.set_rng_state(trch_rndstate_chkpt)
209 |
210 | epoch_resume = checkpoint["epoch"] + 1
211 | bestLoss = checkpoint["best_loss"]
212 | f.write("Resuming from epoch {}\n".format(epoch_resume))
213 | except FileNotFoundError:
214 | f.write("Starting run from scratch\n")
215 |
216 |
217 | for epoch in range(epoch_resume, num_epochs):
218 | print("Epoch #{}".format(epoch), flush=True)
219 | for phase in ['train', 'val']:
220 |
221 | start_time = time.time()
222 |
223 | if device.type == "cuda":
224 | for i in range(torch.cuda.device_count()):
225 | torch.cuda.reset_peak_memory_stats(i)
226 |
227 | if phase == "train":
228 | ds = dataset[phase]
229 | dataloader = torch.utils.data.DataLoader(
230 | ds, batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=(device.type == "cuda"), drop_last=(phase == "train"), worker_init_fn=worker_init_fn)
231 |
232 | loss, loss_reg, loss_ctr, yhat, y, _, _ = echonet.utils.video_segin.run_epoch(model,
233 | dataloader,
234 | phase == "train",
235 | optim,
236 | device)
237 |
238 | r2_value = sklearn.metrics.r2_score(y, yhat)
239 |
240 | f.write("{},{},{},{},{},{},{},{},{},{},{}\n".format(epoch,
241 | phase,
242 | loss,
243 | r2_value,
244 | time.time() - start_time,
245 | y.size,
246 | sum(torch.cuda.max_memory_allocated(i) for i in range(torch.cuda.device_count())),
247 | sum(torch.cuda.max_memory_reserved(i) for i in range(torch.cuda.device_count())),
248 | batch_size,
249 | loss_reg,
250 | loss_ctr))
251 | f.flush()
252 |
253 |
254 | else:
255 | ### for validation
256 | ### store seeds
257 | np_rndstate = np.random.get_state()
258 | trch_rndstate = torch.get_rng_state()
259 |
260 | r2_track = []
261 | loss_track = []
262 | lossreg_track = []
263 | losscor_track = []
264 |
265 |
266 | for val_samp_itr in range(val_samp):
267 |
268 | print("running validation batch for seed =", val_samp_itr)
269 |
270 | np.random.seed(val_samp_itr)
271 | torch.manual_seed(val_samp_itr)
272 |
273 | ds = dataset[phase]
274 | dataloader = torch.utils.data.DataLoader(
275 | ds, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=(device.type == "cuda"), drop_last=(phase == "train"))
276 |
277 | loss_valit, loss_reg_valit, loss_ctr_valit, yhat, y, _, _ = echonet.utils.video_segin.run_epoch(model,
278 | dataloader,
279 | phase == "train",
280 | optim,
281 | device)
282 |
283 | r2_track.append(sklearn.metrics.r2_score(y, yhat))
284 | loss_track.append(loss_valit)
285 | lossreg_track.append(loss_reg_valit)
286 | losscor_track.append(loss_ctr_valit)
287 |
288 | r2_value = np.average(np.array(r2_track))
289 | loss = np.average(np.array(loss_track))
290 | lossreg = np.average(np.array(lossreg_track))
291 | losscor = np.average(np.array(losscor_track))
292 |
293 | f.write("{},{},{},{},{},{},{},{},{},{},{}".format(epoch,
294 | phase,
295 | loss,
296 | r2_value,
297 | time.time() - start_time,
298 | y.size,
299 | sum(torch.cuda.max_memory_allocated(i) for i in range(torch.cuda.device_count())),
300 | sum(torch.cuda.max_memory_reserved(i) for i in range(torch.cuda.device_count())),
301 | batch_size,
302 | lossreg,
303 | losscor))
304 |
305 | for trck_write in range(len(r2_track)):
306 | f.write(",{}".format(r2_track[trck_write]))
307 |
308 | for trck_write in range(len(loss_track)):
309 | f.write(",{}".format(loss_track[trck_write]))
310 |
311 | f.write("\n")
312 | f.flush()
313 |
314 | np.random.set_state(np_rndstate)
315 | torch.set_rng_state(trch_rndstate)
316 |
317 |
318 | scheduler.step()
319 |
320 | # Save checkpoint
321 | save = {
322 | 'epoch': epoch,
323 | 'state_dict': model.state_dict(),
324 | 'period': period,
325 | 'frames': frames,
326 | 'best_loss': bestLoss,
327 | 'loss': loss,
328 | 'r2': r2_value,
329 | 'opt_dict': optim.state_dict(),
330 | 'scheduler_dict': scheduler.state_dict(),
331 | 'np_rndstate': np.random.get_state(),
332 | 'trch_rndstate': torch.get_rng_state()
333 | }
334 | torch.save(save, os.path.join(output, "checkpoint.pt"))
335 |
336 | if lossreg < bestLoss:
337 | print("saved best because {} < {}".format(lossreg, bestLoss))
338 | torch.save(save, os.path.join(output, "best.pt"))
339 | bestLoss = lossreg
340 |
341 |
342 | if num_epochs != 0:
343 | checkpoint = torch.load(os.path.join(output, "best.pt"))
344 | model.load_state_dict(checkpoint['state_dict'], strict = False)
345 | f.write("Best validation loss {} from epoch {}, R2 {}\n".format(checkpoint["loss"], checkpoint["epoch"], checkpoint["r2"]))
346 | f.flush()
347 |
348 | if run_test:
349 | # for split in ["val", "test"]:
350 | for split in ["test", "val"]:
351 | # Performance without test-time augmentation
352 |
353 | if not full_test:
354 |
355 | for seed_itr in range(5):
356 | np.random.seed(seed_itr)
357 | torch.manual_seed(seed_itr)
358 |
359 | if reduced_set:
360 | dataloader = torch.utils.data.DataLoader(
361 | echonet.datasets.Echo(root=data_dir, split=split, **kwargs, ssl_postfix="_ssl_{}_{}".format(rd_label, rd_unlabel), segin_dir = "../infer_buffers/{}/{}_infer_cmpct".format(segsource, split)),
362 | batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=(device.type == "cuda"), worker_init_fn=worker_init_fn)
363 | else:
364 | dataloader = torch.utils.data.DataLoader(
365 | echonet.datasets.Echo(root=data_dir, split=split, **kwargs, ssl_postfix="", segin_dir = "../infer_buffers/{}/{}_infer_cmpct".format(segsource, split)),
366 | batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=(device.type == "cuda"), worker_init_fn=worker_init_fn)
367 |
368 |
369 | loss, loss_reg, loss_ctr, yhat, y, start_frame_record, vidpath_record = echonet.utils.video_segin.run_epoch(model,
370 | dataloader,
371 | False,
372 | None,
373 | device,
374 | run_dir = output,
375 | test_val = split)
376 |
377 | f.write("Seed is {}\n".format(seed_itr))
378 | f.write("{} - {} (one clip) R2: {:.3f} ({:.3f} - {:.3f})\n".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), split, *echonet.utils.bootstrap(y, yhat, sklearn.metrics.r2_score)))
379 | f.write("{} - {} (one clip) MAE: {:.2f} ({:.2f} - {:.2f})\n".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), split, *echonet.utils.bootstrap(y, yhat, sklearn.metrics.mean_absolute_error)))
380 | f.write("{} - {} (one clip) RMSE: {:.2f} ({:.2f} - {:.2f})\n".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), split, *tuple(map(math.sqrt, echonet.utils.bootstrap(y, yhat, sklearn.metrics.mean_squared_error)))))
381 | f.flush()
382 |
383 | with open(os.path.join(output, "z_{}_{}_s{}_strtfrmchk.csv".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), split, seed_itr)), "a") as f_start_frame:
384 | for frame_itr in start_frame_record:
385 | f_start_frame.write("{}\n".format(frame_itr))
386 | f_start_frame.flush()
387 |
388 | with open(os.path.join(output, "z_{}_{}_s{}_vidpthchk.csv".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), split, seed_itr)), "a") as f_vidpath:
389 | for vidpath_itr in vidpath_record:
390 | f_vidpath.write("{}\n".format(vidpath_itr))
391 | f_vidpath.flush()
392 |
393 | else:
394 | # Performance with test-time augmentation
395 | if reduced_set:
396 | ds = echonet.datasets.Echo(root=data_dir, split=split, **kwargs, clips="all", ssl_postfix="_ssl_{}_{}".format(rd_label, rd_unlabel), segin_dir = "../infer_buffers/{}/{}_infer_cmpct".format(segsource, split))
397 | else:
398 | ds = echonet.datasets.Echo(root=data_dir, split=split, **kwargs, clips="all", ssl_postfix="", segin_dir = "../infer_buffers/{}/{}_infer_cmpct".format(segsource, split))
399 |
400 | yhat, y = echonet.utils.video_segin.test_epoch_all(model,
401 | ds,
402 | False,
403 | None,
404 | device,
405 | save_all=True,
406 | block_size=batch_size,
407 | run_dir = output,
408 | test_val = split,
409 | **kwargs,
410 | segsource = segsource)
411 |
412 | f.write("Seed is {} \n".format(seed))
413 | f.write("{} - {} (all clips, mod) R2: {:.3f} ({:.3f} - {:.3f})\n".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), split, *echonet.utils.bootstrap(y, yhat, sklearn.metrics.r2_score)))
414 | f.write("{} - {} (all clips, mod) MAE: {:.2f} ({:.2f} - {:.2f})\n".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), split, *echonet.utils.bootstrap(y, yhat, sklearn.metrics.mean_absolute_error)))
415 | f.write("{} - {} (all clips, mod) RMSE: {:.2f} ({:.2f} - {:.2f})\n".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), split, *tuple(map(math.sqrt, echonet.utils.bootstrap(y, yhat, sklearn.metrics.mean_squared_error)))))
416 | f.flush()
417 |
418 |
419 |
420 |
421 | def test_epoch_all(model, dataset, train, optim, device, save_all=False, block_size=None, run_dir = None, test_val = None, target_type = None, mean = None, std = None, length = None, period = None, segsource = None):
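  | """Evaluate EF predictions over all clips of every video in the dataset.
  |
  | For each video, the raw frames are normalized, the precomputed segmentation
  | logits (from segsource) are converted to a mask channel, and every clip of
  | `length` frames sampled at stride `period` is scored; clip predictions are
  | averaged per video. Per-video results are cached to a CSV so an interrupted
  | run can resume, then merged with FileList.csv to return (yhat, EF) arrays.
  | """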
422 |
423 | assert segsource, "need to feed segsource argument to test_epoch_all"
424 |
425 | model.train(False)
426 |
427 | total = 0 # total training loss
428 | total_reg = 0
429 | total_ncor = 0
430 |
431 | n = 0 # number of videos processed
432 | s1 = 0 # sum of ground truth EF
433 | s2 = 0 # Sum of ground truth EF squared
434 |
435 | yhat = []
436 | y = []
437 |
438 | #### some params in the dataloader
439 |
440 | if (mean is None) or (std is None) or (length is None) or (period is None):
441 | raise ValueError("missing key params: mean, std, length and period must be provided")
442 |
443 | max_length = 250
444 |
445 | if run_dir:
446 |
447 | temp_savefile = os.path.join(run_dir, "temp_inference_{}.csv".format(test_val))
448 |
449 | with torch.set_grad_enabled(False):
450 | orig_filelist = dataset.fnames
451 |
452 | if os.path.isfile(temp_savefile):
453 | exist_data = pd.read_csv(temp_savefile)
454 | exist_file = list(exist_data['fnames'])
455 | target_filelist = sorted(list(set(orig_filelist) - set(exist_file)))
456 | else:
457 | target_filelist = sorted(list(orig_filelist))
458 | exist_data = pd.DataFrame(columns = ['fnames', 'yhat'])
459 |
460 | for filelistitr_idx in range(len(target_filelist)):
461 | filelistitr = target_filelist[filelistitr_idx]
462 |
463 | video_path = os.path.join(echonet.config.DATA_DIR, "Videos", filelistitr)
464 | ### Get data
465 | video = echonet.utils.loadvideo(video_path).astype(np.float32)
466 |
467 | seg_infer_path = os.path.join("../infer_buffers/{}/{}_infer_cmpct".format(segsource, test_val), filelistitr.replace(".avi", ".npy"))
468 | seg_infer_logits = np.load(seg_infer_path)
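  | # expit (the logistic sigmoid) maps the saved logits to probabilities in [0, 1];
  | # rescaling to [-1, 1] zero-centres the mask channel (padding below fills it with -1).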
469 | seg_infer_probs = expit(seg_infer_logits)
470 | seg_infer_prob_norm = seg_infer_probs * 2 - 1
471 |
472 | seg_infer_prob_norm = np.expand_dims(seg_infer_prob_norm, axis=0)
473 |
474 | if isinstance(mean, (float, int)):
475 | video -= mean
476 | else:
477 | video -= mean.reshape(3, 1, 1, 1)
478 |
479 | if isinstance(std, (float, int)):
480 | video /= std
481 | else:
482 | video /= std.reshape(3, 1, 1, 1)
483 |
484 | c, f, h, w = video.shape
485 | if length is None:
486 | # Take as many frames as possible
487 | length = f // period
488 | else:
489 | # Take specified number of frames
490 | length = length
491 |
492 | if max_length is not None:
493 | # Shorten videos to max_length
494 | length = min(length, max_length)
495 |
496 | f_old = f
497 |
498 | if f < length * period:
499 | # Pad video with frames filled with zeros if too short
500 | # 0 represents the mean color (dark grey), since this is after normalization
501 | video = np.concatenate((video, np.zeros((c, length * period - f, h, w), video.dtype)), axis=1)
502 | seg_infer_prob_norm = np.concatenate((seg_infer_prob_norm, np.ones((1, length * period - f, h, w), video.dtype) * -1) , axis=1)
503 | c, f, h, w = video.shape # pylint: disable=E0633
504 |
505 | start = np.arange(f - (length - 1) * period)
506 | #### Do looping starting from here
507 |
508 | reg1 = []
509 | n_clips = start.shape[0]
510 | batch = 1
511 | for s_itr in range(0, start.shape[0], block_size):
512 | print("{}, processing file {} out of {}, block {} out of {}".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), filelistitr_idx, len(target_filelist), s_itr, start.shape[0]), flush=True)
513 | # print("s range", start[s_itr: s_itr + block_size])
514 | # print("frame range", s + period * np.arange(length))
515 | vid_samp = tuple(video[:, s + period * np.arange(length), :, :] for s in start[s_itr: s_itr + block_size])
516 | seg_infer_samp = tuple(seg_infer_prob_norm[:, s + period * np.arange(length), :, :] for s in start[s_itr: s_itr + block_size])
517 |
518 | vid_in = np.concatenate((np.stack(vid_samp), np.stack(seg_infer_samp)), axis=1)
519 |
520 | X1 = torch.tensor(np.stack(vid_in))
521 | if X1.dtype == torch.double:
522 | X1 = X1.float()
523 |
524 | X1 = X1.to(device)
525 |
526 | if device.type == "cuda":
527 | all_output = model(X1)
528 | else:
529 | #### we only ever use cpu for testing
530 | all_output = torch.ones((X1.shape[0]))
531 |
532 | reg1.append(all_output.detach().cpu().numpy())
533 |
534 | reg1 = np.vstack(reg1)
535 | reg1_mean = reg1.reshape(batch, n_clips, -1).mean(1)
536 |
537 | exist_data = exist_data.append({'fnames':filelistitr, 'yhat':reg1_mean[0,0]}, ignore_index=True)
538 |
539 | if filelistitr_idx % 20 == 0:
540 | exist_data.to_csv(temp_savefile, index = False)
541 |
542 | label_data_path = os.path.join(echonet.config.DATA_DIR, "FileList.csv")
543 | label_data = pd.read_csv(label_data_path)
544 | label_data_select = label_data[['FileName','EF']]
545 | label_data_select.columns = ['fnames','EF']
546 | with_predict = exist_data.merge(label_data_select, on='fnames')
547 |
548 | predict_out_path = os.path.join(run_dir, "{}_predictions.csv".format(test_val))
549 | with_predict.to_csv(predict_out_path, index=False)
550 |
551 |
552 | return with_predict['yhat'].to_numpy(), with_predict['EF'].to_numpy()
553 |
554 |
555 | def run_epoch(model, dataloader, train, optim, device, save_all=False, block_size=None, run_dir = None, test_val = None):
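  | """Run one epoch of EF regression training/evaluation.
  |
  | Each batch from the dataloader is a clip whose input already carries the
  | predicted segmentation mask as an extra channel (supplied by the dataset via
  | segin_dir); the loss is the MSE between predicted and ground-truth EF.
  | Returns mean losses, predictions, labels, and (when evaluating) the recorded
  | start frames and video paths.
  | """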
556 |
557 | model.train(train)
558 |
559 | total = 0 # total training loss
560 | total_reg = 0
561 | total_ncor = 0
562 |
563 | n = 0 # number of videos processed
564 | s1 = 0 # sum of ground truth EF
565 | s2 = 0 # Sum of ground truth EF squared
566 |
567 | yhat = []
568 | y = []
569 | start_frame_record = []
570 | vidpath_record = []
571 |
572 | with torch.set_grad_enabled(train):
573 | with tqdm.tqdm(total=len(dataloader)) as pbar:
574 | # samples_cnt = 0
575 | for (X, outcome, start_frame, video_path, _, _) in dataloader:
576 |
577 | if not train:
578 | start_frame_record.append(start_frame.view(-1).to("cpu").detach().numpy())
579 | vidpath_record.append(video_path)
580 |
581 | y.append(outcome.detach().cpu().numpy())
582 |
583 | if X.dtype == torch.double:
584 | X = X.float()
585 |
586 | X = X.to(device)
587 |
588 | outcome = outcome.to(device)
589 |
590 | s1 += outcome.sum()
591 | s2 += (outcome ** 2).sum()
592 |
593 | assert block_size is None, "block_size should be none, not used"
594 |
595 | if device.type == "cuda":
596 | all_output = model(X)
597 | else:
598 | ### We only ever use cpu for testing
599 | all_output = model(X)
600 |
601 |
602 | loss_cor_item = 0
603 | total_ncor = 0
604 |
605 | loss_reg = torch.nn.functional.mse_loss(all_output.view(-1), outcome)
606 | loss = loss_reg
607 |
608 | yhat.append(all_output.view(-1).to("cpu").detach().numpy())
609 |
610 | if train:
611 | optim.zero_grad()
612 | loss.backward()
613 | optim.step()
614 |
615 | total += loss.item() * outcome.size(0)
616 | total_reg += loss_reg.item() * outcome.size(0)
617 |
618 | n += outcome.size(0)
619 |
620 | pbar.set_postfix_str("{:.2f} {:.2f} {:.4f} ({:.2f}) / {:.2f} {}".format(total / n, loss_reg.item(), loss_cor_item, loss.item(), s2 / n - (s1 / n) ** 2, 0))
621 | pbar.update()
622 |
623 | if not save_all:
624 | yhat = np.concatenate(yhat)
625 | if not train:
626 | start_frame_record = np.concatenate(start_frame_record)
627 |
628 | y = np.concatenate(y)
629 |
630 | return total / n, total_reg / n, total_ncor / n, yhat, y, start_frame_record, vidpath_record
631 |
632 |
633 |
634 |
--------------------------------------------------------------------------------
/echonet/utils/vidsegin_teachstd_kd.py:
--------------------------------------------------------------------------------
1 | """Teacher Student Distillation"""
2 |
3 | import math
4 | import os
5 | import time
6 | import shutil
7 | import datetime
8 | import pandas as pd
9 | import cv2
10 |
11 | import click
12 | import matplotlib.pyplot as plt
13 | import numpy as np
14 | import sklearn.metrics
15 | import torch
16 | import torchvision
17 | import tqdm
18 | import subprocess
19 |
20 | import echonet
21 | import echonet.models
22 |
23 | @click.command("vidsegin_teachstd_kd")
24 | @click.option("--data_dir", type=click.Path(exists=True, file_okay=False), default=None)
25 | @click.option("--output", type=click.Path(file_okay=False), default=None)
26 | @click.option("--task", type=str, default="EF")
27 | @click.option("--model_name", type=click.Choice(['mc3_18', 'r2plus1d_18', "r2plus1d_18_segin", 'r3d_18', 'r2plus1d_18_ncor']),
28 | default="r2plus1d_18")
29 | @click.option("--pretrained/--random", default=True)
30 | @click.option("--weights", type=click.Path(exists=True, dir_okay=False), default=None)
31 | @click.option("--weights_0", type=click.Path(exists=True, dir_okay=False), default=None)
32 | @click.option("--run_test/--skip_test", default=False)
33 | @click.option("--num_epochs", type=int, default=30)
34 | @click.option("--lr", type=float, default=1e-4)
35 | @click.option("--weight_decay", type=float, default=1e-4)
36 | @click.option("--lr_step_period", type=int, default=15)
37 | @click.option("--frames", type=int, default=32)
38 | @click.option("--period", type=int, default=2)
39 | @click.option("--num_train_patients", type=int, default=None)
40 | @click.option("--num_workers", type=int, default=4)
41 | @click.option("--batch_size", type=int, default=20)
42 | @click.option("--device", type=str, default=None)
43 | @click.option("--seed", type=int, default=0)
44 | @click.option("--full_test/--quick_test", default=False)
45 | @click.option("--val_samp", type=int, default=3)
46 | @click.option("--reduced_set/--full_set", default=True)
47 | @click.option("--rd_label", type=int, default=920)
48 | @click.option("--rd_unlabel", type=int, default=6440)
49 | @click.option("--max_block", type=int, default=20)
50 | @click.option("--segsource", type=str, default=None)
51 | @click.option("--w_unlb", type=float, default=2.5)
52 | @click.option("--batch_size_unlb", type=int, default=10)
53 | @click.option("--notcamus/--camus", default=True)
54 |
55 | def run(
56 | data_dir=None,
57 | output=None,
58 | task="EF",
59 | model_name="r2plus1d_18",
60 | pretrained=True,
61 | weights=None,
62 | weights_0=None,
63 | run_test=False,
64 | num_epochs=30,
65 | lr=1e-4,
66 | weight_decay=1e-4,
67 | lr_step_period=15,
68 | frames=32,
69 | period=2,
70 | num_train_patients=None,
71 | num_workers=4,
72 | batch_size=20,
73 | device=None,
74 | seed=0,
75 | full_test = False,
76 | val_samp = 3,
77 | reduced_set = True,
78 | rd_label = 920,
79 | rd_unlabel = 6440,
80 | max_block = 20,
81 | segsource = None,
82 | w_unlb = 2.5,
83 | batch_size_unlb=10,
84 | notcamus = True
85 | ):
86 |
87 | assert segsource, "function needs segsource option"
88 |
89 |
90 | if reduced_set:
91 | if not os.path.isfile(os.path.join(echonet.config.DATA_DIR, "FileList_ssl_{}_{}.csv".format(rd_label, rd_unlabel))):
92 | print("Generating new file list for ssl dataset")
93 | np.random.seed(0)
94 |
95 | data = pd.read_csv(os.path.join(echonet.config.DATA_DIR, "FileList.csv"))
96 | data["Split"].map(lambda x: x.upper())
97 |
98 | file_name_list = np.array(data[data['Split']== 'TRAIN']['FileName'])
99 | np.random.shuffle(file_name_list)
100 |
101 | label_list = file_name_list[:rd_label]
102 | unlabel_list = file_name_list[rd_label:rd_label + rd_unlabel]
103 |
104 | data['SSL_SPLIT'] = "EXCLUDE"
105 | data.loc[data['FileName'].isin(label_list), 'SSL_SPLIT'] = "LABELED"
106 | data.loc[data['FileName'].isin(unlabel_list), 'SSL_SPLIT'] = "UNLABELED"
107 |
108 | data.to_csv(os.path.join(echonet.config.DATA_DIR, "FileList_ssl_{}_{}.csv".format(rd_label, rd_unlabel)),index = False)
109 |
110 |
111 | # Seed RNGs
112 | np.random.seed(seed)
113 | torch.manual_seed(seed)
114 |
115 | def worker_init_fn(worker_id):
116 | # print("worker id is", torch.utils.data.get_worker_info().id)
117 | # https://discuss.pytorch.org/t/in-what-order-do-dataloader-workers-do-their-job/88288/2
118 | np.random.seed(np.random.get_state()[1][0] + worker_id)
119 |
120 | # Set default output directory
121 | if output is None:
122 | output = os.path.join("output", "video", "{}_{}_{}_{}".format(model_name, frames, period, "pretrained" if pretrained else "random"))
123 | os.makedirs(output, exist_ok=True)
124 |
125 | bkup_tmstmp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
126 | if os.path.isdir(os.path.join(output, "echonet_{}".format(bkup_tmstmp))):
127 | shutil.rmtree(os.path.join(output, "echonet_{}".format(bkup_tmstmp)))
128 | shutil.copytree("echonet", os.path.join(output, "echonet_{}".format(bkup_tmstmp)))
129 |
130 | # Set device for computations
131 | if device is None:
132 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
133 | elif device == "gpu":
134 | device = torch.device("cuda")
135 | elif device == "cpu":
136 | device = torch.device("cpu")
137 | else:
138 | assert 1==2, "wrong parameter for device"
139 |
140 | model_0 = echonet.models.rnet2dp1.r2plus1d_18_kd(pretrained=pretrained)
141 | model_0.fc = torch.nn.Linear(model_0.fc.in_features, 1)
142 | model_0.fc.bias.data[0] = 55.6
143 |
144 | model_ref = echonet.models.rnet2dp1.r2plus1d_18(pretrained=pretrained)
145 |
146 | model_0.stem = torch.nn.Sequential(
147 | torch.nn.Conv3d(4, 45, kernel_size=(1, 7, 7),
148 | stride=(1, 2, 2), padding=(0, 3, 3),
149 | bias=False),
150 | torch.nn.BatchNorm3d(45),
151 | torch.nn.ReLU(inplace=True),
152 | torch.nn.Conv3d(45, 64, kernel_size=(3, 1, 1),
153 | stride=(1, 1, 1), padding=(1, 0, 0),
154 | bias=False),
155 | torch.nn.BatchNorm3d(64),
156 | torch.nn.ReLU(inplace=True))
157 |
158 | for weight_itr in range(1,6):
159 | model_0.stem[weight_itr].load_state_dict(model_ref.stem[weight_itr].state_dict())
160 |
161 | model_0.stem[0].weight.data[:,:3,:,:,:] = model_ref.stem[0].weight.data[:,:,:,:,:]
162 | model_0.stem[0].weight.data[:,3,:,:,:] = torch.tensor(np.random.uniform(low = -1, high = 1, size = model_0.stem[0].weight.data[:,3,:,:,:].shape)).float()
163 |
164 | model_0 = torch.nn.DataParallel(model_0)
165 |
166 | model_0.to(device)
167 |
168 |
169 | if weights is not None:
170 | checkpoint = torch.load(weights)
171 |         model.load_state_dict(checkpoint['state_dict'])  # NOTE: the student 'model' is only built in the weights_0 branch below; this path assumes it already exists
172 |
173 | ### we initialize teacher and student weights.
174 | if weights_0 is not None:
175 | checkpoint_0 = torch.load(weights_0)
176 | if checkpoint_0.get("state_dict_0"):
177 | ## initialize teacher weights
178 | print("loading from state_dict_0")
179 | model_0.load_state_dict(checkpoint_0['state_dict_0'])
180 |
181 | ## initialize student weights where transferable to speed up training
182 | state_dict = checkpoint_0['state_dict_0']
183 | from collections import OrderedDict
184 | new_state_dict = OrderedDict()
185 | for k, v in state_dict.items():
186 | name = k[7:] # remove `module.`
187 | new_state_dict[name] = v
188 |
189 | model = echonet.models.rnet2dp1.r2plus1d_18_kd(pretrained=pretrained)
190 | model.fc = torch.nn.Linear(model.fc.in_features, 1)
191 | model.fc.bias.data[0] = 55.6
192 |
193 | model.stem = torch.nn.Sequential(
194 | torch.nn.Conv3d(4, 45, kernel_size=(1, 7, 7),
195 | stride=(1, 2, 2), padding=(0, 3, 3),
196 | bias=False),
197 | torch.nn.BatchNorm3d(45),
198 | torch.nn.ReLU(inplace=True),
199 | torch.nn.Conv3d(45, 64, kernel_size=(3, 1, 1),
200 | stride=(1, 1, 1), padding=(1, 0, 0),
201 | bias=False),
202 | torch.nn.BatchNorm3d(64),
203 | torch.nn.ReLU(inplace=True))
204 |
205 | model.load_state_dict(new_state_dict)
206 |
207 | model.stem = torch.nn.Sequential(
208 | torch.nn.Conv3d(3, 45, kernel_size=(1, 7, 7),
209 | stride=(1, 2, 2), padding=(0, 3, 3),
210 | bias=False),
211 | torch.nn.BatchNorm3d(45),
212 | torch.nn.ReLU(inplace=True),
213 | torch.nn.Conv3d(45, 64, kernel_size=(3, 1, 1),
214 | stride=(1, 1, 1), padding=(1, 0, 0),
215 | bias=False),
216 | torch.nn.BatchNorm3d(64),
217 | torch.nn.ReLU(inplace=True))
218 |
219 | for weight_itr in range(1,6):
220 | model.stem[weight_itr].load_state_dict(model_0.module.stem[weight_itr].state_dict())
221 |
222 | model.stem[0].weight.data[:,:3,:,:,:] = model_0.module.stem[0].weight.data[:,:3,:,:,:]
223 |
224 | model = torch.nn.DataParallel(model)
225 | model.to(device)
226 |
227 | elif checkpoint_0.get("state_dict"):
228 | ## initialize teacher weights
229 | print("loading from state_dict")
230 | model_0.load_state_dict(checkpoint_0['state_dict'])
231 |
232 | ## initialize student weights where transferable to speed up training
233 | state_dict = checkpoint_0['state_dict']
234 | from collections import OrderedDict
235 | new_state_dict = OrderedDict()
236 | for k, v in state_dict.items():
237 | name = k[7:] # remove `module.`
238 | new_state_dict[name] = v
239 |
240 | model = echonet.models.rnet2dp1.r2plus1d_18_kd(pretrained=pretrained)
241 | model.fc = torch.nn.Linear(model.fc.in_features, 1)
242 | model.fc.bias.data[0] = 55.6
243 |
244 | model.stem = torch.nn.Sequential(
245 | torch.nn.Conv3d(4, 45, kernel_size=(1, 7, 7),
246 | stride=(1, 2, 2), padding=(0, 3, 3),
247 | bias=False),
248 | torch.nn.BatchNorm3d(45),
249 | torch.nn.ReLU(inplace=True),
250 | torch.nn.Conv3d(45, 64, kernel_size=(3, 1, 1),
251 | stride=(1, 1, 1), padding=(1, 0, 0),
252 | bias=False),
253 | torch.nn.BatchNorm3d(64),
254 | torch.nn.ReLU(inplace=True))
255 |
256 | model.load_state_dict(new_state_dict)
257 |
258 | model.stem = torch.nn.Sequential(
259 | torch.nn.Conv3d(3, 45, kernel_size=(1, 7, 7),
260 | stride=(1, 2, 2), padding=(0, 3, 3),
261 | bias=False),
262 | torch.nn.BatchNorm3d(45),
263 | torch.nn.ReLU(inplace=True),
264 | torch.nn.Conv3d(45, 64, kernel_size=(3, 1, 1),
265 | stride=(1, 1, 1), padding=(1, 0, 0),
266 | bias=False),
267 | torch.nn.BatchNorm3d(64),
268 | torch.nn.ReLU(inplace=True))
269 |
270 | for weight_itr in range(1,6):
271 | model.stem[weight_itr].load_state_dict(model_0.module.stem[weight_itr].state_dict())
272 |
273 | model.stem[0].weight.data[:,:3,:,:,:] = model_0.module.stem[0].weight.data[:,:3,:,:,:]
274 |
275 | model = torch.nn.DataParallel(model)
276 | model.to(device)
277 | else:
278 | assert 1==2, "missing key"
279 |
280 |
281 | optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
282 | if lr_step_period is None:
283 | lr_step_period = math.inf
284 | scheduler = torch.optim.lr_scheduler.StepLR(optim, lr_step_period)
285 |
286 | optim_0 = torch.optim.SGD(model_0.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
287 | if lr_step_period is None:
288 | lr_step_period = math.inf
289 | scheduler_0 = torch.optim.lr_scheduler.StepLR(optim_0, lr_step_period)
290 |
291 |
292 | mean, std = echonet.utils.get_mean_and_std(echonet.datasets.Echo_tskd(root=data_dir, split="train"))
293 | print("mean std", mean, std)
294 | kwargs = {"target_type": ["EF", "IEKD"],
295 | "mean": mean,
296 | "std": std,
297 | "length": frames,
298 | "period": period,
299 | }
300 |
301 | kwargs_testall = {"target_type": "EF",
302 | "mean": mean,
303 | "std": std,
304 | "length": frames,
305 | "period": period,
306 | }
307 |
308 | dataset = {}
309 | dataset_trainsub = {}
310 | if reduced_set:
311 | dataset_trainsub['lb'] = echonet.datasets.Echo_tskd(root=data_dir, split="train", **kwargs, pad=12, ssl_postfix="_ssl_{}_{}".format(rd_label, rd_unlabel), ssl_type = 1, segin_dir = "../infer_buffers/{}/train_infer_cmpct".format(segsource))
312 | dataset_trainsub["unlb_0"] = echonet.datasets.Echo_tskd(root=data_dir, split="train", **kwargs, pad=12, ssl_postfix="_ssl_{}_{}".format(rd_label, rd_unlabel), ssl_type = 2, segin_dir = "../infer_buffers/{}/train_infer_cmpct".format(segsource))
313 | dataset["train"] = dataset_trainsub
314 | dataset["val"] = echonet.datasets.Echo_tskd(root=data_dir, split="val", **kwargs, ssl_postfix="_ssl_{}_{}".format(rd_label, rd_unlabel), segin_dir = "../infer_buffers/{}/val_infer_cmpct".format(segsource))
315 | else:
316 | dataset["train"] = echonet.datasets.Echo_tskd(root=data_dir, split="train", **kwargs, pad=12, ssl_postfix="", segin_dir = "../infer_buffers/{}/train_infer_cmpct".format(segsource))
317 | dataset["val"] = echonet.datasets.Echo_tskd(root=data_dir, split="val", **kwargs, ssl_postfix="", segin_dir = "../infer_buffers/{}/val_infer_cmpct".format(segsource))
318 |
319 |
320 | # Run training and testing loops
321 | with open(os.path.join(output, "log.csv"), "a") as f:
322 |
323 | f.write("Run timestamp: {}\n".format(bkup_tmstmp))
324 |
325 | epoch_resume = 0
326 | bestLoss = float("inf")
327 | try:
328 | # Attempt to load checkpoint
329 | checkpoint = torch.load(os.path.join(output, "checkpoint.pt"))
330 | model.load_state_dict(checkpoint['state_dict'], strict = False)
331 | optim.load_state_dict(checkpoint['opt_dict'])
332 | scheduler.load_state_dict(checkpoint['scheduler_dict'])
333 |
334 | model_0.load_state_dict(checkpoint['state_dict_0'], strict = False)
335 | optim_0.load_state_dict(checkpoint['opt_dict_0'])
336 | scheduler_0.load_state_dict(checkpoint['scheduler_dict_0'])
337 |
338 | np_rndstate_chkpt = checkpoint['np_rndstate']
339 | trch_rndstate_chkpt = checkpoint['trch_rndstate']
340 |
341 | np.random.set_state(np_rndstate_chkpt)
342 | torch.set_rng_state(trch_rndstate_chkpt)
343 |
344 | epoch_resume = checkpoint["epoch"] + 1
345 | bestLoss = checkpoint["best_loss"]
346 | f.write("Resuming from epoch {}\n".format(epoch_resume))
347 | except FileNotFoundError:
348 | f.write("Starting run from scratch\n")
349 |
350 |
351 | for epoch in range(epoch_resume, num_epochs):
352 | print("Epoch #{}".format(epoch), flush=True)
353 | for phase in ['train', 'val']:
354 |
355 | start_time = time.time()
356 |
357 | if device.type == "cuda":
358 | for i in range(torch.cuda.device_count()):
359 | torch.cuda.reset_peak_memory_stats(i)
360 |
361 | if phase == "train":
362 | ds_lb = dataset[phase]['lb']
363 | dataloader_lb = torch.utils.data.DataLoader(
364 | ds_lb, batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=(device.type == "cuda"), drop_last=(phase == "train"), worker_init_fn=worker_init_fn)
365 |
366 | ds_unlb_0 = dataset[phase]['unlb_0']
367 | dataloader_unlb_0 = torch.utils.data.DataLoader(
368 | ds_unlb_0, batch_size=batch_size_unlb, num_workers=num_workers, shuffle=True, pin_memory=(device.type == "cuda"), drop_last=(phase == "train"), worker_init_fn=worker_init_fn)
369 |
370 |
371 | total, total_reg, total_reg_reg, total_reg_unlb, yhat, y = echonet.utils.vidsegin_teachstd_kd.run_epoch(model = model,
372 | model_0 = model_0,
373 | dataloader = dataloader_lb,
374 | dataloader_unlb_0 = dataloader_unlb_0,
375 | train = phase == "train",
376 | optim = optim,
377 | optim_0 = optim_0,
378 | device = device,
379 | w_unlb = w_unlb)
380 |
381 |
382 | r2_value = sklearn.metrics.r2_score(y, yhat)
383 |
384 | f.write("{},{},{},{},{},{},{},{},{},{},{}\n".format(epoch,
385 | phase,
386 | total,
387 | total_reg_reg,
388 | r2_value,
389 | total_reg_unlb,
390 | time.time() - start_time,
391 | y.size,
392 |                                                                 sum(torch.cuda.max_memory_allocated(i) for i in range(torch.cuda.device_count())),
393 |                                                                 sum(torch.cuda.max_memory_reserved(i) for i in range(torch.cuda.device_count())),
394 | batch_size))
395 | f.flush()
396 |
397 | # print("successful run until exit")
398 | # exit()
399 |
400 | else:
401 | ### for validation
402 | ### store seeds
403 | np_rndstate = np.random.get_state()
404 | trch_rndstate = torch.get_rng_state()
405 |
406 | r2_track = []
407 | lossreg_track = []
408 |
409 |
410 | for val_samp_itr in range(val_samp):
411 |
412 | print("running validation batch for seed =", val_samp_itr)
413 |
414 | np.random.seed(val_samp_itr)
415 | torch.manual_seed(val_samp_itr)
416 |
417 | ds = dataset[phase]
418 | dataloader = torch.utils.data.DataLoader(
419 | ds, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=(device.type == "cuda"), drop_last=(phase == "train"))
420 |
421 |
422 | total, total_reg, total_reg_reg, total_reg_unlb, yhat, y = echonet.utils.vidsegin_teachstd_kd.run_epoch(model = model,
423 | model_0 = model_0,
424 | dataloader = dataloader,
425 | dataloader_unlb_0 = None,
426 | train = phase == "train",
427 | optim = optim,
428 | optim_0 = optim_0,
429 | device = device,
430 | w_unlb = w_unlb)
431 |
432 |
433 | r2_track.append(sklearn.metrics.r2_score(y, yhat))
434 | lossreg_track.append(total_reg_reg)
435 |
436 |
437 | r2_value = np.average(np.array(r2_track))
438 | lossreg = np.average(np.array(lossreg_track))
439 |
440 | f.write("{},{},{},{},{},{},{},{},{},{}".format(epoch,
441 | phase,
442 | lossreg,
443 | r2_value,
444 | total_reg_unlb,
445 | time.time() - start_time,
446 | y.size,
447 |                                                                 sum(torch.cuda.max_memory_allocated(i) for i in range(torch.cuda.device_count())),
448 |                                                                 sum(torch.cuda.max_memory_reserved(i) for i in range(torch.cuda.device_count())),
449 | batch_size))
450 |
451 | for trck_write in range(len(r2_track)):
452 | f.write(",{}".format(r2_track[trck_write]))
453 |
454 | for trck_write in range(len(lossreg_track)):
455 | f.write(",{}".format(lossreg_track[trck_write]))
456 |
457 |
458 | f.write("\n")
459 | f.flush()
460 |
461 | np.random.set_state(np_rndstate)
462 | torch.set_rng_state(trch_rndstate)
463 |
464 |
465 | scheduler.step()
466 | scheduler_0.step()
467 |
468 | # Save checkpoint
469 | save = {
470 | 'epoch': epoch,
471 | 'state_dict': model.state_dict(),
472 | 'state_dict_0': model_0.state_dict(),
473 | 'period': period,
474 | 'frames': frames,
475 | 'best_loss': bestLoss,
476 | 'loss': lossreg,
477 | 'r2': r2_value,
478 | 'opt_dict': optim.state_dict(),
479 | 'opt_dict_0': optim_0.state_dict(),
480 | 'scheduler_dict': scheduler.state_dict(),
481 | 'scheduler_dict_0': scheduler_0.state_dict(),
482 | 'np_rndstate': np.random.get_state(),
483 | 'trch_rndstate': torch.get_rng_state()
484 | }
485 | torch.save(save, os.path.join(output, "checkpoint.pt"))
486 |
487 | #### save based on reg loss
488 | if lossreg < bestLoss:
489 | print("saved best because {} < {}".format(lossreg, bestLoss))
490 | torch.save(save, os.path.join(output, "best.pt"))
491 | bestLoss = lossreg
492 |
493 |
494 | # Load best weights
495 | if num_epochs != 0:
496 | checkpoint = torch.load(os.path.join(output, "best.pt"))
497 | model.load_state_dict(checkpoint['state_dict'], strict = False)
498 | model_0.load_state_dict(checkpoint['state_dict_0'], strict = False)
499 | f.write("Best validation loss {} from epoch {}, R2 {}\n".format(checkpoint["loss"], checkpoint["epoch"], checkpoint["r2"]))
500 | f.flush()
501 |
502 | if run_test:
503 | if notcamus:
504 | split_list = ["test", "val"]
505 | # split_list = ["test"]
506 | else:
507 | split_list = ["train", "test"]
508 |
509 | for split in split_list:
510 | # Performance without test-time augmentation
511 |
512 | if not full_test:
513 |
514 | for seed_itr in range(5):
515 | np.random.seed(seed_itr)
516 | torch.manual_seed(seed_itr)
517 |
518 | dataloader = torch.utils.data.DataLoader(
519 | echonet.datasets.Echo(root=data_dir, split=split, **kwargs_testall, ssl_postfix="_ssl_{}_{}".format(rd_label, rd_unlabel), segin_dir = "../infer_buffers/{}/{}_infer_cmpct".format(segsource, split)),
520 | batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=(device.type == "cuda"), worker_init_fn=worker_init_fn)
521 |
522 |
523 |
524 | loss, yhat, y = echonet.utils.vidsegin_teachstd_kd.run_epoch_val(model = model,
525 | model_0 = model_0,
526 | dataloader = dataloader,
527 | train = False,
528 | optim = None,
529 | optim_0 = None,
530 | device = device)
531 |
532 | f.write("Seed is {}\n".format(seed_itr))
533 | f.write("{} - {} (one clip) R2: {:.3f} ({:.3f} - {:.3f})\n".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), split, *echonet.utils.bootstrap(y, yhat, sklearn.metrics.r2_score)))
534 | f.write("{} - {} (one clip) MAE: {:.2f} ({:.2f} - {:.2f})\n".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), split, *echonet.utils.bootstrap(y, yhat, sklearn.metrics.mean_absolute_error)))
535 | f.write("{} - {} (one clip) RMSE: {:.2f} ({:.2f} - {:.2f})\n".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), split, *tuple(map(math.sqrt, echonet.utils.bootstrap(y, yhat, sklearn.metrics.mean_squared_error)))))
536 | f.flush()
537 |
538 |
539 | else:
540 | # Performance with test-time augmentation
541 | if reduced_set:
542 | ds = echonet.datasets.Echo(root=data_dir, split=split, **kwargs_testall, clips="all", ssl_postfix="_ssl_{}_{}".format(rd_label, rd_unlabel))
543 | else:
544 | ds = echonet.datasets.Echo(root=data_dir, split=split, **kwargs_testall, clips="all", ssl_postfix="")
545 |
546 | yhat, y = echonet.utils.vidsegin_teachstd_kd.test_epoch_all(model,
547 | ds,
548 | False,
549 | None,
550 | device,
551 | save_all=True,
552 | block_size=batch_size,
553 | run_dir = output,
554 | test_val = split,
555 | **kwargs)
556 |
557 | f.write("Seed is {} \n".format(seed))
558 | f.write("{} - {} (all clips, mod) R2: {:.3f} ({:.3f} - {:.3f})\n".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), split, *echonet.utils.bootstrap(y, yhat, sklearn.metrics.r2_score)))
559 | f.write("{} - {} (all clips, mod) MAE: {:.2f} ({:.2f} - {:.2f})\n".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), split, *echonet.utils.bootstrap(y, yhat, sklearn.metrics.mean_absolute_error)))
560 | f.write("{} - {} (all clips, mod) RMSE: {:.2f} ({:.2f} - {:.2f})\n".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), split, *tuple(map(math.sqrt, echonet.utils.bootstrap(y, yhat, sklearn.metrics.mean_squared_error)))))
561 | f.flush()
562 |
563 |
564 |
565 |
566 | def test_epoch_all(model, dataset, train, optim, device, save_all=False, block_size=None, run_dir = None, test_val = None, mean = None, std = None, length = None, period = None, target_type = None):
567 | model.train(False)
568 |
569 | total = 0 # total training loss
570 | total_reg = 0
571 | total_ncor = 0
572 |
573 | n = 0 # number of videos processed
574 | s1 = 0 # sum of ground truth EF
575 | s2 = 0 # Sum of ground truth EF squared
576 |
577 | yhat = []
578 | y = []
579 |
580 | #### some params in the dataloader
581 |
582 | if (mean is None) or (std is None) or (length is None) or (period is None):
583 | assert 1==2, "missing key params"
584 |
585 | max_length = 250
586 |
587 | if run_dir:
588 |
589 | temp_savefile = os.path.join(run_dir, "temp_inference_{}.csv".format(test_val))
590 |
591 | with torch.set_grad_enabled(False):
592 | orig_filelist = dataset.fnames
593 |
594 | if os.path.isfile(temp_savefile):
595 | exist_data = pd.read_csv(temp_savefile)
596 | exist_file = list(exist_data['fnames'])
597 | target_filelist = sorted(list(set(orig_filelist) - set(exist_file)))
598 | else:
599 | target_filelist = sorted(list(orig_filelist))
600 | exist_data = pd.DataFrame(columns = ['fnames', 'yhat'])
601 |
602 | for filelistitr_idx in range(len(target_filelist)):
603 | filelistitr = target_filelist[filelistitr_idx]
604 |
605 | video_path = os.path.join(echonet.config.DATA_DIR, "Videos", filelistitr)
606 | ### Get data
607 | video = echonet.utils.loadvideo(video_path).astype(np.float32)
608 |
609 | if isinstance(mean, (float, int)):
610 | video -= mean
611 | else:
612 | video -= mean.reshape(3, 1, 1, 1)
613 |
614 | if isinstance(std, (float, int)):
615 | video /= std
616 | else:
617 | video /= std.reshape(3, 1, 1, 1)
618 |
619 | c, f, h, w = video.shape
620 | if length is None:
621 | # Take as many frames as possible
622 | length = f // period
623 | else:
624 | # Take specified number of frames
625 | length = length
626 |
627 | if max_length is not None:
628 | # Shorten videos to max_length
629 | length = min(length, max_length)
630 |
631 | if f < length * period:
632 | # Pad video with frames filled with zeros if too short
633 | # 0 represents the mean color (dark grey), since this is after normalization
634 | video = np.concatenate((video, np.zeros((c, length * period - f, h, w), video.dtype)), axis=1)
635 | c, f, h, w = video.shape # pylint: disable=E0633
636 |
637 | start = np.arange(f - (length - 1) * period)
638 | #### Do looping starting from here
639 |
640 | reg1 = []
641 | n_clips = start.shape[0]
642 | batch = 1
643 | for s_itr in range(0, start.shape[0], block_size):
644 | print("{}, processing file {} out of {}, block {} out of {}".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), filelistitr_idx, len(target_filelist), s_itr, start.shape[0]), flush=True)
645 | # print("s range", start[s_itr: s_itr + block_size])
646 | # print("frame range", s + period * np.arange(length))
647 | vid_samp = tuple(video[:, s + period * np.arange(length), :, :] for s in start[s_itr: s_itr + block_size])
648 | X1 = torch.tensor(np.stack(vid_samp))
649 | X1 = X1.to(device)
650 |
651 | if device.type == "cuda":
652 | all_output = model(X1)
653 | else:
654 |                         #### cpu runs are only for pipeline testing, so skip the forward pass and emit dummy outputs
655 | all_output = torch.ones((X1.shape[0]))
656 |
657 | reg1.append(all_output[0].detach().cpu().numpy())
658 |
659 | reg1 = np.vstack(reg1)
660 | reg1_mean = reg1.reshape(batch, n_clips, -1).mean(1)
661 |
662 | exist_data = exist_data.append({'fnames':filelistitr, 'yhat':reg1_mean[0,0]}, ignore_index=True)
663 |
664 | if filelistitr_idx % 20 == 0:
665 | exist_data.to_csv(temp_savefile, index = False)
666 |
667 | label_data_path = os.path.join(echonet.config.DATA_DIR, "FileList.csv")
668 | label_data = pd.read_csv(label_data_path)
669 | label_data_select = label_data[['FileName','EF']]
670 | label_data_select.columns = ['fnames','EF']
671 | with_predict = exist_data.merge(label_data_select, on='fnames')
672 |
673 | predict_out_path = os.path.join(run_dir, "{}_predictions.csv".format(test_val))
674 | with_predict.to_csv(predict_out_path, index=False)
675 |
676 |
677 | # print(with_predict)
678 | # exit()
679 | return with_predict['yhat'].to_numpy(), with_predict['EF'].to_numpy()
680 |
681 |
682 | def run_epoch(model,
683 | model_0,
684 | dataloader,
685 | dataloader_unlb_0,
686 | train,
687 | optim,
688 | optim_0,
689 | device,
690 | save_all=False,
691 | block_size=None,
692 | run_dir = None,
693 | test_val = None,
694 | w_unlb = 0):
695 |
696 |
697 | total = 0 # total training loss
698 | total_reg = 0
699 | total_reg_reg = 0
700 | total_loss_reg_reg_unlb = 0
701 |
702 |
703 | n = 0 # number of videos processed
704 | n_frm = 0
705 | s1 = 0 # sum of ground truth EF
706 | s2 = 0 # Sum of ground truth EF squared
707 |
708 | yhat = []
709 | yhat_seg = []
710 | y = []
711 | start_frame_record = []
712 | vidpath_record = []
713 |
714 | if dataloader_unlb_0:
715 | dataloader_unlb_0_itr = iter(dataloader_unlb_0)
716 |
717 | with torch.set_grad_enabled(train):
718 | with tqdm.tqdm(total=len(dataloader)) as pbar:
719 | enum_idx = 0
720 | for (X_all, outcome, seg_info, start_frame, video_path, _, _) in dataloader:
721 | enum_idx = enum_idx + 1
722 |
723 |
724 | if not train:
725 | start_frame_record.append(start_frame.view(-1).to("cpu").detach().numpy())
726 | vidpath_record.append(video_path)
727 |
728 | y.append(outcome.detach().cpu().numpy())
729 |
730 | if X_all.dtype == torch.double:
731 | X_all = X_all.float()
732 |
733 | X_noseg = X_all[:,:3,...].to(device)
734 | X_wseg = X_all.to(device)
735 |
736 | outcome = outcome.to(device)
737 |
738 | s1 += outcome.sum()
739 | s2 += (outcome ** 2).sum()
740 |
741 | if dataloader_unlb_0:
742 |
743 |                     (X_all_unlb, outcome_unlb, seg_info_unlb, start_frame_unlb, video_path_unlb, _, _) = next(dataloader_unlb_0_itr)
744 | if X_all_unlb.dtype == torch.double:
745 | X_all_unlb = X_all_unlb.float()
746 |
747 | X_noseg_unlb = X_all_unlb[:,:3,...].to(device)
748 | X_wseg_unlb = X_all_unlb.to(device)
749 |
750 |
751 | if train:
752 | model.train(True)
753 | else:
754 | model.train(False)
755 | model_0.train(False)
756 |
757 | if train:
758 | all_output_unlb = model(X_noseg_unlb)
759 | else:
760 | with torch.no_grad():
761 | all_output_unlb = model(X_noseg_unlb)
762 |
763 | y_pred_unlb = all_output_unlb[0]
764 |
765 | with torch.no_grad():
766 | all_output_seg_unlb = model_0(X_wseg_unlb)
767 |
768 | y_pred_seg_unlb = all_output_seg_unlb[0]
769 |
770 | y_pred_avg_unlb = y_pred_seg_unlb.view(-1).detach()
771 |
772 | loss_reg_reg_unlb_vid = torch.nn.functional.mse_loss(y_pred_unlb.view(-1), y_pred_avg_unlb)
773 | loss_reg_reg_unlb_seg = torch.nn.functional.mse_loss(y_pred_seg_unlb.view(-1), y_pred_avg_unlb)
774 |
775 | loss_reg_reg_unlb = loss_reg_reg_unlb_vid / 2
776 |
777 | loss_reg_reg_unlb_item = loss_reg_reg_unlb.item()
778 |
779 | else:
780 | loss_reg_reg_unlb_item = 0
781 |
782 |
783 | #### train video model
784 | if train:
785 | model.train(True)
786 | # attKD.train(True)
787 | else:
788 | model.train(False)
789 | # attKD.train(False)
790 | model_0.train(False)
791 |
792 | if train:
793 | all_output = model(X_noseg)
794 | else:
795 | with torch.no_grad():
796 | all_output = model(X_noseg)
797 |
798 | y_pred = all_output[0]
799 |
800 | with torch.no_grad():
801 | all_output_seg = model_0(X_wseg)
802 |
803 | y_pred_seg = all_output_seg[0]
804 |
805 |
806 | loss_reg_reg = torch.nn.functional.mse_loss(y_pred.view(-1), outcome)
807 | yhat.append(y_pred.view(-1).to("cpu").detach().numpy())
808 |
809 | if dataloader_unlb_0:
810 | loss_reg = loss_reg_reg + w_unlb * loss_reg_reg_unlb
811 | else:
812 | loss_reg = loss_reg_reg
813 |
814 |
815 | if train:
816 | optim.zero_grad()
817 | loss_reg.backward()
818 | optim.step()
819 |
820 | total_reg += loss_reg.item() * outcome.size(0)
821 | total_reg_reg += loss_reg_reg.item() * outcome.size(0)
822 |
823 |
824 | loss_reg_item = loss_reg.item()
825 | loss_reg_reg_item = loss_reg_reg.item()
826 |
827 |
828 | total = total_reg
829 | if dataloader_unlb_0:
830 | total_loss_reg_reg_unlb = total_loss_reg_reg_unlb + loss_reg_reg_unlb.item()
831 |
832 | n += outcome.size(0)
833 |
834 | pbar.set_postfix_str("total {:.4f} / reg {:.4f} ({:.4f}) regrg {:.4f} ({:.4f}) regulb {:.4f} ({:.4f})".format(total / n ,
835 | total_reg / n , loss_reg_item,
836 | total_reg_reg / n , loss_reg_reg_item,
837 | total_loss_reg_reg_unlb / n, loss_reg_reg_unlb_item
838 | ))
839 | pbar.update()
840 | yhat_cat = np.concatenate(yhat)
841 | y = np.concatenate(y)
842 |
843 |
844 | return (total / n ,
845 | total_reg / n ,
846 | total_reg_reg / n ,
847 | total_loss_reg_reg_unlb / n ,
848 | yhat_cat, y)
849 |
850 |
851 |
852 |
853 |
854 | def run_epoch_val(model, model_0, dataloader, train, optim, optim_0, device, save_all=False, block_size=None):
855 | """Run one epoch of training/evaluation for segmentation.
856 |
857 | Args:
858 | model (torch.nn.Module): Model to train/evaulate.
859 | dataloder (torch.utils.data.DataLoader): Dataloader for dataset.
860 | train (bool): Whether or not to train model.
861 | optim (torch.optim.Optimizer): Optimizer
862 | device (torch.device): Device to run on
863 | save_all (bool, optional): If True, return predictions for all
864 | test-time augmentations separately. If False, return only
865 | the mean prediction.
866 | Defaults to False.
867 | block_size (int or None, optional): Maximum number of augmentations
868 | to run on at the same time. Use to limit the amount of memory
869 | used. If None, always run on all augmentations simultaneously.
870 | Default is None.
871 | """
872 |
873 | total = 0 # total training loss
874 | total_reg = 0
875 | start_frame_record = []
876 | vidpath_record = []
877 | yhat = []
878 | y = []
879 |
880 | n = 0
881 | with torch.set_grad_enabled(train):
882 | with tqdm.tqdm(total=len(dataloader)) as pbar:
883 | enum_idx = 0
884 | for (X, outcome, start_frame, video_path, _, _) in dataloader:
885 | enum_idx = enum_idx + 1
886 |
887 | if not train:
888 | start_frame_record.append(start_frame.view(-1).to("cpu").detach().numpy())
889 | vidpath_record.append(video_path)
890 |
891 | y.append(outcome.detach().cpu().numpy())
892 |
893 | if X.dtype == torch.double:
894 | X = X.float()
895 |
896 | X = X[:,:3,...].to(device)
897 |
898 | outcome = outcome.to(device)
899 |
900 | model.train(False)
901 | model_0.train(False)
902 | all_output = model(X)
903 |
904 | y_pred = all_output[0]
905 |
906 | loss_reg = torch.nn.functional.mse_loss(y_pred.view(-1), outcome)
907 | yhat.append(y_pred.view(-1).to("cpu").detach().numpy())
908 |
909 | total_reg += loss_reg.item() * outcome.size(0)
910 | total = total_reg
911 | n += outcome.size(0)
912 |
913 | pbar.set_postfix_str("total {:.4f}".format(total / n))
914 | pbar.update()
915 |
916 | yhat = np.concatenate(yhat)
917 |
918 | y = np.concatenate(y)
919 |
920 |
921 | return (total / n, yhat, y)
922 |
923 |
924 |
--------------------------------------------------------------------------------
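The core of run_epoch in vidsegin_teachstd_kd.py above combines a supervised regression loss on labelled clips with a distillation loss that pulls the video-only student towards the segmentation-conditioned teacher on unlabelled clips. A minimal sketch of that composition follows; it assumes, as in the file, that both networks return a tuple whose first element is the EF prediction, that the student takes the first three channels and the teacher the full four-channel input. Function and tensor names here are illustrative, not the repository's.

# Minimal sketch of the semi-supervised loss used in run_epoch above.
import torch
import torch.nn.functional as F

def kd_loss(model, model_0, x_lb, ef_lb, x_unlb, w_unlb):
    # Supervised regression on labelled clips (video channels only).
    pred_lb = model(x_lb[:, :3, ...])[0].view(-1)
    loss_sup = F.mse_loss(pred_lb, ef_lb)

    # Distillation on unlabelled clips: the frozen teacher sees the extra
    # segmentation channel; its detached prediction becomes the target.
    with torch.no_grad():
        target_unlb = model_0(x_unlb)[0].view(-1)
    pred_unlb = model(x_unlb[:, :3, ...])[0].view(-1)
    loss_kd = F.mse_loss(pred_unlb, target_unlb) / 2   # halved as in the file

    # Overall objective optimised by `optim` (w_unlb defaults to 2.5 above).
    return loss_sup + w_unlb * loss_kd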
/flow_a_tmi_revise_v2.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xmed-lab/CSS-SemiVideo/5707b6085cb61807253710766c816e3cd1ae6a25/flow_a_tmi_revise_v2.PNG
--------------------------------------------------------------------------------
/flow_b_tmi_revise.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xmed-lab/CSS-SemiVideo/5707b6085cb61807253710766c816e3cd1ae6a25/flow_b_tmi_revise.PNG
--------------------------------------------------------------------------------
/flow_graph.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xmed-lab/CSS-SemiVideo/5707b6085cb61807253710766c816e3cd1ae6a25/flow_graph.PNG
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # This file may be used to create an environment using:
2 | # $ conda create --name <env> --file <this file>
3 | # platform: linux-64
4 | _libgcc_mutex=0.1=main
5 | ca-certificates=2021.4.13=h06a4308_1
6 | certifi=2020.12.5=py39h06a4308_0
7 | ld_impl_linux-64=2.33.1=h53a641e_7
8 | libffi=3.3=he6710b0_2
9 | libgcc-ng=9.1.0=hdf63c60_0
10 | libstdcxx-ng=9.1.0=hdf63c60_0
11 | ncurses=6.2=he6710b0_1
12 | numpy=1.21.2=pypi_0
13 | openssl=1.1.1k=h27cfd23_0
14 | pillow=8.3.1=pypi_0
15 | pip=21.0.1=py39h06a4308_0
16 | python=3.9.4=hdb3f193_0
17 | readline=8.1=h27cfd23_0
18 | setuptools=52.0.0=py39h06a4308_0
19 | sqlite=3.35.4=hdfb4753_0
20 | tk=8.6.10=hbc83047_0
21 | torch=1.7.1+cu110=pypi_0
22 | torchaudio=0.7.2=pypi_0
23 | torchvision=0.8.2+cu110=pypi_0
24 | typing-extensions=3.10.0.0=pypi_0
25 | tzdata=2020f=h52ac0ba_0
26 | wheel=0.36.2=pyhd3eb1b0_0
27 | xz=5.2.5=h7b6447c_0
28 | zlib=1.2.11=h7b6447c_3
29 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """Metadata for package to allow installation with pip."""
3 |
4 | import os
5 |
6 | import setuptools
7 |
8 | with open("README.md", "r") as fh:
9 | long_description = fh.read()
10 |
11 | # Use same version from code
12 | # See 3 from
13 | # https://packaging.python.org/guides/single-sourcing-package-version/
14 | version = {}
15 | with open(os.path.join("echonet", "__version__.py")) as f:
16 | exec(f.read(), version) # pylint: disable=W0122
17 |
18 | setuptools.setup(
19 | name="echonet",
20 | description="Video-based AI for beat-to-beat cardiac function assessment.",
21 | version=version["__version__"],
22 | url="https://echonet.github.io/dynamic",
23 | packages=setuptools.find_packages(exclude=["output.*", "output*", "*output.*", "*output*", "*output", "output"]),
24 | install_requires=[
25 | "click",
26 | "numpy",
27 | "pandas",
28 | "torch",
29 | "torchvision",
30 | "opencv-python",
31 | "scikit-image",
32 | "tqdm",
33 | "sklearn"
34 | ],
35 | classifiers=[
36 | "Programming Language :: Python :: 3",
37 | ],
38 | entry_points={
39 | "console_scripts": [
40 | "echonet=echonet:main",
41 | ],
42 | }
43 |
44 | )
45 |
--------------------------------------------------------------------------------
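setup.py above registers the `echonet` console entry point, and vidsegin_teachstd_kd.py defines its training routine as the click command `vidsegin_teachstd_kd`. Below is a hedged sketch of invoking that command programmatically through click; all paths and option values are placeholders, not values shipped with the repository.

# Hedged usage sketch: run the teacher-student distillation command through
# its click interface. Every path and option value below is a placeholder.
from echonet.utils.vidsegin_teachstd_kd import run

run.main(
    [
        "--data_dir", "/path/to/EchoNet-Dynamic",  # placeholder dataset root
        "--output", "output/kd_demo",              # placeholder output directory
        "--segsource", "lvseg",                    # placeholder segmentation-source tag (required by the command)
        "--weights_0", "output/teacher/best.pt",   # placeholder teacher checkpoint
        "--num_epochs", "1",
        "--batch_size", "4",
    ],
    standalone_mode=False,  # return to the caller instead of calling sys.exit()
)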