├── .gitignore ├── echonet.cfg ├── echonet ├── __version__.py ├── docs │ └── framework_chart.PNG ├── __main__.py ├── models │ ├── __init__.py │ └── rnet2dp1.py ├── datasets │ ├── __init__.py │ └── echo.py ├── __init__.py ├── config.py └── utils │ ├── __init__.py │ └── video.py ├── requirements.txt ├── LICENSE ├── setup.py └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | trained_model 2 | -------------------------------------------------------------------------------- /echonet.cfg: -------------------------------------------------------------------------------- 1 | DATA_DIR = 2 | -------------------------------------------------------------------------------- /echonet/__version__.py: -------------------------------------------------------------------------------- 1 | """Version number for Echonet package.""" 2 | 3 | __version__ = "1.0.0" 4 | -------------------------------------------------------------------------------- /echonet/docs/framework_chart.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmed-lab/AdaCon/HEAD/echonet/docs/framework_chart.PNG -------------------------------------------------------------------------------- /echonet/__main__.py: -------------------------------------------------------------------------------- 1 | """Entry point for command line.""" 2 | 3 | import echonet 4 | 5 | 6 | if __name__ == '__main__': 7 | echonet.main() 8 | -------------------------------------------------------------------------------- /echonet/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .rnet2dp1 import r2plus1d_18_ctrst 2 | from .rnet2dp1 import SupConLoss_admargin 3 | 4 | __all__ = ["r2plus1d_18_ctrst", "SupConLoss_admargin"] -------------------------------------------------------------------------------- /echonet/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | The echonet.datasets submodule defines a Pytorch dataset for loading 3 | echocardiogram videos. 4 | """ 5 | 6 | from .echo import Echo 7 | 8 | __all__ = ["Echo"] 9 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | certifi==2020.12.5 2 | cycler==0.10.0 3 | decorator==4.4.2 4 | echonet==1.0.0 5 | imageio==2.9.0 6 | joblib==1.0.1 7 | kiwisolver==1.3.1 8 | matplotlib==3.3.4 9 | networkx==2.5 10 | numpy==1.20.1 11 | opencv-python==4.5.1.48 12 | pandas==1.2.3 13 | Pillow==8.1.2 14 | pyparsing==2.4.7 15 | python-dateutil==2.8.1 16 | pytz==2021.1 17 | PyWavelets==1.1.1 18 | scikit-image==0.18.1 19 | scikit-learn==0.24.1 20 | scipy==1.6.1 21 | six==1.15.0 22 | sklearn==0.0 23 | threadpoolctl==2.1.0 24 | tifffile==2021.3.17 25 | torch==1.8.0 26 | torchvision==0.9.0 27 | tqdm==4.59.0 28 | typing-extensions==3.7.4.3 29 | -------------------------------------------------------------------------------- /echonet/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | The echonet package contains code for loading echocardiogram videos, and 3 | functions for training and testing segmentation and ejection fraction 4 | prediction models. 5 | """ 6 | 7 | import click 8 | 9 | from echonet.__version__ import __version__ 10 | from echonet.config import CONFIG as config 11 | import echonet.datasets as datasets 12 | import echonet.utils as utils 13 | import echonet.models as models 14 | 15 | @click.group() 16 | def main(): 17 | """Entry point for command line interface.""" 18 | 19 | 20 | del click 21 | 22 | 23 | main.add_command(utils.video.run) 24 | 25 | __all__ = ["__version__", "config", "datasets", "main", "utils", "models"] 26 | -------------------------------------------------------------------------------- /echonet/config.py: -------------------------------------------------------------------------------- 1 | """Sets paths based on configuration files.""" 2 | 3 | import configparser 4 | import os 5 | import types 6 | 7 | _FILENAME = None 8 | _PARAM = {} 9 | for filename in ["echonet.cfg", 10 | ".echonet.cfg", 11 | os.path.expanduser("~/echonet.cfg"), 12 | os.path.expanduser("~/.echonet.cfg"), 13 | ]: 14 | if os.path.isfile(filename): 15 | _FILENAME = filename 16 | config = configparser.ConfigParser() 17 | with open(filename, "r") as f: 18 | config.read_string("[config]\n" + f.read()) 19 | _PARAM = config["config"] 20 | break 21 | 22 | CONFIG = types.SimpleNamespace( 23 | FILENAME=_FILENAME, 24 | DATA_DIR=_PARAM.get("data_dir", "../EchoNet/Heart-videos/")) 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 XMed-Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Metadata for package to allow installation with pip.""" 3 | 4 | import os 5 | 6 | import setuptools 7 | 8 | with open("README.md", "r") as fh: 9 | long_description = fh.read() 10 | 11 | # Use same version from code 12 | # See 3 from 13 | # https://packaging.python.org/guides/single-sourcing-package-version/ 14 | version = {} 15 | with open(os.path.join("echonet", "__version__.py")) as f: 16 | exec(f.read(), version) # pylint: disable=W0122 17 | 18 | setuptools.setup( 19 | name="echonet", 20 | description="Video-based AI for beat-to-beat cardiac function assessment.", 21 | version=version["__version__"], 22 | url="https://echonet.github.io/dynamic", 23 | packages=setuptools.find_packages(), 24 | install_requires=[ 25 | "click", 26 | "numpy", 27 | "pandas", 28 | "torch", 29 | "torchvision", 30 | "opencv-python", 31 | "scikit-image", 32 | "tqdm", 33 | "sklearn" 34 | ], 35 | classifiers=[ 36 | "Programming Language :: Python :: 3", 37 | ], 38 | entry_points={ 39 | "console_scripts": [ 40 | "echonet=echonet:main", 41 | ], 42 | } 43 | 44 | ) 45 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # IEEE TMI 2021: AdaCon: Adaptive Contrast for Image Regression in Computer-Aided Disease Assessment 2 | 3 | 4 | ![AdaCon framework](echonet/docs/framework_chart.PNG) 5 | 6 | 7 | This is the implementation of AdaCon on the EchoNet-Dynamic Dataset for the paper ["AdaCon: Adaptive Contrast for Image Regression in Computer-Aided Disease Assessment"](http://arxiv.org/abs/2112.11700) (IEEE TMI). 8 | 9 |
10 |
11 | 12 | ## Data 13 | 14 | Researchers can request the EchoNet-Dynamic dataset at https://echonet.github.io/dynamic/ and set the directory path in the configuration file, `echonet.cfg`. 15 | 16 |
17 |
18 | 19 | 20 | ## Environment 21 | 22 | It is recommended to use PyTorch `conda` environments for running the program. A requirements file has been included. 23 | 24 |
25 |
26 | 27 | ## Training and Testing 28 | 29 | The code must first be installed by running 30 | 31 | pip install --user . 32 | 33 | under the `adacon` directory. To train the model from scratch, run: 34 | 35 | ``` 36 | echonet video --frames=32 --model_name=r2plus1d_18 --period=2 --batch_size=20 --run_test --output=training_output 37 | ``` 38 | 39 |
40 |
41 | 42 | ## Pretrained Model 43 | 44 | A trained version of the model can be downloaded from https://hkustconnect-my.sharepoint.com/:u:/g/personal/wdaiaj_connect_ust_hk/EXu95kAzcitGibTOWxwSmDEBKIAia3H8Dw5CbGVDsPbWBg?e=QBzdD6 45 | 46 | Inference with the trained model can be run using 47 | 48 | ``` 49 | echonet video --frames=32 --model_name=r2plus1d_18 --period=2 --batch_size=20 --run_test --output=training_output --weights= --num_epochs=0 50 | ``` 51 | 52 |
53 | 54 | | | MAE | RMSE | R2 | 55 | | ---------- | :-----------: | :-----------: | :-----------: | 56 | | AdaCon | 3.86 | 5.07 | 82.8% | 57 | 58 | 59 |
60 |
61 | 62 | ## Notes 63 | * Contact: DAI Weihang (wdai03@gmail.com) 64 | 65 |
66 |
67 | 68 | ## Citation 69 | If this code is useful for your research, please consider citing: 70 | 71 | ``` 72 | @article{dai2021adaptive, 73 | title={Adaptive Contrast for Image Regression in Computer-Aided Disease Assessment}, 74 | author={Dai, Weihang and Li, Xiaomeng and Chiu, Wan Hang Keith and Kuo, Michael D and Cheng, Kwang-Ting}, 75 | journal={IEEE Transactions on Medical Imaging}, 76 | year={2021}, 77 | publisher={IEEE} 78 | } 79 | ``` -------------------------------------------------------------------------------- /echonet/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Utility functions for videos, plotting and computing performance metrics.""" 2 | 3 | import os 4 | import typing 5 | import datetime 6 | 7 | import cv2 # pytype: disable=attribute-error 8 | import matplotlib 9 | import numpy as np 10 | import torch 11 | import tqdm 12 | 13 | from . import video 14 | 15 | 16 | def loadvideo(filename: str) -> np.ndarray: 17 | """Loads a video from a file. 18 | 19 | Args: 20 | filename (str): filename of video 21 | 22 | Returns: 23 | A np.ndarray with dimensions (channels=3, frames, height, width). The 24 | values will be uint8's ranging from 0 to 255. 25 | 26 | Raises: 27 | FileNotFoundError: Could not find `filename` 28 | ValueError: An error occurred while reading the video 29 | """ 30 | 31 | if not os.path.exists(filename): 32 | raise FileNotFoundError(filename) 33 | capture = cv2.VideoCapture(filename) 34 | 35 | frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT)) 36 | frame_width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)) 37 | frame_height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)) 38 | 39 | v = np.zeros((frame_count, frame_height, frame_width, 3), np.uint8) 40 | for count in range(frame_count): 41 | ret, frame = capture.read() 42 | if not ret: 43 | raise ValueError("Failed to load frame #{} of {}.".format(count, filename)) 44 | 45 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 46 | v[count, :, :] = frame 47 | 48 | v = v.transpose((3, 0, 1, 2)) 49 | 50 | return v 51 | 52 | 53 | def savevideo(filename: str, array: np.ndarray, fps: typing.Union[float, int] = 1): 54 | """Saves a video to a file. 55 | 56 | Args: 57 | filename (str): filename of video 58 | array (np.ndarray): video of uint8's with shape (channels=3, frames, height, width) 59 | fps (float or int): frames per second 60 | 61 | Returns: 62 | None 63 | """ 64 | 65 | c, _, height, width = array.shape 66 | 67 | if c != 3: 68 | raise ValueError("savevideo expects array of shape (channels=3, frames, height, width), got shape ({})".format(", ".join(map(str, array.shape)))) 69 | fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G') 70 | out = cv2.VideoWriter(filename, fourcc, fps, (width, height)) 71 | 72 | for frame in array.transpose((1, 2, 3, 0)): 73 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) 74 | out.write(frame) 75 | 76 | 77 | def get_mean_and_std(dataset: torch.utils.data.Dataset, 78 | samples: int = 128, 79 | batch_size: int = 8, 80 | num_workers: int = 4): 81 | """Computes mean and std from samples from a Pytorch dataset. 82 | 83 | Args: 84 | dataset (torch.utils.data.Dataset): A Pytorch dataset. 85 | ``dataset[i][0]'' is expected to be the i-th video in the dataset, which 86 | should be a ``torch.Tensor'' of dimensions (channels=3, frames, height, width) 87 | samples (int or None, optional): Number of samples to take from dataset. If ``None'', mean and 88 | standard deviation are computed over all elements. 89 | Defaults to 128. 90 | batch_size (int, optional): how many samples per batch to load 91 | Defaults to 8. 92 | num_workers (int, optional): how many subprocesses to use for data 93 | loading. If 0, the data will be loaded in the main process. 94 | Defaults to 4. 95 | 96 | Returns: 97 | A tuple of the mean and standard deviation. Both are represented as np.array's of dimension (channels,). 98 | """ 99 | 100 | if samples is not None and len(dataset) > samples: 101 | indices = np.random.choice(len(dataset), samples, replace=False) 102 | dataset = torch.utils.data.Subset(dataset, indices) 103 | 104 | dataloader = torch.utils.data.DataLoader( 105 | dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True) 106 | 107 | n = 0 # number of elements taken (should be equal to samples by end of for loop) 108 | s1 = 0. # sum of elements along channels (ends up as np.array of dimension (channels,)) 109 | s2 = 0. # sum of squares of elements along channels (ends up as np.array of dimension (channels,)) 110 | # for (x, *_) in tqdm.tqdm(dataloader): 111 | for (x,_,*_) in tqdm.tqdm(dataloader): 112 | x = x.transpose(0, 1).contiguous().view(3, -1) 113 | n += x.shape[1] 114 | s1 += torch.sum(x, dim=1).numpy() 115 | s2 += torch.sum(x ** 2, dim=1).numpy() 116 | mean = s1 / n # type: np.ndarray 117 | std = np.sqrt(s2 / n - mean ** 2) # type: np.ndarray 118 | 119 | mean = mean.astype(np.float32) 120 | std = std.astype(np.float32) 121 | 122 | return mean, std 123 | 124 | 125 | def bootstrap(a, b, func, samples=10000): 126 | """Computes a bootstrapped confidence intervals for ``func(a, b)''. 127 | 128 | Args: 129 | a (array_like): first argument to `func`. 130 | b (array_like): second argument to `func`. 131 | func (callable): Function to compute confidence intervals for. 132 | ``dataset[i][0]'' is expected to be the i-th video in the dataset, which 133 | should be a ``torch.Tensor'' of dimensions (channels=3, frames, height, width) 134 | samples (int, optional): Number of samples to compute. 135 | Defaults to 10000. 136 | 137 | Returns: 138 | A tuple of (`func(a, b)`, estimated 5-th percentile, estimated 95-th percentile). 139 | """ 140 | a = np.array(a) 141 | b = np.array(b) 142 | 143 | bootstraps = [] 144 | for _ in range(samples): 145 | ind = np.random.choice(len(a), len(a)) 146 | bootstraps.append(func(a[ind], b[ind])) 147 | bootstraps = sorted(bootstraps) 148 | 149 | return func(a, b), bootstraps[round(0.05 * len(bootstraps))], bootstraps[round(0.95 * len(bootstraps))] 150 | 151 | 152 | def latexify(): 153 | """Sets matplotlib params to appear more like LaTeX. 154 | 155 | Based on https://nipunbatra.github.io/blog/2014/latexify.html 156 | """ 157 | params = {'backend': 'pdf', 158 | 'axes.titlesize': 8, 159 | 'axes.labelsize': 8, 160 | 'font.size': 8, 161 | 'legend.fontsize': 8, 162 | 'xtick.labelsize': 8, 163 | 'ytick.labelsize': 8, 164 | 'font.family': 'DejaVu Serif', 165 | 'font.serif': 'Computer Modern', 166 | } 167 | matplotlib.rcParams.update(params) 168 | 169 | 170 | def dice_similarity_coefficient(inter, union): 171 | """Computes the dice similarity coefficient. 172 | 173 | Args: 174 | inter (iterable): iterable of the intersections 175 | union (iterable): iterable of the unions 176 | """ 177 | return 2 * sum(inter) / (sum(union) + sum(inter)) 178 | 179 | 180 | __all__ = ["video", "segmentation", "loadvideo", "savevideo", "get_mean_and_std", "bootstrap", "latexify", "dice_similarity_coefficient"] 181 | -------------------------------------------------------------------------------- /echonet/datasets/echo.py: -------------------------------------------------------------------------------- 1 | """EchoNet-Dynamic Dataset.""" 2 | 3 | import os 4 | import collections 5 | import pandas 6 | import datetime 7 | 8 | import numpy as np 9 | import skimage.draw 10 | import torchvision 11 | import echonet 12 | 13 | 14 | class Echo(torchvision.datasets.VisionDataset): 15 | """EchoNet-Dynamic Dataset. 16 | 17 | Args: 18 | root (string): Root directory of dataset (defaults to `echonet.config.DATA_DIR`) 19 | split (string): One of {``train'', ``val'', ``test'', ``all'', or ``external_test''} 20 | target_type (string or list, optional): Type of target to use, 21 | ``Filename'', ``EF'', ``EDV'', ``ESV'', ``LargeIndex'', 22 | ``SmallIndex'', ``LargeFrame'', ``SmallFrame'', ``LargeTrace'', 23 | or ``SmallTrace'' 24 | Can also be a list to output a tuple with all specified target types. 25 | The targets represent: 26 | ``Filename'' (string): filename of video 27 | ``EF'' (float): ejection fraction 28 | ``EDV'' (float): end-diastolic volume 29 | ``ESV'' (float): end-systolic volume 30 | ``LargeIndex'' (int): index of large (diastolic) frame in video 31 | ``SmallIndex'' (int): index of small (systolic) frame in video 32 | ``LargeFrame'' (np.array shape=(3, height, width)): normalized large (diastolic) frame 33 | ``SmallFrame'' (np.array shape=(3, height, width)): normalized small (systolic) frame 34 | ``LargeTrace'' (np.array shape=(height, width)): left ventricle large (diastolic) segmentation 35 | value of 0 indicates pixel is outside left ventricle 36 | 1 indicates pixel is inside left ventricle 37 | ``SmallTrace'' (np.array shape=(height, width)): left ventricle small (systolic) segmentation 38 | value of 0 indicates pixel is outside left ventricle 39 | 1 indicates pixel is inside left ventricle 40 | Defaults to ``EF''. 41 | mean (int, float, or np.array shape=(3,), optional): means for all (if scalar) or each (if np.array) channel. 42 | Used for normalizing the video. Defaults to 0 (video is not shifted). 43 | std (int, float, or np.array shape=(3,), optional): standard deviation for all (if scalar) or each (if np.array) channel. 44 | Used for normalizing the video. Defaults to 0 (video is not scaled). 45 | length (int or None, optional): Number of frames to clip from video. If ``None'', longest possible clip is returned. 46 | Defaults to 16. 47 | period (int, optional): Sampling period for taking a clip from the video (i.e. every ``period''-th frame is taken) 48 | Defaults to 2. 49 | max_length (int or None, optional): Maximum number of frames to clip from video (main use is for shortening excessively 50 | long videos when ``length'' is set to None). If ``None'', shortening is not applied to any video. 51 | Defaults to 250. 52 | clips (int, optional): Number of clips to sample. Main use is for test-time augmentation with random clips. 53 | Defaults to 1. 54 | pad (int or None, optional): Number of pixels to pad all frames on each side (used as augmentation). 55 | and a window of the original size is taken. If ``None'', no padding occurs. 56 | Defaults to ``None''. 57 | noise (float or None, optional): Fraction of pixels to black out as simulated noise. If ``None'', no simulated noise is added. 58 | Defaults to ``None''. 59 | target_transform (callable, optional): A function/transform that takes in the target and transforms it. 60 | external_test_location (string): Path to videos to use for external testing. 61 | """ 62 | 63 | def __init__(self, root=None, 64 | split="train", target_type="EF", 65 | mean=0., std=1., 66 | length=16, period=2, 67 | max_length=250, 68 | clips=1, 69 | pad=None, 70 | noise=None, 71 | target_transform=None, 72 | external_test_location=None): 73 | if root is None: 74 | root = echonet.config.DATA_DIR 75 | 76 | super().__init__(root, target_transform=target_transform) 77 | 78 | self.split = split.upper() 79 | if not isinstance(target_type, list): 80 | target_type = [target_type] 81 | self.target_type = target_type 82 | self.mean = mean 83 | self.std = std 84 | self.length = length 85 | self.max_length = max_length 86 | self.period = period 87 | self.clips = clips 88 | self.pad = pad 89 | self.noise = noise 90 | self.target_transform = target_transform 91 | self.external_test_location = external_test_location 92 | 93 | self.fnames, self.outcome = [], [] 94 | 95 | if self.split == "EXTERNAL_TEST": 96 | self.fnames = sorted(os.listdir(self.external_test_location)) 97 | else: 98 | print(os.path.join(self.root, "FileList.csv")) 99 | 100 | with open(os.path.join(self.root, "FileList.csv")) as f: 101 | data = pandas.read_csv(f) 102 | data["Split"].map(lambda x: x.upper()) 103 | 104 | if self.split != "ALL": 105 | data = data[data["Split"] == self.split] 106 | 107 | 108 | data["EF_bkt"] = data["EF"]//0.05 # // 0.02 // 0.01 109 | EF_freq = data['EF_bkt'].value_counts(dropna=False).rename_axis('EF_bkt_key').reset_index(name='counts') 110 | EF_freq = EF_freq.sort_values(by=['EF_bkt_key']).reset_index() 111 | 112 | EF_dict = {} 113 | 114 | for key_itr_idx in range(len(EF_freq['EF_bkt_key'])): 115 | if key_itr_idx == 0: 116 | EF_dict[EF_freq['EF_bkt_key'][key_itr_idx]] = EF_freq['counts'][key_itr_idx] 117 | else: 118 | EF_dict[EF_freq['EF_bkt_key'][key_itr_idx]] = EF_dict[EF_freq['EF_bkt_key'][key_itr_idx-1]] + EF_freq['counts'][key_itr_idx] 119 | 120 | for key_itr_idx in range(len(EF_freq['EF_bkt_key'])): 121 | EF_dict[EF_freq['EF_bkt_key'][key_itr_idx]] = EF_dict[EF_freq['EF_bkt_key'][key_itr_idx]] - EF_freq['counts'][key_itr_idx]/2 122 | 123 | data['EF_CLS'] = data["EF_bkt"].apply(lambda x: EF_dict[x]) 124 | 125 | self.header = data.columns.tolist() 126 | self.fnames = data["FileName"].tolist() 127 | 128 | self.outcome = data.values.tolist() 129 | 130 | 131 | missing = set(self.fnames) - set(os.listdir(os.path.join(self.root, "Videos"))) 132 | if len(missing) != 0: 133 | print("{} videos could not be found in {}:".format(len(missing), os.path.join(self.root, "Videos"))) 134 | for f in sorted(missing): 135 | print("\t", f) 136 | raise FileNotFoundError(os.path.join(self.root, "Videos", sorted(missing)[0])) 137 | 138 | self.frames = collections.defaultdict(list) 139 | self.trace = collections.defaultdict(_defaultdict_of_lists) 140 | 141 | with open(os.path.join(self.root, "VolumeTracings.csv")) as f: 142 | header = f.readline().strip().split(",") 143 | assert header == ["FileName", "X1", "Y1", "X2", "Y2", "Frame"] 144 | 145 | for line in f: 146 | filename, x1, y1, x2, y2, frame = line.strip().split(',') 147 | filename = filename + ".avi" 148 | x1 = float(x1) 149 | y1 = float(y1) 150 | x2 = float(x2) 151 | y2 = float(y2) 152 | frame = int(frame) 153 | if frame not in self.trace[filename]: 154 | self.frames[filename].append(frame) 155 | self.trace[filename][frame].append((x1, y1, x2, y2)) 156 | for filename in self.frames: 157 | for frame in self.frames[filename]: 158 | self.trace[filename][frame] = np.array(self.trace[filename][frame]) 159 | 160 | keep = [len(self.frames[f]) >= 2 for f in self.fnames] 161 | self.fnames = [f for (f, k) in zip(self.fnames, keep) if k] 162 | self.outcome = [f for (f, k) in zip(self.outcome, keep) if k] 163 | 164 | def __getitem__(self, index): 165 | if self.split == "EXTERNAL_TEST": 166 | video_path = os.path.join(self.external_test_location, self.fnames[index]) 167 | elif self.split == "CLINICAL_TEST": 168 | video_path = os.path.join(self.root, "ProcessedStrainStudyA4c", self.fnames[index]) 169 | else: 170 | video_path = os.path.join(self.root, "Videos", self.fnames[index]) 171 | 172 | video = echonet.utils.loadvideo(video_path).astype(np.float32) 173 | 174 | if self.noise is not None: 175 | n = video.shape[1] * video.shape[2] * video.shape[3] 176 | ind = np.random.choice(n, round(self.noise * n), replace=False) 177 | f = ind % video.shape[1] 178 | ind //= video.shape[1] 179 | i = ind % video.shape[2] 180 | ind //= video.shape[2] 181 | j = ind 182 | video[:, f, i, j] = 0 183 | 184 | if isinstance(self.mean, (float, int)): 185 | video -= self.mean 186 | else: 187 | video -= self.mean.reshape(3, 1, 1, 1) 188 | 189 | if isinstance(self.std, (float, int)): 190 | video /= self.std 191 | else: 192 | video /= self.std.reshape(3, 1, 1, 1) 193 | 194 | c, f, h, w = video.shape 195 | if self.length is None: 196 | length = f // self.period 197 | else: 198 | length = self.length 199 | 200 | if self.max_length is not None: 201 | length = min(length, self.max_length) 202 | 203 | if f < length * self.period: 204 | video = np.concatenate((video, np.zeros((c, length * self.period - f, h, w), video.dtype)), axis=1) 205 | c, f, h, w = video.shape 206 | 207 | if self.clips == "all": 208 | start = np.arange(f - (length - 1) * self.period) 209 | else: 210 | start = np.random.choice(f - (length - 1) * self.period, self.clips) 211 | 212 | 213 | target = [] 214 | target_cls = [] 215 | for t in self.target_type: 216 | key = self.fnames[index] 217 | if t == "Filename": 218 | target.append(self.fnames[index]) 219 | elif t == "LargeIndex": 220 | target.append(np.int(self.frames[key][-1])) 221 | elif t == "SmallIndex": 222 | target.append(np.int(self.frames[key][0])) 223 | elif t == "LargeFrame": 224 | target.append(video[:, self.frames[key][-1], :, :]) 225 | elif t == "SmallFrame": 226 | target.append(video[:, self.frames[key][0], :, :]) 227 | elif t in ["LargeTrace", "SmallTrace"]: 228 | if t == "LargeTrace": 229 | t = self.trace[key][self.frames[key][-1]] 230 | else: 231 | t = self.trace[key][self.frames[key][0]] 232 | x1, y1, x2, y2 = t[:, 0], t[:, 1], t[:, 2], t[:, 3] 233 | x = np.concatenate((x1[1:], np.flip(x2[1:]))) 234 | y = np.concatenate((y1[1:], np.flip(y2[1:]))) 235 | 236 | r, c = skimage.draw.polygon(np.rint(y).astype(np.int), np.rint(x).astype(np.int), (video.shape[2], video.shape[3])) 237 | mask = np.zeros((video.shape[2], video.shape[3]), np.float32) 238 | mask[r, c] = 1 239 | target.append(mask) 240 | else: 241 | if self.split == "CLINICAL_TEST" or self.split == "EXTERNAL_TEST": 242 | target.append(np.float32(0)) 243 | else: 244 | target.append(np.float32(self.outcome[index][self.header.index(t)])) 245 | target_cls.append(np.float32(self.outcome[index][self.header.index('EF_CLS')])) 246 | 247 | if target != []: 248 | target = tuple(target) if len(target) > 1 else target[0] 249 | if self.target_transform is not None: 250 | target = self.target_transform(target) 251 | if target_cls !=[]: 252 | target_cls = tuple(target_cls) if len(target_cls) > 1 else target_cls[0] 253 | 254 | video = tuple(video[:, s + self.period * np.arange(length), :, :] for s in start) 255 | 256 | if self.clips == 1: 257 | video = video[0] 258 | else: 259 | video = np.stack(video) 260 | 261 | if self.pad is not None: 262 | 263 | jit1 = np.random.random()*0.1 264 | jit2 = np.random.random()*0.1 265 | 266 | # video1 = video + jit1 267 | # video2 = video + jit2 268 | 269 | video1 = video.copy() 270 | video2 = video.copy() 271 | 272 | 273 | c, l, h, w = video.shape 274 | 275 | temp1 = np.zeros((c, l, h + 2 * self.pad, w + 2 * self.pad), dtype=video.dtype) 276 | temp1[:, :, self.pad:-self.pad, self.pad:-self.pad] = video1 277 | 278 | temp2 = np.zeros((c, l, h + 2 * self.pad, w + 2 * self.pad), dtype=video.dtype) 279 | temp2[:, :, self.pad:-self.pad, self.pad:-self.pad] = video2 280 | 281 | i1, j1 = np.random.randint(0, 2 * self.pad, 2) 282 | i2, j2 = np.random.randint(0, 2 * self.pad, 2) 283 | 284 | video1 = temp1[:, :, i1:(i1 + h), j1:(j1 + w)] 285 | video2 = temp2[:, :, i2:(i2 + h), j2:(j2 + w)] 286 | 287 | 288 | else: 289 | video1 = video.copy() 290 | video2 = video.copy() 291 | 292 | return video1, video2, target, target_cls, start, video_path 293 | 294 | def __len__(self): 295 | return len(self.fnames) 296 | 297 | def extra_repr(self) -> str: 298 | """Additional information to add at end of __repr__.""" 299 | lines = ["Target type: {target_type}", "Split: {split}"] 300 | return '\n'.join(lines).format(**self.__dict__) 301 | 302 | 303 | def _defaultdict_of_lists(): 304 | """Returns a defaultdict of lists. 305 | 306 | This is used to avoid issues with Windows (if this function is anonymous, 307 | the Echo dataset cannot be used in a dataloader). 308 | """ 309 | 310 | return collections.defaultdict(list) 311 | -------------------------------------------------------------------------------- /echonet/models/rnet2dp1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | import math 5 | from torchvision.models.utils import load_state_dict_from_url 6 | # from ..utils import load_state_dict_from_url 7 | 8 | 9 | __all__ = ['r3d_18', 'mc3_18', 'r2plus1d_18'] 10 | 11 | model_urls = { 12 | 'r3d_18': 'https://download.pytorch.org/models/r3d_18-b3b3357e.pth', 13 | 'mc3_18': 'https://download.pytorch.org/models/mc3_18-a90a0ba3.pth', 14 | 'r2plus1d_18': 'https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth', 15 | } 16 | 17 | 18 | class Conv3DSimple(nn.Conv3d): 19 | def __init__(self, 20 | in_planes, 21 | out_planes, 22 | midplanes=None, 23 | stride=1, 24 | padding=1): 25 | 26 | super(Conv3DSimple, self).__init__( 27 | in_channels=in_planes, 28 | out_channels=out_planes, 29 | kernel_size=(3, 3, 3), 30 | stride=stride, 31 | padding=padding, 32 | bias=False) 33 | 34 | @staticmethod 35 | def get_downsample_stride(stride): 36 | return stride, stride, stride 37 | 38 | 39 | class Conv2Plus1D(nn.Sequential): 40 | 41 | def __init__(self, 42 | in_planes, 43 | out_planes, 44 | midplanes, 45 | stride=1, 46 | padding=1): 47 | super(Conv2Plus1D, self).__init__( 48 | nn.Conv3d(in_planes, midplanes, kernel_size=(1, 3, 3), 49 | stride=(1, stride, stride), padding=(0, padding, padding), 50 | bias=False), 51 | nn.BatchNorm3d(midplanes), 52 | nn.ReLU(inplace=True), 53 | nn.Conv3d(midplanes, out_planes, kernel_size=(3, 1, 1), 54 | stride=(stride, 1, 1), padding=(padding, 0, 0), 55 | bias=False)) 56 | 57 | @staticmethod 58 | def get_downsample_stride(stride): 59 | return stride, stride, stride 60 | 61 | 62 | class Conv3DNoTemporal(nn.Conv3d): 63 | 64 | def __init__(self, 65 | in_planes, 66 | out_planes, 67 | midplanes=None, 68 | stride=1, 69 | padding=1): 70 | 71 | super(Conv3DNoTemporal, self).__init__( 72 | in_channels=in_planes, 73 | out_channels=out_planes, 74 | kernel_size=(1, 3, 3), 75 | stride=(1, stride, stride), 76 | padding=(0, padding, padding), 77 | bias=False) 78 | 79 | @staticmethod 80 | def get_downsample_stride(stride): 81 | return 1, stride, stride 82 | 83 | 84 | class BasicBlock(nn.Module): 85 | 86 | expansion = 1 87 | 88 | def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None): 89 | midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes) 90 | 91 | super(BasicBlock, self).__init__() 92 | self.conv1 = nn.Sequential( 93 | conv_builder(inplanes, planes, midplanes, stride), 94 | nn.BatchNorm3d(planes), 95 | nn.ReLU(inplace=True) 96 | ) 97 | self.conv2 = nn.Sequential( 98 | conv_builder(planes, planes, midplanes), 99 | nn.BatchNorm3d(planes) 100 | ) 101 | self.relu = nn.ReLU(inplace=True) 102 | self.downsample = downsample 103 | self.stride = stride 104 | 105 | def forward(self, x): 106 | residual = x 107 | 108 | out = self.conv1(x) 109 | out = self.conv2(out) 110 | if self.downsample is not None: 111 | residual = self.downsample(x) 112 | 113 | out += residual 114 | out = self.relu(out) 115 | 116 | return out 117 | 118 | 119 | class Bottleneck(nn.Module): 120 | expansion = 4 121 | 122 | def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None): 123 | 124 | super(Bottleneck, self).__init__() 125 | midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes) 126 | 127 | # 1x1x1 128 | self.conv1 = nn.Sequential( 129 | nn.Conv3d(inplanes, planes, kernel_size=1, bias=False), 130 | nn.BatchNorm3d(planes), 131 | nn.ReLU(inplace=True) 132 | ) 133 | # Second kernel 134 | self.conv2 = nn.Sequential( 135 | conv_builder(planes, planes, midplanes, stride), 136 | nn.BatchNorm3d(planes), 137 | nn.ReLU(inplace=True) 138 | ) 139 | 140 | # 1x1x1 141 | self.conv3 = nn.Sequential( 142 | nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False), 143 | nn.BatchNorm3d(planes * self.expansion) 144 | ) 145 | self.relu = nn.ReLU(inplace=True) 146 | self.downsample = downsample 147 | self.stride = stride 148 | 149 | def forward(self, x): 150 | residual = x 151 | 152 | out = self.conv1(x) 153 | out = self.conv2(out) 154 | out = self.conv3(out) 155 | 156 | if self.downsample is not None: 157 | residual = self.downsample(x) 158 | 159 | out += residual 160 | out = self.relu(out) 161 | 162 | return out 163 | 164 | 165 | class BasicStem(nn.Sequential): 166 | """The default conv-batchnorm-relu stem 167 | """ 168 | def __init__(self): 169 | super(BasicStem, self).__init__( 170 | nn.Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), 171 | padding=(1, 3, 3), bias=False), 172 | nn.BatchNorm3d(64), 173 | nn.ReLU(inplace=True)) 174 | 175 | 176 | class R2Plus1dStem(nn.Sequential): 177 | """R(2+1)D stem is different than the default one as it uses separated 3D convolution 178 | """ 179 | def __init__(self): 180 | super(R2Plus1dStem, self).__init__( 181 | nn.Conv3d(3, 45, kernel_size=(1, 7, 7), 182 | stride=(1, 2, 2), padding=(0, 3, 3), 183 | bias=False), 184 | nn.BatchNorm3d(45), 185 | nn.ReLU(inplace=True), 186 | nn.Conv3d(45, 64, kernel_size=(3, 1, 1), 187 | stride=(1, 1, 1), padding=(1, 0, 0), 188 | bias=False), 189 | nn.BatchNorm3d(64), 190 | nn.ReLU(inplace=True)) 191 | 192 | 193 | class VideoResNet(nn.Module): 194 | 195 | def __init__(self, block, conv_makers, layers, 196 | stem, num_classes=400, 197 | zero_init_residual=False): 198 | """Generic resnet video generator. 199 | 200 | Args: 201 | block (nn.Module): resnet building block 202 | conv_makers (list(functions)): generator function for each layer 203 | layers (List[int]): number of blocks per layer 204 | stem (nn.Module, optional): Resnet stem, if None, defaults to conv-bn-relu. Defaults to None. 205 | num_classes (int, optional): Dimension of the final FC layer. Defaults to 400. 206 | zero_init_residual (bool, optional): Zero init bottleneck residual BN. Defaults to False. 207 | """ 208 | super(VideoResNet, self).__init__() 209 | self.inplanes = 64 210 | 211 | self.stem = stem() 212 | 213 | self.layer1 = self._make_layer(block, conv_makers[0], 64, layers[0], stride=1) 214 | self.layer2 = self._make_layer(block, conv_makers[1], 128, layers[1], stride=2) 215 | self.layer3 = self._make_layer(block, conv_makers[2], 256, layers[2], stride=2) 216 | self.layer4 = self._make_layer(block, conv_makers[3], 512, layers[3], stride=2) 217 | 218 | self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1)) 219 | self.fc = nn.Linear(512 * block.expansion, num_classes) 220 | 221 | # init weights 222 | self._initialize_weights() 223 | 224 | if zero_init_residual: 225 | for m in self.modules(): 226 | if isinstance(m, Bottleneck): 227 | nn.init.constant_(m.bn3.weight, 0) 228 | 229 | def forward(self, x): 230 | x = self.stem(x) 231 | 232 | x = self.layer1(x) 233 | x = self.layer2(x) 234 | x = self.layer3(x) 235 | x = self.layer4(x) 236 | 237 | x = self.avgpool(x) 238 | # Flatten the layer to fc 239 | x = x.flatten(1) 240 | x = self.fc(x) 241 | 242 | return x 243 | 244 | def _make_layer(self, block, conv_builder, planes, blocks, stride=1): 245 | downsample = None 246 | 247 | if stride != 1 or self.inplanes != planes * block.expansion: 248 | ds_stride = conv_builder.get_downsample_stride(stride) 249 | downsample = nn.Sequential( 250 | nn.Conv3d(self.inplanes, planes * block.expansion, 251 | kernel_size=1, stride=ds_stride, bias=False), 252 | nn.BatchNorm3d(planes * block.expansion) 253 | ) 254 | layers = [] 255 | layers.append(block(self.inplanes, planes, conv_builder, stride, downsample)) 256 | 257 | self.inplanes = planes * block.expansion 258 | for i in range(1, blocks): 259 | layers.append(block(self.inplanes, planes, conv_builder)) 260 | 261 | return nn.Sequential(*layers) 262 | 263 | def _initialize_weights(self): 264 | for m in self.modules(): 265 | if isinstance(m, nn.Conv3d): 266 | nn.init.kaiming_normal_(m.weight, mode='fan_out', 267 | nonlinearity='relu') 268 | if m.bias is not None: 269 | nn.init.constant_(m.bias, 0) 270 | elif isinstance(m, nn.BatchNorm3d): 271 | nn.init.constant_(m.weight, 1) 272 | nn.init.constant_(m.bias, 0) 273 | elif isinstance(m, nn.Linear): 274 | nn.init.normal_(m.weight, 0, 0.01) 275 | nn.init.constant_(m.bias, 0) 276 | 277 | 278 | 279 | class VideoResNet_Cntrst(nn.Module): 280 | 281 | def __init__(self, block, conv_makers, layers, 282 | stem, num_classes=400, 283 | zero_init_residual=False): 284 | """Generic resnet video generator. 285 | 286 | Args: 287 | block (nn.Module): resnet building block 288 | conv_makers (list(functions)): generator function for each layer 289 | layers (List[int]): number of blocks per layer 290 | stem (nn.Module, optional): Resnet stem, if None, defaults to conv-bn-relu. Defaults to None. 291 | num_classes (int, optional): Dimension of the final FC layer. Defaults to 400. 292 | zero_init_residual (bool, optional): Zero init bottleneck residual BN. Defaults to False. 293 | """ 294 | super(VideoResNet_Cntrst, self).__init__() 295 | self.inplanes = 64 296 | 297 | self.stem = stem() 298 | 299 | self.layer1 = self._make_layer(block, conv_makers[0], 64, layers[0], stride=1) 300 | self.layer2 = self._make_layer(block, conv_makers[1], 128, layers[1], stride=2) 301 | self.layer3 = self._make_layer(block, conv_makers[2], 256, layers[2], stride=2) 302 | self.layer4 = self._make_layer(block, conv_makers[3], 512, layers[3], stride=2) 303 | 304 | self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1)) 305 | self.fc = nn.Linear(512 * block.expansion, num_classes) 306 | 307 | 308 | self._avg_pooling = nn.AdaptiveAvgPool3d((1, 1, 1)) 309 | self.fc_ctr = nn.Sequential(nn.Linear(512 * block.expansion, 512 * block.expansion), nn.ReLU(), nn.Linear(512 * block.expansion, 128)) 310 | 311 | # init weights 312 | self._initialize_weights() 313 | 314 | if zero_init_residual: 315 | for m in self.modules(): 316 | if isinstance(m, Bottleneck): 317 | nn.init.constant_(m.bn3.weight, 0) 318 | 319 | def forward(self, x): 320 | x = self.stem(x) 321 | 322 | x = self.layer1(x) 323 | x = self.layer2(x) 324 | x = self.layer3(x) 325 | x_common = self.layer4(x) 326 | 327 | x_ctrst = self._avg_pooling(x_common) 328 | x_ctrst = x_ctrst.flatten(1) 329 | x_ctrst = self.fc_ctr(x_ctrst) 330 | x_ctrst = F.normalize(x_ctrst, dim=1) 331 | 332 | x_reg = self.avgpool(x_common) 333 | # Flatten the layer to fc 334 | x_reg = x_reg.flatten(1) 335 | x_reg = self.fc(x_reg) 336 | 337 | return x_reg, x_ctrst 338 | 339 | def _make_layer(self, block, conv_builder, planes, blocks, stride=1): 340 | downsample = None 341 | 342 | if stride != 1 or self.inplanes != planes * block.expansion: 343 | ds_stride = conv_builder.get_downsample_stride(stride) 344 | downsample = nn.Sequential( 345 | nn.Conv3d(self.inplanes, planes * block.expansion, 346 | kernel_size=1, stride=ds_stride, bias=False), 347 | nn.BatchNorm3d(planes * block.expansion) 348 | ) 349 | layers = [] 350 | layers.append(block(self.inplanes, planes, conv_builder, stride, downsample)) 351 | 352 | self.inplanes = planes * block.expansion 353 | for i in range(1, blocks): 354 | layers.append(block(self.inplanes, planes, conv_builder)) 355 | 356 | return nn.Sequential(*layers) 357 | 358 | def _initialize_weights(self): 359 | for m in self.modules(): 360 | if isinstance(m, nn.Conv3d): 361 | nn.init.kaiming_normal_(m.weight, mode='fan_out', 362 | nonlinearity='relu') 363 | if m.bias is not None: 364 | nn.init.constant_(m.bias, 0) 365 | elif isinstance(m, nn.BatchNorm3d): 366 | nn.init.constant_(m.weight, 1) 367 | nn.init.constant_(m.bias, 0) 368 | elif isinstance(m, nn.Linear): 369 | nn.init.normal_(m.weight, 0, 0.01) 370 | nn.init.constant_(m.bias, 0) 371 | 372 | 373 | 374 | def _video_resnet(arch, pretrained=False, progress=True, **kwargs): 375 | model = VideoResNet(**kwargs) 376 | 377 | if pretrained: 378 | state_dict = load_state_dict_from_url(model_urls[arch], 379 | progress=progress) 380 | model.load_state_dict(state_dict) 381 | return model 382 | 383 | 384 | def r3d_18(pretrained=False, progress=True, **kwargs): 385 | """Construct 18 layer Resnet3D model as in 386 | https://arxiv.org/abs/1711.11248 387 | 388 | Args: 389 | pretrained (bool): If True, returns a model pre-trained on Kinetics-400 390 | progress (bool): If True, displays a progress bar of the download to stderr 391 | 392 | Returns: 393 | nn.Module: R3D-18 network 394 | """ 395 | 396 | return _video_resnet('r3d_18', 397 | pretrained, progress, 398 | block=BasicBlock, 399 | conv_makers=[Conv3DSimple] * 4, 400 | layers=[2, 2, 2, 2], 401 | stem=BasicStem, **kwargs) 402 | 403 | 404 | def mc3_18(pretrained=False, progress=True, **kwargs): 405 | """Constructor for 18 layer Mixed Convolution network as in 406 | https://arxiv.org/abs/1711.11248 407 | 408 | Args: 409 | pretrained (bool): If True, returns a model pre-trained on Kinetics-400 410 | progress (bool): If True, displays a progress bar of the download to stderr 411 | 412 | Returns: 413 | nn.Module: MC3 Network definition 414 | """ 415 | return _video_resnet('mc3_18', 416 | pretrained, progress, 417 | block=BasicBlock, 418 | conv_makers=[Conv3DSimple] + [Conv3DNoTemporal] * 3, 419 | layers=[2, 2, 2, 2], 420 | stem=BasicStem, **kwargs) 421 | 422 | 423 | def r2plus1d_18(pretrained=False, progress=True, **kwargs): 424 | """Constructor for the 18 layer deep R(2+1)D network as in 425 | https://arxiv.org/abs/1711.11248 426 | 427 | Args: 428 | pretrained (bool): If True, returns a model pre-trained on Kinetics-400 429 | progress (bool): If True, displays a progress bar of the download to stderr 430 | 431 | Returns: 432 | nn.Module: R(2+1)D-18 network 433 | """ 434 | return _video_resnet('r2plus1d_18', 435 | pretrained, progress, 436 | block=BasicBlock, 437 | conv_makers=[Conv2Plus1D] * 4, 438 | layers=[2, 2, 2, 2], 439 | stem=R2Plus1dStem, **kwargs) 440 | 441 | 442 | 443 | 444 | 445 | 446 | 447 | def _video_resnet_ctrst(arch, pretrained=False, progress=True, **kwargs): 448 | model = VideoResNet_Cntrst(**kwargs) 449 | 450 | if pretrained: 451 | state_dict = load_state_dict_from_url(model_urls[arch], 452 | progress=progress) 453 | model.load_state_dict(state_dict, strict=False) 454 | return model 455 | 456 | 457 | def r2plus1d_18_ctrst(pretrained=False, progress=True, **kwargs): 458 | """Constructor for the 18 layer deep R(2+1)D network as in 459 | https://arxiv.org/abs/1711.11248 460 | 461 | Args: 462 | pretrained (bool): If True, returns a model pre-trained on Kinetics-400 463 | progress (bool): If True, displays a progress bar of the download to stderr 464 | 465 | Returns: 466 | nn.Module: R(2+1)D-18 network 467 | """ 468 | return _video_resnet_ctrst('r2plus1d_18', 469 | pretrained, progress, 470 | block=BasicBlock, 471 | conv_makers=[Conv2Plus1D] * 4, 472 | layers=[2, 2, 2, 2], 473 | stem=R2Plus1dStem, **kwargs) 474 | 475 | 476 | 477 | 478 | 479 | 480 | 481 | class SupConLoss_admargin(nn.Module): 482 | """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. 483 | It also supports the unsupervised contrastive loss in SimCLR""" 484 | def __init__(self, temperature=1, contrast_mode='all', 485 | base_temperature=1): 486 | super(SupConLoss_admargin, self).__init__() 487 | self.temperature = temperature 488 | self.contrast_mode = contrast_mode 489 | self.base_temperature = base_temperature 490 | 491 | def forward(self, features, labels=None, mask=None, dist = None, norm_val = 0.2, scale_s = 150): 492 | """Compute loss for model. If both `labels` and `mask` are None, 493 | it degenerates to SimCLR unsupervised loss: 494 | https://arxiv.org/pdf/2002.05709.pdf 495 | 496 | Args: 497 | features: hidden vector of shape [bsz, n_views, ...]. 498 | labels: ground truth of shape [bsz]. 499 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j 500 | has the same class as sample i. Can be asymmetric. 501 | Returns: 502 | A loss scalar. 503 | """ 504 | device = (torch.device('cuda') 505 | if features.is_cuda 506 | else torch.device('cpu')) 507 | 508 | if len(features.shape) < 3: 509 | raise ValueError('`features` needs to be [bsz, n_views, ...],' 510 | 'at least 3 dimensions are required') 511 | if len(features.shape) > 3: 512 | features = features.view(features.shape[0], features.shape[1], -1) 513 | 514 | batch_size = features.shape[0] 515 | if labels is not None and mask is not None: 516 | raise ValueError('Cannot define both `labels` and `mask`') 517 | elif labels is None and mask is None: 518 | mask = torch.eye(batch_size, dtype=torch.float32).to(device) 519 | elif labels is not None: 520 | labels = labels.contiguous().view(-1, 1) 521 | if labels.shape[0] != batch_size: 522 | raise ValueError('Num of labels does not match num of features') 523 | mask = torch.eq(labels, labels.T).float().to(device) 524 | else: 525 | mask = mask.float().to(device) 526 | 527 | contrast_count = features.shape[1] 528 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) 529 | if self.contrast_mode == 'one': 530 | anchor_feature = features[:, 0] 531 | anchor_count = 1 532 | elif self.contrast_mode == 'all': 533 | anchor_feature = contrast_feature 534 | anchor_count = contrast_count 535 | # print(anchor_count) 536 | # exit() 537 | else: 538 | raise ValueError('Unknown mode: {}'.format(self.contrast_mode)) 539 | 540 | # print(dist) 541 | dist_expand = dist.unsqueeze(dim=-1) 542 | dist_expand = dist_expand.expand(-1, batch_size) 543 | dist_abdiff = torch.clamp(torch.multiply(torch.abs(torch.sub(dist_expand, dist_expand.T)), norm_val),0,2) 544 | 545 | dist_fullabdiff = dist_abdiff.repeat(anchor_count, contrast_count) 546 | ones_fullabdiff = torch.ones_like(dist_fullabdiff) 547 | 548 | anchor_dot_contrast = torch.div( 549 | torch.matmul(anchor_feature, contrast_feature.T), 550 | self.temperature) 551 | 552 | 553 | mask = mask.repeat(anchor_count, contrast_count) 554 | 555 | adjn_abdiff = torch.multiply(torch.sub(ones_fullabdiff, mask), dist_fullabdiff) 556 | adj_abdiff = adjn_abdiff 557 | 558 | anchor_dot_contrast = scale_s* (anchor_dot_contrast + adj_abdiff) 559 | 560 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 561 | logits = anchor_dot_contrast - logits_max.detach() 562 | 563 | logits_mask = torch.scatter( 564 | torch.ones_like(mask), 565 | 1, 566 | torch.arange(batch_size * anchor_count).view(-1, 1).to(device), 567 | 0 568 | ) 569 | 570 | mask = mask * logits_mask 571 | 572 | exp_logits = torch.exp(logits) * logits_mask 573 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 574 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) 575 | 576 | loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos 577 | loss = loss.view(anchor_count, batch_size).mean() 578 | 579 | return loss 580 | 581 | 582 | 583 | 584 | 585 | -------------------------------------------------------------------------------- /echonet/utils/video.py: -------------------------------------------------------------------------------- 1 | """Functions for training and running EF prediction.""" 2 | 3 | import math 4 | import os 5 | import time 6 | import shutil 7 | import datetime 8 | import pandas as pd 9 | 10 | import click 11 | import matplotlib.pyplot as plt 12 | import numpy as np 13 | import sklearn.metrics 14 | import torch 15 | import torchvision 16 | import tqdm 17 | 18 | import echonet 19 | import echonet.models 20 | 21 | criterion_cntrst = echonet.models.SupConLoss_admargin(temperature = 1, base_temperature = 1).to("cuda") 22 | 23 | @click.command("video") 24 | @click.option("--data_dir", type=click.Path(exists=True, file_okay=False), default=None) 25 | @click.option("--output", type=click.Path(file_okay=False), default=None) 26 | @click.option("--task", type=str, default="EF") 27 | @click.option("--model_name", type=click.Choice( 28 | sorted(name for name in torchvision.models.video.__dict__ 29 | if name.islower() and not name.startswith("__") and callable(torchvision.models.video.__dict__[name]))), 30 | default="r2plus1d_18") 31 | @click.option("--pretrained/--random", default=True) 32 | @click.option("--weights", type=click.Path(exists=True, dir_okay=False), default=None) 33 | @click.option("--run_test/--skip_test", default=False) 34 | @click.option("--stgno", type=int, default=0) 35 | @click.option("--num_epochs", type=int, default=45) 36 | @click.option("--lr", type=float, default=1e-4) 37 | @click.option("--weight_decay", type=float, default=1e-4) 38 | @click.option("--lr_step_period", type=int, default=15) 39 | @click.option("--frames", type=int, default=32) 40 | @click.option("--period", type=int, default=2) 41 | @click.option("--num_train_patients", type=int, default=None) 42 | @click.option("--num_workers", type=int, default=4) 43 | @click.option("--batch_size", type=int, default=20) 44 | @click.option("--device", type=str, default=None) 45 | @click.option("--seed", type=int, default=0) 46 | @click.option("--ctr_w", type=float, default=0.75) 47 | def run( 48 | data_dir=None, 49 | output=None, 50 | task="EF", 51 | 52 | model_name="r2plus1d_18", 53 | pretrained=True, 54 | weights=None, 55 | 56 | run_test=False, 57 | num_epochs=45, 58 | lr=1e-4, 59 | weight_decay=1e-4, 60 | lr_step_period=15, 61 | frames=32, 62 | period=2, 63 | num_train_patients=None, 64 | num_workers=4, 65 | batch_size=20, 66 | device=None, 67 | seed=0, 68 | 69 | stgno = 0, 70 | ctr_w = 0.75 71 | ): 72 | """Trains/tests EF prediction model. 73 | 74 | \b 75 | Args: 76 | data_dir (str, optional): Directory containing dataset. Defaults to 77 | `echonet.config.DATA_DIR`. 78 | output (str, optional): Directory to place outputs. Defaults to 79 | output/video/_/. 80 | task (str, optional): Name of task to predict. Options are the headers 81 | of FileList.csv. Defaults to ``EF''. 82 | model_name (str, optional): Name of model. One of ``mc3_18'', 83 | ``r2plus1d_18'', or ``r3d_18'' 84 | (options are torchvision.models.video.) 85 | Defaults to ``r2plus1d_18''. 86 | pretrained (bool, optional): Whether to use pretrained weights for model 87 | Defaults to True. 88 | weights (str, optional): Path to checkpoint containing weights to 89 | initialize model. Defaults to None. 90 | run_test (bool, optional): Whether or not to run on test. 91 | Defaults to False. 92 | num_epochs (int, optional): Number of epochs during training. 93 | Defaults to 45. 94 | lr (float, optional): Learning rate for SGD 95 | Defaults to 1e-4. 96 | weight_decay (float, optional): Weight decay for SGD 97 | Defaults to 1e-4. 98 | lr_step_period (int or None, optional): Period of learning rate decay 99 | (learning rate is decayed by a multiplicative factor of 0.1) 100 | Defaults to 15. 101 | frames (int, optional): Number of frames to use in clip 102 | Defaults to 32. 103 | period (int, optional): Sampling period for frames 104 | Defaults to 2. 105 | n_train_patients (int or None, optional): Number of training patients 106 | for ablations. Defaults to all patients. 107 | num_workers (int, optional): Number of subprocesses to use for data 108 | loading. If 0, the data will be loaded in the main process. 109 | Defaults to 4. 110 | device (str or None, optional): Name of device to run on. Options from 111 | https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.device 112 | Defaults to ``cuda'' if available, and ``cpu'' otherwise. 113 | batch_size (int, optional): Number of samples to load per batch 114 | Defaults to 20. 115 | seed (int, optional): Seed for random number generator. Defaults to 0. 116 | """ 117 | 118 | # Seed RNGs 119 | np.random.seed(seed) 120 | torch.manual_seed(seed) 121 | 122 | def worker_init_fn(worker_id): 123 | np.random.seed(np.random.get_state()[1][0] + worker_id) 124 | 125 | # Set default output directory 126 | if output is None: 127 | output = os.path.join("output", "video", "{}_{}_{}_{}".format(model_name, frames, period, "pretrained" if pretrained else "random")) 128 | os.makedirs(output, exist_ok=True) 129 | 130 | if os.path.isdir(os.path.join(output, "echonet_{}".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S")))): 131 | shutil.rmtree(os.path.join(output, "echonet_{}".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S")))) 132 | shutil.copytree("echonet", os.path.join(output, "echonet_{}".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S")))) 133 | 134 | # Set device for computations 135 | if device is None: 136 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 137 | 138 | model = echonet.models.rnet2dp1.r2plus1d_18_ctrst(pretrained=pretrained) 139 | model.fc = torch.nn.Linear(model.fc.in_features, 1) 140 | model.fc.bias.data[0] = 55.6 141 | if device.type == "cuda": 142 | model = torch.nn.DataParallel(model) 143 | model.to(device) 144 | 145 | 146 | if weights is not None: 147 | checkpoint = torch.load(weights) 148 | model.load_state_dict(checkpoint['state_dict']) 149 | 150 | optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay) 151 | if lr_step_period is None: 152 | lr_step_period = math.inf 153 | scheduler = torch.optim.lr_scheduler.StepLR(optim, lr_step_period) 154 | 155 | mean, std = echonet.utils.get_mean_and_std(echonet.datasets.Echo(root=data_dir, split="train")) 156 | print("mean std", mean, std) 157 | kwargs = {"target_type": task, 158 | "mean": mean, 159 | "std": std, 160 | "length": frames, 161 | "period": period, 162 | } 163 | 164 | 165 | dataset = {} 166 | dataset["train"] = echonet.datasets.Echo(root=data_dir, split="train", **kwargs, pad=12) 167 | if num_train_patients is not None and len(dataset["train"]) > num_train_patients: 168 | indices = np.random.choice(len(dataset["train"]), num_train_patients, replace=False) 169 | dataset["train"] = torch.utils.data.Subset(dataset["train"], indices) 170 | dataset["val"] = echonet.datasets.Echo(root=data_dir, split="val", **kwargs) 171 | 172 | with open(os.path.join(output, "log.csv"), "a") as f: 173 | epoch_resume = 0 174 | bestLoss = float("inf") 175 | try: 176 | checkpoint = torch.load(os.path.join(output, "checkpoint.pt")) 177 | model.load_state_dict(checkpoint['state_dict']) 178 | optim.load_state_dict(checkpoint['opt_dict']) 179 | scheduler.load_state_dict(checkpoint['scheduler_dict']) 180 | epoch_resume = checkpoint["epoch"] + 1 181 | bestLoss = checkpoint["best_loss"] 182 | f.write("Resuming from epoch {}\n".format(epoch_resume)) 183 | except FileNotFoundError: 184 | f.write("Starting run from scratch\n") 185 | 186 | for epoch in range(epoch_resume, num_epochs): 187 | print("Epoch #{}".format(epoch), flush=True) 188 | for phase in ['train', 'val']: 189 | start_time = time.time() 190 | for i in range(torch.cuda.device_count()): 191 | torch.cuda.reset_peak_memory_stats(i) 192 | 193 | ds = dataset[phase] 194 | dataloader = torch.utils.data.DataLoader( 195 | ds, batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=(device.type == "cuda"), drop_last=(phase == "train"), worker_init_fn=worker_init_fn) 196 | 197 | loss, loss_reg, loss_ctr, yhat, y, _, _ = echonet.utils.video.run_epoch(model, dataloader, phase == "train", optim, device, stgno = stgno, ctr_w = ctr_w) 198 | 199 | f.write("{},{},{},{},{},{},{},{},{},{},{}\n".format(epoch, 200 | phase, 201 | loss, 202 | sklearn.metrics.r2_score(y, yhat), 203 | time.time() - start_time, 204 | y.size, 205 | sum(torch.cuda.max_memory_allocated() for i in range(torch.cuda.device_count())), 206 | sum(torch.cuda.max_memory_reserved() for i in range(torch.cuda.device_count())), 207 | batch_size, 208 | loss_reg, 209 | loss_ctr)) 210 | f.flush() 211 | scheduler.step() 212 | 213 | save = { 214 | 'epoch': epoch, 215 | 'state_dict': model.state_dict(), 216 | 'period': period, 217 | 'frames': frames, 218 | 'best_loss': bestLoss, 219 | 'loss': loss, 220 | 'r2': sklearn.metrics.r2_score(y, yhat), 221 | 'opt_dict': optim.state_dict(), 222 | 'scheduler_dict': scheduler.state_dict(), 223 | } 224 | torch.save(save, os.path.join(output, "checkpoint.pt")) 225 | if loss_reg < bestLoss: 226 | torch.save(save, os.path.join(output, "best.pt")) 227 | bestLoss = loss_reg 228 | 229 | if num_epochs != 0: 230 | checkpoint = torch.load(os.path.join(output, "best.pt")) 231 | model.load_state_dict(checkpoint['state_dict']) 232 | f.write("Best validation loss {} from epoch {}, R2 {}\n".format(checkpoint["loss"], checkpoint["epoch"], checkpoint["r2"])) 233 | f.flush() 234 | 235 | if run_test: 236 | for split in ["test", "val"]: 237 | np.random.seed(seed) 238 | torch.manual_seed(seed) 239 | 240 | ds = echonet.datasets.Echo(root=data_dir, split=split, **kwargs, clips="all") 241 | yhat, y = echonet.utils.video.test_epoch_all(model, ds, False, None, device, save_all=True, block_size=batch_size, run_dir = output, test_val = split, **kwargs) 242 | 243 | f.write("Seed is {} \n".format(seed)) 244 | f.write("{} - {} (all clips, mod) R2: {:.3f} ({:.3f} - {:.3f})\n".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), split, *echonet.utils.bootstrap(y, yhat, sklearn.metrics.r2_score))) 245 | f.write("{} - {} (all clips, mod) MAE: {:.2f} ({:.2f} - {:.2f})\n".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), split, *echonet.utils.bootstrap(y, yhat, sklearn.metrics.mean_absolute_error))) 246 | f.write("{} - {} (all clips, mod) RMSE: {:.2f} ({:.2f} - {:.2f})\n".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), split, *tuple(map(math.sqrt, echonet.utils.bootstrap(y, yhat, sklearn.metrics.mean_squared_error))))) 247 | f.flush() 248 | 249 | echonet.utils.latexify() 250 | 251 | fig = plt.figure(figsize=(3, 3)) 252 | lower = min(y.min(), yhat.min()) 253 | upper = max(y.max(), yhat.max()) 254 | plt.scatter(y, yhat, color="k", s=1, edgecolor=None, zorder=2) 255 | plt.plot([0, 100], [0, 100], linewidth=1, zorder=3) 256 | plt.axis([lower - 3, upper + 3, lower - 3, upper + 3]) 257 | plt.gca().set_aspect("equal", "box") 258 | plt.xlabel("Actual EF (%)") 259 | plt.ylabel("Predicted EF (%)") 260 | plt.xticks([10, 20, 30, 40, 50, 60, 70, 80]) 261 | plt.yticks([10, 20, 30, 40, 50, 60, 70, 80]) 262 | plt.grid(color="gainsboro", linestyle="--", linewidth=1, zorder=1) 263 | plt.tight_layout() 264 | plt.savefig(os.path.join(output, "{}_scatter.pdf".format(split))) 265 | plt.close(fig) 266 | 267 | # Plot AUROC 268 | fig = plt.figure(figsize=(3, 3)) 269 | plt.plot([0, 1], [0, 1], linewidth=1, color="k", linestyle="--") 270 | for thresh in [35, 40, 45, 50]: 271 | fpr, tpr, _ = sklearn.metrics.roc_curve(y > thresh, yhat) 272 | print(thresh, sklearn.metrics.roc_auc_score(y > thresh, yhat)) 273 | plt.plot(fpr, tpr) 274 | 275 | plt.axis([-0.01, 1.01, -0.01, 1.01]) 276 | plt.xlabel("False Positive Rate") 277 | plt.ylabel("True Positive Rate") 278 | plt.tight_layout() 279 | plt.savefig(os.path.join(output, "{}_roc.pdf".format(split))) 280 | plt.close(fig) 281 | 282 | 283 | 284 | 285 | 286 | 287 | 288 | 289 | 290 | def test_epoch_all(model, dataset, train, optim, device, save_all=False, block_size=None, run_dir = None, test_val = None, mean = None, std = None, length = None, period = None, target_type = None): 291 | model.train(False) 292 | 293 | total = 0 294 | total_reg = 0 295 | total_ctr = 0 296 | 297 | n = 0 298 | s1 = 0 299 | s2 = 0 300 | 301 | yhat = [] 302 | y = [] 303 | 304 | if (mean is None) or (std is None) or (length is None) or (period is None): 305 | assert 1==2, "missing key params" 306 | 307 | max_length = 250 308 | 309 | if run_dir: 310 | outftcltdir = os.path.join(run_dir, "feat_collect_{}".format(test_val)) 311 | if not os.path.isdir(outftcltdir): 312 | os.makedirs(outftcltdir) 313 | 314 | temp_savefile = os.path.join(run_dir, "temp_inference_{}.csv".format(test_val)) 315 | 316 | with torch.set_grad_enabled(False): 317 | orig_filelist = dataset.fnames 318 | 319 | if os.path.isfile(temp_savefile): 320 | exist_data = pd.read_csv(temp_savefile) 321 | exist_file = list(exist_data['fnames']) 322 | target_filelist = sorted(list(set(orig_filelist) - set(exist_file))) 323 | else: 324 | target_filelist = sorted(list(orig_filelist)) 325 | exist_data = pd.DataFrame(columns = ['fnames', 'yhat']) 326 | 327 | for filelistitr_idx in range(len(target_filelist)): 328 | filelistitr = target_filelist[filelistitr_idx] 329 | 330 | video_path = os.path.join(echonet.config.DATA_DIR, "Videos", filelistitr) 331 | ### Get data 332 | video = echonet.utils.loadvideo(video_path).astype(np.float32) 333 | 334 | if isinstance(mean, (float, int)): 335 | video -= mean 336 | else: 337 | video -= mean.reshape(3, 1, 1, 1) 338 | 339 | if isinstance(std, (float, int)): 340 | video /= std 341 | else: 342 | video /= std.reshape(3, 1, 1, 1) 343 | 344 | c, f, h, w = video.shape 345 | if length is None: 346 | length = f // period 347 | else: 348 | length = length 349 | 350 | if max_length is not None: 351 | length = min(length, max_length) 352 | 353 | if f < length * period: 354 | video = np.concatenate((video, np.zeros((c, length * period - f, h, w), video.dtype)), axis=1) 355 | c, f, h, w = video.shape 356 | 357 | start = np.arange(f - (length - 1) * period) 358 | 359 | reg1 = [] 360 | n_clips = start.shape[0] 361 | batch = 1 362 | for s_itr in range(0, start.shape[0], block_size): 363 | print("{}, processing file {} out of {}, block {} out of {}".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), filelistitr_idx, len(target_filelist), s_itr, start.shape[0]), flush=True) 364 | 365 | vid_samp = tuple(video[:, s + period * np.arange(length), :, :] for s in start[s_itr: s_itr + block_size]) 366 | X1 = torch.tensor(np.stack(vid_samp)) 367 | X1 = X1.to(device) 368 | 369 | all_output = model(X1) 370 | reg1.append(all_output[0].detach().cpu().numpy()) 371 | reg1 = np.vstack(reg1) 372 | reg1_mean = reg1.reshape(batch, n_clips, -1).mean(1) 373 | 374 | exist_data = exist_data.append({'fnames':filelistitr, 'yhat':reg1_mean[0,0]}, ignore_index=True) 375 | 376 | if filelistitr_idx % 20 == 0: 377 | exist_data.to_csv(temp_savefile, index = False) 378 | 379 | 380 | label_data_path = os.path.join(echonet.config.DATA_DIR, "FileList.csv") 381 | label_data = pd.read_csv(label_data_path) 382 | label_data_select = label_data[['FileName','EF']] 383 | label_data_select.columns = ['fnames','EF'] 384 | with_predict = exist_data.merge(label_data_select, on='fnames') 385 | 386 | predict_out_path = os.path.join(run_dir, "{}_predictions.csv".format(test_val)) 387 | with_predict.to_csv(predict_out_path, index=False) 388 | 389 | 390 | return with_predict['yhat'].to_numpy(), with_predict['EF'].to_numpy() 391 | 392 | 393 | def run_epoch(model, dataloader, train, optim, device, save_all=False, block_size=None, run_dir = None, test_val = None, stgno = 0, ctr_w = 0.75): 394 | """Run one epoch of training/evaluation for segmentation. 395 | 396 | Args: 397 | model (torch.nn.Module): Model to train/evaulate. 398 | dataloder (torch.utils.data.DataLoader): Dataloader for dataset. 399 | train (bool): Whether or not to train model. 400 | optim (torch.optim.Optimizer): Optimizer 401 | device (torch.device): Device to run on 402 | save_all (bool, optional): If True, return predictions for all 403 | test-time augmentations separately. If False, return only 404 | the mean prediction. 405 | Defaults to False. 406 | block_size (int or None, optional): Maximum number of augmentations 407 | to run on at the same time. Use to limit the amount of memory 408 | used. If None, always run on all augmentations simultaneously. 409 | Default is None. 410 | """ 411 | 412 | model.train(train) 413 | 414 | total = 0 415 | total_reg = 0 416 | total_ctr = 0 417 | 418 | n = 0 419 | s1 = 0 420 | s2 = 0 421 | 422 | yhat = [] 423 | y = [] 424 | start_frame_record = [] 425 | vidpath_record = [] 426 | 427 | if run_dir: 428 | outftcltdir = os.path.join(run_dir, "feat_collect_{}".format(test_val)) 429 | if not os.path.isdir(outftcltdir): 430 | os.makedirs(outftcltdir) 431 | 432 | with torch.set_grad_enabled(train): 433 | with tqdm.tqdm(total=len(dataloader)) as pbar: 434 | enum_idx = 0 435 | for (X1, X2, outcome, outcome_cls, start_frame, video_path) in dataloader: 436 | if run_dir: 437 | outpath_ftvl = os.path.join(outftcltdir, "featvl_{}".format(enum_idx)) 438 | outpath_ftv2 = os.path.join(outftcltdir, "featv2_{}".format(enum_idx)) 439 | outpath_lb = os.path.join(outftcltdir, "lb_{}".format(enum_idx)) 440 | enum_idx = enum_idx + 1 441 | 442 | bsz = outcome.shape[0] 443 | 444 | if run_dir: 445 | np.save(outpath_lb, outcome.detach().cpu().numpy()) 446 | 447 | y.append(outcome.detach().cpu().numpy()) 448 | X1 = X1.to(device) 449 | X2 = X2.to(device) 450 | X = torch.cat((X1, X2), dim=0).to('cuda') 451 | 452 | outcome = outcome.to(device) 453 | outcome_cls = outcome_cls.to(device) 454 | 455 | n_clips = 0 456 | s1 += outcome.sum() 457 | s2 += (outcome ** 2).sum() 458 | 459 | all_ouput = model(X) 460 | outputs = all_ouput[0] 461 | ctr_feat = all_ouput[1] 462 | 463 | f1, f2 = torch.split(ctr_feat, [bsz, bsz], dim=0) 464 | 465 | if run_dir: 466 | np.save(outpath_ftvl, f1.detach().cpu().numpy()) 467 | np.save(outpath_ftv2, f2.detach().cpu().numpy()) 468 | 469 | features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1) 470 | 471 | reg1, reg2 = torch.split(outputs, [bsz, bsz], dim=0) 472 | 473 | yhat.append(reg1.view(-1).to("cpu").detach().numpy()) 474 | 475 | if not train: 476 | start_frame_record.append(start_frame.view(-1).to("cpu").detach().numpy()) 477 | vidpath_record.append(video_path) 478 | 479 | 480 | loss_reg = torch.nn.functional.mse_loss(reg1.view(-1), outcome) 481 | loss_ctr = criterion_cntrst(features, outcome_cls, dist = outcome_cls, norm_val=2/7465, scale_s = 150) 482 | 483 | loss = 1 * loss_reg + 0.75 * loss_ctr 484 | 485 | if train: 486 | optim.zero_grad() 487 | loss.backward() 488 | optim.step() 489 | 490 | total += loss.item() * outcome.size(0) 491 | total_reg += loss_reg.item() * outcome.size(0) 492 | total_ctr += loss_ctr.item() * outcome.size(0) 493 | 494 | n += outcome.size(0) 495 | pbar.set_postfix_str("{:.2f} {:.2f} {:.2f} ({:.2f}) / {:.2f} {}".format(total / n, total_reg / n, total_ctr / n, loss.item(), s2 / n - (s1 / n) ** 2, n_clips)) 496 | pbar.update() 497 | 498 | if not save_all: 499 | yhat = np.concatenate(yhat) 500 | if not train: 501 | start_frame_record = np.concatenate(start_frame_record) 502 | 503 | y = np.concatenate(y) 504 | 505 | return total / n, total_reg / n, total_ctr / n, yhat, y, start_frame_record, vidpath_record 506 | 507 | 508 | 509 | --------------------------------------------------------------------------------