├── .gitignore
├── res
│   ├── env
│   │   ├── dataset.json
│   │   └── paths.json
│   ├── models
│   │   ├── online_bk.json
│   │   ├── online_d.json
│   │   └── online_fm.json
│   ├── run
│   │   ├── hp_bk.json
│   │   ├── hp_d.json
│   │   └── hp_fm.json
│   └── datasets
│       ├── Spine.json
│       ├── DDH.json
│       └── Fetus.json
├── datasets
│   ├── functional
│   │   ├── __init__.py
│   │   └── common.py
│   ├── __init__.py
│   ├── BaseDataset.py
│   ├── Spine.py
│   ├── DDH.py
│   └── Fetus.py
├── models
│   ├── functional
│   │   ├── __init__.py
│   │   └── common.py
│   ├── layers
│   │   ├── __init__.py
│   │   ├── convolutional_rnn
│   │   │   ├── utils.py
│   │   │   ├── __init__.py
│   │   │   ├── functional.py
│   │   │   └── module.py
│   │   ├── canny.py
│   │   └── resnet3d.py
│   ├── __init__.py
│   ├── BaseModel.py
│   ├── online_backbone.py
│   ├── online_discriminator.py
│   └── online_framework.py
├── requirements.txt
├── configs
│   ├── __init__.py
│   ├── Env.py
│   ├── Run.py
│   └── BaseConfig.py
├── utils
│   ├── __init__.py
│   ├── common.py
│   ├── image.py
│   ├── logger.py
│   ├── metric.py
│   ├── reconstruction.py
│   └── simulation.py
├── README.md
├── LICENSE
├── scripts
│   ├── ol_curve.py
│   └── visual_reco.py
└── main.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea 2 | __pycache__ 3 | save 4 | test
--------------------------------------------------------------------------------
/res/env/dataset.json:
--------------------------------------------------------------------------------
1 | { 2 | "num_workers": 4, 3 | "pin_memory": true 4 | }
--------------------------------------------------------------------------------
/datasets/functional/__init__.py:
--------------------------------------------------------------------------------
1 | from . import common 2 | 3 | 4 | __all__ = ['common'] 5 |
--------------------------------------------------------------------------------
/models/functional/__init__.py:
--------------------------------------------------------------------------------
1 | from . import common 2 | 3 | 4 | __all__ = ['common'] 5 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | opencv-python 2 | numpy 3 | torch 4 | scipy 5 | timm 6 | matplotlib 7 | pyvista
--------------------------------------------------------------------------------
/res/models/online_bk.json:
--------------------------------------------------------------------------------
1 | { 2 | "name": "Online_Backbone", 3 | "weight_motion": 0.1 4 | }
--------------------------------------------------------------------------------
/models/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from . import convolutional_rnn 2 | from . import canny 3 | from . import resnet3d 4 |
--------------------------------------------------------------------------------
/res/models/online_d.json:
--------------------------------------------------------------------------------
1 | { 2 | "name": "Online_Discriminator", 3 | "backbone_weight": "./save/online_bk-hp_bk-Spine/online_bk_backbone_last.pth" 4 | }
--------------------------------------------------------------------------------
/configs/__init__.py:
--------------------------------------------------------------------------------
1 | from .BaseConfig import BaseConfig 2 | from .Env import env 3 | from .Run import Run 4 | 5 | __all__ = ['BaseConfig', 'env', 'Run'] 6 |
--------------------------------------------------------------------------------
/models/functional/common.py:
--------------------------------------------------------------------------------
1 | import models 2 | 3 | __all__ = ['find'] 4 | 5 | 6 | def find(name): 7 | model = getattr(models, name, None) 8 | return model if model is not None and issubclass(model, models.BaseModel) else None 9 |
--------------------------------------------------------------------------------
/res/run/hp_bk.json:
--------------------------------------------------------------------------------
1 | { 2 | "name": "hp_backbone", 3 | "batch_size": 1, 4 | "test_batch_size": 1, 5 | "epochs": 200, 6 | "save_step": 1, 7 | "lr": 1e-4, 8 | "betas": [0.9, 0.999], 9 | "step_size": 30, 10 | "gamma": 0.5 11 | }
--------------------------------------------------------------------------------
/res/run/hp_d.json:
--------------------------------------------------------------------------------
1 | { 2 | "name": "hp_discriminator", 3 | "batch_size": 1, 4 | "test_batch_size": 1, 5 | "epochs": 50, 6 | "save_step": 1, 7 | "lr": 1e-4, 8 | "betas": [0.9, 0.999], 9 | "step_size": 30, 10 | "gamma": 0.5 11 | }
--------------------------------------------------------------------------------
/res/run/hp_fm.json:
--------------------------------------------------------------------------------
1 | { 2 | "name": "hp_framework", 3 | "batch_size": 1, 4 | "test_batch_size": 1, 5 | "epochs": 0, 6 | "save_step": 1, 7 | "ol_epochs": 30, 8 | "lr_psc": 5e-6, 9 | "lr_fcc_gas": 5e-7, 10 | "betas": [0.9, 0.999] 11 | }
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .logger import Logger 2 | 3 | from . import common 4 | from . import metric, image 5 | from . import simulation, reconstruction 6 | 7 | 8 | __all__ = [ 9 | 'Logger', 10 | 'common', 11 | 'metric', 'image', 12 | 'simulation', 'reconstruction', 13 | ] 14 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .BaseDataset import BaseDataset, BaseSplit 2 | from . import functional 3 | 4 | from .DDH import DDH 5 | from .Fetus import Fetus 6 | from .Spine import Spine 7 | 8 | 9 | __all__ = [ 10 | 'BaseDataset', 'BaseSplit', 'functional', 11 | 12 | 'DDH', 'Fetus', 'Spine', 13 | ] 14 |
--------------------------------------------------------------------------------
/res/env/paths.json:
--------------------------------------------------------------------------------
1 | { 2 | "root_folder": ".", 3 | "dataset_cfgs_folder": "res/datasets", 4 | "model_cfgs_folder": "res/models", 5 | "run_cfgs_folder": "res/run", 6 | "save_folder": "save", 7 | "check_file": "_checkpoint.pth", 8 | "loss_file": "_loss.npy", 9 | "predict_file": "_predict.npy", 10 | "logging_file": "_logging.log" 11 | }
--------------------------------------------------------------------------------
/datasets/functional/common.py:
--------------------------------------------------------------------------------
1 | import datasets 2 | 3 | __all__ = ['more', 'find'] 4 | 5 | 6 | def more(cfg): 7 | dataset = getattr(datasets, cfg.name, None) 8 | return dataset.more(dataset._more(cfg)) if dataset else cfg 9 | 10 | 11 | def find(name): 12 | dataset = getattr(datasets, name, None) 13 | return dataset if dataset is not None and issubclass(dataset, datasets.BaseDataset) else None 14 |
--------------------------------------------------------------------------------
/res/models/online_fm.json:
--------------------------------------------------------------------------------
1 | { 2 | "name": "Online_Framework", 3 | "backbone_weight": "./save/online_bk-hp_bk-Spine/online_bk_backbone_last.pth", 4 | "discriminator_weight": "./save/online_d-hp_d-Spine/online_d_discriminator_last.pth", 5 | "discriminator_opt_cycle": 5, 6 | "down_ratio": 0.3, 7 | "psc_epoch": 3, 8 | "psc_threshold": 0.4, 9 | "psc_max_acquisition": 20, 10 | "reco_rate": 0.5, 11 | "weight_fcc": 0.01 12 | }
--------------------------------------------------------------------------------
/res/datasets/Spine.json:
--------------------------------------------------------------------------------
1 | { 2 | "name": "Spine", 3 | "paths": { 4 | "source": "../../datasets/spine/data_300.npy", 5 | "target": "../../datasets/spine/label_dp_300.npy" 6 | }, 7 | "series_per_data": [100, 20, 20], 8 | "train_test_range": [45, 12, 11], 9 | "frame_rate": [2, 11], 10 | "source": { 11 | "width": 300, 12 | "height": 300, 13 | "channel": 6 14 | }, 15 | "target": { 16 | "elements": 15 17 | } 18 | }
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .BaseModel import BaseModel 2 | 3 | from . import functional 4 | from . import layers 5 | 6 | from .online_backbone import Online_Backbone 7 | from .online_discriminator import Online_Discriminator 8 | from .online_framework import Online_Framework 9 | 10 | 11 | __all__ = [ 12 | 'BaseModel', 'functional', 13 | 14 | 'layers', 15 | 16 | 'Online_Backbone', 17 | 'Online_Discriminator', 18 | 'Online_Framework', 19 | ] 20 |
--------------------------------------------------------------------------------
/models/layers/convolutional_rnn/utils.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Iterable 2 | from itertools import repeat 3 | 4 | 5 | """ Copied from torch.nn.modules.utils """ 6 | 7 | 8 | def _ntuple(n): 9 | def parse(x): 10 | if isinstance(x, Iterable): 11 | return x 12 | return tuple(repeat(x, n)) 13 | return parse 14 | 15 | 16 | _single = _ntuple(1) 17 | _pair = _ntuple(2) 18 | _triple = _ntuple(3) 19 | _quadruple = _ntuple(4) 20 |
--------------------------------------------------------------------------------
/res/datasets/DDH.json:
--------------------------------------------------------------------------------
1 | { 2 | "name": "DDH", 3 | "paths": { 4 | "source": "../../datasets/3D_DDH/all", 5 | "order": "../../datasets/3D_DDH/all.npy" 6 | }, 7 | "series_per_data": [40, 20, 20], 8 | "series_min_length": 90, 9 | "series_max_length": 90, 10 | "ps": 3, 11 | "train_test_range": [135, 16, 18], 12 | "frame_rate": [1, 10], 13 | "load_mode": "memory", 14 | "source": { 15 | "width": 300, 16 | "height": 300, 17 | "channel": 6, 18 | "origin": [185, 250, 199], 19 | "max_distance": [160, 60, 120] 20 | }, 21 | "target": { 22 | "elements": 15 23 | } 24 | }
--------------------------------------------------------------------------------
/res/datasets/Fetus.json:
--------------------------------------------------------------------------------
1 | { 2 | "name": "Fetus", 3 | "paths": { 4 | "source": "../../datasets/us_pose_data/imgs_trans_128", 5 | "order": "../../datasets/us_pose_data/Fetus_order_128.npy" 6 | }, 7 | "series_per_data": [60, 20, 20], 8 | "series_min_length": 90, 9 | "series_max_length": 90, 10 | "ps": 1, 11 | "train_test_range": [98, 15, 15], 12 | "frame_rate": [1, 10], 13 | "load_mode": "memory", 14 | "source": { 15 | "width": 300, 16 | "height": 300, 17 | "channel": 6, 18 | "origin": [232, 262, 274], 19 | "max_distance": [100, 50, 50] 20 | }, 21 | "target": { 22 | "elements": 15 23 | } 24 | }
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # RecON 2 | 3 | This repository is the official implementation for "[RecON: Online Learning for Sensorless Freehand 3D Ultrasound Reconstruction](https://doi.org/10.1016/j.media.2023.102810)". 4 | 5 | ## Environment 6 | - PyTorch with GPU support 7 | - OpenCV-Python built with CUDA support 8 | - Run `pip install -r requirements.txt` 9 | 10 | ## Training 11 | - Backbone 12 | ```shell 13 | python3 -m main -m online_bk -d Spine -r hp_bk -g0 14 | ``` 15 | - Discriminator 16 | ```shell 17 | python3 -m main -m online_d -d Spine -r hp_d -g0 18 | ``` 19 | 20 | ## Online Learning 21 | ```shell 22 | python3 -m main -m online_fm -d Spine -r hp_fm -g0 -t0 23 | ``` 24 | 25 | ## Demo 26 | An interactive demo is available [here](http://apps.myluo.cn/RecON).
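27 | 28 | ## Visualization 29 | The bundled scripts can be used to inspect online-learning results; the commands below are a usage sketch, assuming the default `save/online_fm-hp_fm-Spine/RecON` output folder of the command above (adjust `dir_save` in each script for other runs). 30 | ```shell 31 | cd scripts 32 | python3 ol_curve.py      # plots the averaged MEA/FDR/ADR/MD/SD/HD curves over the online-learning epochs 33 | python3 visual_reco.py   # renders the GT and RecON volumes side by side with pyvista 34 | ``` 35 |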
-------------------------------------------------------------------------------- /models/layers/convolutional_rnn/__init__.py: -------------------------------------------------------------------------------- 1 | # https://github.com/kamo-naoyuki/pytorch_convolutional_rnn 2 | 3 | from .module import Conv1dRNN 4 | from .module import Conv1dLSTM 5 | from .module import Conv1dPeepholeLSTM 6 | from .module import Conv1dGRU 7 | 8 | from .module import Conv2dRNN 9 | from .module import Conv2dLSTM 10 | from .module import Conv2dPeepholeLSTM 11 | from .module import Conv2dGRU 12 | 13 | from .module import Conv3dRNN 14 | from .module import Conv3dLSTM 15 | from .module import Conv3dPeepholeLSTM 16 | from .module import Conv3dGRU 17 | 18 | from .module import Conv1dRNNCell 19 | from .module import Conv1dLSTMCell 20 | from .module import Conv1dPeepholeLSTMCell 21 | from .module import Conv1dGRUCell 22 | 23 | from .module import Conv2dRNNCell 24 | from .module import Conv2dLSTMCell 25 | from .module import Conv2dPeepholeLSTMCell 26 | from .module import Conv2dGRUCell 27 | 28 | from .module import Conv3dRNNCell 29 | from .module import Conv3dLSTMCell 30 | from .module import Conv3dPeepholeLSTMCell 31 | from .module import Conv3dGRUCell 32 | -------------------------------------------------------------------------------- /configs/Env.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import configs 4 | 5 | __all__ = ['env'] 6 | 7 | 8 | class Env(configs.BaseConfig): 9 | 10 | def __init__(self): 11 | super(Env, self).__init__({}) 12 | cfg_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../res/env')) 13 | if not os.path.exists(cfg_dir): 14 | os.makedirs(cfg_dir) 15 | for file in os.listdir(cfg_dir): 16 | setattr(self, os.path.splitext(file)[0], configs.BaseConfig(os.path.join(cfg_dir, file))) 17 | if hasattr(self, 'paths') and hasattr(self.paths, 'root_folder'): 18 | self.paths.root_folder = os.path.abspath( 19 | os.path.join(os.path.dirname(__file__), '..', self.paths.root_folder)) 20 | else: 21 | raise ValueError('Lack of `res/env/paths.json` file or `root_folder` value') 22 | 23 | def getdir(self, path): 24 | return os.path.abspath(os.path.join(self.paths.root_folder, path)) 25 | 26 | def chdir(self, path): 27 | os.chdir(self.getdir(path)) 28 | 29 | 30 | env = Env() 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | MIT License 3 | 4 | Copyright (c) 2020-present Mingyuan Luo 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 |
--------------------------------------------------------------------------------
/configs/Run.py:
--------------------------------------------------------------------------------
1 | import os 2 | 3 | import torch 4 | 5 | import configs 6 | 7 | __all__ = ['Run'] 8 | 9 | 10 | class Run(configs.BaseConfig): 11 | 12 | epochs: int 13 | save_step: int 14 | batch_size: int 15 | test_batch_size: int 16 | 17 | def __init__(self, cfg, gpus: str = '0', **kwargs): 18 | super(Run, self).__init__(cfg, gpus=gpus, **kwargs) 19 | self._more() 20 | 21 | def _more(self): 22 | self._set_gpus() 23 | if self.gpus: 24 | self.cuda = torch.cuda.is_available() and getattr(self, 'cuda', True) 25 | self.device = torch.device("cuda", 0) if self.cuda else torch.device("cpu") 26 | else: 27 | self.cuda = False 28 | self.device = torch.device('cpu') 29 | 30 | torch.backends.cudnn.enabled = True 31 | torch.backends.cudnn.benchmark = True 32 | torch.backends.cudnn.deterministic = True 33 | 34 | def _set_gpus(self): 35 | if 'CUDA_VISIBLE_DEVICES' in os.environ.keys(): 36 | self.gpus = os.environ['CUDA_VISIBLE_DEVICES'] 37 | 38 | if self.gpus.lower() == 'cpu': 39 | self.gpus = [] 40 | elif self.gpus == '': 41 | self.gpus = list(range(torch.cuda.device_count())) 42 | else: 43 | self.gpus = [int(g) for g in self.gpus.split(',')] 44 | 45 | os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(g) for g in self.gpus]) 46 |
--------------------------------------------------------------------------------
/utils/common.py:
--------------------------------------------------------------------------------
1 | import os 2 | 3 | import configs 4 | import random 5 | 6 | import numpy as np 7 | import torch 8 | 9 | __all__ = ['set_seed', 'merge_dict', 'get_filename', 'get_path', 'real_config_path'] 10 | 11 | 12 | def set_seed(seed=0): 13 | seed = int(seed) 14 | torch.manual_seed(seed) 15 | torch.cuda.manual_seed(seed) 16 | torch.cuda.manual_seed_all(seed) 17 | np.random.seed(seed) 18 | random.seed(seed) 19 | os.environ['PYTHONHASHSEED'] = str(seed) 20 | 21 | 22 | def merge_dict(dst: dict, src: dict): 23 | for key, value in src.items(): 24 | if isinstance(value, torch.Tensor): 25 | value = value.unsqueeze(-1) 26 | if key in dst.keys(): 27 | dst[key] = torch.cat([dst[key], value.detach()]) 28 | else: 29 | dst[key] = value.detach() 30 | else: 31 | if key in dst.keys(): 32 | dst[key].append(value) 33 | else: 34 | dst[key] = [value] 35 | 36 | 37 | def get_filename(path): 38 | return os.path.splitext(os.path.split(path)[1])[0] 39 | 40 | 41 | def get_path(model_cfg, dataset_cfg, run_cfg): 42 | dirname = get_filename(model_cfg._path) + '-' + get_filename(run_cfg._path) + '-' + get_filename(dataset_cfg._path) 43 | return os.path.join(configs.env.getdir(configs.env.paths.save_folder), dirname) 44 | 45 | 46 | def real_config_path(config_path, set_folder): 47 | if os.path.exists(config_path): 48 | return os.path.abspath(config_path) 49 | else: 50 | return configs.env.getdir(os.path.join(set_folder, config_path + '.json')) 51 |
--------------------------------------------------------------------------------
/utils/image.py:
--------------------------------------------------------------------------------
1 | import warnings 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | 7 | from models.layers.canny import Canny2D 8 | 9 | 10 | def get_optical_flow(slices, device): 11 | height, width = slices.shape[-2], slices.shape[-1] 12 | if not hasattr(get_optical_flow, 'of'): 13 | if hasattr(cv2, 'cuda_NvidiaOpticalFlow_2_0'): 14 | of = cv2.cuda_NvidiaOpticalFlow_2_0.create((width, height), 5, 1, 1, False, False, False, device.index) 15 | elif hasattr(cv2, 'cuda_NvidiaOpticalFlow_1_0'): 16 | of = cv2.cuda_NvidiaOpticalFlow_1_0.create(width, height, 5, False, False, False, device.index) 17 | warnings.warn('use cuda_NvidiaOpticalFlow_1_0!') 18 | else: 19 | warnings.warn('opencv-python not support cuda!') 20 | return None 21 | get_optical_flow.of = of 22 | 23 | of = get_optical_flow.of 24 | slices_np = slices.type(torch.uint8).cpu().numpy() 25 | flows = [] 26 | for i in range(1, len(slices)): 27 | flow = of.calc(slices_np[i - 1, :, :], slices_np[i, :, :], None) 28 | if hasattr(cv2, 'cuda_NvidiaOpticalFlow_2_0'): 29 | flow = of.convertToFloat(flow[0], None) 30 | else: 31 | flow = of.upSampler(flow[0], width, height, of.getGridSize(), None) 32 | flows.append(flow) 33 | flows = np.stack(flows, axis=0).transpose((0, 3, 1, 2)) 34 | flows = torch.from_numpy(flows).type(slices.dtype).to(slices.device) 35 | return flows 36 | 37 | 38 | def get_edge(slices, device, threshold=0.1): 39 | if not hasattr(get_edge, 'canny'): 40 | get_edge.canny = Canny2D(threshold=threshold).to(device) 41 | with torch.no_grad(): 42 | edge = get_edge.canny(slices.unsqueeze(1)) 43 | return edge 44 |
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | import logging 2 | import os 3 | import platform 4 | 5 | import numpy as np 6 | 7 | import configs 8 | 9 | __all__ = ['Logger'] 10 | 11 | 12 | class Logger(object): 13 | 14 | def __init__(self, path, prefix): 15 | self.path = path 16 | self.logging_file = os.path.join(path, prefix + configs.env.paths.logging_file) 17 | self._mkfile() 18 | self.logger = self._setlogger() 19 | 20 | def _mkfile(self): 21 | if not os.path.exists(self.path): 22 | os.makedirs(self.path) 23 | if platform.system() == 'Windows': 24 | open(self.logging_file, 'a') 25 | else: 26 | not os.path.exists(self.logging_file) and os.mknod(self.logging_file) 27 | 28 | def _setlogger(self): 29 | global _utils_logger 30 | if '_utils_logger' not in globals(): 31 | _utils_logger = logging.getLogger() 32 | else: 33 | for idx in reversed(range(len(_utils_logger.handlers))): 34 | _utils_logger.handlers[idx].close() 35 | _utils_logger.removeHandler(_utils_logger.handlers[idx]) 36 | 37 | _utils_logger.setLevel(level=logging.INFO) 38 | formatter = logging.Formatter('%(asctime)s : %(message)s') 39 | 40 | handler = logging.FileHandler(self.logging_file) 41 | handler.setLevel(logging.INFO) 42 | handler.setFormatter(formatter) 43 | _utils_logger.addHandler(handler) 44 | console = logging.StreamHandler() 45 | console.setLevel(logging.INFO) 46 | console.setFormatter(formatter) 47 | _utils_logger.addHandler(console) 48 | 49 | return _utils_logger 50 | 51 | def info(self, msg): 52 | self.logger.info(msg) 53 | 54 | def info_scalars(self, msg: str, infos: tuple, scalars: dict): 55 | scalars_list = list() 56 | if scalars: 57 | for name, value in scalars.items(): 58 | if not name.startswith('_'): 59 | msg += ' ' + name + ': {:.6f}' 60 | scalars_list.append(value) 61 | self.info(msg.format(*infos, *scalars_list)) 62 | 63 | def save_npy(self, filename, data): 64 | np.save(os.path.join(self.path, filename), data) 65 |
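66 | 67 | # Usage sketch (hypothetical path and prefix; in the pipeline, models and datasets reference a shared Logger via their `logger` attribute): 68 | #   logger = Logger('save/demo', 'online_bk')  # creates save/demo/online_bk_logging.log and echoes to console 69 | #   logger.info_scalars('Epoch: {}', (1,), {'loss': 0.123456})  # logs "Epoch: 1 loss: 0.123456" 70 | #   logger.save_npy('online_bk_loss.npy', np.zeros(3))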
--------------------------------------------------------------------------------
/scripts/ol_curve.py:
--------------------------------------------------------------------------------
1 | import os 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import torch 6 | 7 | 8 | if __name__ == '__main__': 9 | device = torch.device("cuda:0") 10 | dir_save = r'../save/online_fm-hp_fm-Spine/RecON' 11 | 12 | MEA, FDR, ADR, MD, SD, HD = [], [], [], [], [], [] 13 | for file in sorted(os.listdir(dir_save)): 14 | if not file.startswith('value_'): 15 | continue 16 | value = torch.load(os.path.join(dir_save, file), map_location=device) 17 | mea, fdr, adr, md, sd, hd = [], [], [], [], [], [] 18 | for idx, loss in enumerate(value['loss']): 19 | mea.append(loss['MEA']) 20 | fdr.append(loss['FDR']) 21 | adr.append(loss['ADR']) 22 | md.append(loss['MD']) 23 | sd.append(loss['SD']) 24 | hd.append(loss['HD']) 25 | MEA.append(torch.tensor(mea, device=device)) 26 | FDR.append(torch.tensor(fdr, device=device)) 27 | ADR.append(torch.tensor(adr, device=device)) 28 | MD.append(torch.tensor(md, device=device)) 29 | SD.append(torch.tensor(sd, device=device)) 30 | HD.append(torch.tensor(hd, device=device)) 31 | MEA = torch.stack(MEA, dim=0) 32 | FDR = torch.stack(FDR, dim=0) 33 | ADR = torch.stack(ADR, dim=0) 34 | MD = torch.stack(MD, dim=0) 35 | SD = torch.stack(SD, dim=0) 36 | HD = torch.stack(HD, dim=0) 37 | 38 | MEA = torch.mean(MEA, dim=0).cpu().numpy() 39 | FDR = torch.mean(FDR, dim=0).cpu().numpy() 40 | ADR = torch.mean(ADR, dim=0).cpu().numpy() 41 | MD = torch.mean(MD, dim=0).cpu().numpy() 42 | SD = torch.mean(SD, dim=0).cpu().numpy() 43 | HD = torch.mean(HD, dim=0).cpu().numpy() 44 | x = np.linspace(0, len(MEA) - 1, len(MEA)) 45 | 46 | plt.figure() 47 | plt.subplot(2, 3, 1) 48 | plt.title('MEA') 49 | plt.plot(x, MEA) 50 | plt.subplot(2, 3, 2) 51 | plt.title('FDR') 52 | plt.plot(x, FDR) 53 | plt.subplot(2, 3, 3) 54 | plt.title('ADR') 55 | plt.plot(x, ADR) 56 | plt.subplot(2, 3, 4) 57 | plt.title('MD') 58 | plt.plot(x, MD) 59 | plt.subplot(2, 3, 5) 60 | plt.title('SD') 61 | plt.plot(x, SD) 62 | plt.subplot(2, 3, 6) 63 | plt.title('HD') 64 | plt.plot(x, HD) 65 | plt.show() 66 |
--------------------------------------------------------------------------------
/configs/BaseConfig.py:
--------------------------------------------------------------------------------
1 | import json 2 | from typing import Union 3 | 4 | __all__ = ['BaseConfig'] 5 | 6 | 7 | class BaseConfig(object): 8 | 9 | _path: str 10 | name: str 11 | 12 | def __init__(self, cfg: Union[dict, str], **kwargs): 13 | self._space = 0 14 | self._load(cfg) if isinstance(cfg, dict) else self._fromfile(cfg) 15 | 16 | for k, v in kwargs.items(): 17 | setattr(self, k, v) 18 | 19 | @staticmethod 20 | def _values(key: str): 21 | return not (key.startswith('_') or key == 'name') 22 | 23 | def dict(self): 24 | value_dict = dict(vars(self)) 25 | for key in list(value_dict.keys()): 26 | if not self._values(key): 27 | value_dict.pop(key, None) 28 | return value_dict 29 | 30 | def __eq__(self, other): 31 | if isinstance(other, self.__class__): 32 | for key in vars(other).keys(): 33 | if other._values(key) and not hasattr(self, key): 34 | return False 35 | for key, value in vars(self).items(): 36 | if self._values(key) and (not hasattr(other, key) or value != getattr(other, key)): 37 | return False 38 | return True 39 | else: 40 | return False 41 | 42 | def __repr__(self): 43 | s, sp = '', ' ' * self._space 44 | if self._space == 0: 45 | s += self.__class__.__name__ 46 | if hasattr(self, 'name'): 47 | s += ' (' + str(self.name) + ')' 48 | s += ': ' 49 | s += '{\n' 50 | v = list(vars(self).items()) 51 | v.sort() 52 | for key, value in v: 53 | if not self._values(key): 54 | continue 55 | s += sp + ' ' + key + ': ' 56 | if issubclass(value.__class__, BaseConfig): 57 | value._space = self._space + 2 58 | if isinstance(value, str): 59 | s += "'" 60 | s += str(value) 61 | if isinstance(value, str): 62 | s += "'" 63 | s += '\n' 64 | s += sp + '}' 65 | if self._space == 0: 66 | s += '\n' 67 | return s 68 | 69 | def _load(self, config: dict): 70 | for key, value in config.items(): 71 | setattr(self, key, BaseConfig(value) if isinstance(value, dict) else value) 72 | 73 | def _fromfile(self, path: str): 74 | with open(path, 'r') as f: 75 | self._load(json.load(f)) 76 | setattr(self, '_path', path) 77 |
--------------------------------------------------------------------------------
/scripts/visual_reco.py:
--------------------------------------------------------------------------------
1 | import os 2 | 3 | import numpy as np 4 | import pyvista as pv 5 | import torch 6 | 7 | import utils 8 | 9 | 10 | def polyline_from_points(points): 11 | poly = pv.PolyData() 12 | poly.points = points 13 | the_cell = np.arange(0, len(points), dtype=np.int_) 14 | the_cell = np.insert(the_cell, 0, len(points)) 15 | poly.lines = the_cell 16 | return poly 17 | 18 | 19 | def line_points(points: np.ndarray): 20 | line = polyline_from_points(points) 21 | line["scalars"] = np.arange(line.n_points) 22 | tube = line.tube(radius=2) 23 | return tube 24 | 25 | 26 | def pv_series(p, points): 27 | p.add_mesh(line_points(points[:, 0, :]), point_size=20.0, render_points_as_spheres=True) 28 | p.add_mesh(line_points(points[:, 1, :]), point_size=20.0, render_points_as_spheres=True) 29 | p.add_mesh(line_points(points[:, 2, :]), point_size=20.0, render_points_as_spheres=True) 30 | p.add_mesh(line_points(2 * points[:, 0, :] - points[:, 1, :]), point_size=20.0, render_points_as_spheres=True) 31 | p.add_mesh(line_points(2 * points[:, 0, :] - points[:, 2, :]), point_size=20.0, render_points_as_spheres=True) 32 | 33 | 34 | if __name__ == '__main__': 35 | device = torch.device("cuda:0") 36 | dir_save = r'../save/online_fm-hp_fm-Spine/RecON' 37 | 38 | n = len(os.listdir(dir_save)) // 2 39 | for idx in range(0, n): 40 | source = torch.load(os.path.join(dir_save, 'source_' + str(idx) + '.pth'), map_location=device) 41 | value = torch.load(os.path.join(dir_save, 'value_' + str(idx) + '.pth'), map_location=device) 42 | 43 | volume_real = utils.reconstruction.reco(source[0].squeeze(0).squeeze(1), value['real_series'].detach()) 44 | volume = utils.reconstruction.reco(source[0].squeeze(0).squeeze(1), value['fake_series'][-1].detach()) 45 | 46 | pv.set_plot_theme('document') 47 | p = pv.Plotter(shape=(1, 2)) 48 | p.subplot(0, 0) 49 | p.add_text('GT', position='upper_left') 50 | p.add_volume(volume_real[0].cpu().numpy(), cmap='gray', mapper='gpu') 51 | pv_series(p, value['real_series'].detach().cpu().numpy() - volume_real[1].cpu().numpy()) 52 | p.show_bounds(location='origin') 53 | p.subplot(0, 1) 54 | p.add_text('RecON', position='upper_left') 55 | p.add_volume(volume[0].cpu().numpy(), cmap='gray', mapper='gpu') 56 | pv_series(p, value['fake_series'][-1].detach().cpu().numpy() - volume[1].cpu().numpy()) 57 | p.show_bounds(location='origin') 58 | p.link_views() 59 | p.add_axes() 60 | p.show() 61 |
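62 | 63 | # Expected inputs (an assumption based on the file names used above): the online-learning run leaves paired dumps 64 | # source_<i>.pth (input sweeps) and value_<i>.pth (dict with 'real_series', 'fake_series' and the per-epoch 'loss' metrics that ol_curve.py plots).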
-------------------------------------------------------------------------------- /datasets/BaseDataset.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import os 3 | import time 4 | 5 | import torch 6 | from torch.utils.data import Dataset 7 | 8 | import configs 9 | from utils import common, Logger 10 | 11 | __all__ = ['BaseDataset', 'BaseSplit'] 12 | 13 | 14 | class BaseDataset(Dataset, metaclass=abc.ABCMeta): 15 | 16 | logger: Logger 17 | 18 | def __init__(self, cfg, **kwargs): 19 | self.name = os.path.splitext(os.path.split(cfg._path)[1])[0] 20 | self.cfg = self.more(self._more(cfg)) 21 | self.data, self.cfg.data_count = self.load() 22 | 23 | for k, v in kwargs.items(): 24 | setattr(self, k, v) 25 | 26 | @staticmethod 27 | def _more(cfg): 28 | for name, value in configs.env.dataset.dict().items(): 29 | setattr(cfg, name, getattr(cfg, name, value)) 30 | return cfg 31 | 32 | @staticmethod 33 | def more(cfg): 34 | return cfg 35 | 36 | @abc.abstractmethod 37 | def load(self): 38 | raise NotImplementedError 39 | 40 | @abc.abstractmethod 41 | def __getitem__(self, index): 42 | raise NotImplementedError 43 | 44 | def __len__(self): 45 | return self.cfg.data_count 46 | 47 | def split(self): 48 | self.trainset_length = int(self.cfg.series_per_data[0] * self.cfg.train_test_range[0]) 49 | self.valset_length = int(self.cfg.series_per_data[1] * self.cfg.train_test_range[1]) 50 | self.testset_length = len(self) - self.trainset_length - self.valset_length 51 | 52 | index_range_trainset = [[0, self.trainset_length]] 53 | index_range_valset = [[self.trainset_length, self.trainset_length + self.valset_length]] 54 | index_range_testset = [[self.trainset_length + self.valset_length, len(self)]] 55 | 56 | return BaseSplit(self, index_range_trainset), BaseSplit(self, index_range_valset), BaseSplit(self, index_range_testset) 57 | 58 | def get_idx(self, index): 59 | if index < self.trainset_length: 60 | idx = torch.div(index, self.cfg.series_per_data[0], rounding_mode='floor') 61 | common.set_seed(int(time.time() * 1000) % (1 << 32) + index) 62 | elif index < self.trainset_length + self.valset_length: 63 | idx = torch.div(index - self.trainset_length, self.cfg.series_per_data[1], rounding_mode='floor') \ 64 | + torch.div(self.trainset_length, self.cfg.series_per_data[0], rounding_mode='floor') 65 | common.set_seed(index * 3) 66 | else: 67 | idx = torch.div(index - self.trainset_length - self.valset_length, self.cfg.series_per_data[2], rounding_mode='floor') \ 68 | + torch.div(self.valset_length, self.cfg.series_per_data[1], rounding_mode='floor') \ 69 | + torch.div(self.trainset_length, self.cfg.series_per_data[0], rounding_mode='floor') 70 | common.set_seed(index * 3) 71 | return idx 72 | 73 | 74 | class BaseSplit(Dataset): 75 | 76 | def __init__(self, dataset, index_range_set): 77 | self.dataset = dataset 78 | self.indexset = self._index(index_range_set) 79 | self.count = len(self.indexset) 80 | 81 | if hasattr(self.dataset, 'logger'): 82 | self.logger = self.dataset.logger 83 | 84 | def _index(self, index_range_set): 85 | indexset = [] 86 | for index_range in index_range_set: 87 | indexset.extend(range(index_range[0], index_range[1])) 88 | return indexset 89 | 90 | def __getitem__(self, index): 91 | return self.dataset[self.indexset[index]][0], index 92 | 93 | def __len__(self): 94 | return self.count 95 | -------------------------------------------------------------------------------- /utils/metric.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.spatial 3 | import torch 4 | 5 | import utils 6 | 7 | 8 | def correlation_loss(labels, outputs, eps=1e-6): 9 | x = outputs.flatten() 10 | y = labels.flatten() 11 | xy = x * y 12 | mean_xy = torch.mean(xy) 13 | mean_x = torch.mean(x) 14 | mean_y = torch.mean(y) 15 | cov_xy = mean_xy - mean_x * mean_y 16 | 17 | var_x = torch.sum((x - mean_x) ** 2 / x.shape[0]) + eps 18 | var_y = torch.sum((y - mean_y) ** 2 / y.shape[0]) + eps 19 | 20 | corr_xy = cov_xy / (torch.sqrt(var_x * var_y)) 21 | 22 | loss = 1 - corr_xy 23 | return loss 24 | 25 | 26 | def final_drift_rate(predict, target, eps=1e-6): 27 | final_drift = torch.norm(target[-1] - predict[-1], 2, dim=-1) 28 | dist = torch.sum(torch.norm(target[1:] - target[:-1], 2, dim=-1)) + eps 29 | return final_drift / dist 30 | 31 | 32 | def average_drift_rate(predict, target, eps=0.1): 33 | drift = torch.norm(target[1:] - predict[1:], 2, dim=-1) 34 | dist = torch.cumsum(torch.norm(target[1:] - target[:-1], 2, dim=-1), dim=0) 35 | if eps is not None: 36 | flag = dist >= eps 37 | drift = drift[flag] 38 | dist = dist[flag] 39 | return torch.mean(drift / dist, dim=0) 40 | 41 | 42 | def max_drift(predict, target): 43 | drift = torch.norm(target[1:] - predict[1:], 2, dim=-1) 44 | return torch.max(drift, dim=0)[0] 45 | 46 | 47 | def sum_drift(predict, target): 48 | drift = torch.norm(target[1:] - predict[1:], 2, dim=-1) 49 | return torch.sum(drift) 50 | 51 | 52 | def symmetric_hausdorff_distance(predict, target): 53 | h_pt = scipy.spatial.distance.directed_hausdorff(predict, target)[0] 54 | h_tp = scipy.spatial.distance.directed_hausdorff(target, predict)[0] 55 | return max(h_tp, h_pt) 56 | 57 | 58 | def get_metric(real_series, fake_series): 59 | metric_dict = {} 60 | 61 | real_axis = utils.simulation.get_axis(real_series) 62 | fake_axis = utils.simulation.get_axis(fake_series) 63 | cos = torch.sum(real_axis * fake_axis, dim=-1) 64 | cos.clamp_(-1.0 + 1.0e-7, 1.0 - 1.0e-7) 65 | angle = torch.acos(cos) * 180 / np.pi 66 | metric_dict['MEA'] = torch.mean(angle) 67 | 68 | fdr_pc = final_drift_rate(fake_series[:, 0, :], real_series[:, 0, :]) 69 | fdr_p1 = final_drift_rate(fake_series[:, 1, :], real_series[:, 1, :]) 70 | fdr_p2 = final_drift_rate(fake_series[:, 2, :], real_series[:, 2, :]) 71 | metric_dict['FDR'] = (fdr_pc + fdr_p1 + fdr_p2) / 3 72 | 73 | adr_pc = average_drift_rate(fake_series[:, 0, :], real_series[:, 0, :]) 74 | adr_p1 = average_drift_rate(fake_series[:, 1, :], real_series[:, 1, :]) 75 | adr_p2 = average_drift_rate(fake_series[:, 2, :], real_series[:, 2, :]) 76 | metric_dict['ADR'] = (adr_pc + adr_p1 + adr_p2) / 3 77 | 78 | md_pc = max_drift(fake_series[:, 0, :], real_series[:, 0, :]) 79 | md_p1 = max_drift(fake_series[:, 1, :], real_series[:, 1, :]) 80 | md_p2 = max_drift(fake_series[:, 2, :], real_series[:, 2, :]) 81 | metric_dict['MD'] = (md_pc + md_p1 + md_p2) / 3 82 | 83 | sd_pc = sum_drift(fake_series[:, 0, :], real_series[:, 0, :]) 84 | sd_p1 = sum_drift(fake_series[:, 1, :], real_series[:, 1, :]) 85 | sd_p2 = sum_drift(fake_series[:, 2, :], real_series[:, 2, :]) 86 | metric_dict['SD'] = (sd_pc + sd_p1 + sd_p2) / 3 87 | 88 | fake_series_cpu, real_series_cpu = fake_series.cpu().numpy(), real_series.cpu().numpy() 89 | hausdorff_pc = symmetric_hausdorff_distance(fake_series_cpu[:, 0, :], real_series_cpu[:, 0, :]) 90 | hausdorff_p1 = symmetric_hausdorff_distance(fake_series_cpu[:, 1, :], real_series_cpu[:, 1, :]) 91 | 
hausdorff_p2 = symmetric_hausdorff_distance(fake_series_cpu[:, 2, :], real_series_cpu[:, 2, :]) 92 | metric_dict['HD'] = (hausdorff_pc + hausdorff_p1 + hausdorff_p2) / 3 93 | 94 | metric_dict = {k: v if isinstance(v, torch.Tensor) else torch.tensor(v, dtype=real_series.dtype, device=real_series.device) for k, v in metric_dict.items()} 95 | 96 | return metric_dict 97 | -------------------------------------------------------------------------------- /datasets/Spine.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | import configs 8 | import datasets 9 | import utils 10 | 11 | __all__ = ['Spine'] 12 | 13 | 14 | class Spine(datasets.BaseDataset): 15 | 16 | @staticmethod 17 | def more(cfg): 18 | cfg.source.elements = cfg.source.width * cfg.source.height * cfg.source.channel 19 | cfg.paths.source = configs.env.getdir(cfg.paths.source) 20 | cfg.paths.target = configs.env.getdir(cfg.paths.target) 21 | 22 | cfg.num_workers = 0 23 | cfg.pin_memory = False 24 | 25 | cfg.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 26 | 27 | return cfg 28 | 29 | def load(self): 30 | source_raw = np.load(self.cfg.paths.source, allow_pickle=True)[()] 31 | target_raw = np.load(self.cfg.paths.target, allow_pickle=True)[()] 32 | 33 | source, target_dof, target_point = [], [], [] 34 | for k1 in sorted(source_raw.keys()): 35 | for k2 in sorted(source_raw[k1].keys()): 36 | for k3 in sorted(source_raw[k1][k2].keys()): 37 | slices = torch.from_numpy(source_raw[k1][k2][k3]) 38 | source.append(slices) 39 | tp = torch.from_numpy(target_raw[k1][k2][k3]['point'].astype(np.float32)) 40 | tp = tp.view(-1, 3, 3) 41 | tp = torch.cat([ 42 | tp[:, 0:1, :], 43 | tp[:, 0:1, :] + (tp[:, 1:2, :] * self.cfg.source.height / 2 - tp[:, 2:3, :] * self.cfg.source.width / 2), 44 | tp[:, 0:1, :] + (tp[:, 1:2, :] * self.cfg.source.height / 2 + tp[:, 2:3, :] * self.cfg.source.width / 2) 45 | ], dim=1) 46 | tp = tp.view(-1, 9) 47 | tp, td = self.preprocessing(tp) 48 | target_point.append(tp) 49 | target_dof.append(td) 50 | 51 | trainset_length = int(self.cfg.series_per_data[0] * self.cfg.train_test_range[0]) 52 | valset_length = int(self.cfg.series_per_data[1] * self.cfg.train_test_range[1]) 53 | testset_length = int(self.cfg.series_per_data[2] * self.cfg.train_test_range[2]) 54 | data_count = trainset_length + valset_length + testset_length 55 | 56 | return {'source': source, 'target_dof': target_dof, 'target_point': target_point}, data_count 57 | 58 | def preprocessing(self, tp): 59 | tp = tp.view(-1, 3, 3) 60 | pall = torch.cat([tp, 2 * tp[:, 0:1, :] - tp[:, 1:2, :], 2 * tp[:, 0:1, :] - tp[:, 2:3, :]], dim=1) 61 | min_loca = torch.min(pall.reshape(-1, 3), dim=0)[0] 62 | tp = tp - min_loca.unsqueeze(0).unsqueeze(0) 63 | td = utils.simulation.series_to_dof(tp) 64 | tp = tp.view(-1, 9) 65 | return tp, td 66 | 67 | def __getitem__(self, index): 68 | idx = self.get_idx(index) 69 | 70 | source = self.data['source'][idx].to(self.cfg.device) 71 | target_point = self.data['target_point'][idx] 72 | 73 | frame_rate = torch.randint(self.cfg.frame_rate[0], self.cfg.frame_rate[1] + 1, (1,)) 74 | source = source[::frame_rate] 75 | target_point = target_point[::frame_rate] 76 | target_point, target_dof = self.preprocessing(target_point.view(-1, 3, 3)) 77 | optical_flow = utils.image.get_optical_flow(source, device=self.cfg.device) 78 | edge = utils.image.get_edge(source, device=self.cfg.device) 79 | 80 | 
source_out = source.unsqueeze(1) 81 | target_out = torch.cat([F.pad(target_dof, (0, 0, 0, 1)), target_point.view(-1, 9)], dim=-1) 82 | 83 | sample_dict = { 84 | 'source': source_out, 'target': target_out, 85 | 'optical_flow': optical_flow, 86 | 'edge': edge, 87 | 'frame_rate': frame_rate, 88 | 'info': torch.tensor(len(source_out)) 89 | } 90 | 91 | utils.common.set_seed(int(time.time() * 1000) % (1 << 32) + index) 92 | return sample_dict, index 93 | -------------------------------------------------------------------------------- /models/BaseModel.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import os 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | import configs 8 | from datasets import BaseDataset 9 | from utils import Logger 10 | from utils.common import get_path 11 | 12 | __all__ = ['BaseModel'] 13 | 14 | 15 | class _ProcessHook(object, metaclass=abc.ABCMeta): 16 | 17 | @abc.abstractmethod 18 | def train(self, epoch_info: dict, sample_dict: dict): 19 | raise NotImplementedError 20 | 21 | def train_return_hook(self, epoch_info: dict, return_all: dict): 22 | return return_all 23 | 24 | @abc.abstractmethod 25 | def test(self, epoch_info: dict, sample_dict: dict): 26 | raise NotImplementedError 27 | 28 | def test_return_hook(self, epoch_info: dict, return_all: dict): 29 | return return_all 30 | 31 | 32 | class BaseModel(_ProcessHook, metaclass=abc.ABCMeta): 33 | 34 | dataset: BaseDataset 35 | logger: Logger 36 | main_msg: dict 37 | 38 | def __init__(self, cfg, data_cfg, run, **kwargs): 39 | self.name = os.path.splitext(os.path.split(cfg._path)[1])[0] 40 | self.cfg = cfg 41 | self.data_cfg = data_cfg 42 | self.run = run 43 | self.path = get_path(cfg, data_cfg, run) 44 | self.device = self.run.device 45 | 46 | self._save_list = [] 47 | 48 | for k, v in kwargs.items(): 49 | setattr(self, k, v) 50 | 51 | def apply(self, fn): 52 | for name, value in self.__dict__.items(): 53 | if isinstance(value, nn.Module): 54 | self.__dict__[name].apply(fn) 55 | 56 | def modules(self): 57 | m = {} 58 | for name, value in list(vars(self).items()): 59 | if isinstance(value, nn.Module): 60 | m[name] = value 61 | return m 62 | 63 | def train_return_hook(self, epoch_info: dict, return_all: dict): 64 | _count = torch.tensor(return_all.pop('_count'), dtype=torch.float32, device=self.device) 65 | _count_sum = torch.sum(_count) 66 | for key, value in return_all.items(): 67 | if not isinstance(value, torch.Tensor): 68 | value = torch.tensor(value, dtype=torch.float32, device=self.device) 69 | elif value.device != self.device: 70 | value = value.to(self.device) 71 | return_all[key] = _count @ value / _count_sum 72 | return return_all 73 | 74 | def load(self, start_epoch=None, path=None): 75 | assert start_epoch is None or (isinstance(start_epoch, int) and start_epoch >= 0) 76 | path = path or self.path 77 | if start_epoch is None: 78 | check_path = os.path.join(path, self.name + configs.env.paths.check_file) 79 | if os.path.exists(check_path): 80 | check_data = torch.load(check_path) 81 | start_epoch = check_data['epoch'] 82 | self.main_msg = check_data['main_msg'] 83 | else: 84 | start_epoch = 0 85 | if start_epoch > 0: 86 | for name, value in self.__dict__.items(): 87 | if isinstance(value, (nn.Module, torch.optim.Optimizer)) or name in self._save_list: 88 | load_path = os.path.join(path, self.name + '_' + name + '_' + str(start_epoch) + '.pth') 89 | if not os.path.exists(load_path) and isinstance(value, torch.optim.Optimizer): 90 | 
self.logger.info(f"IGNORE! Optimizer weight `{load_path}` not found!") 91 | continue 92 | load_value = torch.load(load_path, map_location=self.device) 93 | if isinstance(value, (nn.Module, torch.optim.Optimizer)): 94 | self.__dict__[name].load_state_dict(load_value) 95 | else: 96 | self.__dict__[name] = load_value 97 | return start_epoch 98 | 99 | def save(self, epoch, path=None): 100 | path = path or self.path 101 | if not os.path.exists(path): 102 | os.makedirs(path) 103 | for name, value in self.__dict__.items(): 104 | if isinstance(value, (nn.Module, torch.optim.Optimizer)) or name in self._save_list: 105 | save_value = value.state_dict() if isinstance(value, (nn.Module, torch.optim.Optimizer)) else value 106 | torch.save(save_value, os.path.join(path, self.name + '_' + name + '_' + str(epoch) + '.pth')) 107 | torch.save(dict(epoch=epoch, main_msg=self.main_msg), 108 | os.path.join(path, self.name + configs.env.paths.check_file)) 109 | -------------------------------------------------------------------------------- /models/online_backbone.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import timm 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | import models 8 | import utils 9 | 10 | 11 | class Backbone(nn.Module): 12 | 13 | def __init__(self, in_planes, num_classes): 14 | super().__init__() 15 | self.resnet = timm.create_model('resnet18', pretrained=True, in_chans=in_planes, num_classes=0, global_pool='') 16 | self.lstm = models.layers.convolutional_rnn.Conv2dLSTM(512, 512, kernel_size=3, batch_first=True) 17 | self.avg = nn.AdaptiveAvgPool2d(1) 18 | self.fc = nn.Linear(512, num_classes) 19 | 20 | def forward(self, x, return_feature=False): 21 | b, t, c, h, w = x.shape 22 | x = (x - torch.mean(x, dim=[3, 4], keepdim=True)) / (torch.std(x, dim=[3, 4], keepdim=True) + 1e-6) 23 | x = x.view(b * t, c, h, w) 24 | x = self.resnet(x) 25 | x = x.view(b, t, *x.shape[1:]) 26 | if return_feature: 27 | f = self.avg(x) 28 | f = f.view(f.size(0), f.size(1), -1) 29 | else: 30 | f = None 31 | x = self.lstm(x)[0] 32 | x = self.avg(x) 33 | x = x.view(x.size(0), x.size(1), -1) 34 | x = self.fc(x) 35 | return x, f 36 | 37 | 38 | class Online_Backbone(models.BaseModel): 39 | 40 | def __init__(self, cfg, data_cfg, run, **kwargs): 41 | super().__init__(cfg, data_cfg, run, **kwargs) 42 | self.backbone = Backbone(self.data_cfg.source.channel, self.data_cfg.target.elements - 9).to(self.device) 43 | self.optimizer = torch.optim.Adam(self.backbone.parameters(), lr=self.run.lr, betas=self.run.betas) 44 | self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=self.run.step_size, gamma=self.run.gamma) 45 | self.flag_motion = True 46 | 47 | def criterion(self, real_target, fake_target, feature=None): 48 | real_dist, real_angle = real_target.split([3, self.data_cfg.target.elements - 12], dim=-1) 49 | fake_dist, fake_angle = fake_target.split([3, self.data_cfg.target.elements - 12], dim=-1) 50 | 51 | loss_dist = F.l1_loss(real_dist, fake_dist) * 3 52 | loss_angle = F.l1_loss(real_angle, fake_angle) * 3 53 | loss_corr = utils.metric.correlation_loss(real_target, fake_target) 54 | 55 | loss_dict = {'loss_dist': loss_dist, 'loss_angle': loss_angle, 'loss_corr': loss_corr} 56 | 57 | if self.flag_motion: 58 | fake_motion = torch.norm(fake_dist, p=2, dim=-1) + 1e-6 59 | feature = torch.norm(feature, p=2, dim=-1) 60 | loss_motion = torch.mean(feature / fake_motion) * self.cfg.weight_motion 61 | 
loss_dict['loss_motion'] = loss_motion 62 | 63 | return loss_dict 64 | 65 | def train(self, epoch_info, sample_dict): 66 | real_source = sample_dict['source'].to(self.device) 67 | real_target = sample_dict['target'].to(self.device) 68 | edge = sample_dict['edge'].to(self.device) 69 | optical_flow = sample_dict['optical_flow'].to(self.device) 70 | 71 | real_target = real_target[:, :-1, :-9] 72 | real_target[:, :, 3:] = real_target[:, :, 3:] * 100 73 | 74 | self.backbone.train() 75 | self.optimizer.zero_grad() 76 | input = torch.cat([real_source[:, :-1, ...], real_source[:, 1:, ...], edge[:, :-1, ...], edge[:, 1:, ...], optical_flow], dim=2) 77 | fake_target, feature = self.backbone(input, return_feature=self.flag_motion) 78 | 79 | losses = self.criterion(real_target, fake_target, feature) 80 | loss = sum(losses.values()) 81 | loss.backward() 82 | self.optimizer.step() 83 | self.scheduler.step(epoch_info['epoch']) 84 | 85 | return {'loss': loss, **losses} 86 | 87 | def test(self, epoch_info, sample_dict): 88 | real_source = sample_dict['source'].to(self.device) 89 | real_target = sample_dict['target'].to(self.device).squeeze(0) 90 | edge = sample_dict['edge'].to(self.device) 91 | optical_flow = sample_dict['optical_flow'].to(self.device) 92 | 93 | real_series = real_target[:, -9:].view(-1, 3, 3) 94 | 95 | self.backbone.eval() 96 | input = torch.cat([real_source[:, :-1, ...], real_source[:, 1:, ...], edge[:, :-1, ...], edge[:, 1:, ...], optical_flow], dim=2) 97 | fake_gaps, _ = self.backbone(input) 98 | fake_gaps = fake_gaps[0, :, :] 99 | fake_gaps[:, 3:] /= 100 100 | 101 | fake_series = utils.simulation.dof_to_series(real_series[0:1, :, :], fake_gaps.unsqueeze(0)).squeeze(0) 102 | losses = utils.metric.get_metric(real_series, fake_series) 103 | 104 | return losses 105 | 106 | def test_return_hook(self, epoch_info, return_all): 107 | return_info = {} 108 | for key, value in return_all.items(): 109 | return_info[key] = np.sum(value) / epoch_info['batch_per_epoch'] 110 | if return_info: 111 | self.logger.info_scalars('{} Epoch: {}\t', (epoch_info['log_text'], epoch_info['epoch']), return_info) 112 | return return_all 113 |
--------------------------------------------------------------------------------
/models/online_discriminator.py:
--------------------------------------------------------------------------------
1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | import configs 7 | import models 8 | import utils 9 | 10 | 11 | class Discriminator(nn.Module): 12 | 13 | def __init__(self): 14 | super(Discriminator, self).__init__() 15 | self.resnet3d = models.layers.resnet3d.generate_model(10, n_input_channels=1, n_classes=1) 16 | 17 | def forward(self, x): 18 | x = self.resnet3d(x) 19 | return x 20 | 21 | 22 | class Online_Discriminator(models.BaseModel): 23 | 24 | def __init__(self, cfg, data_cfg, run, **kwargs): 25 | super().__init__(cfg, data_cfg, run, **kwargs) 26 | self.backbone = models.online_backbone.Backbone(self.data_cfg.source.channel, self.data_cfg.target.elements - 9).to(self.device) 27 | self.backbone.load_state_dict(torch.load(configs.env.getdir(self.cfg.backbone_weight))) 28 | self.discriminator = Discriminator().to(self.device) 29 | 30 | self.optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=self.run.lr, betas=self.run.betas) 31 | self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=self.run.step_size, gamma=self.run.gamma) 32 | self.criterion = torch.nn.BCEWithLogitsLoss().to(self.device) 33 | 34 | self.down_ratio = 0.3 35 | self.mat_scale = torch.eye(4, dtype=torch.float32, device=self.device) 36 | self.mat_scale[0, 0] = self.down_ratio 37 | self.mat_scale[1, 1] = self.down_ratio 38 | self.mat_scale[2, 2] = self.down_ratio 39 | 40 | def get_reco(self, epoch_info, real_source, real_target, edge, optical_flow): 41 | self.backbone.eval() 42 | with torch.no_grad(): 43 | down_source = F.interpolate(real_source.squeeze(-3), scale_factor=self.down_ratio).unsqueeze(-3) 44 | real_input = torch.cat([real_source[:, :-1, ...], real_source[:, 1:, ...], edge[:, :-1, ...], edge[:, 1:, ...], optical_flow], dim=2) 45 | real_gaps = real_target[:, :-1, :-9] 46 | real_series_0 = real_target[:, 0, -9:].view(-1, 3, 3) 47 | 48 | fake_gaps, _ = self.backbone(real_input, return_feature=False) 49 | fake_gaps = torch.cat([fake_gaps[:, :, :3], fake_gaps[:, :, 3:] / 100], dim=-1) 50 | 51 | real_series = utils.simulation.dof_to_series(real_series_0, real_gaps) 52 | fake_series = utils.simulation.dof_to_series(real_series_0, fake_gaps) 53 | 54 | reco, label = [], [] 55 | for idx, index in enumerate(epoch_info['index']): 56 | if np.random.rand() >= 0.5: 57 | r_reco, _ = utils.reconstruction.reco(down_source[idx].squeeze(0).squeeze(1), real_series[idx].squeeze(0), mat_scale=self.mat_scale) 58 | reco.append(r_reco) 59 | label.append(1) 60 | else: 61 | r_reco, _ = utils.reconstruction.reco(down_source[idx].squeeze(0).squeeze(1), fake_series[idx].squeeze(0), mat_scale=self.mat_scale) 62 | reco.append(r_reco) 63 | label.append(0) 64 | torch.cuda.empty_cache() 65 | reco = torch.stack(reco, dim=0).unsqueeze(1) 66 | label = torch.tensor(label, dtype=torch.float32, device=self.device) 67 | 68 | return reco, label 69 | 70 | def train(self, epoch_info, sample_dict): 71 | real_source = sample_dict['source'].to(self.device) 72 | real_target = sample_dict['target'].to(self.device) 73 | edge = sample_dict['edge'].to(self.device) 74 | optical_flow = sample_dict['optical_flow'].to(self.device) 75 | 76 | self.discriminator.train() 77 | self.optimizer.zero_grad() 78 | reco, label = self.get_reco(epoch_info, real_source, real_target, edge, optical_flow) 79 | pred = self.discriminator(reco) 80 | loss = self.criterion(pred, label.unsqueeze(-1)) 81 | loss.backward() 82 | self.optimizer.step() 83 | self.scheduler.step(epoch_info['epoch']) 84 | 85 | return {'loss': loss} 86 | 87 | def test(self, epoch_info, sample_dict): 88 | real_source = sample_dict['source'].to(self.device) 89 | real_target = sample_dict['target'].to(self.device) 90 | edge = sample_dict['edge'].to(self.device) 91 | optical_flow = sample_dict['optical_flow'].to(self.device) 92 | 93 | self.discriminator.eval() 94 | reco, label = self.get_reco(epoch_info, real_source, real_target, edge, optical_flow) 95 | pred = self.discriminator(reco) 96 | pred_label = pred >= 0 97 | accuracy = pred_label.eq(label.view_as(pred_label)).sum().item() / len(reco) 98 | accuracy = torch.tensor(accuracy, dtype=torch.float32, device=self.device) 99 | 100 | return {'accuracy': accuracy} 101 | 102 | def test_return_hook(self, epoch_info, return_all): 103 | return_info = {} 104 | for key, value in return_all.items(): 105 | return_info[key] = np.sum(value) / epoch_info['batch_per_epoch'] 106 | if return_info: 107 | self.logger.info_scalars('{} Epoch: {}\t', (epoch_info['log_text'], epoch_info['epoch']), return_info) 108 | return return_all 109 |
--------------------------------------------------------------------------------
/utils/reconstruction.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def get_axis(series, eps=1e-20): 6 | old_dtype = series.dtype 7 | series = series.type(torch.float64) 8 | 9 | p1p2 = series[:, 1:3, :] - series[:, 0:1, :] 10 | ax_x = p1p2[:, 1, :] - p1p2[:, 0, :] 11 | ax_x = F.normalize(ax_x, p=2, dim=-1, eps=eps) 12 | ax_y = -p1p2[:, 1, :] - p1p2[:, 0, :] 13 | ax_y = F.normalize(ax_y, p=2, dim=-1, eps=eps) 14 | ax_z = torch.cross(ax_x, ax_y, dim=-1) 15 | ax_z = F.normalize(ax_z, p=2, dim=-1, eps=eps) 16 | axis = torch.stack([ax_x, ax_y, ax_z], dim=1) 17 | 18 | axis = axis.type(old_dtype) 19 | return axis 20 | 21 | 22 | def _get_weight(dist, iter=2, temperature=0.001, eps=1e-10): 23 | weight = torch.reciprocal(dist + eps) 24 | w_iter = weight * torch.softmax(weight / temperature, dim=0) 25 | for _ in range(iter - 1): 26 | w_iter = weight * torch.softmax(torch.abs(weight - w_iter) / temperature, dim=0) + w_iter 27 | weight = w_iter / torch.sum(w_iter, dim=0, keepdim=True) 28 | return weight 29 | 30 | 31 | def _reco_block(slices, matrix, mesh): 32 | n = len(slices) 33 | rm_shape = mesh.shape[:3] 34 | 35 | loca = torch.einsum('Nij,XYZj->NXYZi', matrix, mesh) 36 | 37 | flag = (torch.sum(~(loca[..., 2] < -0.5), dim=0) == 0) | (torch.sum(~(loca[..., 2] > 0.5), dim=0) == 0) 38 | weight = _get_weight(torch.abs(loca[..., 2])) 39 | grid = loca[..., :2].view(n, 1, -1, 2) 40 | grid = torch.cat([grid[..., 0:1] / (slices.shape[-2] / 2), grid[..., 1:2] / (slices.shape[-1] / 2)], dim=-1) 41 | 42 | value = F.grid_sample(slices.unsqueeze(1), grid.flip(-1), mode='bilinear', padding_mode='border', align_corners=False) 43 | value = value.view(n, *rm_shape) 44 | 45 | flag3 = grid.view(n, *rm_shape, 2) 46 | flag3 = (torch.abs(flag3[..., 0]) > 1) | (torch.abs(flag3[..., 1]) > 1) 47 | weight[flag3] = 0 48 | 49 | weight[:, flag] = 0 50 | volume = torch.einsum('NXYZ,NXYZ->XYZ', weight, value) 51 | return volume 52 | 53 | 54 | def _reco_split(slices, matrix, reco_mesh, chunk_size=None): 55 | chunk_size = chunk_size or (50, 50, 50) 56 | cza = [] 57 | for cz in torch.split(reco_mesh, chunk_size[2], dim=2): 58 | cya = [] 59 | for cy in torch.split(cz, chunk_size[1], dim=1): 60 | cxa = [] 61 | for cx in torch.split(cy, chunk_size[0], dim=0): 62 | cx = _reco_block(slices, matrix, cx) 63 | cxa.append(cx) 64 | cxa = torch.cat(cxa, dim=0) 65 | cya.append(cxa) 66 | cya = torch.cat(cya, dim=1) 67 | cza.append(cya) 68 | cza = torch.cat(cza, dim=2) 69 | return cza 70 | 71 | 72 | def _reco_stack(down_source, matrix, reco_mesh, chunk_size=None): 73 | reco = [] 74 | for idx in range(len(down_source)): 75 | reco.append(_reco_split(down_source[idx].squeeze(1), matrix[idx, ...], reco_mesh, chunk_size=chunk_size)) 76 | reco = torch.stack(reco, dim=0).unsqueeze(1) 77 | return reco 78 | 79 | 80 | def get_reco_size(series, mat_scale=None): 81 | series = torch.cat([series, 2 * series[:, 0:1, :] - series[:, 1:2, :], 2 * series[:, 0:1, :] - series[:, 2:3, :]], dim=1) 82 | min_point = torch.min(series.view(-1, 3), dim=0)[0] 83 | max_point = torch.max(series.view(-1, 3), dim=0)[0] 84 | range_point = max_point - min_point + 1 85 | if mat_scale is not None: 86 | range_point[0] *= mat_scale[0, 0] 87 | range_point[1] *= mat_scale[1, 1] 88 | range_point[2] *= mat_scale[2, 2] 89 | reco_size = torch.ceil(range_point).long().tolist() 90 | bias = min_point - 0.5 91 | return reco_size, bias 92 | 93 | 94 | def get_matrix(series, mat_scale=None): 95 | 
axis = get_axis(series).permute(0, 2, 1) 96 | axis = torch.cat([axis, series[:, 0:1, :].permute(0, 2, 1)], dim=-1) 97 | axis = F.pad(axis, [0, 0, 0, 1]) 98 | axis[:, -1, -1] = 1 99 | 100 | if mat_scale is not None: 101 | mat_scale_inv = torch.inverse(mat_scale) 102 | mat_scale = mat_scale.unsqueeze(0).expand(len(axis), 4, 4) 103 | mat_scale_inv = mat_scale_inv.unsqueeze(0).expand(len(axis), 4, 4) 104 | axis = torch.bmm(mat_scale, torch.bmm(axis, mat_scale_inv)) 105 | 106 | axis = torch.inverse(axis) 107 | return axis 108 | 109 | 110 | def transform(points, height, width): 111 | axis = get_axis(points).permute(0, 2, 1) 112 | 113 | if not hasattr(transform, 'mesh') or height != transform.height or width != transform.width: 114 | range_x = torch.arange(-(height - 1) / 2, (height + 1) / 2, dtype=points.dtype, device=points.device) 115 | range_y = torch.arange(-(width - 1) / 2, (width + 1) / 2, dtype=points.dtype, device=points.device) 116 | mesh_x, mesh_y = torch.meshgrid(range_x, range_y, indexing='ij') 117 | mesh = torch.stack([mesh_y, -mesh_x, torch.zeros_like(mesh_x)], dim=-1) 118 | transform.mesh = mesh 119 | transform.height = height 120 | transform.width = width 121 | 122 | center = points[:, 0, :].unsqueeze(1).unsqueeze(1) 123 | 124 | local_mesh = torch.einsum('Nij,HWj->NHWi', axis, transform.mesh) + center 125 | return local_mesh 126 | 127 | 128 | def reco(source, series, mat_scale=None, volume_size=None): 129 | if volume_size is not None: 130 | reco_size = volume_size 131 | if not hasattr(reco, 'bias'): 132 | reco.bias = -torch.tensor(volume_size, dtype=series.dtype, device=series.device) / 2 133 | bias = reco.bias 134 | else: 135 | reco_size, bias = get_reco_size(series, mat_scale) 136 | series = series - bias 137 | 138 | matrix = get_matrix(series, mat_scale) 139 | matrix = torch.stack([-matrix[:, 1], matrix[:, 0], matrix[:, 2]], dim=1) 140 | 141 | if not hasattr(reco, 'reco_mesh') or reco_size != reco.reco_size: 142 | reco_mesh = torch.meshgrid([torch.arange(0.5, length + 0.5, dtype=source.dtype, device=series.device) for length in reco_size], indexing='ij') 143 | reco_mesh = torch.stack(reco_mesh, dim=-1) 144 | reco_mesh = F.pad(reco_mesh, (0, 1)) 145 | reco_mesh[..., -1] = 1 146 | reco.reco_mesh = reco_mesh 147 | reco.reco_size = reco_size 148 | 149 | volume = _reco_stack(source.unsqueeze(0).unsqueeze(2), matrix.unsqueeze(0), reco.reco_mesh, chunk_size=volume_size).squeeze(0).squeeze(0) 150 | return volume, bias 151 | 152 | 153 | def get_slice(volume, series, shape): 154 | mesh = transform(series, *shape) 155 | if not hasattr(get_slice, 'volume_size') or volume.shape != get_slice.shape: 156 | get_slice.volume_size = torch.tensor(volume.shape, dtype=mesh.dtype, device=mesh.device).unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0) 157 | get_slice.shape = volume.shape 158 | mesh = mesh.unsqueeze(0) / get_slice.volume_size 159 | mesh = (mesh * 2 - 1).flip(-1) 160 | slices = F.grid_sample(volume.unsqueeze(0).unsqueeze(0), mesh, mode='bilinear', padding_mode='border', align_corners=False) 161 | slices = slices.squeeze(1).unsqueeze(2) 162 | return slices 163 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | 5 | import numpy as np 6 | import torch 7 | 8 | import configs 9 | import datasets 10 | import models 11 | import utils 12 | 13 | 14 | class Main(object): 15 | 16 | def __init__(self, args): 17 | self.args 
= args 18 | self.model_cfg = configs.BaseConfig(utils.common.real_config_path( 19 | args.model_config_path, configs.env.paths.model_cfgs_folder)) 20 | self.run_cfg = configs.Run(utils.common.real_config_path( 21 | args.run_config_path, configs.env.paths.run_cfgs_folder), gpus=args.gpus) 22 | self.dataset_cfg = datasets.functional.common.more(configs.BaseConfig( 23 | utils.common.real_config_path(args.dataset_config_path, configs.env.paths.dataset_cfgs_folder))) 24 | print(args) 25 | 26 | self._init() 27 | self._get_component() 28 | self.show_cfgs() 29 | 30 | def _init(self): 31 | utils.common.set_seed(0) 32 | self.msg = {} 33 | 34 | def _get_component(self): 35 | self.path = utils.common.get_path(self.model_cfg, self.dataset_cfg, self.run_cfg) 36 | self.logger = utils.Logger(self.path, utils.common.get_filename(self.model_cfg._path)) 37 | 38 | self.dataset = datasets.functional.common.find(self.dataset_cfg.name)(self.dataset_cfg, logger=self.logger) 39 | self.model = models.functional.common.find(self.model_cfg.name)( 40 | self.model_cfg, self.dataset.cfg, self.run_cfg, dataset=self.dataset, logger=self.logger, main_msg=self.msg) 41 | self.start_epoch = self.model.load(self.args.test_epoch) 42 | 43 | def show_cfgs(self): 44 | self.logger.info(self.model.cfg) 45 | self.logger.info(self.run_cfg) 46 | self.logger.info(self.dataset.cfg) 47 | 48 | def split(self): 49 | self.trainset, self.valset, self.testset = self.dataset.split() 50 | 51 | self.train_loader = torch.utils.data.DataLoader( 52 | self.trainset, 53 | batch_size=self.run_cfg.batch_size, 54 | shuffle=True, 55 | collate_fn=getattr(self.trainset.dataset, 'collate_fn', None), 56 | num_workers=self.dataset.cfg.num_workers, 57 | pin_memory=self.dataset.cfg.pin_memory, 58 | sampler=None 59 | ) 60 | 61 | self.run_cfg.test_batch_size = getattr(self.run_cfg, 'test_batch_size', self.run_cfg.batch_size) 62 | self.val_loader = torch.utils.data.DataLoader( 63 | self.valset, 64 | batch_size=self.run_cfg.test_batch_size, 65 | shuffle=False, 66 | collate_fn=getattr(self.valset.dataset, 'collate_fn', None), 67 | num_workers=self.dataset.cfg.num_workers, 68 | pin_memory=self.dataset.cfg.pin_memory, 69 | sampler=None 70 | ) 71 | self.test_loader = torch.utils.data.DataLoader( 72 | self.testset, 73 | batch_size=self.run_cfg.test_batch_size, 74 | shuffle=False, 75 | collate_fn=getattr(self.testset.dataset, 'collate_fn', None), 76 | num_workers=self.dataset.cfg.num_workers, 77 | pin_memory=self.dataset.cfg.pin_memory, 78 | sampler=None 79 | ) 80 | 81 | def train(self, epoch): 82 | utils.common.set_seed(int(time.time()) + epoch) 83 | torch.cuda.empty_cache() 84 | count, loss_all = 0, {} 85 | batch_per_epoch, count_data = len(self.train_loader), len(self.train_loader.dataset) 86 | log_step = 1 87 | epoch_info = {'epoch': epoch, 'batch_per_epoch': batch_per_epoch, 'count_data': count_data} 88 | for batch_idx, (sample_dict, index) in enumerate(self.train_loader): 89 | _count = len(list(sample_dict.values())[0]) 90 | epoch_info['batch_idx'] = batch_idx 91 | epoch_info['index'] = index 92 | epoch_info['batch_count'] = _count 93 | loss_dict = self.model.train(epoch_info, sample_dict) 94 | loss_dict['_count'] = _count 95 | utils.common.merge_dict(loss_all, loss_dict) 96 | count += _count 97 | if batch_idx % log_step == 0: 98 | self.logger.info_scalars('Train Epoch: {} [{}/{} ({:.0f}%)]\t', (epoch, count, count_data, 100. 
* count / count_data), loss_dict) 99 | if epoch % self.run_cfg.save_step == 0: 100 | loss_file = os.path.join(self.path, self.model.name + '_' + str(epoch) + configs.env.paths.loss_file) 101 | self.logger.save_npy(loss_file, {k: v.cpu().detach().numpy() if isinstance(v, torch.Tensor) else v for k, v in loss_all.items()}) 102 | loss_all = self.model.train_return_hook(epoch_info, loss_all) 103 | self.logger.info_scalars('Train Epoch: {}\t', (epoch,), loss_all) 104 | if epoch % self.run_cfg.save_step == 0: 105 | self.model.save(epoch) 106 | 107 | def test(self, epoch, data_loader=None, log_text=None): 108 | utils.common.set_seed(int(time.time()) + epoch) 109 | torch.cuda.empty_cache() 110 | predict, count = {}, 0 111 | data_loader = data_loader or self.test_loader 112 | log_text = log_text or 'Test' 113 | with torch.no_grad(): 114 | batch_per_epoch, count_data = len(data_loader), len(data_loader.dataset) 115 | log_step = max(int(np.power(10, np.floor(np.log10(batch_per_epoch / 10)))), 1) if batch_per_epoch > 0 else 1 116 | epoch_info = {'epoch': epoch, 'batch_per_epoch': batch_per_epoch, 'count_data': count_data, 'log_text': log_text} 117 | for batch_idx, (sample_dict, index) in enumerate(data_loader): 118 | _count = len(list(sample_dict.values())[0]) 119 | epoch_info['batch_idx'] = batch_idx 120 | epoch_info['index'] = index 121 | epoch_info['batch_count'] = _count 122 | output_dict = self.model.test(epoch_info, sample_dict) 123 | count += _count 124 | if batch_idx % log_step == 0: 125 | self.logger.info('{} Epoch: {} [{}/{} ({:.0f}%)]'.format(log_text, epoch, count, count_data, 100. * count / count_data)) 126 | for name, value in output_dict.items(): 127 | v = value.float() if value.shape else value.unsqueeze(0) 128 | v = v.cpu().numpy() 129 | predict[name] = np.concatenate([predict[name], v]) if name in predict.keys() else v 130 | predict = self.model.test_return_hook(epoch_info, predict) 131 | predict_file = os.path.join(self.path, self.model.name + '_' + str(epoch) + configs.env.paths.predict_file) 132 | self.logger.save_npy(predict_file, predict) 133 | 134 | def val_test(self, epoch): 135 | self.test(epoch, data_loader=self.val_loader, log_text='Val') 136 | self.test(epoch, data_loader=self.test_loader, log_text='Test') 137 | 138 | 139 | def run(): 140 | parser = argparse.ArgumentParser(description='RecON') 141 | parser.add_argument('-m', '--model_config_path', type=str, required=True, metavar='/path/to/model/config.json', 142 | help='Path to model config .json file') 143 | parser.add_argument('-r', '--run_config_path', type=str, required=True, metavar='/path/to/run/config.json', 144 | help='Path to run config .json file') 145 | parser.add_argument('-d', '--dataset_config_path', type=str, required=True, metavar='/path/to/dataset/config.json', 146 | help='Path to dataset config .json file') 147 | parser.add_argument('-g', '--gpus', type=str, default='0', metavar='cuda device, i.e. 0 or cpu', 148 | help='cuda device, i.e. 
0 or cpu') 149 | parser.add_argument('-t', '--test_epoch', type=int, metavar='epoch to test', help='epoch to test') 150 | args = parser.parse_args() 151 | 152 | main = Main(args) 153 | main.split() 154 | if args.test_epoch is None: 155 | if main.start_epoch == 0: 156 | main.val_test(main.start_epoch) 157 | for epoch in range(main.start_epoch + 1, main.run_cfg.epochs + 1): 158 | main.train(epoch) 159 | if epoch % main.run_cfg.save_step == 0: 160 | main.val_test(epoch) 161 | else: 162 | main.test(main.start_epoch) 163 | 164 | 165 | if __name__ == '__main__': 166 | run() 167 | -------------------------------------------------------------------------------- /models/layers/canny.py: -------------------------------------------------------------------------------- 1 | # https://github.com/DCurro/CannyEdgePytorch 2 | 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | from scipy.signal.windows import gaussian 7 | 8 | 9 | class Canny(nn.Module): 10 | 11 | def __init__(self, threshold=10.0, filter_size=5, dtype=torch.float32): 12 | super().__init__() 13 | self.threshold = threshold 14 | self.filter_size = filter_size 15 | 16 | generated_filters = gaussian(filter_size, std=1.0).reshape([1, filter_size]) 17 | gf = torch.from_numpy(generated_filters).type(dtype) 18 | 19 | self.gaussian_filter_horizontal = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(1, filter_size), padding=(0, filter_size // 2)) 20 | self.gaussian_filter_horizontal.weight.data.copy_(gf) 21 | nn.init.constant_(self.gaussian_filter_horizontal.bias, 0.0) 22 | self.gaussian_filter_vertical = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(filter_size, 1), padding=(filter_size // 2, 0)) 23 | self.gaussian_filter_vertical.weight.data.copy_(gf.T) 24 | nn.init.constant_(self.gaussian_filter_vertical.bias, 0.0) 25 | 26 | sobel_filter = np.array([[1, 0, -1], [2, 0, -2], [1, 0, -1]]) 27 | sf = torch.from_numpy(sobel_filter).type(dtype) 28 | 29 | self.sobel_filter_horizontal = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=sobel_filter.shape, padding=sobel_filter.shape[0] // 2) 30 | self.sobel_filter_horizontal.weight.data.copy_(sf) 31 | nn.init.constant_(self.sobel_filter_horizontal.bias, 0.0) 32 | self.sobel_filter_vertical = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=sobel_filter.shape, padding=sobel_filter.shape[0] // 2) 33 | self.sobel_filter_vertical.weight.data.copy_(sf.T) 34 | nn.init.constant_(self.sobel_filter_vertical.bias, 0.0) 35 | 36 | # filters were flipped manually 37 | filter_0 = np.array([[0, 0, 0], [0, 1, -1], [0, 0, 0]]) 38 | filter_45 = np.array([[0, 0, 0], [0, 1, 0], [0, 0, -1]]) 39 | filter_90 = np.array([[0, 0, 0], [0, 1, 0], [0, -1, 0]]) 40 | filter_135 = np.array([[0, 0, 0], [0, 1, 0], [-1, 0, 0]]) 41 | filter_180 = np.array([[0, 0, 0], [-1, 1, 0], [0, 0, 0]]) 42 | filter_225 = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 0]]) 43 | filter_270 = np.array([[0, -1, 0], [0, 1, 0], [0, 0, 0]]) 44 | filter_315 = np.array([[0, 0, -1], [0, 1, 0], [0, 0, 0]]) 45 | all_filters = np.stack([filter_0, filter_45, filter_90, filter_135, filter_180, filter_225, filter_270, filter_315]) 46 | 47 | self.directional_filter = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=filter_0.shape, padding=filter_0.shape[-1] // 2) 48 | self.directional_filter.weight.data.copy_(torch.from_numpy(all_filters[:, None, ...]).type(dtype)) 49 | nn.init.constant_(self.directional_filter.bias, 0.0) 50 | 51 | def forward(self, img): 52 | img_r = img[:, 0:1] 53 | img_g = img[:, 1:2] 54 | img_b = img[:, 
2:3] 55 | 56 | blur_horizontal = self.gaussian_filter_horizontal(img_r) 57 | blurred_img_r = self.gaussian_filter_vertical(blur_horizontal) 58 | blur_horizontal = self.gaussian_filter_horizontal(img_g) 59 | blurred_img_g = self.gaussian_filter_vertical(blur_horizontal) 60 | blur_horizontal = self.gaussian_filter_horizontal(img_b) 61 | blurred_img_b = self.gaussian_filter_vertical(blur_horizontal) 62 | 63 | blurred_img = torch.stack([blurred_img_r, blurred_img_g, blurred_img_b], dim=1) 64 | blurred_img = torch.stack([torch.squeeze(blurred_img)]) 65 | 66 | grad_x_r = self.sobel_filter_horizontal(blurred_img_r) 67 | grad_y_r = self.sobel_filter_vertical(blurred_img_r) 68 | grad_x_g = self.sobel_filter_horizontal(blurred_img_g) 69 | grad_y_g = self.sobel_filter_vertical(blurred_img_g) 70 | grad_x_b = self.sobel_filter_horizontal(blurred_img_b) 71 | grad_y_b = self.sobel_filter_vertical(blurred_img_b) 72 | 73 | # COMPUTE THICK EDGES 74 | 75 | grad_mag = torch.sqrt(grad_x_r ** 2 + grad_y_r ** 2) 76 | grad_mag += torch.sqrt(grad_x_g ** 2 + grad_y_g ** 2) 77 | grad_mag += torch.sqrt(grad_x_b ** 2 + grad_y_b ** 2) 78 | grad_orientation = (torch.atan2(grad_y_r + grad_y_g + grad_y_b, grad_x_r + grad_x_g + grad_x_b) * (180.0 / np.pi)) 79 | grad_orientation += 180.0 80 | grad_orientation = torch.round(grad_orientation / 45.0) * 45.0 81 | 82 | # THIN EDGES (NON-MAX SUPPRESSION) 83 | 84 | all_filtered = self.directional_filter(grad_mag) 85 | 86 | indices_positive = (grad_orientation / 45) % 8 87 | indices_negative = ((grad_orientation / 45) + 4) % 8 88 | 89 | height = indices_positive.size(2) 90 | width = indices_positive.size(3) 91 | pixel_count = height * width 92 | pixel_range = torch.arange(pixel_count, device=img.device) 93 | 94 | indices = (indices_positive.view(-1).data * pixel_count + pixel_range).squeeze() 95 | channel_select_filtered_positive = all_filtered.view(-1)[indices.long()].view(1, height, width) 96 | 97 | indices = (indices_negative.view(-1).data * pixel_count + pixel_range).squeeze() 98 | channel_select_filtered_negative = all_filtered.view(-1)[indices.long()].view(1, height, width) 99 | 100 | channel_select_filtered = torch.stack([channel_select_filtered_positive, channel_select_filtered_negative]) 101 | 102 | is_max = channel_select_filtered.min(dim=0)[0] > 0.0 103 | is_max = torch.unsqueeze(is_max, dim=0) 104 | 105 | thin_edges = grad_mag.clone() 106 | thin_edges[is_max == 0] = 0.0 107 | 108 | # THRESHOLD 109 | 110 | thresholded = thin_edges.clone() 111 | thresholded[thin_edges < self.threshold] = 0.0 112 | 113 | early_threshold = grad_mag.clone() 114 | early_threshold[grad_mag < self.threshold] = 0.0 115 | 116 | assert grad_mag.size() == grad_orientation.size() == thin_edges.size() == thresholded.size() == early_threshold.size() 117 | 118 | return blurred_img, grad_mag, grad_orientation, thin_edges, thresholded, early_threshold 119 | 120 | 121 | class Canny2D(nn.Module): 122 | 123 | def __init__(self, threshold=10.0, filter_size=5, dtype=torch.float32): 124 | super().__init__() 125 | self.threshold = threshold 126 | self.filter_size = filter_size 127 | 128 | generated_filters = gaussian(filter_size, std=1.0).reshape([1, filter_size]) 129 | gf = torch.from_numpy(generated_filters).type(dtype) 130 | 131 | self.gaussian_filter_horizontal = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(1, filter_size), padding=(0, filter_size // 2)) 132 | self.gaussian_filter_horizontal.weight.data.copy_(gf) 133 | nn.init.constant_(self.gaussian_filter_horizontal.bias, 0.0) 134 | 
self.gaussian_filter_vertical = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(filter_size, 1), padding=(filter_size // 2, 0)) 135 | self.gaussian_filter_vertical.weight.data.copy_(gf.T) 136 | nn.init.constant_(self.gaussian_filter_vertical.bias, 0.0) 137 | 138 | sobel_filter = np.array([[1, 0, -1], [2, 0, -2], [1, 0, -1]]) 139 | sf = torch.from_numpy(sobel_filter).type(dtype) 140 | 141 | self.sobel_filter_horizontal = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=sobel_filter.shape, padding=sobel_filter.shape[0] // 2) 142 | self.sobel_filter_horizontal.weight.data.copy_(sf) 143 | nn.init.constant_(self.sobel_filter_horizontal.bias, 0.0) 144 | self.sobel_filter_vertical = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=sobel_filter.shape, padding=sobel_filter.shape[0] // 2) 145 | self.sobel_filter_vertical.weight.data.copy_(sf.T) 146 | nn.init.constant_(self.sobel_filter_vertical.bias, 0.0) 147 | 148 | def forward(self, img): 149 | horizontal = self.gaussian_filter_horizontal(img) 150 | blurred_img = self.gaussian_filter_vertical(horizontal) 151 | 152 | grad_x = self.sobel_filter_horizontal(blurred_img) 153 | grad_y = self.sobel_filter_vertical(blurred_img) 154 | 155 | grad_mag = torch.sqrt(grad_x ** 2 + grad_y ** 2) 156 | 157 | grad_mag = (grad_mag - grad_mag.min()) / (grad_mag.max() - grad_mag.min()) 158 | grad_mag[grad_mag < self.threshold] = 0.0 159 | 160 | return grad_mag 161 | -------------------------------------------------------------------------------- /models/layers/resnet3d.py: -------------------------------------------------------------------------------- 1 | # https://github.com/kenshohara/3D-ResNets-PyTorch 2 | 3 | from functools import partial 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.utils.checkpoint import checkpoint 9 | 10 | 11 | def get_inplanes(): 12 | return [64, 128, 256, 512] 13 | 14 | 15 | def conv3x3x3(in_planes, out_planes, stride=1): 16 | return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 17 | 18 | 19 | def conv1x1x1(in_planes, out_planes, stride=1): 20 | return nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 21 | 22 | 23 | class BasicBlock(nn.Module): 24 | expansion = 1 25 | 26 | def __init__(self, in_planes, planes, stride=1, downsample=None): 27 | super().__init__() 28 | 29 | self.conv1 = conv3x3x3(in_planes, planes, stride) 30 | self.bn1 = nn.BatchNorm3d(planes) 31 | self.relu = nn.ReLU(inplace=True) 32 | self.conv2 = conv3x3x3(planes, planes) 33 | self.bn2 = nn.BatchNorm3d(planes) 34 | self.downsample = downsample 35 | self.stride = stride 36 | 37 | def forward(self, x): 38 | residual = x 39 | 40 | out = self.conv1(x) 41 | out = self.bn1(out) 42 | out = self.relu(out) 43 | 44 | out = self.conv2(out) 45 | out = self.bn2(out) 46 | 47 | if self.downsample is not None: 48 | residual = self.downsample(x) 49 | 50 | out += residual 51 | out = self.relu(out) 52 | 53 | return out 54 | 55 | 56 | class Bottleneck(nn.Module): 57 | expansion = 4 58 | 59 | def __init__(self, in_planes, planes, stride=1, downsample=None): 60 | super().__init__() 61 | 62 | self.conv1 = conv1x1x1(in_planes, planes) 63 | self.bn1 = nn.BatchNorm3d(planes) 64 | self.conv2 = conv3x3x3(planes, planes, stride) 65 | self.bn2 = nn.BatchNorm3d(planes) 66 | self.conv3 = conv1x1x1(planes, planes * self.expansion) 67 | self.bn3 = nn.BatchNorm3d(planes * self.expansion) 68 | self.relu = nn.ReLU(inplace=True) 69 | self.downsample = downsample 70 | self.stride = 
stride 71 | 72 | def forward(self, x): 73 | residual = x 74 | 75 | out = self.conv1(x) 76 | out = self.bn1(out) 77 | out = self.relu(out) 78 | 79 | out = self.conv2(out) 80 | out = self.bn2(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv3(out) 84 | out = self.bn3(out) 85 | 86 | if self.downsample is not None: 87 | residual = self.downsample(x) 88 | 89 | out += residual 90 | out = self.relu(out) 91 | 92 | return out 93 | 94 | 95 | class ResNet(nn.Module): 96 | 97 | def __init__(self, 98 | block, 99 | layers, 100 | block_inplanes, 101 | n_input_channels=3, 102 | conv1_t_size=7, 103 | conv1_t_stride=1, 104 | no_max_pool=False, 105 | shortcut_type='B', 106 | widen_factor=1.0, 107 | n_classes=400): 108 | super().__init__() 109 | 110 | block_inplanes = [int(x * widen_factor) for x in block_inplanes] 111 | 112 | self.in_planes = block_inplanes[0] 113 | self.no_max_pool = no_max_pool 114 | 115 | self.conv1 = nn.Conv3d(n_input_channels, 116 | self.in_planes, 117 | kernel_size=(conv1_t_size, 7, 7), 118 | stride=(conv1_t_stride, 2, 2), 119 | padding=(conv1_t_size // 2, 3, 3), 120 | bias=False) 121 | self.bn1 = nn.BatchNorm3d(self.in_planes) 122 | self.relu = nn.ReLU(inplace=True) 123 | self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1) 124 | self.layer1 = self._make_layer(block, block_inplanes[0], layers[0], 125 | shortcut_type) 126 | self.layer2 = self._make_layer(block, 127 | block_inplanes[1], 128 | layers[1], 129 | shortcut_type, 130 | stride=2) 131 | self.layer3 = self._make_layer(block, 132 | block_inplanes[2], 133 | layers[2], 134 | shortcut_type, 135 | stride=2) 136 | self.layer4 = self._make_layer(block, 137 | block_inplanes[3], 138 | layers[3], 139 | shortcut_type, 140 | stride=2) 141 | 142 | self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1)) 143 | self.fc = nn.Linear(block_inplanes[3] * block.expansion, n_classes) 144 | 145 | for m in self.modules(): 146 | if isinstance(m, nn.Conv3d): 147 | nn.init.kaiming_normal_(m.weight, 148 | mode='fan_out', 149 | nonlinearity='relu') 150 | elif isinstance(m, nn.BatchNorm3d): 151 | nn.init.constant_(m.weight, 1) 152 | nn.init.constant_(m.bias, 0) 153 | 154 | def _downsample_basic_block(self, x, planes, stride): 155 | out = F.avg_pool3d(x, kernel_size=1, stride=stride) 156 | zero_pads = torch.zeros(out.size(0), planes - out.size(1), out.size(2), 157 | out.size(3), out.size(4)) 158 | if isinstance(out.data, torch.cuda.FloatTensor): 159 | zero_pads = zero_pads.cuda() 160 | 161 | out = torch.cat([out.data, zero_pads], dim=1) 162 | 163 | return out 164 | 165 | def _make_layer(self, block, planes, blocks, shortcut_type, stride=1): 166 | downsample = None 167 | if stride != 1 or self.in_planes != planes * block.expansion: 168 | if shortcut_type == 'A': 169 | downsample = partial(self._downsample_basic_block, 170 | planes=planes * block.expansion, 171 | stride=stride) 172 | else: 173 | downsample = nn.Sequential( 174 | conv1x1x1(self.in_planes, planes * block.expansion, stride), 175 | nn.BatchNorm3d(planes * block.expansion)) 176 | 177 | layers = [] 178 | layers.append( 179 | block(in_planes=self.in_planes, 180 | planes=planes, 181 | stride=stride, 182 | downsample=downsample)) 183 | self.in_planes = planes * block.expansion 184 | for i in range(1, blocks): 185 | layers.append(block(self.in_planes, planes)) 186 | 187 | return nn.Sequential(*layers) 188 | 189 | def forward(self, x, cp=False): 190 | x = self.conv1(x) 191 | x = self.bn1(x) 192 | x = self.relu(x) 193 | if not self.no_max_pool: 194 | x = self.maxpool(x) 195 | 196 | if cp: 197 | x 
= checkpoint(self.layer1, x) 198 | x = checkpoint(self.layer2, x) 199 | x = checkpoint(self.layer3, x) 200 | x = checkpoint(self.layer4, x) 201 | else: 202 | x = self.layer1(x) 203 | x = self.layer2(x) 204 | x = self.layer3(x) 205 | x = self.layer4(x) 206 | 207 | x = self.avgpool(x) 208 | 209 | x = x.view(x.size(0), -1) 210 | x = self.fc(x) 211 | 212 | return x 213 | 214 | 215 | def generate_model(model_depth, **kwargs): 216 | assert model_depth in [10, 18, 34, 50, 101, 152, 200] 217 | 218 | if model_depth == 10: 219 | model = ResNet(BasicBlock, [1, 1, 1, 1], get_inplanes(), **kwargs) 220 | elif model_depth == 18: 221 | model = ResNet(BasicBlock, [2, 2, 2, 2], get_inplanes(), **kwargs) 222 | elif model_depth == 34: 223 | model = ResNet(BasicBlock, [3, 4, 6, 3], get_inplanes(), **kwargs) 224 | elif model_depth == 50: 225 | model = ResNet(Bottleneck, [3, 4, 6, 3], get_inplanes(), **kwargs) 226 | elif model_depth == 101: 227 | model = ResNet(Bottleneck, [3, 4, 23, 3], get_inplanes(), **kwargs) 228 | elif model_depth == 152: 229 | model = ResNet(Bottleneck, [3, 8, 36, 3], get_inplanes(), **kwargs) 230 | elif model_depth == 200: 231 | model = ResNet(Bottleneck, [3, 24, 36, 3], get_inplanes(), **kwargs) 232 | 233 | return model 234 | -------------------------------------------------------------------------------- /models/layers/convolutional_rnn/functional.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | try: 6 | # pytorch<=0.4.1 7 | from torch.nn._functions.thnn import rnnFusedPointwise as fusedBackend 8 | except ImportError: 9 | fusedBackend = None 10 | 11 | from .utils import _single, _pair, _triple 12 | 13 | 14 | def RNNReLUCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None, linear_func=None): 15 | """ Copied from torch.nn._functions.rnn and modified """ 16 | if linear_func is None: 17 | linear_func = F.linear 18 | hy = F.relu(linear_func(input, w_ih, b_ih) + linear_func(hidden, w_hh, b_hh)) 19 | return hy 20 | 21 | 22 | def RNNTanhCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None, linear_func=None): 23 | """ Copied from torch.nn._functions.rnn and modified """ 24 | if linear_func is None: 25 | linear_func = F.linear 26 | hy = torch.tanh(linear_func(input, w_ih, b_ih) + linear_func(hidden, w_hh, b_hh)) 27 | return hy 28 | 29 | 30 | def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None, linear_func=None): 31 | """ Copied from torch.nn._functions.rnn and modified """ 32 | if linear_func is None: 33 | linear_func = F.linear 34 | if input.is_cuda and linear_func is F.linear and fusedBackend is not None: 35 | igates = linear_func(input, w_ih) 36 | hgates = linear_func(hidden[0], w_hh) 37 | state = fusedBackend.LSTMFused.apply 38 | return state(igates, hgates, hidden[1]) if b_ih is None else state(igates, hgates, hidden[1], b_ih, b_hh) 39 | 40 | hx, cx = hidden 41 | gates = linear_func(input, w_ih, b_ih) + linear_func(hx, w_hh, b_hh) 42 | ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) 43 | 44 | ingate = torch.sigmoid(ingate) 45 | forgetgate = torch.sigmoid(forgetgate) 46 | cellgate = torch.tanh(cellgate) 47 | outgate = torch.sigmoid(outgate) 48 | 49 | cy = (forgetgate * cx) + (ingate * cellgate) 50 | hy = outgate * torch.tanh(cy) 51 | 52 | return hy, cy 53 | 54 | 55 | def PeepholeLSTMCell(input, hidden, w_ih, w_hh, w_pi, w_pf, w_po, 56 | b_ih=None, b_hh=None, linear_func=None): 57 | if linear_func is None: 58 | linear_func = F.linear 59 | hx, cx 
= hidden 60 | gates = linear_func(input, w_ih, b_ih) + linear_func(hx, w_hh, b_hh) 61 | ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) 62 | 63 | ingate += linear_func(cx, w_pi) 64 | forgetgate += linear_func(cx, w_pf) 65 | ingate = torch.sigmoid(ingate) 66 | forgetgate = torch.sigmoid(forgetgate) 67 | cellgate = torch.tanh(cellgate) 68 | 69 | cy = (forgetgate * cx) + (ingate * cellgate) 70 | outgate += linear_func(cy, w_po) 71 | outgate = torch.sigmoid(outgate) 72 | 73 | hy = outgate * torch.tanh(cy) 74 | 75 | return hy, cy 76 | 77 | 78 | def GRUCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None, linear_func=None): 79 | """ Copied from torch.nn._functions.rnn and modified """ 80 | if linear_func is None: 81 | linear_func = F.linear 82 | if input.is_cuda and linear_func is F.linear and fusedBackend is not None: 83 | gi = linear_func(input, w_ih) 84 | gh = linear_func(hidden, w_hh) 85 | state = fusedBackend.GRUFused.apply 86 | return state(gi, gh, hidden) if b_ih is None else state(gi, gh, hidden, b_ih, b_hh) 87 | gi = linear_func(input, w_ih, b_ih) 88 | gh = linear_func(hidden, w_hh, b_hh) 89 | i_r, i_i, i_n = gi.chunk(3, 1) 90 | h_r, h_i, h_n = gh.chunk(3, 1) 91 | 92 | resetgate = torch.sigmoid(i_r + h_r) 93 | inputgate = torch.sigmoid(i_i + h_i) 94 | newgate = torch.tanh(i_n + resetgate * h_n) 95 | hy = newgate + inputgate * (hidden - newgate) 96 | 97 | return hy 98 | 99 | 100 | def StackedRNN(inners, num_layers, lstm=False, dropout=0, train=True): 101 | """ Copied from torch.nn._functions.rnn and modified """ 102 | 103 | num_directions = len(inners) 104 | total_layers = num_layers * num_directions 105 | 106 | def forward(input, hidden, weight, batch_sizes): 107 | assert(len(weight) == total_layers) 108 | next_hidden = [] 109 | ch_dim = input.dim() - weight[0][0].dim() + 1 110 | 111 | if lstm: 112 | hidden = list(zip(*hidden)) 113 | 114 | for i in range(num_layers): 115 | all_output = [] 116 | for j, inner in enumerate(inners): 117 | l = i * num_directions + j 118 | 119 | hy, output = inner(input, hidden[l], weight[l], batch_sizes) 120 | next_hidden.append(hy) 121 | all_output.append(output) 122 | 123 | input = torch.cat(all_output, ch_dim) 124 | 125 | if dropout != 0 and i < num_layers - 1: 126 | input = F.dropout(input, p=dropout, training=train, inplace=False) 127 | 128 | if lstm: 129 | next_h, next_c = zip(*next_hidden) 130 | next_hidden = ( 131 | torch.cat(next_h, 0).view(total_layers, *next_h[0].size()), 132 | torch.cat(next_c, 0).view(total_layers, *next_c[0].size()) 133 | ) 134 | else: 135 | next_hidden = torch.cat(next_hidden, 0).view( 136 | total_layers, *next_hidden[0].size()) 137 | 138 | return next_hidden, input 139 | 140 | return forward 141 | 142 | 143 | def Recurrent(inner, reverse=False): 144 | """ Copied from torch.nn._functions.rnn without any modification """ 145 | def forward(input, hidden, weight, batch_sizes): 146 | output = [] 147 | steps = range(input.size(0) - 1, -1, -1) if reverse else range(input.size(0)) 148 | for i in steps: 149 | hidden = inner(input[i], hidden, *weight) 150 | # hack to handle LSTM 151 | output.append(hidden[0] if isinstance(hidden, tuple) else hidden) 152 | 153 | if reverse: 154 | output.reverse() 155 | output = torch.cat(output, 0).view(input.size(0), *output[0].size()) 156 | 157 | return hidden, output 158 | 159 | return forward 160 | 161 | 162 | def variable_recurrent_factory(inner, reverse=False): 163 | """ Copied from torch.nn._functions.rnn without any modification """ 164 | if reverse: 165 | return 
VariableRecurrentReverse(inner) 166 | else: 167 | return VariableRecurrent(inner) 168 | 169 | 170 | def VariableRecurrent(inner): 171 | """ Copied from torch.nn._functions.rnn without any modification """ 172 | def forward(input, hidden, weight, batch_sizes): 173 | output = [] 174 | input_offset = 0 175 | last_batch_size = batch_sizes[0] 176 | hiddens = [] 177 | flat_hidden = not isinstance(hidden, tuple) 178 | if flat_hidden: 179 | hidden = (hidden,) 180 | for batch_size in batch_sizes: 181 | step_input = input[input_offset:input_offset + batch_size] 182 | input_offset += batch_size 183 | 184 | dec = last_batch_size - batch_size 185 | if dec > 0: 186 | hiddens.append(tuple(h[-dec:] for h in hidden)) 187 | hidden = tuple(h[:-dec] for h in hidden) 188 | last_batch_size = batch_size 189 | 190 | if flat_hidden: 191 | hidden = (inner(step_input, hidden[0], *weight),) 192 | else: 193 | hidden = inner(step_input, hidden, *weight) 194 | 195 | output.append(hidden[0]) 196 | hiddens.append(hidden) 197 | hiddens.reverse() 198 | 199 | hidden = tuple(torch.cat(h, 0) for h in zip(*hiddens)) 200 | assert hidden[0].size(0) == batch_sizes[0] 201 | if flat_hidden: 202 | hidden = hidden[0] 203 | output = torch.cat(output, 0) 204 | 205 | return hidden, output 206 | 207 | return forward 208 | 209 | 210 | def VariableRecurrentReverse(inner): 211 | """ Copied from torch.nn._functions.rnn without any modification """ 212 | def forward(input, hidden, weight, batch_sizes): 213 | output = [] 214 | input_offset = input.size(0) 215 | last_batch_size = batch_sizes[-1] 216 | initial_hidden = hidden 217 | flat_hidden = not isinstance(hidden, tuple) 218 | if flat_hidden: 219 | hidden = (hidden,) 220 | initial_hidden = (initial_hidden,) 221 | hidden = tuple(h[:batch_sizes[-1]] for h in hidden) 222 | for i in reversed(range(len(batch_sizes))): 223 | batch_size = batch_sizes[i] 224 | inc = batch_size - last_batch_size 225 | if inc > 0: 226 | hidden = tuple(torch.cat((h, ih[last_batch_size:batch_size]), 0) 227 | for h, ih in zip(hidden, initial_hidden)) 228 | last_batch_size = batch_size 229 | step_input = input[input_offset - batch_size:input_offset] 230 | input_offset -= batch_size 231 | 232 | if flat_hidden: 233 | hidden = (inner(step_input, hidden[0], *weight),) 234 | else: 235 | hidden = inner(step_input, hidden, *weight) 236 | output.append(hidden[0]) 237 | 238 | output.reverse() 239 | output = torch.cat(output, 0) 240 | if flat_hidden: 241 | hidden = hidden[0] 242 | return hidden, output 243 | 244 | return forward 245 | 246 | 247 | def ConvNdWithSamePadding(convndim=2, stride=1, dilation=1, groups=1): 248 | def forward(input, w, b=None): 249 | if convndim == 1: 250 | ntuple = _single 251 | elif convndim == 2: 252 | ntuple = _pair 253 | elif convndim == 3: 254 | ntuple = _triple 255 | else: 256 | raise ValueError('convndim must be 1, 2, or 3, but got {}'.format(convndim)) 257 | 258 | if input.dim() != convndim + 2: 259 | raise RuntimeError('Input dim must be {}, but got {}'.format(convndim + 2, input.dim())) 260 | if w.dim() != convndim + 2: 261 | raise RuntimeError('w must be {}, but got {}'.format(convndim + 2, w.dim())) 262 | 263 | insize = input.shape[2:] 264 | kernel_size = w.shape[2:] 265 | _stride = ntuple(stride) 266 | _dilation = ntuple(dilation) 267 | 268 | ps = [(i + 1 - h + s * (h - 1) + d * (k - 1)) // 2 269 | for h, k, s, d in list(zip(insize, kernel_size, _stride, _dilation))[::-1] for i in range(2)] 270 | # Pad so that the output has the same spatial shape as the input 271 | input = F.pad(input,
ps, 'constant', 0) 272 | return getattr(F, 'conv{}d'.format(convndim))( 273 | input, w, b, stride=_stride, padding=ntuple(0), dilation=_dilation, groups=groups) 274 | return forward 275 | 276 | 277 | def _conv_cell_helper(mode, convndim=2, stride=1, dilation=1, groups=1): 278 | linear_func = ConvNdWithSamePadding(convndim=convndim, stride=stride, dilation=dilation, groups=groups) 279 | 280 | if mode == 'RNN_RELU': 281 | cell = partial(RNNReLUCell, linear_func=linear_func) 282 | elif mode == 'RNN_TANH': 283 | cell = partial(RNNTanhCell, linear_func=linear_func) 284 | elif mode == 'LSTM': 285 | cell = partial(LSTMCell, linear_func=linear_func) 286 | elif mode == 'GRU': 287 | cell = partial(GRUCell, linear_func=linear_func) 288 | elif mode == 'PeepholeLSTM': 289 | cell = partial(PeepholeLSTMCell, linear_func=linear_func) 290 | else: 291 | raise Exception('Unknown mode: {}'.format(mode)) 292 | return cell 293 | 294 | 295 | def AutogradConvRNN( 296 | mode, num_layers=1, batch_first=False, 297 | dropout=0, train=True, bidirectional=False, variable_length=False, 298 | convndim=2, stride=1, dilation=1, groups=1): 299 | """ Copied from torch.nn._functions.rnn and modified """ 300 | cell = _conv_cell_helper(mode, convndim=convndim, stride=stride, dilation=dilation, groups=groups) 301 | 302 | rec_factory = variable_recurrent_factory if variable_length else Recurrent 303 | 304 | if bidirectional: 305 | layer = (rec_factory(cell), rec_factory(cell, reverse=True)) 306 | else: 307 | layer = (rec_factory(cell),) 308 | 309 | func = StackedRNN(layer, num_layers, (mode in ('LSTM', 'PeepholeLSTM')), dropout=dropout, train=train) 310 | 311 | def forward(input, weight, hidden, batch_sizes): 312 | if batch_first and batch_sizes is None: 313 | input = input.transpose(0, 1) 314 | 315 | nexth, output = func(input, hidden, weight, batch_sizes) 316 | 317 | if batch_first and batch_sizes is None: 318 | output = output.transpose(0, 1) 319 | 320 | return output, nexth 321 | 322 | return forward 323 | -------------------------------------------------------------------------------- /models/online_framework.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | import configs 8 | import models 9 | import utils 10 | 11 | 12 | def pad_volume(v1, v2): 13 | max_s0 = max(v1.shape[0], v2.shape[0]) 14 | max_s1 = max(v1.shape[1], v2.shape[1]) 15 | max_s2 = max(v1.shape[2], v2.shape[2]) 16 | diff_v1_s0 = max_s0 - v1.shape[0] 17 | diff_v1_s1 = max_s1 - v1.shape[1] 18 | diff_v1_s2 = max_s2 - v1.shape[2] 19 | diff_v2_s0 = max_s0 - v2.shape[0] 20 | diff_v2_s1 = max_s1 - v2.shape[1] 21 | diff_v2_s2 = max_s2 - v2.shape[2] 22 | v1 = F.pad(v1, (diff_v1_s2 // 2, diff_v1_s2 - diff_v1_s2 // 2, diff_v1_s1 // 2, diff_v1_s1 - diff_v1_s1 // 2, diff_v1_s0 // 2, diff_v1_s0 - diff_v1_s0 // 2)) 23 | v2 = F.pad(v2, (diff_v2_s2 // 2, diff_v2_s2 - diff_v2_s2 // 2, diff_v2_s1 // 2, diff_v2_s1 - diff_v2_s1 // 2, diff_v2_s0 // 2, diff_v2_s0 - diff_v2_s0 // 2)) 24 | return v1, v2 25 | 26 | 27 | class Online_Framework(models.BaseModel): 28 | 29 | def __init__(self, cfg, data_cfg, run, **kwargs): 30 | super().__init__(cfg, data_cfg, run, **kwargs) 31 | self.backbone = models.online_backbone.Backbone(self.data_cfg.source.channel, self.data_cfg.target.elements - 9).to(self.device) 32 | self.backbone_start_weight = torch.load(configs.env.getdir(self.cfg.backbone_weight)) 33 | self.backbone.load_state_dict(self.backbone_start_weight) 
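# NOTE: the state dicts loaded above are deliberately kept in memory as
# self.backbone_start_weight / self.discriminator_start_weight, so that
# test_optimize() can reset both networks to their offline pretrained state
# before each per-sequence online-learning run.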
34 | 35 | self.discriminator = models.online_discriminator.Discriminator().to(self.device) 36 | self.discriminator_start_weight = torch.load(configs.env.getdir(self.cfg.discriminator_weight)) 37 | self.discriminator.load_state_dict(self.discriminator_start_weight) 38 | 39 | self.mat_scale = torch.eye(4, dtype=torch.float32, device=self.device) 40 | self.mat_scale[0, 0] = self.cfg.down_ratio 41 | self.mat_scale[1, 1] = self.cfg.down_ratio 42 | self.mat_scale[2, 2] = self.cfg.down_ratio 43 | 44 | def train(self, epoch_info, sample_dict): 45 | return {} 46 | 47 | def criterion(self, real_target, fake_target): 48 | real_dist, real_angle = real_target.split([3, self.data_cfg.target.elements - 12], dim=-1) 49 | fake_dist, fake_angle = fake_target.split([3, self.data_cfg.target.elements - 12], dim=-1) 50 | 51 | loss_dist = F.l1_loss(real_dist, fake_dist) * 3 52 | loss_angle = F.l1_loss(real_angle, fake_angle) * 3 53 | loss_corr = utils.metric.correlation_loss(real_target, fake_target) 54 | 55 | loss_dict = {'loss_dist': loss_dist, 'loss_angle': loss_angle, 'loss_corr': loss_corr} 56 | return loss_dict 57 | 58 | def test_optimize(self, epoch_info, real_source, real_target, edge, optical_flow, epoch): 59 | self.backbone.load_state_dict(self.backbone_start_weight) 60 | self.discriminator.load_state_dict(self.discriminator_start_weight) 61 | 62 | down_source = F.interpolate(real_source.squeeze(-3), scale_factor=self.cfg.down_ratio).unsqueeze(-3) 63 | real_gaps = real_target[0, :-1, :-9] 64 | real_series = real_target[0, :, -9:].view(-1, 3, 3) 65 | 66 | real_input = torch.cat([real_source[:, :-1, ...], real_source[:, 1:, ...], edge[:, :-1, ...], edge[:, 1:, ...], optical_flow], dim=2) 67 | 68 | value = {'real_source': real_source, 'real_gaps': real_gaps.clone(), 'real_series': real_series.clone()} 69 | self.backbone.eval() 70 | self.discriminator.eval() 71 | fake2_gaps, _ = self.backbone(real_input, return_feature=False) 72 | fake2_gaps = fake2_gaps[0, :, :] 73 | fake2_gaps[:, 3:] /= 100 74 | fake2_series = utils.simulation.dof_to_series(real_series[0:1, :, :], fake2_gaps.unsqueeze(0)).squeeze(0) 75 | losses2 = utils.metric.get_metric(real_series, fake2_series) 76 | value['fake_gaps'] = [fake2_gaps] 77 | value['fake_series'] = [fake2_series] 78 | value['loss'] = [losses2] 79 | 80 | self.optimizer_psc = torch.optim.Adam(self.backbone.parameters(), lr=self.run.lr_psc, betas=self.run.betas) 81 | self.optimizer_g = torch.optim.Adam(self.backbone.parameters(), lr=self.run.lr_fcc_gas, betas=self.run.betas) 82 | self.optimizer_d = torch.optim.Adam(self.discriminator.parameters(), lr=self.run.lr_fcc_gas, betas=self.run.betas) 83 | 84 | with torch.enable_grad(): 85 | for idx in range(1, epoch + 1): 86 | self.logger.info(f"RecON: Data {epoch_info['index'].item() + 1}/{epoch_info['count_data']} Epoch {idx}/{epoch}") 87 | self.backbone.train() 88 | self.discriminator.train() 89 | 90 | with torch.no_grad(): 91 | fake_gaps, _ = self.backbone(real_input, return_feature=False) 92 | fake_gaps = torch.cat([fake_gaps[:, :, :3], fake_gaps[:, :, 3:] / 100], dim=-1) 93 | 94 | for idx_psc in range(1, self.cfg.psc_epoch + 1): 95 | with torch.no_grad(): 96 | corr = np.inf 97 | acq = 0 98 | best_corr = np.inf 99 | best_sample = None 100 | while corr > 1 - self.cfg.psc_threshold and acq < self.cfg.psc_max_acquisition: 101 | d_idx = torch.randint(self.dataset.trainset_length, (1,), dtype=torch.long, device=self.device) 102 | r_data = self.dataset[d_idx[0]][0] 103 | gaps = r_data['target'].to(self.device).unsqueeze(0) 104 | gaps
= gaps[:, :-1, :-9] 105 | min_length = min(fake_gaps.shape[1], gaps.shape[1]) 106 | corr = utils.metric.correlation_loss(fake_gaps[:, :min_length], gaps[:, :min_length]) 107 | if corr < best_corr: 108 | best_corr = corr 109 | best_sample = r_data, gaps 110 | acq += 1 111 | if acq == self.cfg.psc_max_acquisition: 112 | r_data, gaps = best_sample 113 | slices = r_data['source'].to(self.device).unsqueeze(0) 114 | ed = r_data['edge'].to(self.device).unsqueeze(0) 115 | of = r_data['optical_flow'].to(self.device).unsqueeze(0) 116 | gaps[:, :, 3:] = gaps[:, :, 3:] * 100 117 | 118 | self.optimizer_psc.zero_grad() 119 | ri = torch.cat([slices[:, :-1, ...], slices[:, 1:, ...], ed[:, :-1, ...], ed[:, 1:, ...], of], dim=2) 120 | fgaps, feature = self.backbone(ri, return_feature=False) 121 | 122 | losses = self.criterion(gaps, fgaps) 123 | loss = sum(losses.values()) 124 | 125 | self.logger.info_scalars('PSC iter {}/{}\t', (idx_psc, self.cfg.psc_epoch), {'loss_psc': loss, **losses}) 126 | loss.backward() 127 | self.optimizer_psc.step() 128 | torch.cuda.empty_cache() 129 | 130 | if idx % self.cfg.discriminator_opt_cycle == 0: 131 | with torch.no_grad(): 132 | d_idx = torch.randint(self.dataset.trainset_length, (1,), dtype=torch.long, device=self.device) 133 | r_data = self.dataset[d_idx[0]][0] 134 | r_source = r_data['source'].to(self.device) 135 | r_target = r_data['target'].to(self.device) 136 | 137 | r_down = F.interpolate(r_source.unsqueeze(0).squeeze(-3), scale_factor=self.cfg.down_ratio).unsqueeze(-3) 138 | r_reco, _ = utils.reconstruction.reco(r_down[:, ::2].squeeze(0).squeeze(1), r_target[::2, -9:].view(-1, 3, 3), mat_scale=self.mat_scale) 139 | 140 | self.optimizer_d.zero_grad() 141 | 142 | fake_gaps, _ = self.backbone(real_input, return_feature=False) 143 | fake_gaps = torch.cat([fake_gaps[0, :, :3], fake_gaps[0, :, 3:] / 100], dim=-1) 144 | fake_series = utils.simulation.dof_to_series(real_series[0:1, :, :], fake_gaps.unsqueeze(0)).squeeze(0) 145 | 146 | reco, _ = utils.reconstruction.reco(down_source[:, ::2].squeeze(0).squeeze(1), fake_series[::2], mat_scale=self.mat_scale) 147 | 148 | pred_real = self.discriminator(r_reco.unsqueeze(0).unsqueeze(0)) 149 | pred_fake = self.discriminator(reco.unsqueeze(0).unsqueeze(0)) 150 | loss_d_gas = pred_fake - pred_real 151 | r_reco_resize, reco_resize = pad_volume(r_reco, reco) 152 | d_norm = 2 * (r_reco_resize - reco_resize).abs().mean() 153 | loss_qp = loss_d_gas ** 2 / d_norm 154 | loss_d = loss_d_gas + loss_qp 155 | loss_d = torch.mean(loss_d) 156 | 157 | self.logger.info_scalars('GAS_d\t\t', (), {'loss_d': loss_d.item(), 'loss_d_gas': loss_d_gas.item(), 'loss_qp': loss_qp.item()}) 158 | loss_d.backward() 159 | self.optimizer_d.step() 160 | torch.cuda.empty_cache() 161 | 162 | self.optimizer_g.zero_grad() 163 | 164 | fake_gaps, _ = self.backbone(real_input, return_feature=False) 165 | fake_gaps = torch.cat([fake_gaps[0, :, :3], fake_gaps[0, :, 3:] / 100], dim=-1) 166 | fake_series = utils.simulation.dof_to_series(real_series[0:1, :, :], fake_gaps.unsqueeze(0)).squeeze(0) 167 | 168 | index_rate = self.cfg.reco_rate 169 | index_reco = (torch.arange(int(fake_series.shape[0] * index_rate), device=self.device) / index_rate).type(torch.int64) 170 | index_slice = torch.tensor([i for i in range(fake_series.shape[0]) if i not in index_reco], device=self.device) 171 | reco, min_point = utils.reconstruction.reco(down_source.index_select(1, index_reco).squeeze(0).squeeze(1), fake_series.index_select(0, index_reco), mat_scale=self.mat_scale) 172 | r_rec = 
F.interpolate(reco.unsqueeze(0).unsqueeze(0), scale_factor=1 / self.cfg.down_ratio).squeeze(0).squeeze(0) 173 | slices = utils.reconstruction.get_slice(r_rec, fake_series.index_select(0, index_slice) - min_point.unsqueeze(0).unsqueeze(0), real_source.shape[-2:]) 174 | 175 | loss_fcc = F.l1_loss(slices, real_source.index_select(1, index_slice)) * self.cfg.weight_fcc 176 | pred_fake = self.discriminator(reco.unsqueeze(0).unsqueeze(0)) 177 | loss_g_gas = -torch.mean(pred_fake) 178 | loss_g = loss_g_gas + loss_fcc 179 | 180 | self.logger.info_scalars('GAS_g+FCC\t', (), {'loss_g': loss_g.item(), 'loss_g_gas': loss_g_gas.item(), 'loss_fcc': loss_fcc.item()}) 181 | loss_g.backward() 182 | self.optimizer_g.step() 183 | 184 | with torch.no_grad(): 185 | self.backbone.eval() 186 | self.discriminator.eval() 187 | fake2_gaps, _ = self.backbone(real_input, return_feature=False) 188 | fake2_gaps = fake2_gaps[0, :, :] 189 | fake2_gaps[:, 3:] /= 100 190 | fake2_series = utils.simulation.dof_to_series(real_series[0:1, :, :], fake2_gaps.unsqueeze(0)).squeeze(0) 191 | losses2 = utils.metric.get_metric(real_series, fake2_series) 192 | value['fake_gaps'].append(fake2_gaps) 193 | value['fake_series'].append(fake2_series) 194 | value['loss'].append(losses2) 195 | 196 | self.backbone.eval() 197 | self.discriminator.eval() 198 | return value 199 | 200 | def test(self, epoch_info, sample_dict): 201 | if epoch_info['index'].item() < 0: 202 | return {} 203 | utils.common.set_seed(epoch_info['index'].item() * 42) 204 | real_source = sample_dict['source'].to(self.device) 205 | real_target = sample_dict['target'].to(self.device) 206 | edge = sample_dict['edge'].to(self.device) 207 | optical_flow = sample_dict['optical_flow'].to(self.device) 208 | frame_rate = sample_dict['frame_rate'] 209 | length = min(sample_dict['info'], real_source.shape[1]) 210 | 211 | value = self.test_optimize(epoch_info, real_source, real_target, edge, optical_flow, epoch=self.run.ol_epochs) 212 | value['frame_rate'] = [frame_rate] 213 | value['length'] = [length] 214 | 215 | path = os.path.join(self.path, 'RecON') 216 | if not os.path.exists(path): 217 | os.makedirs(path) 218 | source = value.pop('real_source') 219 | torch.save(source, os.path.join(path, 'source_' + str(epoch_info['index'].item()) + '.pth')) 220 | torch.save(value, os.path.join(path, 'value_' + str(epoch_info['index'].item()) + '.pth')) 221 | 222 | return {} 223 | -------------------------------------------------------------------------------- /utils/simulation.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | 8 | def get_axis(series, eps: float = 1e-20): 9 | old_dtype = series.dtype 10 | series = series.type(torch.float64) 11 | 12 | p1p2 = series[:, 1:3, :] - series[:, 0:1, :] 13 | ax_x = p1p2[:, 1, :] - p1p2[:, 0, :] 14 | ax_x = F.normalize(ax_x, p=2.0, dim=-1, eps=eps) 15 | ax_y = -p1p2[:, 1, :] - p1p2[:, 0, :] 16 | ax_y = F.normalize(ax_y, p=2.0, dim=-1, eps=eps) 17 | ax_z = torch.cross(ax_x, ax_y, dim=-1) 18 | ax_z = F.normalize(ax_z, p=2.0, dim=-1, eps=eps) 19 | axis = torch.stack([ax_x, ax_y, ax_z], dim=1) 20 | 21 | axis = axis.type(old_dtype) 22 | return axis 23 | 24 | 25 | def get_normal(points, eps: float = 1e-20): 26 | old_dtype = points.dtype 27 | points = points.type(torch.float64) 28 | 29 | p1p2 = points[1:, :] - points[0:1, :] 30 | p1p2 = F.normalize(p1p2, p=2.0, dim=-1, eps=eps) 31 | normal = torch.cross(p1p2[0, :], p1p2[1, 
:], dim=-1) 32 | 33 | normal = normal.type(old_dtype) 34 | return normal 35 | 36 | 37 | def get_quaternion_matrix(normal, sin_2, cos_2): 38 | old_dtype = normal.dtype 39 | normal = normal.type(torch.float64) 40 | sin_2 = sin_2.type(torch.float64) if isinstance(sin_2, torch.Tensor) else sin_2 41 | cos_2 = cos_2.type(torch.float64) if isinstance(cos_2, torch.Tensor) else cos_2 42 | 43 | bcd = sin_2 * normal 44 | b2, c2, d2 = (2 * bcd ** 2).split(1, dim=-1) 45 | ab, ac, ad = (2 * cos_2 * bcd).split(1, dim=-1) 46 | 47 | bc = 2 * bcd[:, 0:1] * bcd[:, 1:2] 48 | bd = 2 * bcd[:, 0:1] * bcd[:, 2:3] 49 | cd = 2 * bcd[:, 1:2] * bcd[:, 2:3] 50 | 51 | matrix = torch.cat([1 - c2 - d2, bc - ad, ac + bd, bc + ad, 1 - b2 - d2, cd - ab, bd - ac, ab + cd, 1 - b2 - c2], dim=-1) 52 | matrix = matrix.view(matrix.shape[0], 3, 3) 53 | 54 | matrix = matrix.type(old_dtype) 55 | return matrix 56 | 57 | 58 | def euler_from_matrix(matrix, eps: float = 1e-6): 59 | i, j, k = 0, 1, 2 60 | M = matrix[:, :3, :3] 61 | 62 | cy = torch.sqrt(M[:, i, i] * M[:, i, i] + M[:, j, i] * M[:, j, i]) 63 | ax = torch.atan2(M[:, k, j], M[:, k, k]) 64 | ay = torch.atan2(-M[:, k, i], cy) 65 | az = torch.atan2(M[:, j, i], M[:, i, i]) 66 | flag = cy <= eps 67 | ax2 = torch.atan2(-M[:, j, k], M[:, j, j]) 68 | ax[flag, ...] = ax2[flag, ...] 69 | az[flag, ...] = 0 70 | 71 | a = torch.stack([ax, ay, az], dim=-1) 72 | return a 73 | 74 | 75 | def euler_matrix(angle): 76 | i, j, k = 0, 1, 2 77 | ai, aj, ak = angle[:, 0], angle[:, 1], angle[:, 2] 78 | 79 | si, sj, sk = torch.sin(ai), torch.sin(aj), torch.sin(ak) 80 | ci, cj, ck = torch.cos(ai), torch.cos(aj), torch.cos(ak) 81 | cc, cs = ci * ck, ci * sk 82 | sc, ss = si * ck, si * sk 83 | 84 | M = torch.eye(4, dtype=ai.dtype, device=ai.device).unsqueeze(0).repeat(len(ai), 1, 1) 85 | M[:, i, i] = cj * ck 86 | M[:, i, j] = sj * sc - cs 87 | M[:, i, k] = sj * cc + ss 88 | M[:, j, i] = cj * sk 89 | M[:, j, j] = sj * ss + cc 90 | M[:, j, k] = sj * cs - sc 91 | M[:, k, i] = -sj 92 | M[:, k, j] = cj * si 93 | M[:, k, k] = cj * ci 94 | 95 | return M 96 | 97 | 98 | def affine_matrix_from_points(v0, v1): 99 | t0 = -torch.mean(v0, dim=-1) 100 | v0 = v0 + t0.unsqueeze(-1) 101 | t1 = -torch.mean(v1, dim=-1) 102 | v1 = v1 + t1.unsqueeze(-1) 103 | 104 | u, s, vh = torch.svd(torch.bmm(v1, v0.permute(0, 2, 1)).cpu()) 105 | if u.device != v0.device: 106 | u, vh = torch.cat([u, vh], dim=-1).to(v0.device).split(3, dim=-1) 107 | vh = vh.permute(0, 2, 1) 108 | R = torch.bmm(u, vh) 109 | 110 | flag = torch.det(R) < 0.0 111 | out = u[:, :, 2:3] * (vh[:, 2:3, :] * 2.0) 112 | R[flag, ...] = R[flag, ...] - out[flag, ...] 
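# NOTE: the block above is the Kabsch/Procrustes solution for the best-fit
# rigid transform: the SVD of the cross-covariance of the centered point sets
# gives the optimal rotation R = U V^T, and a reflection (det(R) < 0) is
# corrected by subtracting 2 * u_3 v_3^T, i.e. flipping the last singular vector.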
113 | 114 | M = torch.cat([R, torch.sum(R * t0.unsqueeze(1), dim=-1, keepdim=True) - t1.unsqueeze(-1)], dim=-1) 115 | M = F.pad(M, [0, 0, 0, 1]) 116 | M[:, -1, -1] = 1.0 117 | return M 118 | 119 | 120 | def quaternion_rotation_mul_theta(point, origin, normal, sin_2, cos_2): 121 | old_dtype = point.dtype 122 | point = point.type(torch.float64) 123 | origin = origin.type(torch.float64) 124 | normal = normal.type(torch.float64) 125 | sin_2 = sin_2.type(torch.float64) if isinstance(sin_2, torch.Tensor) else sin_2 126 | cos_2 = cos_2.type(torch.float64) if isinstance(cos_2, torch.Tensor) else cos_2 127 | 128 | point = point - origin 129 | matrix = get_quaternion_matrix(normal, sin_2, cos_2) 130 | next_points = (matrix @ point.T.unsqueeze(0).expand(matrix.shape[0], 3, 3)).permute(0, 2, 1) 131 | next_points = next_points + origin.unsqueeze(0) 132 | 133 | next_points = next_points.type(old_dtype) 134 | return next_points 135 | 136 | 137 | def series_to_dof(series): 138 | old_dtype = series.dtype 139 | series = series.type(torch.float64) 140 | 141 | angle_mat = get_axis(series[:-1]).permute(0, 2, 1) 142 | angle_mat_inv = torch.inverse(angle_mat) 143 | 144 | p0p1 = torch.bmm(torch.cat([angle_mat_inv, angle_mat_inv], dim=0), torch.cat([series[:-1, :, :] - series[:-1, 0:1, :], series[1:, :, :] - series[:-1, 0:1, :]], dim=0).permute(0, 2, 1)) 145 | trmat_ax_p0 = affine_matrix_from_points(p0p1[:len(angle_mat_inv)], p0p1[len(angle_mat_inv):]) 146 | angle_ax_p0 = euler_from_matrix(trmat_ax_p0) 147 | 148 | dist_ax_p0_tr = trmat_ax_p0[:, :3, 3] 149 | 150 | dofs = torch.cat([dist_ax_p0_tr, angle_ax_p0], dim=-1) 151 | dofs = dofs.type(old_dtype) 152 | return dofs 153 | 154 | 155 | def dof_to_series(start_point, dof): 156 | old_type = start_point.dtype 157 | start_point = start_point.type(torch.float64) 158 | dof = dof.type(torch.float64) 159 | 160 | b, t, _ = dof.shape 161 | dof = dof.view(b * t, -1) 162 | matrix = euler_matrix(dof[:, 3:]) 163 | matrix[:, :3, 3] = dof[:, :3] 164 | matrix = matrix.view(b, t, 4, 4) 165 | 166 | start_axis = get_axis(start_point).permute(0, 2, 1) 167 | start_matrix = torch.cat([start_axis, start_point[:, 0, :].unsqueeze(-1)], dim=-1) 168 | start_matrix = F.pad(start_matrix, (0, 0, 0, 1)) 169 | start_matrix[:, 3, 3] = 1 170 | start_matrix_inv = torch.inverse(start_matrix) 171 | 172 | matrix_chain = [start_matrix] 173 | for idx in range(matrix.shape[1]): 174 | matrix_chain.append(torch.bmm(matrix_chain[-1], matrix[:, idx])) 175 | matrix_chain = torch.stack(matrix_chain, dim=1) 176 | 177 | start_point_4d = F.pad(start_point, (0, 1)) 178 | start_point_4d[:, :, 3] = 1 179 | series = torch.einsum('btij,bjk,bkl->btil', matrix_chain, start_matrix_inv, start_point_4d.permute(0, 2, 1)).permute(0, 1, 3, 2)[..., :3] 180 | 181 | series = series.type(old_type) 182 | return series 183 | 184 | 185 | def series_to_mesh(series, height, width, origin=None): 186 | axis = get_axis(series).permute(0, 2, 1) 187 | 188 | if not hasattr(series_to_mesh, 'mesh') or height != series_to_mesh.height or width != series_to_mesh.width: 189 | range_x = torch.arange(-(height - 1) / 2, (height + 1) / 2, dtype=series.dtype, device=series.device) 190 | range_y = torch.arange(-(width - 1) / 2, (width + 1) / 2, dtype=series.dtype, device=series.device) 191 | mesh_x, mesh_y = torch.meshgrid(range_x, range_y, indexing='ij') 192 | mesh = torch.stack([mesh_y, -mesh_x, torch.zeros_like(mesh_x)], dim=-1) 193 | series_to_mesh.mesh = mesh 194 | series_to_mesh.height = height 195 | series_to_mesh.width = width 196 | 197 | center 
= series[:, 0, :].unsqueeze(1).unsqueeze(1) 198 | 199 | local_mesh = torch.einsum('Nij,HWj->NHWi', axis, series_to_mesh.mesh) + center 200 | if origin is not None: 201 | local_mesh = local_mesh + origin.unsqueeze(0).unsqueeze(0).unsqueeze(0) 202 | return local_mesh 203 | 204 | 205 | def is_in_ellipsoid(point, radius, keepdim=False): 206 | assert point.ndim == radius.ndim and radius.ndim in [1, 2] 207 | assert point.shape[-1] == radius.shape[-1] == 3 208 | return torch.norm(point / radius, 2, dim=-1, keepdim=keepdim) <= 1 209 | 210 | 211 | def draw_sobol_normal_samples(d: int, n: int, dtype: torch.dtype, device: torch.device): 212 | engine = torch.quasirandom.SobolEngine(dimension=d, scramble=True, seed=None) 213 | samples = engine.draw(n, dtype=dtype) 214 | v = 0.5 + (1 - torch.finfo(samples.dtype).eps) * (samples - 0.5) 215 | samples = torch.erfinv(2 * v - 1) * math.sqrt(2) 216 | return samples.to(device=device) 217 | 218 | 219 | def sample_hypersphere(d: int, n: int, dtype: torch.dtype, device: torch.device): 220 | if d == 1: 221 | rnd = torch.randint(0, 2, (n, 1), dtype=dtype, device=device) 222 | return 2 * rnd - 1 223 | rnd = torch.randn(n, d, dtype=dtype, device=device) 224 | samples = rnd / torch.norm(rnd, dim=-1, keepdim=True) 225 | return samples 226 | 227 | 228 | def sample_ellipsoid(n, radius=None): 229 | if not hasattr(sample_ellipsoid, 'L_inv'): 230 | assert isinstance(radius, torch.Tensor) 231 | C = 1 / radius ** 2 * torch.eye(3, dtype=torch.float64, device=radius.device) 232 | sample_ellipsoid.L_inv = torch.inverse(torch.linalg.cholesky(C)).unsqueeze(0) 233 | sample_ellipsoid.L_inv = sample_ellipsoid.L_inv.type(radius.dtype) 234 | 235 | theta = draw_sobol_normal_samples(d=3, n=n, dtype=radius.dtype, device=radius.device) 236 | theta = F.normalize(theta, 2, dim=-1) 237 | r = torch.rand(n, 1, dtype=radius.dtype, device=radius.device) ** (1 / 3) 238 | x = r * theta 239 | u = sample_ellipsoid.L_inv.expand(x.shape[0], 3, 3).bmm(x.unsqueeze(-1)).squeeze(-1) 240 | return u 241 | 242 | 243 | def sample_points(n, height, width, radius): 244 | old_type = radius.dtype 245 | radius = radius.type(torch.float64) 246 | 247 | if not hasattr(sample_points, 'axis_x') or n != len(sample_points.axis_x): 248 | sample_points.axis_x = torch.tensor([[1.0, 0.0, 0.0]], dtype=radius.dtype, device=radius.device).expand((n, 3)) 249 | sample_points.axis_y = torch.tensor([[0.0, 1.0, 0.0]], dtype=radius.dtype, device=radius.device).expand((n, 3)) 250 | sample_points.axis_x = sample_points.axis_x.type(radius.dtype) 251 | sample_points.axis_y = sample_points.axis_y.type(radius.dtype) 252 | 253 | centers = sample_ellipsoid(n, radius=radius) 254 | ax = sample_hypersphere(d=3, n=n, dtype=radius.dtype, device=radius.device) 255 | 256 | plane_a = ax.cross(sample_points.axis_x, dim=-1) 257 | plane_a_ = ax.cross(sample_points.axis_y, dim=-1) 258 | flag_plane_a = torch.norm(plane_a, 1, dim=-1) == 0 259 | plane_a[flag_plane_a] = plane_a_[flag_plane_a] 260 | plane_a = F.normalize(plane_a, p=2, dim=-1) 261 | plane_b = ax.cross(plane_a, dim=-1) 262 | theta = torch.rand((n, 1), dtype=centers.dtype, device=centers.device) * 2 * np.pi 263 | ay = torch.cos(theta) * plane_a + torch.sin(theta) * plane_b 264 | 265 | h, w = ay * height / 2, ax * width / 2 266 | corner1s, corner2s = centers - h - w, centers - h + w 267 | 268 | return torch.stack([centers, corner1s, corner2s], dim=1).type(old_type) 269 | 270 | 271 | def sample_points_at_border(n, height, width, radius, direct_down='-axis_y'): 272 | old_type = radius.dtype 273 | 
radius = radius.type(torch.float64) 274 | 275 | if not hasattr(sample_points_at_border, 'axis_x') or n != len(sample_points_at_border.axis_x): 276 | sample_points_at_border.axis_x = torch.tensor([[1.0, 0.0, 0.0]], dtype=radius.dtype, device=radius.device).expand((n, 3)) 277 | sample_points_at_border.axis_y = torch.tensor([[0.0, 1.0, 0.0]], dtype=radius.dtype, device=radius.device).expand((n, 3)) 278 | sample_points_at_border.axis_z = torch.tensor([[0.0, 0.0, 1.0]], dtype=radius.dtype, device=radius.device).expand((n, 3)) 279 | sample_points_at_border.axis_x = sample_points_at_border.axis_x.type(radius.dtype) 280 | sample_points_at_border.axis_y = sample_points_at_border.axis_y.type(radius.dtype) 281 | sample_points_at_border.axis_z = sample_points_at_border.axis_z.type(radius.dtype) 282 | 283 | if direct_down.startswith('-'): 284 | direct_down = direct_down[1:] 285 | direct_down_coefficient = -1 286 | else: 287 | direct_down_coefficient = 1 288 | 289 | centers = sample_ellipsoid(n, radius) 290 | dist_flag = is_in_ellipsoid(centers, radius * 0.8, keepdim=True) | (torch.abs(centers[:, 1:2]) > 0.6 * radius[0, 1]) 291 | flag_n = torch.sum(dist_flag) 292 | while flag_n > 0: 293 | centers_ = sample_ellipsoid(flag_n, radius) 294 | centers[dist_flag.squeeze(-1), :] = centers_ 295 | dist_flag = is_in_ellipsoid(centers, radius * 0.8, keepdim=True) | (torch.abs(centers[:, 1:2]) > 0.6 * radius[0, 1]) 296 | flag_n = torch.sum(dist_flag) 297 | 298 | direct_down = direct_down_coefficient * getattr(sample_points_at_border, direct_down) 299 | direct_down = sample_points_by_limit(direct_down, min_cos=0.94, n=1, dtype=radius.dtype, device=radius.device).squeeze(1) 300 | normals = sample_points_by_limit(-centers, min_cos=0.984, n=1, dtype=radius.dtype, device=radius.device).squeeze(1) 301 | normals = normals - torch.sum(normals * direct_down, dim=-1, keepdim=True) * direct_down 302 | normals = F.normalize(normals, p=2, dim=-1) 303 | direct_plane = normals.cross(direct_down, dim=-1) 304 | 305 | h, w = direct_down * height / 2, direct_plane * width / 2 306 | corner1s, corner2s = centers + h - w, centers + h + w 307 | 308 | return torch.stack([centers, corner1s, corner2s], dim=1).type(old_type) 309 | 310 | 311 | def sample_points_by_limit(normals, min_cos=-1.0, max_cos=1.0, n=1, dtype=torch.float32, device='cuda'): 312 | assert normals.ndim == 2 313 | assert not (torch.norm(normals, 1, dim=-1) == 0).any() 314 | 315 | A0 = normals 316 | n0, n1, n2 = torch.split(normals, 1, dim=-1) 317 | n0n1, n1n2, n0n2 = n0 * n1, n1 * n2, n0 * n2 318 | s0, s1, s2 = torch.split(normals ** 2, 1, dim=-1) 319 | zero = torch.zeros_like(n0) 320 | A1 = torch.cat([-n1, n0 - n2, n1], dim=-1) 321 | A1_exp = torch.cat([n2, zero, -n0], dim=-1) 322 | A1_zero = torch.norm(A1, 1, dim=-1) == 0 323 | A1[A1_zero] = A1_exp[A1_zero] 324 | A2 = torch.cat([s1 + s2 - n0n2, -n0n1 - n1n2, s0 + s1 - n0n2], dim=-1) 325 | A2_exp = torch.cat([n1, -n0, zero], dim=-1) 326 | A2_zero = torch.norm(A2, 1, dim=-1) == 0 327 | A2[A2_zero] = A2_exp[A2_zero] 328 | 329 | A = torch.stack([A0, A1, A2], dim=1) 330 | A = A / torch.norm(A, 2, dim=-1, keepdim=True) 331 | A = A.permute(0, 2, 1) 332 | 333 | r = torch.rand((normals.shape[0], 2, n), dtype=dtype, device=device) 334 | theta, t = torch.split(r, 1, dim=1) 335 | min_theta = np.arccos(max_cos) 336 | max_theta = np.arccos(min_cos) 337 | theta = (1 - theta) * (max_theta - min_theta) + min_theta 338 | length_shallow, radius = torch.cos(theta), torch.sin(theta) 339 | t = t * np.pi * 2 340 | cos, sin = radius * 
torch.cos(t), radius * torch.sin(t) 341 | mat = torch.cat([length_shallow, cos, sin], dim=1) 342 | 343 | result = torch.bmm(A, mat).permute(0, 2, 1) 344 | result = result / torch.norm(result, dim=-1, keepdim=True) 345 | 346 | return result 347 | 348 | 349 | def get_slices(volume, series, height, width, volume_origin): 350 | mesh = series_to_mesh(series, height=height, width=width, origin=volume_origin) 351 | volume_size = torch.tensor(volume.shape, dtype=mesh.dtype, device=mesh.device).unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0) 352 | mesh = mesh.unsqueeze(0) / volume_size 353 | mesh = (mesh * 2 - 1).flip(-1) 354 | slices = F.grid_sample(volume.unsqueeze(0).unsqueeze(0), mesh, mode='bilinear', padding_mode='border', align_corners=False) 355 | slices = slices.squeeze(0).squeeze(0) 356 | return slices 357 | 358 | 359 | def solve_quadratic_equation(a, b, c): 360 | delta = b ** 2 - 4 * a * c 361 | assert delta >= 0.0, delta 362 | if a == 0.0: 363 | return -c / b 364 | theta = delta.sqrt() if isinstance(a, (torch.Tensor, np.ndarray)) else math.sqrt(delta) 365 | a2 = 2 * a 366 | x1, x2 = (-b + theta) / a2, (-b - theta) / a2 367 | mi, ma = min(x1, x2), max(x1, x2) 368 | if mi >= 0: 369 | return mi 370 | if ma >= 0: 371 | return ma 372 | return None 373 | 374 | 375 | def update_dva(d, v, a): 376 | flag = False 377 | if d < 0: 378 | d, v, a = -d, -v, -a 379 | flag = True 380 | _a, _b, _c = a / 2, v, -d 381 | nd, nv, na = d, v, a 382 | d = _b ** 2 / (4 * _c) 383 | if _a < d: 384 | na = np.random.rand() * d * 2 385 | if flag: 386 | nd, nv, na = -nd, -nv, -na 387 | return nd, nv, na 388 | 389 | 390 | def sample_line(start_points, end_point, start_velocity, start_accele, radius, n=None, min_n=None): 391 | old_dtype = start_points.dtype 392 | start_points = start_points.type(torch.float64) 393 | end_point = end_point.type(torch.float64) 394 | start_velocity = start_velocity.type(torch.float64) if isinstance(start_velocity, torch.Tensor) else start_velocity 395 | start_accele = start_accele.type(torch.float64) if isinstance(start_accele, torch.Tensor) else start_accele 396 | radius = radius.type(torch.float64) 397 | 398 | gap = end_point - start_points[0, :] 399 | gap0l = torch.norm(gap, 2, dim=-1, keepdim=True) 400 | gap0d = gap / gap0l 401 | 402 | sv0 = torch.sum(start_velocity * gap0d) * gap0d 403 | sv1 = start_velocity - sv0 404 | sa0 = torch.sum(start_accele * gap0d) * gap0d 405 | sa1 = start_accele - sa0 406 | 407 | _, _, nsa0 = update_dva(gap0l, (sv0 / gap0d)[0], (sa0 / gap0d)[0]) 408 | sa0 = nsa0 * gap0d 409 | t = solve_quadratic_equation((sa0 / gap0d)[0] / 2, (sv0 / gap0d)[0], -gap0l) 410 | b = -(6 * sv1 + 3 * sa1 * t) / (t ** 2) 411 | 412 | n = n or max(int(torch.round(t)), 1) 413 | if min_n is not None and n < min_n: 414 | return None 415 | tgap = t / n 416 | 417 | idx = torch.arange(1, n + 1, dtype=start_points.dtype, device=start_points.device).unsqueeze(-1) 418 | tidx = idx * tgap 419 | tidx = tidx + 0.05 * tgap * torch.randn_like(tidx) 420 | t2 = tidx ** 2 421 | t3 = tidx * t2 422 | 423 | s0 = sv0 * tidx + sa0 * t2 / 2 424 | s1 = sv1 * tidx + sa1 * t2 / 2 + b * t3 / 6 425 | s = s0 + s1 426 | 427 | series = start_points.unsqueeze(0) + s.unsqueeze(1) 428 | length_velocity = torch.norm(series - torch.cat([start_points.unsqueeze(0), series[:-1, :, :]], dim=0), 2, dim=[1, 2]) 429 | length_rand = torch.randn((series.shape[0], 1, 3), dtype=series.dtype, device=series.device) 430 | series = series + (0.01 * length_velocity.unsqueeze(-1).unsqueeze(-1) * length_rand).expand(series.shape[0], 
3, 3) 431 | 432 | flag = is_in_ellipsoid(series[:, 0, :], radius) 433 | flag = ~(torch.cumsum(~flag, dim=-1) > 0) 434 | series = series[flag, ...] 435 | if min_n is not None and len(series) < min_n: 436 | return None 437 | 438 | series = series.type(old_dtype) 439 | return series 440 | 441 | 442 | def sample_sector(start_points, sector_point, end_direct, max_velocity, start_accele, end_accele, radius, n=None, min_n=None): 443 | old_dtype = start_points.dtype 444 | start_points = start_points.type(torch.float64) 445 | sector_point = sector_point.type(torch.float64) 446 | end_direct = end_direct.type(torch.float64) if isinstance(end_direct, torch.Tensor) else end_direct 447 | max_velocity = max_velocity.type(torch.float64) if isinstance(max_velocity, torch.Tensor) else max_velocity 448 | start_accele = start_accele.type(torch.float64) if isinstance(start_accele, torch.Tensor) else start_accele 449 | end_accele = end_accele.type(torch.float64) if isinstance(end_accele, torch.Tensor) else end_accele 450 | radius = radius.type(torch.float64) 451 | 452 | start_direct = start_points[0, :] - sector_point 453 | start_direct = start_direct / torch.norm(start_direct, 2) 454 | end_direct = end_direct / torch.norm(end_direct, 2) 455 | sector_direct = torch.cross(start_direct, end_direct, dim=-1) 456 | sector_direct = sector_direct / torch.norm(sector_direct, 2) 457 | cos = torch.sum(start_direct * end_direct) 458 | cos.clamp_(-1.0 + 1.0e-7, 1.0 - 1.0e-7) 459 | angle = torch.acos(cos).item() 460 | 461 | ts = 0 462 | gap_s = start_accele * ts ** 2 / 2 463 | te = 0 464 | gap_e = -end_accele * te ** 2 / 2 465 | gap_med = angle - gap_s - gap_e 466 | tm = gap_med / max_velocity 467 | 468 | if tm < 0: 469 | ts, te = 0, 0 470 | tm = angle / max_velocity 471 | gap_s, gap_e = 0, 0 472 | gap_med = angle 473 | 474 | t = ts + tm + te 475 | n = n or max(int(round(t)), 1) 476 | if min_n is not None and n < min_n: 477 | return None 478 | tgap = t / n 479 | 480 | idx = torch.arange(1, n + 1, dtype=start_points.dtype, device=start_points.device).unsqueeze(-1) 481 | tidx = idx * tgap 482 | tidx = tidx + 0.05 * tgap * torch.randn_like(tidx) 483 | 484 | flag_1 = tidx <= ts 485 | flag_2 = ~flag_1 & (tidx <= ts + tm) 486 | flag_3 = ~(flag_1 | flag_2) 487 | 488 | s1 = start_accele * tidx[flag_1] ** 2 / 2 489 | s2 = gap_s + max_velocity * (tidx[flag_2] - ts) 490 | tt = tidx[flag_3] - ts - tm 491 | s3 = gap_s + gap_med + max_velocity * tt + end_accele * tt ** 2 / 2 492 | s = torch.cat([s1, s2, s3], dim=-1) 493 | s = s + 0.1 * torch.cat([s[1:] - s[:-1], s[0:1]], dim=-1) * torch.randn_like(s) 494 | 495 | s_ = s.unsqueeze(-1) 496 | s_sin = torch.sin(s_ / 2) 497 | s_cos = torch.cos(s_ / 2) 498 | series = quaternion_rotation_mul_theta(start_points, sector_point.unsqueeze(0), sector_direct, s_sin, s_cos) 499 | 500 | flag = is_in_ellipsoid(series[:, 0, :], radius) 501 | flag = ~(torch.cumsum(~flag, dim=-1) > 0) 502 | series = series[flag, ...] 
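# The masking above keeps only the leading run of frames whose slice center
# is still inside the ellipsoidal volume: cumsum over the negated flag turns
# positive at the first out-of-volume frame, so that frame and every later
# one are dropped.
# E.g. flag = [T, T, F, T] -> cumsum(~flag) = [0, 0, 1, 1] -> keep [T, T, F, F].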
503 | if min_n is not None and len(series) < min_n: 504 | return None 505 | 506 | series = series.type(old_dtype) 507 | return series 508 | -------------------------------------------------------------------------------- /datasets/DDH.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import warnings 4 | 5 | import nibabel as nib 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | import configs 11 | import datasets 12 | import utils 13 | 14 | __all__ = ['DDH'] 15 | 16 | 17 | class DDH(datasets.BaseDataset): 18 | 19 | def __init__(self, cfg, **kwargs): 20 | super().__init__(cfg, **kwargs) 21 | self.origin = torch.tensor(self.cfg.source.origin, dtype=torch.float32, device=self.cfg.device) 22 | self.rad = torch.tensor([self.cfg.source.max_distance], dtype=torch.float32, device=self.cfg.device) 23 | self.direct_down = '-axis_y' 24 | 25 | self.pd = {} 26 | self.pr = {} 27 | self.pi = {} 28 | 29 | @staticmethod 30 | def more(cfg): 31 | cfg.source.elements = cfg.source.width * cfg.source.height * cfg.source.channel 32 | cfg.paths.source = configs.env.getdir(cfg.paths.source) 33 | cfg.paths.order = configs.env.getdir(cfg.paths.order) 34 | 35 | cfg.num_workers = 0 36 | cfg.pin_memory = False 37 | 38 | cfg.cuda_data_count = 0 39 | cfg.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 40 | cfg.load_mode = getattr(cfg, 'load_mode', 'memory') 41 | 42 | cfg.while_max_time = 20 43 | 44 | return cfg 45 | 46 | def load(self): 47 | if self.cfg.load_mode == 'disk': 48 | source_data = [] 49 | for file in sorted(os.listdir(self.cfg.paths.source)): 50 | if file.endswith('.nii.gz'): 51 | source = os.path.join(self.cfg.paths.source, file) 52 | source_data.append(source) 53 | else: 54 | source_data_npy_path = os.path.join(self.cfg.paths.source, self.__class__.__name__ + '_data.npy') 55 | if not os.path.exists(source_data_npy_path): 56 | source_data = [] 57 | for file in sorted(os.listdir(self.cfg.paths.source)): 58 | if file.endswith('.nii.gz'): 59 | source = nib.load(os.path.join(self.cfg.paths.source, file)).get_fdata(dtype=np.float32) 60 | source_data.append(source) 61 | np.save(source_data_npy_path, source_data) 62 | else: 63 | source_data = np.load(source_data_npy_path, allow_pickle=True) 64 | 65 | if self.cfg.load_mode == 'memory': 66 | source_data = [torch.from_numpy(s).to(self.cfg.device) if i < self.cfg.cuda_data_count else torch.from_numpy(s) for i, s in enumerate(source_data)] 67 | 68 | if not os.path.exists(self.cfg.paths.order): 69 | order = np.arange(len(source_data)) 70 | np.save(self.cfg.paths.order, order) 71 | warnings.warn(f'Index file `{self.cfg.paths.order}` is created!') 72 | else: 73 | order = np.load(self.cfg.paths.order) 74 | 75 | source_data = [source_data[idx] for idx in order] 76 | self.order = order 77 | 78 | trainset_length = int(self.cfg.series_per_data[0] * self.cfg.train_test_range[0]) 79 | valset_length = int(self.cfg.series_per_data[1] * self.cfg.train_test_range[1]) 80 | testset_length = int(self.cfg.series_per_data[2] * self.cfg.train_test_range[2]) 81 | data_count = trainset_length + valset_length + testset_length 82 | 83 | return {'source': source_data}, data_count 84 | 85 | def get_volume(self, idx): 86 | if self.cfg.load_mode == 'disk': 87 | volume = torch.from_numpy(nib.load(self.data['source'][idx]).get_fdata(dtype=np.float32)) 88 | else: 89 | volume = self.data['source'][idx] 90 | volume = volume.to(self.cfg.device) 91 | return volume 92 | 93 | def 
sample_series(self, points, n, action, rev=False): 94 | old_dtype = points.dtype 95 | points = points.type(torch.float64) 96 | series = [points] 97 | infos = {} 98 | 99 | count_actions = 1 100 | infos['count_action'] = count_actions 101 | infos['actions'] = [] 102 | infos['count_slice_in_action'] = [1] 103 | 104 | for idx in range(count_actions): 105 | start_points = series[-1] 106 | normal = utils.simulation.get_normal(series[-1]) 107 | infos['actions'].append(action) 108 | 109 | if action == 'line': 110 | line_series = None 111 | while_time = 0 112 | nn = 7 + np.random.randint(-2, 3) 113 | while line_series is None: 114 | if while_time >= self.cfg.while_max_time: 115 | return None 116 | while_time += 1 117 | direct = utils.simulation.sample_points_by_limit(normal.unsqueeze(0) * (-1 if rev else 1), min_cos=0.98, n=1, dtype=normal.dtype, device=self.cfg.device).squeeze() 118 | length = 12 * np.random.rand() + 32 119 | end_point = start_points[0, :] + direct * length 120 | start_velocity = direct * (2.0 + 0.2 * np.random.randn()) + 0.05 * torch.randn((3,), dtype=direct.dtype, device=self.cfg.device) 121 | start_accele = 0 122 | line_series = utils.simulation.sample_line(series[-1], end_point, start_velocity, start_accele, radius=self.rad, n=n, min_n=nn - 1) 123 | line_series = line_series[:nn - 1, ...] 124 | infos['count_slice_in_action'].append(len(line_series)) 125 | series.extend(line_series) 126 | action_length = len(line_series) 127 | 128 | line_series = None 129 | while_time = 0 130 | nn = 4 + np.random.randint(-2, 2) 131 | while line_series is None: 132 | if while_time >= self.cfg.while_max_time: 133 | return None 134 | while_time += 1 135 | length = 60 * np.random.rand() + 180 136 | end_point = end_point + direct * length 137 | start_velocity = direct * (18 + 6.0 * np.random.randn()) + 0.05 * torch.randn((3,), dtype=direct.dtype, device=self.cfg.device) 138 | start_accele = 0 139 | line_series = utils.simulation.sample_line(series[-1], end_point, start_velocity, start_accele, radius=self.rad, n=n, min_n=nn - 1) 140 | line_series = line_series[:nn - 1, ...] 141 | infos['count_slice_in_action'][-1] += len(line_series) 142 | series.extend(line_series) 143 | action_length += len(line_series) 144 | 145 | line_series = None 146 | while_time = 0 147 | nn = 7 + np.random.randint(-2, 3) 148 | while line_series is None: 149 | if while_time >= self.cfg.while_max_time: 150 | return None 151 | while_time += 1 152 | length = 12 * np.random.rand() + 32 153 | end_point = end_point + direct * length 154 | start_velocity = direct * (2.0 + 0.2 * np.random.randn()) + 0.05 * torch.randn((3,), dtype=direct.dtype, device=self.cfg.device) 155 | start_accele = 0 156 | line_series = utils.simulation.sample_line(series[-1], end_point, start_velocity, start_accele, radius=self.rad, n=n, min_n=nn - 1) 157 | line_series = line_series[:nn - 1, ...] 
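# (Third of the four chained sub-segments: the 'line' action alternates a
# slow, short sweep of ~2 units/frame over 32-44 units with a fast, long one
# of ~18 units/frame over 180-240 units, twice in a row -- presumably a
# careful local scan followed by a rapid reposition of the probe.)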
158 | infos['count_slice_in_action'][-1] += len(line_series) 159 | series.extend(line_series) 160 | action_length += len(line_series) 161 | 162 | line_series = None 163 | while_time = 0 164 | nn = 4 + np.random.randint(-2, 2) 165 | while line_series is None: 166 | if while_time >= self.cfg.while_max_time: 167 | return None 168 | while_time += 1 169 | length = 60 * np.random.rand() + 180 170 | end_point = end_point + direct * length 171 | start_velocity = direct * (18 + 6.0 * np.random.randn()) + 0.05 * torch.randn((3,), dtype=direct.dtype, device=self.cfg.device) 172 | start_accele = 0 173 | line_series = utils.simulation.sample_line(series[-1], end_point, start_velocity, start_accele, radius=self.rad, n=n, min_n=nn - 1) 174 | line_series = line_series[:nn - 1, ...] 175 | infos['count_slice_in_action'][-1] += len(line_series) 176 | series.extend(line_series) 177 | action_length += len(line_series) 178 | 179 | elif action == 'sector': 180 | length_edge = torch.norm(start_points[2, :] - start_points[1, :], 2) 181 | 182 | sector_point = start_points[0, :] - (start_points[2, :] - start_points[0, :] + start_points[1, :] - start_points[0, :]) / 2 183 | sector_series = None 184 | while_time = 0 185 | nn = 4 + np.random.randint(-2, 3) 186 | while sector_series is None: 187 | if while_time >= self.cfg.while_max_time: 188 | return None 189 | while_time += 1 190 | pc = start_points[2, :] - start_points[0, :] + start_points[1, :] - start_points[0, :] 191 | 192 | pc_norm = F.normalize(pc, p=2, dim=0) 193 | p1p2 = (start_points[2, :] - start_points[1, :]) / length_edge 194 | ed_cos = 1 195 | wwhile_time = 0 196 | while ed_cos < -0.5 or ed_cos > 0.5: 197 | if wwhile_time >= self.cfg.while_max_time: 198 | return None 199 | wwhile_time += 1 200 | min_cos, max_cos = (0.98, 1.0) if rev else (0.93, 0.98) 201 | end_direct = utils.simulation.sample_points_by_limit(pc.unsqueeze(0), min_cos=min_cos, max_cos=max_cos, n=1, dtype=pc.dtype, device=self.cfg.device).squeeze() 202 | ed = end_direct - torch.sum(end_direct * pc_norm) * pc_norm 203 | ed = F.normalize(ed, p=2, dim=0) 204 | ed_cos = torch.sum(ed * p1p2) 205 | 206 | max_velocity = np.pi / 180 * (1.5 + 0.05 * np.random.randn()) 207 | start_accele = 0 208 | end_accele = 0 209 | sector_series = utils.simulation.sample_sector(series[-1], sector_point, end_direct, max_velocity, start_accele, end_accele, radius=self.rad, n=n, min_n=nn - 1) 210 | sector_series = sector_series[:nn - 1, ...] 
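# First fan sweep of the 'sector' action: the plane is rotated about
# sector_point -- the midpoint of the image edge opposite the corners,
# presumably where the probe contacts the surface -- towards end_direct at
# roughly 1.5 degrees per frame (max_velocity above).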
211 | infos['count_slice_in_action'].append(len(sector_series)) 212 | series.extend(sector_series) 213 | action_length = len(sector_series) 214 | 215 | sector_point = start_points[0, :] - (start_points[2, :] - start_points[0, :] + start_points[1, :] - start_points[0, :]) / 2 216 | sector_series = None 217 | while_time = 0 218 | nn = 4 + np.random.randint(-2, 3) 219 | while sector_series is None: 220 | if while_time >= self.cfg.while_max_time: 221 | return None 222 | while_time += 1 223 | pc = start_points[2, :] - start_points[0, :] + start_points[1, :] - start_points[0, :] 224 | 225 | # TODO more faster, not while 226 | pc_norm = F.normalize(pc, p=2, dim=0) 227 | p1p2 = (start_points[2, :] - start_points[1, :]) / length_edge 228 | ed_cos = 1 229 | wwhile_time = 0 230 | while ed_cos < -0.5 or ed_cos > 0.5: 231 | if wwhile_time >= self.cfg.while_max_time: 232 | return None 233 | wwhile_time += 1 234 | min_cos, max_cos = (0.93, 0.98) if rev else (0.98, 1.0) 235 | end_direct = utils.simulation.sample_points_by_limit(pc.unsqueeze(0), min_cos=min_cos, max_cos=max_cos, n=1, dtype=pc.dtype, device=self.cfg.device).squeeze() 236 | ed = end_direct - torch.sum(end_direct * pc_norm) * pc_norm 237 | ed = F.normalize(ed, p=2, dim=0) 238 | ed_cos = torch.sum(ed * p1p2) 239 | 240 | max_velocity = np.pi / 180 * (1.5 + 0.05 * np.random.randn()) 241 | start_accele = 0 242 | end_accele = 0 243 | sector_series = utils.simulation.sample_sector(series[-1], sector_point, end_direct, max_velocity, start_accele, end_accele, radius=self.rad, n=n, min_n=nn - 1) 244 | sector_series = sector_series[:nn - 1, ...] 245 | infos['count_slice_in_action'][-1] += len(sector_series) 246 | series.extend(sector_series) 247 | action_length += len(sector_series) 248 | 249 | else: 250 | raise ValueError('{} is not an action type.'.format(action)) 251 | 252 | series = torch.stack(series, dim=0) 253 | infos['count_slices'] = len(series) 254 | infos['series'] = series.type(old_dtype) 255 | return infos 256 | 257 | def slice_n(self, idx, n, frame_rate=None, optical_flow=True, edge=True): 258 | data = self.get_volume(idx) 259 | 260 | points = utils.simulation.sample_points_at_border(n, height=self.cfg.source.height, width=self.cfg.source.width, radius=self.rad, direct_down=self.direct_down) 261 | series = [] 262 | 263 | ii = 0 264 | while ii < n: 265 | while_broke_flag = False 266 | while_broke_count1 = 0 267 | while_broke_count2 = 0 268 | rev = False 269 | has_sector = False 270 | has_sector2 = False 271 | 272 | ii_loop = 0 273 | loops = 7 274 | ss_loop = None 275 | 276 | while ii_loop < loops: 277 | while_broke_flag = False 278 | 279 | if not has_sector2 and (np.random.rand() < 2 / (loops - 1) or ii_loop == loops - 2): 280 | action = 'sector' 281 | else: 282 | if np.random.rand() < 0.33: 283 | action = 'sector' 284 | else: 285 | action = 'line' 286 | 287 | start_point = points[ii] if ii_loop == 0 else ss_loop['series'][-1] 288 | ss = self.sample_series(start_point, n=None, action=action, rev=rev) 289 | while_time = 0 290 | while ss is None: 291 | while_time += 1 292 | if while_time >= self.cfg.while_max_time: 293 | if ii_loop == 0: 294 | random_point = utils.simulation.sample_points_at_border(1, height=self.cfg.source.height, width=self.cfg.source.width, radius=self.rad, direct_down=self.direct_down) 295 | points[ii] = random_point[0, ...] 
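# (Only on the very first sub-action is the border start point itself
# resampled; failures deeper in the chain merely flag a restart below.)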
296 | while_broke_flag = True 297 | break 298 | rev = not rev 299 | ss = self.sample_series(start_point, n=None, action=action, rev=rev) 300 | if while_broke_flag: 301 | while_broke_flag = False 302 | rev = not rev 303 | while_broke_count1 += 1 304 | if while_broke_count1 >= self.cfg.while_max_time: 305 | while_broke_count1 = 0 306 | while_broke_count2 = 0 307 | rev = False 308 | has_sector = False 309 | has_sector2 = False 310 | ii_loop = 0 311 | ss_loop = None 312 | continue 313 | if action != 'sector': 314 | if frame_rate is not None: 315 | ss['series'] = ss['series'][::frame_rate[ii]] 316 | ss['count_slices'] = len(ss['series']) 317 | ss['count_slice_in_action'][-1] = ss['count_slices'] - 1 318 | 319 | slices = utils.simulation.get_slices(data, ss['series'], height=self.cfg.source.height, width=self.cfg.source.width, volume_origin=self.origin) 320 | slices_all_value = (torch.norm(slices, 1, dim=[1, 2]) == 0).any() 321 | w_time = 0 322 | while slices_all_value: 323 | w_time += 1 324 | if w_time >= self.cfg.while_max_time: 325 | while_broke_flag = True 326 | break 327 | ss = self.sample_series(start_point, n=None, action=action, rev=rev) 328 | while_time = 0 329 | while ss is None: 330 | while_time += 1 331 | if while_time >= self.cfg.while_max_time: 332 | if ii_loop == 0: 333 | random_point = utils.simulation.sample_points_at_border(1, height=self.cfg.source.height, width=self.cfg.source.width, radius=self.rad, direct_down=self.direct_down) 334 | points[ii] = random_point[0, ...] 335 | while_broke_flag = True 336 | break 337 | rev = not rev 338 | ss = self.sample_series(start_point, n=None, action=action, rev=rev) 339 | if while_broke_flag: 340 | break 341 | if action != 'sector': 342 | if frame_rate is not None: 343 | ss['series'] = ss['series'][::frame_rate[ii]] 344 | ss['count_slices'] = len(ss['series']) 345 | ss['count_slice_in_action'][-1] = ss['count_slices'] - 1 346 | slices = utils.simulation.get_slices(data, ss['series'], height=self.cfg.source.height, width=self.cfg.source.width, volume_origin=self.origin) 347 | slices_all_value = (torch.norm(slices, 1, dim=[1, 2]) == 0).any() 348 | if while_broke_flag: 349 | while_broke_flag = False 350 | rev = not rev 351 | while_broke_count2 += 1 352 | if while_broke_count2 >= self.cfg.while_max_time: 353 | while_broke_count1 = 0 354 | while_broke_count2 = 0 355 | rev = False 356 | has_sector = False 357 | has_sector2 = False 358 | ii_loop = 0 359 | ss_loop = None 360 | continue 361 | 362 | ss['slices'] = slices 363 | if ii_loop == 0: 364 | ss_loop = ss 365 | else: 366 | ss_loop['count_slice_in_action'][-1] += ss['count_slice_in_action'][-1] 367 | ss_loop['count_slices'] += ss['count_slice_in_action'][-1] 368 | ss_loop['series'] = torch.cat([ss_loop['series'], ss['series'][1:, ...]], dim=0) 369 | ss_loop['slices'] = torch.cat([ss_loop['slices'], ss['slices'][1:, ...]], dim=0) 370 | if ss_loop['count_slices'] >= self.cfg.series_min_length: 371 | break 372 | ii_loop += 1 373 | 374 | if action == 'sector': 375 | if not has_sector: 376 | has_sector = True 377 | else: 378 | has_sector2 = True 379 | 380 | if not while_broke_flag: 381 | if frame_rate is not None: 382 | ss_loop['frame_rate'] = frame_rate[ii] 383 | ss_loop['gaps'] = utils.simulation.series_to_dof(ss_loop['series']) 384 | ss_loop['indices'] = torch.randperm(len(ss_loop['gaps'])) 385 | if optical_flow: 386 | ss_loop['opticalflow'] = utils.image.get_optical_flow(ss_loop['slices'], device=self.cfg.device) 387 | if edge: 388 | ss_loop['edge'] = utils.image.get_edge(ss_loop['slices'], 
device=self.cfg.device) 389 | if frame_rate is None: 390 | assert len(ss_loop['slices']) == len(ss_loop['series']) \ 391 | and len(ss_loop['series']) == ss_loop['count_slices'] \ 392 | and ss_loop['count_slices'] == ss_loop['count_slice_in_action'][-1] + 1, \ 393 | (len(ss_loop['slices']), len(ss_loop['series']), ss_loop['count_slices'], ss_loop['count_slice_in_action'][-1]) 394 | series.append(ss_loop) 395 | ii += 1 396 | 397 | return series 398 | 399 | def pad(self, data, max_length=None): 400 | if data is None: 401 | return None 402 | max_length = max_length or self.cfg.series_max_length 403 | if len(data) > max_length: 404 | data = data[:max_length] 405 | return data 406 | 407 | def __getitem__(self, index): 408 | idx_num = self.get_idx(index) 409 | idx = str(idx_num) 410 | 411 | if idx not in self.pr.keys() or len(self.pr[idx]) == 0: 412 | ps = self.cfg.ps if index < self.trainset_length else 1 413 | fr = torch.randint(self.cfg.frame_rate[0], self.cfg.frame_rate[1] + 1, (ps,)) 414 | self.pd[idx] = self.slice_n(idx_num, n=ps, frame_rate=fr) 415 | self.pr[idx] = torch.randperm(ps) 416 | data = self.pd[idx][self.pr[idx][0].item()] 417 | 418 | info = torch.tensor([len(data['slices'])]) 419 | frame_rate = data['frame_rate'] 420 | pad_slices = self.pad(data['slices']) 421 | pad_gaps = self.pad(data['gaps'], self.cfg.series_max_length - 1) 422 | pad_series = self.pad(data['series']) 423 | pad_optical_flow = self.pad(data['opticalflow'], self.cfg.series_max_length - 1) 424 | pad_edge = self.pad(data['edge']) 425 | 426 | source = pad_slices.unsqueeze(1) 427 | target = torch.cat([F.pad(pad_gaps, (0, 0, 0, 1)), pad_series.view(-1, 9)], dim=-1) 428 | 429 | self.pd[idx][self.pr[idx][0].item()] = [] 430 | self.pr[idx] = self.pr[idx][1:] 431 | 432 | sample_dict = { 433 | 'source': source, 'target': target, 434 | 'optical_flow': pad_optical_flow, 435 | 'edge': pad_edge, 436 | 'frame_rate': frame_rate, 437 | 'info': info 438 | } 439 | 440 | utils.common.set_seed(int(time.time() * 1000) % (1 << 32) + index) 441 | return sample_dict, index 442 | -------------------------------------------------------------------------------- /datasets/Fetus.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import warnings 4 | 5 | import nibabel as nib 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | import configs 11 | import datasets 12 | import utils 13 | 14 | __all__ = ['Fetus'] 15 | 16 | 17 | class Fetus(datasets.BaseDataset): 18 | 19 | def __init__(self, cfg, **kwargs): 20 | super().__init__(cfg, **kwargs) 21 | self.origin = torch.tensor(self.cfg.source.origin, dtype=torch.float32, device=self.cfg.device) 22 | self.rad = torch.tensor([self.cfg.source.max_distance], dtype=torch.float32, device=self.cfg.device) 23 | self.direct_down = 'axis_z' 24 | 25 | self.pd = {} 26 | self.pr = {} 27 | self.pi = {} 28 | 29 | @staticmethod 30 | def more(cfg): 31 | cfg.source.elements = cfg.source.width * cfg.source.height * cfg.source.channel 32 | cfg.paths.source = configs.env.getdir(cfg.paths.source) 33 | cfg.paths.order = configs.env.getdir(cfg.paths.order) 34 | 35 | cfg.num_workers = 0 36 | cfg.pin_memory = False 37 | 38 | cfg.cuda_data_count = 0 39 | cfg.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 40 | cfg.load_mode = getattr(cfg, 'load_mode', 'memory') 41 | 42 | cfg.while_max_time = 20 43 | 44 | return cfg 45 | 46 | def load(self): 47 | if self.cfg.load_mode == 'disk': 48 | source_data = [] 49 | 
for file in sorted(os.listdir(self.cfg.paths.source)): 50 | if file.endswith('.nii.gz'): 51 | source = os.path.join(self.cfg.paths.source, file) 52 | source_data.append(source) 53 | else: 54 | source_data_npy_path = os.path.join(self.cfg.paths.source, self.__class__.__name__ + '_data.npy') 55 | if not os.path.exists(source_data_npy_path): 56 | source_data = [] 57 | for file in sorted(os.listdir(self.cfg.paths.source)): 58 | if file.endswith('.nii.gz'): 59 | source = nib.load(os.path.join(self.cfg.paths.source, file)).get_fdata(dtype=np.float32) 60 | source_data.append(source) 61 | np.save(source_data_npy_path, source_data) 62 | else: 63 | source_data = np.load(source_data_npy_path, allow_pickle=True) 64 | 65 | if self.cfg.load_mode == 'memory': 66 | source_data = [torch.from_numpy(s).to(self.cfg.device) if i < self.cfg.cuda_data_count else torch.from_numpy(s) for i, s in enumerate(source_data)] 67 | 68 | if not os.path.exists(self.cfg.paths.order): 69 | order = np.arange(len(source_data)) 70 | np.save(self.cfg.paths.order, order) 71 | warnings.warn(f'Index file `{self.cfg.paths.order}` is created!') 72 | else: 73 | order = np.load(self.cfg.paths.order) 74 | 75 | source_data = [source_data[idx] for idx in order] 76 | self.order = order 77 | 78 | trainset_length = int(self.cfg.series_per_data[0] * self.cfg.train_test_range[0]) 79 | valset_length = int(self.cfg.series_per_data[1] * self.cfg.train_test_range[1]) 80 | testset_length = int(self.cfg.series_per_data[2] * self.cfg.train_test_range[2]) 81 | data_count = trainset_length + valset_length + testset_length 82 | 83 | return {'source': source_data}, data_count 84 | 85 | def get_volume(self, idx): 86 | if self.cfg.load_mode == 'disk': 87 | volume = torch.from_numpy(nib.load(self.data['source'][idx]).get_fdata(dtype=np.float32)) 88 | else: 89 | volume = self.data['source'][idx] 90 | volume = volume.to(self.cfg.device) 91 | return volume 92 | 93 | def sample_series(self, points, n, action, rev=False): 94 | old_dtype = points.dtype 95 | points = points.type(torch.float64) 96 | series = [points] 97 | infos = {} 98 | 99 | count_actions = 1 100 | infos['count_action'] = count_actions 101 | infos['actions'] = [] 102 | infos['count_slice_in_action'] = [1] 103 | 104 | for idx in range(count_actions): 105 | start_points = series[-1] 106 | normal = utils.simulation.get_normal(series[-1]) 107 | infos['actions'].append(action) 108 | 109 | if action == 'line': 110 | line_series = None 111 | while_time = 0 112 | nn = 7 + np.random.randint(-2, 3) 113 | while line_series is None: 114 | if while_time >= self.cfg.while_max_time: 115 | return None 116 | while_time += 1 117 | direct = utils.simulation.sample_points_by_limit(normal.unsqueeze(0) * (-1 if rev else 1), min_cos=0.98, n=1, dtype=normal.dtype, device=self.cfg.device).squeeze() 118 | length = 6 * np.random.rand() + 16 119 | end_point = start_points[0, :] + direct * length 120 | start_velocity = direct * (1.0 + 0.1 * np.random.randn()) + 0.05 * torch.randn((3,), dtype=direct.dtype, device=self.cfg.device) 121 | start_accele = 0 122 | line_series = utils.simulation.sample_line(series[-1], end_point, start_velocity, start_accele, radius=self.rad, n=n, min_n=nn - 1) 123 | line_series = line_series[:nn - 1, ...] 
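# Same four-segment 'line' action as DDH.sample_series, but with smaller and
# slower motions for the fetal volume: short sweeps of 16-22 units at
# ~1 unit/frame and long sweeps of 120-160 units at ~12 units/frame.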
124 | infos['count_slice_in_action'].append(len(line_series)) 125 | series.extend(line_series) 126 | action_length = len(line_series) 127 | 128 | line_series = None 129 | while_time = 0 130 | nn = 4 + np.random.randint(-2, 2) 131 | while line_series is None: 132 | if while_time >= self.cfg.while_max_time: 133 | return None 134 | while_time += 1 135 | length = 40 * np.random.rand() + 120 136 | end_point = end_point + direct * length 137 | start_velocity = direct * (12 + 4.0 * np.random.randn()) + 0.05 * torch.randn((3,), dtype=direct.dtype, device=self.cfg.device) 138 | start_accele = 0 139 | line_series = utils.simulation.sample_line(series[-1], end_point, start_velocity, start_accele, radius=self.rad, n=n, min_n=nn - 1) 140 | line_series = line_series[:nn - 1, ...] 141 | infos['count_slice_in_action'][-1] += len(line_series) 142 | series.extend(line_series) 143 | action_length += len(line_series) 144 | 145 | line_series = None 146 | while_time = 0 147 | nn = 7 + np.random.randint(-2, 3) 148 | while line_series is None: 149 | if while_time >= self.cfg.while_max_time: 150 | return None 151 | while_time += 1 152 | length = 6 * np.random.rand() + 16 153 | end_point = end_point + direct * length 154 | start_velocity = direct * (1.0 + 0.1 * np.random.randn()) + 0.05 * torch.randn((3,), dtype=direct.dtype, device=self.cfg.device) 155 | start_accele = 0 156 | line_series = utils.simulation.sample_line(series[-1], end_point, start_velocity, start_accele, radius=self.rad, n=n, min_n=nn - 1) 157 | line_series = line_series[:nn - 1, ...] 158 | infos['count_slice_in_action'][-1] += len(line_series) 159 | series.extend(line_series) 160 | action_length += len(line_series) 161 | 162 | line_series = None 163 | while_time = 0 164 | nn = 4 + np.random.randint(-2, 2) 165 | while line_series is None: 166 | if while_time >= self.cfg.while_max_time: 167 | return None 168 | while_time += 1 169 | length = 40 * np.random.rand() + 120 170 | end_point = end_point + direct * length 171 | start_velocity = direct * (12 + 4.0 * np.random.randn()) + 0.05 * torch.randn((3,), dtype=direct.dtype, device=self.cfg.device) 172 | start_accele = 0 173 | line_series = utils.simulation.sample_line(series[-1], end_point, start_velocity, start_accele, radius=self.rad, n=n, min_n=nn - 1) 174 | line_series = line_series[:nn - 1, ...] 
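# NOTE: action_length is accumulated across the sub-segments but never read
# afterwards; infos['count_slice_in_action'] carries the tally that the
# callers actually use.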
175 | infos['count_slice_in_action'][-1] += len(line_series) 176 | series.extend(line_series) 177 | action_length += len(line_series) 178 | 179 | elif action == 'sector': 180 | length_edge = torch.norm(start_points[2, :] - start_points[1, :], 2) 181 | 182 | sector_point = start_points[0, :] - (start_points[2, :] - start_points[0, :] + start_points[1, :] - start_points[0, :]) / 2 183 | sector_series = None 184 | while_time = 0 185 | nn = 4 + np.random.randint(-2, 3) 186 | while sector_series is None: 187 | if while_time >= self.cfg.while_max_time: 188 | return None 189 | while_time += 1 190 | pc = start_points[2, :] - start_points[0, :] + start_points[1, :] - start_points[0, :] 191 | 192 | pc_norm = F.normalize(pc, p=2, dim=0) 193 | p1p2 = (start_points[2, :] - start_points[1, :]) / length_edge 194 | ed_cos = 1 195 | wwhile_time = 0 196 | while ed_cos < -0.5 or ed_cos > 0.5: 197 | if wwhile_time >= self.cfg.while_max_time: 198 | return None 199 | wwhile_time += 1 200 | min_cos, max_cos = (0.98, 1.0) if rev else (0.93, 0.98) 201 | end_direct = utils.simulation.sample_points_by_limit(pc.unsqueeze(0), min_cos=min_cos, max_cos=max_cos, n=1, dtype=pc.dtype, device=self.cfg.device).squeeze() 202 | ed = end_direct - torch.sum(end_direct * pc_norm) * pc_norm 203 | ed = F.normalize(ed, p=2, dim=0) 204 | ed_cos = torch.sum(ed * p1p2) 205 | 206 | max_velocity = np.pi / 180 * (1.5 + 0.05 * np.random.randn()) 207 | start_accele = 0 208 | end_accele = 0 209 | sector_series = utils.simulation.sample_sector(series[-1], sector_point, end_direct, max_velocity, start_accele, end_accele, radius=self.rad, n=n, min_n=nn - 1) 210 | sector_series = sector_series[:nn - 1, ...] 211 | infos['count_slice_in_action'].append(len(sector_series)) 212 | series.extend(sector_series) 213 | action_length = len(sector_series) 214 | 215 | sector_point = start_points[0, :] - (start_points[2, :] - start_points[0, :] + start_points[1, :] - start_points[0, :]) / 2 216 | sector_series = None 217 | while_time = 0 218 | nn = 4 + np.random.randint(-2, 3) 219 | while sector_series is None: 220 | if while_time >= self.cfg.while_max_time: 221 | return None 222 | while_time += 1 223 | pc = start_points[2, :] - start_points[0, :] + start_points[1, :] - start_points[0, :] 224 | 225 | # TODO more faster, not while 226 | pc_norm = F.normalize(pc, p=2, dim=0) 227 | p1p2 = (start_points[2, :] - start_points[1, :]) / length_edge 228 | ed_cos = 1 229 | wwhile_time = 0 230 | while ed_cos < -0.5 or ed_cos > 0.5: 231 | if wwhile_time >= self.cfg.while_max_time: 232 | return None 233 | wwhile_time += 1 234 | min_cos, max_cos = (0.93, 0.98) if rev else (0.98, 1.0) 235 | end_direct = utils.simulation.sample_points_by_limit(pc.unsqueeze(0), min_cos=min_cos, max_cos=max_cos, n=1, dtype=pc.dtype, device=self.cfg.device).squeeze() 236 | ed = end_direct - torch.sum(end_direct * pc_norm) * pc_norm 237 | ed = F.normalize(ed, p=2, dim=0) 238 | ed_cos = torch.sum(ed * p1p2) 239 | 240 | max_velocity = np.pi / 180 * (1.5 + 0.05 * np.random.randn()) 241 | start_accele = 0 242 | end_accele = 0 243 | sector_series = utils.simulation.sample_sector(series[-1], sector_point, end_direct, max_velocity, start_accele, end_accele, radius=self.rad, n=n, min_n=nn - 1) 244 | sector_series = sector_series[:nn - 1, ...] 
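# Second fan sweep: the (min_cos, max_cos) band is swapped relative to the
# first sweep, so the plane fans over the complementary angular range --
# presumably producing a back-and-forth wobble of the imaging plane.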
245 | infos['count_slice_in_action'][-1] += len(sector_series) 246 | series.extend(sector_series) 247 | action_length += len(sector_series) 248 | 249 | else: 250 | raise ValueError('{} is not an action type.'.format(action)) 251 | 252 | series = torch.stack(series, dim=0) 253 | infos['count_slices'] = len(series) 254 | infos['series'] = series.type(old_dtype) 255 | return infos 256 | 257 | def slice_n(self, idx, n, frame_rate=None, optical_flow=True, edge=True): 258 | data = self.get_volume(idx) 259 | 260 | points = utils.simulation.sample_points_at_border(n, height=self.cfg.source.height, width=self.cfg.source.width, radius=self.rad, direct_down=self.direct_down) 261 | series = [] 262 | 263 | ii = 0 264 | while ii < n: 265 | while_broke_flag = False 266 | while_broke_count1 = 0 267 | while_broke_count2 = 0 268 | rev = False 269 | has_sector = False 270 | has_sector2 = False 271 | 272 | ii_loop = 0 273 | loops = 7 274 | ss_loop = None 275 | 276 | while ii_loop < loops: 277 | while_broke_flag = False 278 | 279 | if not has_sector2 and (np.random.rand() < 2 / (loops - 1) or ii_loop == loops - 2): 280 | action = 'sector' 281 | else: 282 | action = 'line' 283 | 284 | start_point = points[ii] if ii_loop == 0 else ss_loop['series'][-1] 285 | ss = self.sample_series(start_point, n=None, action=action, rev=rev) 286 | while_time = 0 287 | while ss is None: 288 | while_time += 1 289 | if while_time >= self.cfg.while_max_time: 290 | if ii_loop == 0: 291 | random_point = utils.simulation.sample_points_at_border(1, height=self.cfg.source.height, width=self.cfg.source.width, radius=self.rad, direct_down=self.direct_down) 292 | points[ii] = random_point[0, ...] 293 | while_broke_flag = True 294 | break 295 | rev = not rev 296 | ss = self.sample_series(start_point, n=None, action=action, rev=rev) 297 | if while_broke_flag: 298 | while_broke_flag = False 299 | rev = not rev 300 | while_broke_count1 += 1 301 | if while_broke_count1 >= self.cfg.while_max_time: 302 | while_broke_count1 = 0 303 | while_broke_count2 = 0 304 | rev = False 305 | has_sector = False 306 | has_sector2 = False 307 | ii_loop = 0 308 | ss_loop = None 309 | random_point = utils.simulation.sample_points_at_border(1, height=self.cfg.source.height, width=self.cfg.source.width, radius=self.rad, direct_down=self.direct_down) 310 | points[ii] = random_point[0, ...] 311 | continue 312 | if action != 'sector': 313 | if frame_rate is not None: 314 | ss['series'] = ss['series'][::frame_rate[ii]] 315 | ss['count_slices'] = len(ss['series']) 316 | ss['count_slice_in_action'][-1] = ss['count_slices'] - 1 317 | 318 | slices = utils.simulation.get_slices(data, ss['series'], height=self.cfg.source.height, width=self.cfg.source.width, volume_origin=self.origin) 319 | slices_all_value = (torch.norm(slices, 1, dim=[1, 2]) == 0).any() 320 | w_time = 0 321 | while slices_all_value: 322 | w_time += 1 323 | if w_time >= self.cfg.while_max_time: 324 | while_broke_flag = True 325 | break 326 | ss = self.sample_series(start_point, n=None, action=action, rev=rev) 327 | while_time = 0 328 | while ss is None: 329 | while_time += 1 330 | if while_time >= self.cfg.while_max_time: 331 | if ii_loop == 0: 332 | random_point = utils.simulation.sample_points_at_border(1, height=self.cfg.source.height, width=self.cfg.source.width, radius=self.rad, direct_down=self.direct_down) 333 | points[ii] = random_point[0, ...] 
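# (As in DDH.slice_n, the start point is only resampled at the first
# sub-action here; unlike DDH, Fetus also redraws the border point whenever
# a full chain restart is triggered by the retry counters below.)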
334 | while_broke_flag = True 335 | break 336 | rev = not rev 337 | ss = self.sample_series(start_point, n=None, action=action, rev=rev) 338 | if while_broke_flag: 339 | break 340 | if action != 'sector': 341 | if frame_rate is not None: 342 | ss['series'] = ss['series'][::frame_rate[ii]] 343 | ss['count_slices'] = len(ss['series']) 344 | ss['count_slice_in_action'][-1] = ss['count_slices'] - 1 345 | slices = utils.simulation.get_slices(data, ss['series'], height=self.cfg.source.height, width=self.cfg.source.width, volume_origin=self.origin) 346 | slices_all_value = (torch.norm(slices, 1, dim=[1, 2]) == 0).any() 347 | if while_broke_flag: 348 | while_broke_flag = False 349 | rev = not rev 350 | while_broke_count2 += 1 351 | if while_broke_count2 >= self.cfg.while_max_time: 352 | while_broke_count1 = 0 353 | while_broke_count2 = 0 354 | rev = False 355 | has_sector = False 356 | has_sector2 = False 357 | ii_loop = 0 358 | ss_loop = None 359 | random_point = utils.simulation.sample_points_at_border(n, height=self.cfg.source.height, width=self.cfg.source.width, radius=self.rad, direct_down=self.direct_down) 360 | points[ii] = random_point[0, ...] 361 | continue 362 | 363 | ss['slices'] = slices 364 | if ii_loop == 0: 365 | ss_loop = ss 366 | else: 367 | ss_loop['count_slice_in_action'][-1] += ss['count_slice_in_action'][-1] 368 | ss_loop['count_slices'] += ss['count_slice_in_action'][-1] 369 | ss_loop['series'] = torch.cat([ss_loop['series'], ss['series'][1:, ...]], dim=0) 370 | ss_loop['slices'] = torch.cat([ss_loop['slices'], ss['slices'][1:, ...]], dim=0) 371 | if ss_loop['count_slices'] >= self.cfg.series_min_length: 372 | break 373 | ii_loop += 1 374 | 375 | if action == 'sector': 376 | if not has_sector: 377 | has_sector = True 378 | else: 379 | has_sector2 = True 380 | 381 | if not while_broke_flag: 382 | if frame_rate is not None: 383 | ss_loop['frame_rate'] = frame_rate[ii] 384 | ss_loop['gaps'] = utils.simulation.series_to_dof(ss_loop['series']) 385 | ss_loop['indices'] = torch.randperm(len(ss_loop['gaps'])) 386 | if optical_flow: 387 | ss_loop['opticalflow'] = utils.image.get_optical_flow(ss_loop['slices'], device=self.cfg.device) 388 | if edge: 389 | ss_loop['edge'] = utils.image.get_edge(ss_loop['slices'], device=self.cfg.device) 390 | if frame_rate is None: 391 | assert len(ss_loop['slices']) == len(ss_loop['series']) \ 392 | and len(ss_loop['series']) == ss_loop['count_slices'] \ 393 | and ss_loop['count_slices'] == ss_loop['count_slice_in_action'][-1] + 1, \ 394 | (len(ss_loop['slices']), len(ss_loop['series']), ss_loop['count_slices'], ss_loop['count_slice_in_action'][-1]) 395 | series.append(ss_loop) 396 | ii += 1 397 | 398 | return series 399 | 400 | def pad(self, data, max_length=None): 401 | if data is None: 402 | return None 403 | max_length = max_length or self.cfg.series_max_length 404 | if len(data) > max_length: 405 | data = data[:max_length] 406 | return data 407 | 408 | def __getitem__(self, index): 409 | idx_num = self.get_idx(index) 410 | idx = str(idx_num) 411 | 412 | if idx not in self.pr.keys() or len(self.pr[idx]) == 0: 413 | ps = self.cfg.ps if index < self.trainset_length else 1 414 | fr = torch.randint(self.cfg.frame_rate[0], self.cfg.frame_rate[1] + 1, (ps,)) 415 | self.pd[idx] = self.slice_n(idx_num, n=ps, frame_rate=fr) 416 | self.pr[idx] = torch.randperm(ps) 417 | data = self.pd[idx][self.pr[idx][0].item()] 418 | 419 | info = torch.tensor([len(data['slices'])]) 420 | frame_rate = data['frame_rate'] 421 | pad_slices = self.pad(data['slices']) 422 | 
pad_gaps = self.pad(data['gaps'], self.cfg.series_max_length - 1) 423 | pad_series = self.pad(data['series']) 424 | pad_optical_flow = self.pad(data['opticalflow'], self.cfg.series_max_length - 1) 425 | pad_edge = self.pad(data['edge']) 426 | 427 | source = pad_slices.unsqueeze(1) 428 | target = torch.cat([F.pad(pad_gaps, (0, 0, 0, 1)), pad_series.view(-1, 9)], dim=-1) 429 | 430 | self.pd[idx][self.pr[idx][0].item()] = [] 431 | self.pr[idx] = self.pr[idx][1:] 432 | 433 | sample_dict = { 434 | 'source': source, 'target': target, 435 | 'optical_flow': pad_optical_flow, 436 | 'edge': pad_edge, 437 | 'frame_rate': frame_rate, 438 | 'info': info 439 | } 440 | 441 | utils.common.set_seed(int(time.time() * 1000) % (1 << 32) + index) 442 | return sample_dict, index 443 | -------------------------------------------------------------------------------- /models/layers/convolutional_rnn/module.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Union, Sequence 3 | 4 | import torch 5 | from torch.nn import Parameter 6 | from torch.nn.utils.rnn import PackedSequence 7 | 8 | from .functional import AutogradConvRNN, _conv_cell_helper 9 | from .utils import _single, _pair, _triple 10 | 11 | 12 | class ConvNdRNNBase(torch.nn.Module): 13 | def __init__(self, 14 | mode: str, 15 | in_channels: int, 16 | out_channels: int, 17 | kernel_size: Union[int, Sequence[int]], 18 | num_layers: int=1, 19 | bias: bool=True, 20 | batch_first: bool=False, 21 | dropout: float=0., 22 | bidirectional: bool=False, 23 | convndim: int=2, 24 | stride: Union[int, Sequence[int]]=1, 25 | dilation: Union[int, Sequence[int]]=1, 26 | groups: int=1): 27 | super().__init__() 28 | self.mode = mode 29 | self.in_channels = in_channels 30 | self.out_channels = out_channels 31 | self.num_layers = num_layers 32 | self.bias = bias 33 | self.batch_first = batch_first 34 | self.dropout = dropout 35 | self.bidirectional = bidirectional 36 | self.convndim = convndim 37 | 38 | if convndim == 1: 39 | ntuple = _single 40 | elif convndim == 2: 41 | ntuple = _pair 42 | elif convndim == 3: 43 | ntuple = _triple 44 | else: 45 | raise ValueError('convndim must be 1, 2, or 3, but got {}'.format(convndim)) 46 | 47 | self.kernel_size = ntuple(kernel_size) 48 | self.stride = ntuple(stride) 49 | self.dilation = ntuple(dilation) 50 | 51 | self.groups = groups 52 | 53 | num_directions = 2 if bidirectional else 1 54 | 55 | if mode in ('LSTM', 'PeepholeLSTM'): 56 | gate_size = 4 * out_channels 57 | elif mode == 'GRU': 58 | gate_size = 3 * out_channels 59 | else: 60 | gate_size = out_channels 61 | 62 | self._all_weights = [] 63 | for layer in range(num_layers): 64 | for direction in range(num_directions): 65 | layer_input_size = in_channels if layer == 0 else out_channels * num_directions 66 | w_ih = Parameter(torch.Tensor(gate_size, layer_input_size // groups, *self.kernel_size)) 67 | w_hh = Parameter(torch.Tensor(gate_size, out_channels // groups, *self.kernel_size)) 68 | 69 | b_ih = Parameter(torch.Tensor(gate_size)) 70 | b_hh = Parameter(torch.Tensor(gate_size)) 71 | 72 | if mode == 'PeepholeLSTM': 73 | w_pi = Parameter(torch.Tensor(out_channels, out_channels // groups, *self.kernel_size)) 74 | w_pf = Parameter(torch.Tensor(out_channels, out_channels // groups, *self.kernel_size)) 75 | w_po = Parameter(torch.Tensor(out_channels, out_channels // groups, *self.kernel_size)) 76 | layer_params = (w_ih, w_hh, w_pi, w_pf, w_po, b_ih, b_hh) 77 | param_names = ['weight_ih_l{}{}', 'weight_hh_l{}{}', 78 
| 'weight_pi_l{}{}', 'weight_pf_l{}{}', 'weight_po_l{}{}'] 79 | else: 80 | layer_params = (w_ih, w_hh, b_ih, b_hh) 81 | param_names = ['weight_ih_l{}{}', 'weight_hh_l{}{}'] 82 | if bias: 83 | param_names += ['bias_ih_l{}{}', 'bias_hh_l{}{}'] 84 | 85 | suffix = '_reverse' if direction == 1 else '' 86 | param_names = [x.format(layer, suffix) for x in param_names] 87 | 88 | for name, param in zip(param_names, layer_params): 89 | setattr(self, name, param) 90 | self._all_weights.append(param_names) 91 | 92 | self.reset_parameters() 93 | 94 | def reset_parameters(self): 95 | stdv = 1.0 / math.sqrt(self.out_channels) 96 | for weight in self.parameters(): 97 | weight.data.uniform_(-stdv, stdv) 98 | 99 | def check_forward_args(self, input, hidden, batch_sizes): 100 | is_input_packed = batch_sizes is not None 101 | expected_input_dim = (2 if is_input_packed else 3) + self.convndim 102 | if input.dim() != expected_input_dim: 103 | raise RuntimeError( 104 | 'input must have {} dimensions, got {}'.format( 105 | expected_input_dim, input.dim())) 106 | ch_dim = 1 if is_input_packed else 2 107 | if self.in_channels != input.size(ch_dim): 108 | raise RuntimeError( 109 | 'input.size({}) must be equal to in_channels . Expected {}, got {}'.format( 110 | ch_dim, self.in_channels, input.size(ch_dim))) 111 | 112 | if is_input_packed: 113 | mini_batch = int(batch_sizes[0]) 114 | else: 115 | mini_batch = input.size(0) if self.batch_first else input.size(1) 116 | 117 | num_directions = 2 if self.bidirectional else 1 118 | expected_hidden_size = (self.num_layers * num_directions, 119 | mini_batch, self.out_channels) + input.shape[ch_dim + 1:] 120 | 121 | def check_hidden_size(hx, expected_hidden_size, msg='Expected hidden size {}, got {}'): 122 | if tuple(hx.size()) != expected_hidden_size: 123 | raise RuntimeError(msg.format(expected_hidden_size, tuple(hx.size()))) 124 | 125 | if self.mode in ('LSTM', 'PeepholeLSTM'): 126 | check_hidden_size(hidden[0], expected_hidden_size, 127 | 'Expected hidden[0] size {}, got {}') 128 | check_hidden_size(hidden[1], expected_hidden_size, 129 | 'Expected hidden[1] size {}, got {}') 130 | else: 131 | check_hidden_size(hidden, expected_hidden_size) 132 | 133 | def forward(self, input, hx=None): 134 | is_packed = isinstance(input, PackedSequence) 135 | if is_packed: 136 | input, batch_sizes = input 137 | max_batch_size = batch_sizes[0] 138 | insize = input.shape[2:] 139 | else: 140 | batch_sizes = None 141 | max_batch_size = input.size(0) if self.batch_first else input.size(1) 142 | insize = input.shape[3:] 143 | 144 | if hx is None: 145 | num_directions = 2 if self.bidirectional else 1 146 | hx = input.new_zeros(self.num_layers * num_directions, max_batch_size, self.out_channels, 147 | *insize, requires_grad=False) 148 | if self.mode in ('LSTM', 'PeepholeLSTM'): 149 | hx = (hx, hx) 150 | 151 | self.check_forward_args(input, hx, batch_sizes) 152 | func = AutogradConvRNN( 153 | self.mode, 154 | num_layers=self.num_layers, 155 | batch_first=self.batch_first, 156 | dropout=self.dropout, 157 | train=self.training, 158 | bidirectional=self.bidirectional, 159 | variable_length=batch_sizes is not None, 160 | convndim=self.convndim, 161 | stride=self.stride, 162 | dilation=self.dilation, 163 | groups=self.groups 164 | ) 165 | output, hidden = func(input, self.all_weights, hx, batch_sizes) 166 | if is_packed: 167 | output = PackedSequence(output, batch_sizes) 168 | return output, hidden 169 | 170 | def extra_repr(self): 171 | s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' 
172 | ', stride={stride}') 173 | if self.dilation != (1,) * len(self.dilation): 174 | s += ', dilation={dilation}' 175 | if self.groups != 1: 176 | s += ', groups={groups}' 177 | if self.num_layers != 1: 178 | s += ', num_layers={num_layers}' 179 | if self.bias is not True: 180 | s += ', bias={bias}' 181 | if self.batch_first is not False: 182 | s += ', batch_first={batch_first}' 183 | if self.dropout != 0: 184 | s += ', dropout={dropout}' 185 | if self.bidirectional is not False: 186 | s += ', bidirectional={bidirectional}' 187 | return s.format(**self.__dict__) 188 | 189 | def __setstate__(self, d): 190 | super(ConvNdRNNBase, self).__setstate__(d) 191 | if 'all_weights' in d: 192 | self._all_weights = d['all_weights'] 193 | if isinstance(self._all_weights[0][0], str): 194 | return 195 | num_layers = self.num_layers 196 | num_directions = 2 if self.bidirectional else 1 197 | self._all_weights = [] 198 | for layer in range(num_layers): 199 | for direction in range(num_directions): 200 | suffix = '_reverse' if direction == 1 else '' 201 | if self.mode == 'PeepholeLSTM': 202 | weights = ['weight_ih_l{}{}', 'weight_hh_l{}{}', 203 | 'weight_pi_l{}{}', 'weight_pf_l{}{}', 'weight_po_l{}{}', 204 | 'bias_ih_l{}{}', 'bias_hh_l{}{}'] 205 | else: 206 | weights = ['weight_ih_l{}{}', 'weight_hh_l{}{}', 207 | 'bias_ih_l{}{}', 'bias_hh_l{}{}'] 208 | weights = [x.format(layer, suffix) for x in weights] 209 | if self.bias: 210 | self._all_weights += [weights] 211 | else: 212 | self._all_weights += [weights[:len(weights) // 2]] 213 | 214 | @property 215 | def all_weights(self): 216 | return [[getattr(self, weight) for weight in weights] for weights in self._all_weights] 217 | 218 | 219 | class Conv1dRNN(ConvNdRNNBase): 220 | def __init__(self, 221 | in_channels: int, 222 | out_channels: int, 223 | kernel_size: Union[int, Sequence[int]], 224 | nonlinearity: str='tanh', 225 | num_layers: int=1, 226 | bias: bool=True, 227 | batch_first: bool=False, 228 | dropout: float=0., 229 | bidirectional: bool=False, 230 | stride: Union[int, Sequence[int]]=1, 231 | dilation: Union[int, Sequence[int]]=1, 232 | groups: int=1): 233 | if nonlinearity == 'tanh': 234 | mode = 'RNN_TANH' 235 | elif nonlinearity == 'relu': 236 | mode = 'RNN_RELU' 237 | else: 238 | raise ValueError("Unknown nonlinearity '{}'".format(nonlinearity)) 239 | super().__init__( 240 | mode=mode, 241 | in_channels=in_channels, 242 | out_channels=out_channels, 243 | kernel_size=kernel_size, 244 | num_layers=num_layers, 245 | bias=bias, 246 | batch_first=batch_first, 247 | dropout=dropout, 248 | bidirectional=bidirectional, 249 | convndim=1, 250 | stride=stride, 251 | dilation=dilation, 252 | groups=groups) 253 | 254 | 255 | class Conv1dPeepholeLSTM(ConvNdRNNBase): 256 | def __init__(self, 257 | in_channels: int, 258 | out_channels: int, 259 | kernel_size: Union[int, Sequence[int]], 260 | num_layers: int=1, 261 | bias: bool=True, 262 | batch_first: bool=False, 263 | dropout: float=0., 264 | bidirectional: bool=False, 265 | stride: Union[int, Sequence[int]]=1, 266 | dilation: Union[int, Sequence[int]]=1, 267 | groups: int=1): 268 | super().__init__( 269 | mode='PeepholeLSTM', 270 | in_channels=in_channels, 271 | out_channels=out_channels, 272 | kernel_size=kernel_size, 273 | num_layers=num_layers, 274 | bias=bias, 275 | batch_first=batch_first, 276 | dropout=dropout, 277 | bidirectional=bidirectional, 278 | convndim=1, 279 | stride=stride, 280 | dilation=dilation, 281 | groups=groups) 282 | 283 | 284 | class Conv1dLSTM(ConvNdRNNBase): 285 | def __init__(self, 
286 | in_channels: int, 287 | out_channels: int, 288 | kernel_size: Union[int, Sequence[int]], 289 | num_layers: int=1, 290 | bias: bool=True, 291 | batch_first: bool=False, 292 | dropout: float=0., 293 | bidirectional: bool=False, 294 | stride: Union[int, Sequence[int]]=1, 295 | dilation: Union[int, Sequence[int]]=1, 296 | groups: int=1): 297 | super().__init__( 298 | mode='LSTM', 299 | in_channels=in_channels, 300 | out_channels=out_channels, 301 | kernel_size=kernel_size, 302 | num_layers=num_layers, 303 | bias=bias, 304 | batch_first=batch_first, 305 | dropout=dropout, 306 | bidirectional=bidirectional, 307 | convndim=1, 308 | stride=stride, 309 | dilation=dilation, 310 | groups=groups) 311 | 312 | 313 | class Conv1dGRU(ConvNdRNNBase): 314 | def __init__(self, 315 | in_channels: int, 316 | out_channels: int, 317 | kernel_size: Union[int, Sequence[int]], 318 | num_layers: int=1, 319 | bias: bool=True, 320 | batch_first: bool=False, 321 | dropout: float=0., 322 | bidirectional: bool=False, 323 | stride: Union[int, Sequence[int]]=1, 324 | dilation: Union[int, Sequence[int]]=1, 325 | groups: int=1): 326 | super().__init__( 327 | mode='GRU', 328 | in_channels=in_channels, 329 | out_channels=out_channels, 330 | kernel_size=kernel_size, 331 | num_layers=num_layers, 332 | bias=bias, 333 | batch_first=batch_first, 334 | dropout=dropout, 335 | bidirectional=bidirectional, 336 | convndim=1, 337 | stride=stride, 338 | dilation=dilation, 339 | groups=groups) 340 | 341 | 342 | class Conv2dRNN(ConvNdRNNBase): 343 | def __init__(self, 344 | in_channels: int, 345 | out_channels: int, 346 | kernel_size: Union[int, Sequence[int]], 347 | nonlinearity: str='tanh', 348 | num_layers: int=1, 349 | bias: bool=True, 350 | batch_first: bool=False, 351 | dropout: float=0., 352 | bidirectional: bool=False, 353 | stride: Union[int, Sequence[int]]=1, 354 | dilation: Union[int, Sequence[int]]=1, 355 | groups: int=1): 356 | if nonlinearity == 'tanh': 357 | mode = 'RNN_TANH' 358 | elif nonlinearity == 'relu': 359 | mode = 'RNN_RELU' 360 | else: 361 | raise ValueError("Unknown nonlinearity '{}'".format(nonlinearity)) 362 | super().__init__( 363 | mode=mode, 364 | in_channels=in_channels, 365 | out_channels=out_channels, 366 | kernel_size=kernel_size, 367 | num_layers=num_layers, 368 | bias=bias, 369 | batch_first=batch_first, 370 | dropout=dropout, 371 | bidirectional=bidirectional, 372 | convndim=2, 373 | stride=stride, 374 | dilation=dilation, 375 | groups=groups) 376 | 377 | 378 | class Conv2dLSTM(ConvNdRNNBase): 379 | def __init__(self, 380 | in_channels: int, 381 | out_channels: int, 382 | kernel_size: Union[int, Sequence[int]], 383 | num_layers: int=1, 384 | bias: bool=True, 385 | batch_first: bool=False, 386 | dropout: float=0., 387 | bidirectional: bool=False, 388 | stride: Union[int, Sequence[int]]=1, 389 | dilation: Union[int, Sequence[int]]=1, 390 | groups: int=1): 391 | super().__init__( 392 | mode='LSTM', 393 | in_channels=in_channels, 394 | out_channels=out_channels, 395 | kernel_size=kernel_size, 396 | num_layers=num_layers, 397 | bias=bias, 398 | batch_first=batch_first, 399 | dropout=dropout, 400 | bidirectional=bidirectional, 401 | convndim=2, 402 | stride=stride, 403 | dilation=dilation, 404 | groups=groups) 405 | 406 | 407 | class Conv2dPeepholeLSTM(ConvNdRNNBase): 408 | def __init__(self, 409 | in_channels: int, 410 | out_channels: int, 411 | kernel_size: Union[int, Sequence[int]], 412 | num_layers: int=1, 413 | bias: bool=True, 414 | batch_first: bool=False, 415 | dropout: float=0., 416 | 
bidirectional: bool=False, 417 | stride: Union[int, Sequence[int]]=1, 418 | dilation: Union[int, Sequence[int]]=1, 419 | groups: int=1): 420 | super().__init__( 421 | mode='PeepholeLSTM', 422 | in_channels=in_channels, 423 | out_channels=out_channels, 424 | kernel_size=kernel_size, 425 | num_layers=num_layers, 426 | bias=bias, 427 | batch_first=batch_first, 428 | dropout=dropout, 429 | bidirectional=bidirectional, 430 | convndim=2, 431 | stride=stride, 432 | dilation=dilation, 433 | groups=groups) 434 | 435 | 436 | class Conv2dGRU(ConvNdRNNBase): 437 | def __init__(self, 438 | in_channels: int, 439 | out_channels: int, 440 | kernel_size: Union[int, Sequence[int]], 441 | num_layers: int=1, 442 | bias: bool=True, 443 | batch_first: bool=False, 444 | dropout: float=0., 445 | bidirectional: bool=False, 446 | stride: Union[int, Sequence[int]]=1, 447 | dilation: Union[int, Sequence[int]]=1, 448 | groups: int=1): 449 | super().__init__( 450 | mode='GRU', 451 | in_channels=in_channels, 452 | out_channels=out_channels, 453 | kernel_size=kernel_size, 454 | num_layers=num_layers, 455 | bias=bias, 456 | batch_first=batch_first, 457 | dropout=dropout, 458 | bidirectional=bidirectional, 459 | convndim=2, 460 | stride=stride, 461 | dilation=dilation, 462 | groups=groups) 463 | 464 | 465 | class Conv3dRNN(ConvNdRNNBase): 466 | def __init__(self, 467 | in_channels: int, 468 | out_channels: int, 469 | kernel_size: Union[int, Sequence[int]], 470 | nonlinearity: str='tanh', 471 | num_layers: int=1, 472 | bias: bool=True, 473 | batch_first: bool=False, 474 | dropout: float=0., 475 | bidirectional: bool=False, 476 | stride: Union[int, Sequence[int]]=1, 477 | dilation: Union[int, Sequence[int]]=1, 478 | groups: int=1): 479 | if nonlinearity == 'tanh': 480 | mode = 'RNN_TANH' 481 | elif nonlinearity == 'relu': 482 | mode = 'RNN_RELU' 483 | else: 484 | raise ValueError("Unknown nonlinearity '{}'".format(nonlinearity)) 485 | super().__init__( 486 | mode=mode, 487 | in_channels=in_channels, 488 | out_channels=out_channels, 489 | kernel_size=kernel_size, 490 | num_layers=num_layers, 491 | bias=bias, 492 | batch_first=batch_first, 493 | dropout=dropout, 494 | bidirectional=bidirectional, 495 | convndim=3, 496 | stride=stride, 497 | dilation=dilation, 498 | groups=groups) 499 | 500 | 501 | class Conv3dLSTM(ConvNdRNNBase): 502 | def __init__(self, 503 | in_channels: int, 504 | out_channels: int, 505 | kernel_size: Union[int, Sequence[int]], 506 | num_layers: int=1, 507 | bias: bool=True, 508 | batch_first: bool=False, 509 | dropout: float=0., 510 | bidirectional: bool=False, 511 | stride: Union[int, Sequence[int]]=1, 512 | dilation: Union[int, Sequence[int]]=1, 513 | groups: int=1): 514 | super().__init__( 515 | mode='LSTM', 516 | in_channels=in_channels, 517 | out_channels=out_channels, 518 | kernel_size=kernel_size, 519 | num_layers=num_layers, 520 | bias=bias, 521 | batch_first=batch_first, 522 | dropout=dropout, 523 | bidirectional=bidirectional, 524 | convndim=3, 525 | stride=stride, 526 | dilation=dilation, 527 | groups=groups) 528 | 529 | 530 | class Conv3dPeepholeLSTM(ConvNdRNNBase): 531 | def __init__(self, 532 | in_channels: int, 533 | out_channels: int, 534 | kernel_size: Union[int, Sequence[int]], 535 | num_layers: int=1, 536 | bias: bool=True, 537 | batch_first: bool=False, 538 | dropout: float=0., 539 | bidirectional: bool=False, 540 | stride: Union[int, Sequence[int]]=1, 541 | dilation: Union[int, Sequence[int]]=1, 542 | groups: int=1): 543 | super().__init__( 544 | mode='PeepholeLSTM', 545 | 
in_channels=in_channels, 546 | out_channels=out_channels, 547 | kernel_size=kernel_size, 548 | num_layers=num_layers, 549 | bias=bias, 550 | batch_first=batch_first, 551 | dropout=dropout, 552 | bidirectional=bidirectional, 553 | convndim=3, 554 | stride=stride, 555 | dilation=dilation, 556 | groups=groups) 557 | 558 | 559 | class Conv3dGRU(ConvNdRNNBase): 560 | def __init__(self, 561 | in_channels: int, 562 | out_channels: int, 563 | kernel_size: Union[int, Sequence[int]], 564 | num_layers: int=1, 565 | bias: bool=True, 566 | batch_first: bool=False, 567 | dropout: float=0., 568 | bidirectional: bool=False, 569 | stride: Union[int, Sequence[int]]=1, 570 | dilation: Union[int, Sequence[int]]=1, 571 | groups: int=1): 572 | super().__init__( 573 | mode='GRU', 574 | in_channels=in_channels, 575 | out_channels=out_channels, 576 | kernel_size=kernel_size, 577 | num_layers=num_layers, 578 | bias=bias, 579 | batch_first=batch_first, 580 | dropout=dropout, 581 | bidirectional=bidirectional, 582 | convndim=3, 583 | stride=stride, 584 | dilation=dilation, 585 | groups=groups) 586 | 587 | 588 | class ConvRNNCellBase(torch.nn.Module): 589 | def __init__(self, 590 | mode: str, 591 | in_channels: int, 592 | out_channels: int, 593 | kernel_size: Union[int, Sequence[int]], 594 | bias: bool=True, 595 | convndim: int=2, 596 | stride: Union[int, Sequence[int]]=1, 597 | dilation: Union[int, Sequence[int]]=1, 598 | groups: int=1 599 | ): 600 | super().__init__() 601 | self.mode = mode 602 | self.in_channels = in_channels 603 | self.out_channels = out_channels 604 | self.bias = bias 605 | self.convndim = convndim 606 | 607 | if convndim == 1: 608 | ntuple = _single 609 | elif convndim == 2: 610 | ntuple = _pair 611 | elif convndim == 3: 612 | ntuple = _triple 613 | else: 614 | raise ValueError('convndim must be 1, 2, or 3, but got {}'.format(convndim)) 615 | 616 | self.kernel_size = ntuple(kernel_size) 617 | self.stride = ntuple(stride) 618 | self.dilation = ntuple(dilation) 619 | 620 | self.groups = groups 621 | 622 | if mode in ('LSTM', 'PeepholeLSTM'): 623 | gate_size = 4 * out_channels  # four gates (i, f, g, o) stacked along the channel dim 624 | elif mode == 'GRU': 625 | gate_size = 3 * out_channels  # three gates (r, z, n) 626 | else: 627 | gate_size = out_channels  # vanilla RNN: a single transformation 628 | 629 | self.weight_ih = Parameter(torch.Tensor(gate_size, in_channels // groups, *self.kernel_size))  # conv kernels stand in for nn.RNN's dense weight matrices 630 | self.weight_hh = Parameter(torch.Tensor(gate_size, out_channels // groups, *self.kernel_size)) 631 | 632 | if bias: 633 | self.bias_ih = Parameter(torch.Tensor(gate_size)) 634 | self.bias_hh = Parameter(torch.Tensor(gate_size)) 635 | else: 636 | self.register_parameter('bias_ih', None) 637 | self.register_parameter('bias_hh', None) 638 | 639 | if mode == 'PeepholeLSTM': 640 | self.weight_pi = Parameter(torch.Tensor(out_channels, out_channels // groups, *self.kernel_size)) 641 | self.weight_pf = Parameter(torch.Tensor(out_channels, out_channels // groups, *self.kernel_size)) 642 | self.weight_po = Parameter(torch.Tensor(out_channels, out_channels // groups, *self.kernel_size)) 643 | 644 | self.reset_parameters() 645 | 646 | def extra_repr(self): 647 | s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' 648 | ', stride={stride}') 649 | if self.dilation != (1,) * len(self.dilation): 650 | s += ', dilation={dilation}' 651 | if self.groups != 1: 652 | s += ', groups={groups}' 653 | if self.bias is not True: 654 | s += ', bias={bias}' 655 | # cells are unidirectional: ConvRNNCellBase sets no bidirectional attribute, 656 | # so no such flag is reported here 657 | return s.format(**self.__dict__) 658 | 659 | def check_forward_input(self, input):
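        # `input` here is a single time step of shape (batch, in_channels, *spatial);
        # only the channel dim is validated, because forward() below builds the hidden
        # state to match whatever spatial size the input actually carries.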
660 | if input.size(1) != self.in_channels: 661 | raise RuntimeError( 662 | "input has inconsistent channels: got {}, expected {}".format( 663 | input.size(1), self.in_channels)) 664 | 665 | def check_forward_hidden(self, input, hx, hidden_label=''): 666 | if input.size(0) != hx.size(0): 667 | raise RuntimeError( 668 | "Input batch size {} doesn't match hidden{} batch size {}".format( 669 | input.size(0), hidden_label, hx.size(0))) 670 | 671 | if hx.size(1) != self.out_channels: 672 | raise RuntimeError( 673 | "hidden{} has inconsistent hidden_size: got {}, expected {}".format( 674 | hidden_label, hx.size(1), self.out_channels)) 675 | 676 | def reset_parameters(self): 677 | stdv = 1.0 / math.sqrt(self.out_channels) 678 | for weight in self.parameters(): 679 | weight.data.uniform_(-stdv, stdv) 680 | 681 | def forward(self, input, hx=None): 682 | self.check_forward_input(input) 683 | 684 | if hx is None: 685 | batch_size = input.size(0) 686 | insize = input.shape[2:] 687 | hx = input.new_zeros(batch_size, self.out_channels, *insize, requires_grad=False) 688 | if self.mode in ('LSTM', 'PeepholeLSTM'): 689 | hx = (hx, hx) 690 | if self.mode in ('LSTM', 'PeepholeLSTM'): 691 | self.check_forward_hidden(input, hx[0]) 692 | self.check_forward_hidden(input, hx[1]) 693 | else: 694 | self.check_forward_hidden(input, hx) 695 | 696 | cell = _conv_cell_helper( 697 | self.mode, 698 | convndim=self.convndim, 699 | stride=self.stride, 700 | dilation=self.dilation, 701 | groups=self.groups) 702 | if self.mode == 'PeepholeLSTM': 703 | return cell( 704 | input, hx, 705 | self.weight_ih, self.weight_hh, self.weight_pi, self.weight_pf, self.weight_po, 706 | self.bias_ih, self.bias_hh 707 | ) 708 | else: 709 | return cell( 710 | input, hx, 711 | self.weight_ih, self.weight_hh, 712 | self.bias_ih, self.bias_hh, 713 | ) 714 | 715 | 716 | class Conv1dRNNCell(ConvRNNCellBase): 717 | def __init__(self, 718 | in_channels: int, 719 | out_channels: int, 720 | kernel_size: Union[int, Sequence[int]], 721 | nonlinearity: str='tanh', 722 | bias: bool=True, 723 | stride: Union[int, Sequence[int]]=1, 724 | dilation: Union[int, Sequence[int]]=1, 725 | groups: int=1 726 | ): 727 | if nonlinearity == 'tanh': 728 | mode = 'RNN_TANH' 729 | elif nonlinearity == 'relu': 730 | mode = 'RNN_RELU' 731 | else: 732 | raise ValueError("Unknown nonlinearity '{}'".format(nonlinearity)) 733 | super().__init__( 734 | mode=mode, 735 | in_channels=in_channels, 736 | out_channels=out_channels, 737 | kernel_size=kernel_size, 738 | bias=bias, 739 | convndim=1, 740 | stride=stride, 741 | dilation=dilation, 742 | groups=groups 743 | ) 744 | 745 | 746 | class Conv1dLSTMCell(ConvRNNCellBase): 747 | def __init__(self, 748 | in_channels: int, 749 | out_channels: int, 750 | kernel_size: Union[int, Sequence[int]], 751 | bias: bool=True, 752 | stride: Union[int, Sequence[int]]=1, 753 | dilation: Union[int, Sequence[int]]=1, 754 | groups: int=1 755 | ): 756 | super().__init__( 757 | mode='LSTM', 758 | in_channels=in_channels, 759 | out_channels=out_channels, 760 | kernel_size=kernel_size, 761 | bias=bias, 762 | convndim=1, 763 | stride=stride, 764 | dilation=dilation, 765 | groups=groups 766 | ) 767 | 768 | 769 | class Conv1dPeepholeLSTMCell(ConvRNNCellBase): 770 | def __init__(self, 771 | in_channels: int, 772 | out_channels: int, 773 | kernel_size: Union[int, Sequence[int]], 774 | bias: bool=True, 775 | stride: Union[int, Sequence[int]]=1, 776 | dilation: Union[int, Sequence[int]]=1, 777 | groups: int=1 778 | ): 779 | super().__init__( 780 | 
mode='PeepholeLSTM', 781 | in_channels=in_channels, 782 | out_channels=out_channels, 783 | kernel_size=kernel_size, 784 | bias=bias, 785 | convndim=1, 786 | stride=stride, 787 | dilation=dilation, 788 | groups=groups 789 | ) 790 | 791 | 792 | class Conv1dGRUCell(ConvRNNCellBase): 793 | def __init__(self, 794 | in_channels: int, 795 | out_channels: int, 796 | kernel_size: Union[int, Sequence[int]], 797 | bias: bool=True, 798 | stride: Union[int, Sequence[int]]=1, 799 | dilation: Union[int, Sequence[int]]=1, 800 | groups: int=1 801 | ): 802 | super().__init__( 803 | mode='GRU', 804 | in_channels=in_channels, 805 | out_channels=out_channels, 806 | kernel_size=kernel_size, 807 | bias=bias, 808 | convndim=1, 809 | stride=stride, 810 | dilation=dilation, 811 | groups=groups 812 | ) 813 | 814 | 815 | class Conv2dRNNCell(ConvRNNCellBase): 816 | def __init__(self, 817 | in_channels: int, 818 | out_channels: int, 819 | kernel_size: Union[int, Sequence[int]], 820 | nonlinearity: str='tanh', 821 | bias: bool=True, 822 | stride: Union[int, Sequence[int]]=1, 823 | dilation: Union[int, Sequence[int]]=1, 824 | groups: int=1 825 | ): 826 | if nonlinearity == 'tanh': 827 | mode = 'RNN_TANH' 828 | elif nonlinearity == 'relu': 829 | mode = 'RNN_RELU' 830 | else: 831 | raise ValueError("Unknown nonlinearity '{}'".format(nonlinearity)) 832 | super().__init__( 833 | mode=mode, 834 | in_channels=in_channels, 835 | out_channels=out_channels, 836 | kernel_size=kernel_size, 837 | bias=bias, 838 | convndim=2, 839 | stride=stride, 840 | dilation=dilation, 841 | groups=groups 842 | ) 843 | 844 | 845 | class Conv2dLSTMCell(ConvRNNCellBase): 846 | def __init__(self, 847 | in_channels: int, 848 | out_channels: int, 849 | kernel_size: Union[int, Sequence[int]], 850 | bias: bool=True, 851 | stride: Union[int, Sequence[int]]=1, 852 | dilation: Union[int, Sequence[int]]=1, 853 | groups: int=1 854 | ): 855 | super().__init__( 856 | mode='LSTM', 857 | in_channels=in_channels, 858 | out_channels=out_channels, 859 | kernel_size=kernel_size, 860 | bias=bias, 861 | convndim=2, 862 | stride=stride, 863 | dilation=dilation, 864 | groups=groups 865 | ) 866 | 867 | 868 | class Conv2dPeepholeLSTMCell(ConvRNNCellBase): 869 | def __init__(self, 870 | in_channels: int, 871 | out_channels: int, 872 | kernel_size: Union[int, Sequence[int]], 873 | bias: bool=True, 874 | stride: Union[int, Sequence[int]]=1, 875 | dilation: Union[int, Sequence[int]]=1, 876 | groups: int=1 877 | ): 878 | super().__init__( 879 | mode='PeepholeLSTM', 880 | in_channels=in_channels, 881 | out_channels=out_channels, 882 | kernel_size=kernel_size, 883 | bias=bias, 884 | convndim=2, 885 | stride=stride, 886 | dilation=dilation, 887 | groups=groups 888 | ) 889 | 890 | 891 | class Conv2dGRUCell(ConvRNNCellBase): 892 | def __init__(self, 893 | in_channels: int, 894 | out_channels: int, 895 | kernel_size: Union[int, Sequence[int]], 896 | bias: bool=True, 897 | stride: Union[int, Sequence[int]]=1, 898 | dilation: Union[int, Sequence[int]]=1, 899 | groups: int=1 900 | ): 901 | super().__init__( 902 | mode='GRU', 903 | in_channels=in_channels, 904 | out_channels=out_channels, 905 | kernel_size=kernel_size, 906 | bias=bias, 907 | convndim=2, 908 | stride=stride, 909 | dilation=dilation, 910 | groups=groups 911 | ) 912 | 913 | 914 | class Conv3dRNNCell(ConvRNNCellBase): 915 | def __init__(self, 916 | in_channels: int, 917 | out_channels: int, 918 | kernel_size: Union[int, Sequence[int]], 919 | nonlinearity: str='tanh', 920 | bias: bool=True, 921 | stride: Union[int, 
Sequence[int]]=1, 922 | dilation: Union[int, Sequence[int]]=1, 923 | groups: int=1 924 | ): 925 | if nonlinearity == 'tanh': 926 | mode = 'RNN_TANH' 927 | elif nonlinearity == 'relu': 928 | mode = 'RNN_RELU' 929 | else: 930 | raise ValueError("Unknown nonlinearity '{}'".format(nonlinearity)) 931 | super().__init__( 932 | mode=mode, 933 | in_channels=in_channels, 934 | out_channels=out_channels, 935 | kernel_size=kernel_size, 936 | bias=bias, 937 | convndim=3, 938 | stride=stride, 939 | dilation=dilation, 940 | groups=groups 941 | ) 942 | 943 | 944 | class Conv3dLSTMCell(ConvRNNCellBase): 945 | def __init__(self, 946 | in_channels: int, 947 | out_channels: int, 948 | kernel_size: Union[int, Sequence[int]], 949 | bias: bool=True, 950 | stride: Union[int, Sequence[int]]=1, 951 | dilation: Union[int, Sequence[int]]=1, 952 | groups: int=1 953 | ): 954 | super().__init__( 955 | mode='LSTM', 956 | in_channels=in_channels, 957 | out_channels=out_channels, 958 | kernel_size=kernel_size, 959 | bias=bias, 960 | convndim=3, 961 | stride=stride, 962 | dilation=dilation, 963 | groups=groups 964 | ) 965 | 966 | 967 | class Conv3dPeepholeLSTMCell(ConvRNNCellBase): 968 | def __init__(self, 969 | in_channels: int, 970 | out_channels: int, 971 | kernel_size: Union[int, Sequence[int]], 972 | bias: bool=True, 973 | stride: Union[int, Sequence[int]]=1, 974 | dilation: Union[int, Sequence[int]]=1, 975 | groups: int=1 976 | ): 977 | super().__init__( 978 | mode='PeepholeLSTM', 979 | in_channels=in_channels, 980 | out_channels=out_channels, 981 | kernel_size=kernel_size, 982 | bias=bias, 983 | convndim=3, 984 | stride=stride, 985 | dilation=dilation, 986 | groups=groups 987 | ) 988 | 989 | 990 | class Conv3dGRUCell(ConvRNNCellBase): 991 | def __init__(self, 992 | in_channels: int, 993 | out_channels: int, 994 | kernel_size: Union[int, Sequence[int]], 995 | bias: bool=True, 996 | stride: Union[int, Sequence[int]]=1, 997 | dilation: Union[int, Sequence[int]]=1, 998 | groups: int=1 999 | ): 1000 | super().__init__( 1001 | mode='GRU', 1002 | in_channels=in_channels, 1003 | out_channels=out_channels, 1004 | kernel_size=kernel_size, 1005 | bias=bias, 1006 | convndim=3, 1007 | stride=stride, 1008 | dilation=dilation, 1009 | groups=groups 1010 | ) 1011 | --------------------------------------------------------------------------------
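Usage note (editor's sketch, not part of the repository): the Conv{1,2,3}d{RNN,LSTM,PeepholeLSTM,GRU} classes mirror the torch.nn.RNN/LSTM/GRU interface, and the *Cell classes mirror torch.nn.LSTMCell and friends. The snippet below is a minimal sketch under two assumptions: that the package's __init__.py re-exports these classes under the import path implied by the tree above, and that ConvNdRNNBase.forward follows torch.nn.LSTM's (output, (h_n, c_n)) return convention with same-padded convolutions, so that stride 1 with an odd kernel preserves spatial dims.

import torch

# assumed import path, matching the repository layout
from models.layers.convolutional_rnn import Conv2dLSTM, Conv2dLSTMCell

seq_len, batch, c_in, c_out, height, width = 5, 2, 3, 8, 16, 16

# Sequence-level module: input is (seq_len, batch, c_in, H, W) because
# batch_first defaults to False, as in torch.nn.LSTM.
rnn = Conv2dLSTM(in_channels=c_in, out_channels=c_out, kernel_size=3, num_layers=2)
x = torch.randn(seq_len, batch, c_in, height, width)
output, (h_n, c_n) = rnn(x)   # output: (5, 2, 8, 16, 16) under the padding assumption

# Single-step cell: input is (batch, c_in, H, W); with hx=None the cell
# zero-initializes (h, c) at the input's spatial size, as in forward() above.
cell = Conv2dLSTMCell(in_channels=c_in, out_channels=c_out, kernel_size=3)
h_t, c_t = cell(x[0])         # each (2, 8, 16, 16)

Because the parameter names reuse torch.nn.RNN's scheme (weight_ih_l{layer}{suffix}, with a _reverse suffix for the backward direction), the __setstate__/_all_weights bookkeeping above keeps pickling and state_dict round-trips compatible with stock PyTorch conventions.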