├── dfibert ├── envs │ ├── __init__.py │ ├── _state.py │ ├── neuroanatomical_utils.py │ ├── NARLTractEnvironment.py │ └── RLTractEnvironment_fast.py ├── ext │ ├── __init__.py │ └── soft_dtw_cuda.py ├── __init__.py ├── data │ ├── exceptions.py │ ├── postprocessing.py │ └── __init__.py ├── dataset │ ├── exceptions.py │ ├── processing.py │ └── __init__.py ├── tracker │ ├── nn │ │ ├── _segment_tree.py │ │ ├── supervised_pretraining.py │ │ └── rl.py │ └── __init__.py └── util.py ├── examples ├── data │ └── ismrm_seeds_CST.npy ├── workflow_dFTlib.py ├── train_dqn_supervised.py ├── train.py ├── mlp_training.py └── rlWorkflow.ipynb ├── .gitignore ├── README.md └── requirements.txt /dfibert/envs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dfibert/ext/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /examples/data/ismrm_seeds_CST.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nih23/deepFibreTracking/HEAD/examples/data/ismrm_seeds_CST.npy -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | .vscode/ 3 | __pycache__/ 4 | *.py[cod] 5 | /cache/ 6 | /data/ 7 | /examples/data/ 8 | *.pt 9 | *.vtk -------------------------------------------------------------------------------- /dfibert/__init__.py: -------------------------------------------------------------------------------- 1 | from . import data 2 | from . import dataset 3 | from . import envs 4 | from . import tracker 5 | from . import util -------------------------------------------------------------------------------- /dfibert/envs/_state.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class TractographyState: 5 | def __init__(self, coordinate, interpolFuncHandle): 6 | self.coordinate = coordinate 7 | self.interpolFuncHandle = interpolFuncHandle 8 | self.interpolatedDWI = None 9 | 10 | def getCoordinate(self): 11 | return self.coordinate 12 | 13 | def getValue(self): 14 | if self.interpolatedDWI is None: 15 | # interpolate DWI value at self.coordinate 16 | self.interpolatedDWI = self.interpolFuncHandle(self.coordinate) 17 | return self.interpolatedDWI 18 | 19 | def __add__(self, other): 20 | return self.getCoordinate() + other.getCoordinate() 21 | 22 | def __sub__(self, other): 23 | return self.getCoordinate() - other.getCoordinate() 24 | -------------------------------------------------------------------------------- /dfibert/data/exceptions.py: -------------------------------------------------------------------------------- 1 | class PointOutsideOfDWIError(LookupError): 2 | """ 3 | Error thrown if given points are outside of the DWI-Image. 4 | This can be bypassed by passing `ignore_outside_points = True` 5 | to the raising function. However, it should be noted that this 6 | is not recommended behaviour. 7 | Attributes 8 | ---------- 9 | data_container : DataContainer 10 | The `DataContainer` whose DWI-Image is too small to cover the points. 11 | points: ndarray 12 | The point array which is responsible for raising the error.
13 | affected_points: ndarray 14 | The affected points lying outside of the DWI-image. 15 | """ 16 | 17 | def __init__(self, data_container, points, affected_points): 18 | """ 19 | Parameters 20 | ---------- 21 | data_container : DataContainer 22 | The `DataContainer` whose DWI-Image is too small to cover the points. 23 | points: ndarray 24 | The point array which is responsible for raising the error. 25 | affected_points: ndarray 26 | The affected points lying outside of the DWI-image. 27 | """ 28 | self.data_container = data_container 29 | self.points = points 30 | self.affected_points = affected_points 31 | super().__init__(("While parsing {no_points} points for further processing, " 32 | "it became apparent that {aff} of the points " 33 | "don't lie inside of DataContainer 'xyz'.") 34 | .format(no_points=len(points), aff=len(affected_points))) 35 | -------------------------------------------------------------------------------- /examples/workflow_dFTlib.py: -------------------------------------------------------------------------------- 1 | """Example code for demonstration purposes. Usable for testing.""" 2 | 3 | from dfibert.data import DataPreprocessor 4 | from dfibert.data.postprocessing import Resample100, SphericalHarmonics 5 | from dfibert.dataset import StreamlineDataset, ConcatenatedDataset 6 | from dfibert.tracker import get_csd_streamlines, save_streamlines, load_streamlines, get_dti_streamlines, filtered_streamlines_by_length 7 | from dfibert.dataset.processing import RegressionProcessing, ClassificationProcessing 8 | 9 | 10 | def main(): 11 | """Main method""" 12 | hcp_data = DataPreprocessor().get_hcp("data/HCP/100307") 13 | ismrm_data = DataPreprocessor().get_ismrm("path/to/ismrm") 14 | print("Loaded DataContainers") 15 | hcp_sl = get_dti_streamlines(hcp_data, random_seeds=True, seeds_count=10000) 16 | 17 | save_streamlines(hcp_sl, "sls3.vtk") 18 | save_streamlines(filtered_streamlines_by_length(hcp_sl), "sls4.vtk") 19 | ismrm_sl = load_streamlines("path/to/ismrm/ground_truth") 20 | print("Tracked Streamlines") 21 | preprocessor = DataPreprocessor().normalize().fa_estimate().crop() 22 | 23 | ismrm_data = preprocessor.preprocess(ismrm_data) 24 | hcp_data = preprocessor.preprocess(hcp_data) 25 | 26 | print("Normalized and cropped data") 27 | 28 | processing = RegressionProcessing(postprocessing=Resample100()) 29 | csd = StreamlineDataset(hcp_sl, hcp_data, processing) 30 | ismrm_set = StreamlineDataset(ismrm_sl, ismrm_data, processing) 31 | 32 | dataset = ConcatenatedDataset([csd, ismrm_set]) 33 | print("Initialised Regression Datasets") 34 | processing = ClassificationProcessing(postprocessing=SphericalHarmonics()) 35 | csd_classification = StreamlineDataset(hcp_sl, hcp_data, processing) 36 | ismrm_classification = StreamlineDataset(ismrm_sl, ismrm_data, processing) 37 | dataset_classification = ConcatenatedDataset([csd_classification, ismrm_classification]) 38 | print("Initialised Classification Datasets") 39 | 40 | 41 | if __name__ == "__main__": 42 | main() 43 | -------------------------------------------------------------------------------- /dfibert/dataset/exceptions.py: -------------------------------------------------------------------------------- 1 | class WrongDatasetTypePassedError(Exception): 2 | """Error thrown if `ConcatenatedDataset` retrieves wrong datasets. 3 | 4 | This means that the datasets you passed aren't exclusively IterableDatasets. 5 | 6 | Attributes 7 | ---------- 8 | caller: ConcatenatedDataset 9 | The ConcatenatedDataset raising the error.
10 | dataset: BaseDataset 11 | The dataset causing the error. 12 | """ 13 | 14 | def __init__(self, concat, dataset, message): 15 | """ 16 | Parameters 17 | ---------- 18 | concat : ConcatenatedDataset 19 | The dataset raising the error. 20 | dataset : BaseDataset 21 | The dataset responsible for the error. 22 | message : str 23 | Your specific error message. 24 | """ 25 | self.caller = concat 26 | self.dataset = dataset 27 | super().__init__(message) 28 | 29 | class FeatureShapesNotEqualError(Exception): 30 | """Error thrown if FeatureShapes of `ConcatenatedDataset` are not equal, but requested. 31 | 32 | This error only occurs if `ConcatenatedDataset().get_feature_shape` is called 33 | and the datasets in the ConcatenatedDataset don't have equal feature shapes. 34 | 35 | Attributes 36 | ---------- 37 | index: int 38 | The index of the BaseDataset responsible for the error. 39 | shape1: tuple 40 | The shape of the reference dataset. 41 | shape2: tuple 42 | The shape of the mismatching dataset. 43 | """ 44 | def __init__(self, index, s1, s2): 45 | """ 46 | Parameters 47 | ---------- 48 | index : int 49 | The index of the BaseDataset responsible for the error. 50 | s1 : tuple 51 | The shape of the reference dataset. 52 | s2 : tuple 53 | The shape of the dataset causing the error. 54 | """ 55 | self.shape1 = s1 56 | self.shape2 = s2 57 | self.index = index 58 | super().__init__(("The shape of the dataset {idx} ({s2}) " 59 | "is not equal to the base shape of the reference dataset 0 ({s1})" 60 | ).format(idx=index, s2=s2, s1=s1)) -------------------------------------------------------------------------------- /dfibert/envs/neuroanatomical_utils.py: -------------------------------------------------------------------------------- 1 | from dfibert.data import DataPreprocessor 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.utils.data import Dataset, DataLoader 6 | 7 | import gym 8 | import numpy as np 9 | import random 10 | import os, sys 11 | 12 | sys.path.insert(0, '..') 13 | 14 | from collections import deque 15 | 16 | from dfibert.tracker.nn.rl import Agent, DQN 17 | from dfibert.tracker import save_streamlines, load_streamlines 18 | from dfibert.envs._state import TractographyState 19 | from tqdm import trange 20 | from dipy.tracking import utils 21 | import dipy.reconst.dti as dti 22 | from dipy.direction import peaks_from_model 23 | 24 | 25 | def convPoint(p, dims):  # map voxel indices in [0, dims - 1] to the normalised range [-1, 1] 26 | dims = dims - 1 27 | return (p - dims / 2.) / (dims / 2.)
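# A minimal sketch of the mapping above (this helper and its values are
# illustrative, not part of the original module): convPoint sends voxel index 0
# to -1.0, the volume centre (dims - 1) / 2 to 0.0, and index dims - 1 to +1.0,
# i.e. the normalised coordinate range that grid_sample expects.
def _convPoint_example():
    dims = torch.tensor([145., 174., 145.])  # assumed DWI volume shape
    pts = torch.tensor([[0., 0., 0.], [72., 86.5, 72.], [144., 173., 144.]])
    return convPoint(pts, dims)  # rows approx. [-1, -1, -1], [0, 0, 0], [1, 1, 1]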
28 | 29 | 30 | def interpolate3dAt(data, positions): 31 | # Warning: data is supposed to be CxHxWxD 32 | # normalise coordinates into range [-1,1] 33 | pts = positions.float() 34 | pts = convPoint(pts, torch.tensor(data.shape[1:4]).to(data.device)) 35 | # reverse pts 36 | pts = pts[:, (2, 1, 0)] 37 | # look up values with grid_sample (nearest-neighbour mode) 38 | return torch.nn.functional.grid_sample(data.unsqueeze(0), 39 | pts.unsqueeze(0).unsqueeze(0).unsqueeze(0), 40 | align_corners=False, mode="nearest") 41 | 42 | 43 | class FiberBundleDataset(Dataset): 44 | def __init__(self, path_to_files, b_val=1000, device="cpu", dataset=None): 45 | streamlines = load_streamlines(path=path_to_files) 46 | 47 | if dataset is None: 48 | preprocessor = DataPreprocessor().normalize().crop(b_val).fa_estimate() 49 | dataset = preprocessor.get_ismrm("data/ISMRM2015/") 50 | self.dataset = dataset 51 | self.streamlines = [torch.from_numpy(self.dataset.to_ijk(sl)).to(device) for sl in streamlines] 52 | self.tractMask = torch.zeros(self.dataset.binary_mask.shape) 53 | 54 | for sl in self.streamlines: 55 | pi = torch.floor(sl).to(torch.long) 56 | self.tractMask[pi.chunk(chunks=3, dim=1)] = 1 57 | 58 | def __len__(self): 59 | return len(self.streamlines) 60 | 61 | def __getitem__(self, idx): 62 | streamline = self.streamlines[idx] 63 | sl_1 = streamline[0:-2] 64 | sl_2 = streamline[1:-1] 65 | return sl_1, sl_2 66 | -------------------------------------------------------------------------------- /dfibert/data/postprocessing.py: -------------------------------------------------------------------------------- 1 | """ 2 | The postprocessing submodule of the data module hosts different options 3 | of postprocessing the DWI data. Those can be passed to datasets for further use. 4 | """ 5 | from typing import Union 6 | import numpy as np 7 | from dipy.core.sphere import Sphere 8 | from dipy.reconst.shm import real_sym_sh_mrtrix, smooth_pinv 9 | from dipy.data import get_sphere 10 | 11 | from dfibert.util import get_2D_sphere 12 | 13 | 14 | class PostprocessingOption(object): 15 | def process(self, data_container, points, dwi): 16 | raise NotImplementedError() 17 | 18 | 19 | class Raw(PostprocessingOption): 20 | """Does no resampling.""" 21 | def process(self, data_container, points, dwi): 22 | return dwi 23 | 24 | 25 | class SphericalHarmonics(PostprocessingOption): 26 | def __init__(self, sh_order=8, smooth=0.006): 27 | """ 28 | Resamples the data using spherical harmonics. 29 | 30 | The resampled data is calculated using the DWI Sphere. 31 | 32 | Parameters 33 | ---------- 34 | sh_order 35 | The order of the spherical harmonics 36 | smooth : the Laplace-Beltrami regularisation weight of the spherical harmonics fit 37 | """ 38 | super().__init__() 39 | self.sh_order = sh_order 40 | self.smooth = smooth 41 | 42 | def process(self, data_container, points, dwi): 43 | raw_sphere = Sphere(xyz=data_container.bvecs) 44 | 45 | real_sh, _, n = real_sym_sh_mrtrix(self.sh_order, raw_sphere.theta, raw_sphere.phi) 46 | l = -n * (n + 1) 47 | inv_b = smooth_pinv(real_sh, np.sqrt(self.smooth) * l) 48 | data_sh = np.dot(dwi, inv_b.T) 49 | 50 | return data_sh 51 | 52 | 53 | class Resample(SphericalHarmonics): 54 | def __init__(self, sh_order=8, smooth=0.006, sphere: Union[Sphere, str] = "repulsion100"): 55 | """ 56 | Resample the values according to given sphere or directions. 57 | 58 | The data is first fitted with spherical harmonics on the acquisition sphere, then evaluated on the new sphere.
59 | 60 | Parameters 61 | ---------- 62 | sh_order 63 | The order of the spherical harmonics 64 | smooth : the Laplace-Beltrami regularisation weight of the spherical harmonics fit 65 | sphere 66 | The sphere we are resampling to 67 | """ 68 | super().__init__(sh_order=sh_order, smooth=smooth) 69 | if isinstance(sphere, Sphere): 70 | self.sphere = sphere 71 | else: # get with name 72 | self.sphere = get_sphere(sphere) 73 | self.real_sh, _, _ = real_sym_sh_mrtrix(self.sh_order, self.sphere.theta, self.sphere.phi) 74 | 75 | def process(self, data_container, points, dwi): 76 | data_sh = super().process(data_container, points, dwi) 77 | data_resampled = np.dot(data_sh, self.real_sh.T) 78 | 79 | return data_resampled 80 | 81 | 82 | class Resample100(Resample): 83 | def __init__(self, sh_order=8, smooth=0.006): 84 | """ 85 | Resamples the values to 100 directions with the repulsion100 sphere. 86 | 87 | Just a shortcut for the `resample` option. 88 | 89 | Parameters 90 | ---------- 91 | sh_order 92 | The order of the spherical harmonics 93 | smooth : the Laplace-Beltrami regularisation weight of the spherical harmonics fit 94 | """ 95 | super().__init__(sh_order=sh_order, smooth=smooth, sphere="repulsion100") 96 | 97 | 98 | class Resample2D(Resample): 99 | def __init__(self, sh_order=8, smooth=0.006, no_thetas=16, no_phis=16): 100 | """ 101 | Resamples the values to the directions of the 2D sphere. 102 | Just a shortcut for the `resample` option with 2D sphere. 103 | 104 | See `dfibert.util.get_2D_sphere` for more details on how the 2D sphere is generated. 105 | 106 | Parameters 107 | ---------- 108 | sh_order 109 | The order of the spherical harmonics 110 | smooth : the Laplace-Beltrami regularisation weight of the spherical harmonics fit 111 | no_thetas 112 | the number of thetas to use for the sphere generation 113 | no_phis 114 | the number of phis to use for the sphere generation 115 | """ 116 | super().__init__(sh_order=sh_order, smooth=smooth, sphere=get_2D_sphere(no_phis=no_phis, no_thetas=no_thetas)) -------------------------------------------------------------------------------- /examples/train_dqn_supervised.py: -------------------------------------------------------------------------------- 1 | from tqdm import trange 2 | from tqdm.autonotebook import tqdm 3 | import dipy.reconst.dti as dti 4 | from dipy.tracking import utils 5 | 6 | import time 7 | import torch 8 | import torch.nn.functional as F 9 | from torch.utils.data import Dataset, DataLoader 10 | 11 | import numpy as np 12 | import random 13 | import sys 14 | sys.path.insert(0,'..') 15 | 16 | import dfibert.envs.RLTractEnvironment_fast as RLTe 17 | from dfibert.tracker.nn.rl import DQN 18 | from dfibert.envs._state import TractographyState 19 | from dfibert.tracker import save_streamlines 20 | 21 | 22 | class SupervisedRewardDataset(Dataset): 23 | def __init__(self, inp, outp): 24 | self.inp = inp 25 | self.outp = outp 26 | 27 | def __getitem__(self, index): 28 | return (self.inp[index,], self.outp[index,]) 29 | 30 | def __len__(self): 31 | return len(self.inp) 32 | 33 | 34 | def save_model(path_checkpoint, model, epoch, loss, n_actions): 35 | print("Writing checkpoint to %s" % (path_checkpoint)) 36 | checkpoint = {} 37 | checkpoint["model"] = model.state_dict() 38 | checkpoint["epoch"] = epoch 39 | checkpoint["loss"] = loss 40 | checkpoint["n_actions"] = n_actions 41 | torch.save(checkpoint, path_checkpoint) 42 | 43 | 44 | #------------------ 45 | def train(): 46 | batch_size = 1024 47 | seed_selection_fa_Threshold = 0.1 48 | epochs = 1000 49 | lr = 1e-4 50 | 51 | 52 | #------------------ 53 | seeds_CST = np.load('data/ismrm_seeds_CST.npy') 54 | seeds_CST = torch.from_numpy(seeds_CST) 55 | env = RLTe.RLTractEnvironment(dataset = 'ISMRM', step_width=0.8,
56 | device = 'cuda:0', seeds = seeds_CST, action_space=20, 57 | odf_mode = "DTI", 58 | fa_threshold=0.2, tracking_in_RAS=False) 59 | 60 | 61 | #------------------ 62 | # initialize points for training 63 | dti_model = dti.TensorModel(env.dataset.gtab, fit_method='LS') 64 | dti_fit = dti_model.fit(env.dataset.dwi, mask=env.dataset.binary_mask) 65 | 66 | fa_img = dti_fit.fa 67 | seed_mask = fa_img.copy() 68 | seed_mask[seed_mask >= seed_selection_fa_Threshold] = 1 69 | seed_mask[seed_mask < seed_selection_fa_Threshold] = 0 70 | 71 | seeds = utils.seeds_from_mask(seed_mask, affine=np.eye(4), density=1) # tracking in IJK 72 | seeds = torch.from_numpy(seeds).to(env.device) 73 | print("We got %d seeds" % (len(seeds))) 74 | 75 | #------------------ 76 | # build one training sample per seed point within our brain mask 77 | 78 | noActions = 20 79 | dimDWI = 100 80 | noPoints = len(seeds) 81 | 82 | dwi_data = torch.zeros([noPoints,3,3,3,dimDWI], device = env.device) 83 | rewards = torch.zeros([noPoints,noActions], device = env.device) 84 | 85 | for i in tqdm(range(noPoints), ascii=True): 86 | pos = seeds[i] 87 | 88 | # instantiate TractographyState 89 | state = TractographyState(pos, None) 90 | 91 | # call env.reward_for_state 92 | reward_ = env.reward_for_state(state, direction = "forward", prev_direction = None) 93 | 94 | #dwi_ = env.dwi_interpolator(pos.to(env.device)) 95 | dwi_interpol_ = env.interpolate_dwi_at_state(pos.to(env.device)) 96 | 97 | dwi_data[i,:,:,:,:] = dwi_interpol_ 98 | rewards[i,:] = reward_ 99 | 100 | 101 | ds = SupervisedRewardDataset(dwi_data, rewards) 102 | train_loader = DataLoader(ds, batch_size=batch_size,shuffle=True) 103 | 104 | 105 | model = DQN(input_shape = 3*3*3*dimDWI, n_actions=noActions, hidden_size=128, num_hidden=3).to(env.device) 106 | 107 | # optimizer 108 | optimizer = torch.optim.Adam(model.parameters(),lr=lr) 109 | criterion = torch.nn.MSELoss() 110 | 111 | begin = time.time() 112 | with trange(epochs, unit="epochs", ascii=True) as pbar: 113 | for epoch in pbar: 114 | # Set current and total loss value 115 | current_loss = 0.0 116 | 117 | model.train() # Optional when not using Model Specific layer 118 | for i, data in enumerate(train_loader,0): 119 | x_batch = data[0].view([data[0].shape[0], -1]) 120 | y_batch = data[1] 121 | optimizer.zero_grad() 122 | pred = model(x_batch) 123 | 124 | loss = criterion(pred,y_batch) 125 | 126 | loss.backward() 127 | optimizer.step() 128 | 129 | if( (epoch % 50) == 0): 130 | p_cp = 'checkpoints/defi_super_%d.pt' % (epoch) 131 | save_model(p_cp, model, epoch, loss.item(), noActions) 132 | 133 | pbar.set_postfix(loss=loss.item()) 134 | 135 | end = time.time() 136 | print("time:", end - begin) 137 | 138 | 139 | 140 | 141 | if __name__ == "__main__": 142 | train() -------------------------------------------------------------------------------- /dfibert/tracker/nn/_segment_tree.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Source: https://github.com/Curt-Park/rainbow-is-all-you-need 3 | """Segment tree for Prioritized Replay Buffer.""" 4 | 5 | import operator 6 | from typing import Callable 7 | 8 | 9 | class SegmentTree: 10 | """ Create SegmentTree.
11 | 12 | Taken from OpenAI baselines github repository: 13 | https://github.com/openai/baselines/blob/master/baselines/common/segment_tree.py 14 | 15 | Attributes: 16 | capacity (int) 17 | tree (list) 18 | operation (function) 19 | 20 | """ 21 | 22 | def __init__(self, capacity: int, operation: Callable, init_value: float): 23 | """Initialization. 24 | 25 | Args: 26 | capacity (int) 27 | operation (function) 28 | init_value (float) 29 | 30 | """ 31 | assert ( 32 | capacity > 0 and capacity & (capacity - 1) == 0 33 | ), "capacity must be positive and a power of 2." 34 | self.capacity = capacity 35 | self.tree = [init_value for _ in range(2 * capacity)] 36 | self.operation = operation 37 | 38 | def _operate_helper( 39 | self, start: int, end: int, node: int, node_start: int, node_end: int 40 | ) -> float: 41 | """Returns result of operation in segment.""" 42 | if start == node_start and end == node_end: 43 | return self.tree[node] 44 | mid = (node_start + node_end) // 2 45 | if end <= mid: 46 | return self._operate_helper(start, end, 2 * node, node_start, mid) 47 | else: 48 | if mid + 1 <= start: 49 | return self._operate_helper(start, end, 2 * node + 1, mid + 1, node_end) 50 | else: 51 | return self.operation( 52 | self._operate_helper(start, mid, 2 * node, node_start, mid), 53 | self._operate_helper(mid + 1, end, 2 * node + 1, mid + 1, node_end), 54 | ) 55 | 56 | def operate(self, start: int = 0, end: int = 0) -> float: 57 | """Returns result of applying `self.operation`.""" 58 | if end <= 0: 59 | end += self.capacity 60 | end -= 1 61 | 62 | return self._operate_helper(start, end, 1, 0, self.capacity - 1) 63 | 64 | def __setitem__(self, idx: int, val: float): 65 | """Set value in tree.""" 66 | idx += self.capacity 67 | self.tree[idx] = val 68 | 69 | idx //= 2 70 | while idx >= 1: 71 | self.tree[idx] = self.operation(self.tree[2 * idx], self.tree[2 * idx + 1]) 72 | idx //= 2 73 | 74 | def __getitem__(self, idx: int) -> float: 75 | """Get real value in leaf node of tree.""" 76 | assert 0 <= idx < self.capacity 77 | 78 | return self.tree[self.capacity + idx] 79 | 80 | 81 | class SumSegmentTree(SegmentTree): 82 | """ Create SumSegmentTree. 83 | 84 | Taken from OpenAI baselines github repository: 85 | https://github.com/openai/baselines/blob/master/baselines/common/segment_tree.py 86 | 87 | """ 88 | 89 | def __init__(self, capacity: int): 90 | """Initialization. 91 | 92 | Args: 93 | capacity (int) 94 | 95 | """ 96 | super(SumSegmentTree, self).__init__( 97 | capacity=capacity, operation=operator.add, init_value=0.0 98 | ) 99 | 100 | def sum(self, start: int = 0, end: int = 0) -> float: 101 | """Returns arr[start] + ... + arr[end].""" 102 | return super(SumSegmentTree, self).operate(start, end) 103 | 104 | def retrieve(self, upperbound: float) -> int: 105 | """Find the highest index `i` such that sum(arr[0] + ... + arr[i - 1]) <= upperbound.""" 106 | # TODO: Check assert case and fix bug 107 | assert 0 <= upperbound <= self.sum() + 1e-5, "upperbound: {}".format(upperbound) 108 | 109 | idx = 1 110 | 111 | while idx < self.capacity: # while non-leaf 112 | left = 2 * idx 113 | right = left + 1 114 | if self.tree[left] > upperbound: 115 | idx = 2 * idx 116 | else: 117 | upperbound -= self.tree[left] 118 | idx = right 119 | return idx - self.capacity 120 | 121 | 122 | class MinSegmentTree(SegmentTree): 123 | """ Create MinSegmentTree.
124 | 125 | Taken from OpenAI baselines github repository: 126 | https://github.com/openai/baselines/blob/master/baselines/common/segment_tree.py 127 | 128 | """ 129 | 130 | def __init__(self, capacity: int): 131 | """Initialization. 132 | 133 | Args: 134 | capacity (int) 135 | 136 | """ 137 | super(MinSegmentTree, self).__init__( 138 | capacity=capacity, operation=min, init_value=float("inf") 139 | ) 140 | 141 | def min(self, start: int = 0, end: int = 0) -> float: 142 | """Returns min(arr[start], ..., arr[end]).""" 143 | return super(MinSegmentTree, self).operate(start, end) 144 | -------------------------------------------------------------------------------- /dfibert/tracker/nn/supervised_pretraining.py: -------------------------------------------------------------------------------- 1 | from tqdm import trange, tqdm 2 | import dipy.reconst.dti as dti 3 | from dipy.tracking import utils 4 | 5 | import time 6 | import torch 7 | import torch.nn.functional as F 8 | from torch.utils.data import Dataset, DataLoader 9 | 10 | import numpy as np 11 | import random 12 | import sys 13 | sys.path.insert(0,'..') 14 | 15 | import dfibert.envs.RLTractEnvironment_fast as RLTe 16 | from dfibert.envs._state import TractographyState 17 | from dfibert.tracker import save_streamlines 18 | 19 | from copy import deepcopy 20 | import os 21 | 22 | class SupervisedRewardDataset(Dataset): 23 | def __init__(self, inp, outp): 24 | self.inp = inp 25 | self.outp = outp 26 | 27 | def __getitem__(self, index): 28 | return (self.inp[index,], self.outp[index,]) 29 | 30 | def __len__(self): 31 | return len(self.inp) 32 | 33 | 34 | def save_model(path_checkpoint, model, epoch, loss, n_actions): 35 | print("Writing checkpoint to %s" % (path_checkpoint)) 36 | checkpoint = {} 37 | checkpoint["model"] = model.state_dict() 38 | checkpoint["epoch"] = epoch 39 | checkpoint["loss"] = loss 40 | checkpoint["n_actions"] = n_actions 41 | torch.save(checkpoint, path_checkpoint) 42 | 43 | 44 | #------------------ 45 | def train(dqn, env, batch_size: int = 1024, epochs: int = 1000, lr: float = 1e-4, 46 | seed_selection_fa_Threshold: float = 0.1, 47 | path: str = './supervised_checkpoints', 48 | wandb_log: bool = False): 49 | 50 | os.makedirs(path, exist_ok=True) 51 | 52 | if wandb_log: 53 | import wandb 54 | 55 | #------------------ 56 | # initialize points for training 57 | dti_model = dti.TensorModel(env.dataset.gtab, fit_method='LS') 58 | dti_fit = dti_model.fit(env.dataset.dwi, mask=env.dataset.binary_mask) 59 | 60 | fa_img = dti_fit.fa 61 | seed_mask = fa_img.copy() 62 | seed_mask[seed_mask >= seed_selection_fa_Threshold] = 1 63 | seed_mask[seed_mask < seed_selection_fa_Threshold] = 0 64 | 65 | seeds = utils.seeds_from_mask(seed_mask, affine=np.eye(4), density=1) # tracking in IJK 66 | seeds = torch.from_numpy(seeds).to(env.device) 67 | print("We got %d seeds" % (len(seeds))) 68 | 69 | #------------------ 70 | # build one training sample per seed point within our brain mask 71 | 72 | noActions = env.action_space.n 73 | #dimDWI = 100 74 | _ = env.reset() 75 | stateShape = tuple(env.state.getValue().shape) 76 | noPoints = len(seeds) 77 | 78 | #dwi_data = torch.zeros([noPoints,3,3,3,dimDWI], device = env.device) 79 | dwi_data = torch.zeros([noPoints,*stateShape], device = env.device) 80 | actions = torch.zeros([noPoints], device = env.device, dtype=torch.int64) 81 | 82 | print("Filling dataset..") 83 | for i in tqdm(range(noPoints), ascii=True): 84 | pos = seeds[i] 85 | # instantiate TractographyState 86 | state = TractographyState(pos, None) 87 |
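        # reward_for_state returns one reward per discrete action; the argmax
        # below collapses that vector into a single class label, which is what
        # the CrossEntropyLoss used further down expects.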
88 | # call env.reward_for_state 89 | reward_ = env.reward_for_state(state, direction = "forward", prev_direction = None) 90 | action_ = torch.argmax(reward_) 91 | #action_ = torch.LongTensor(action_) 92 | #dwi_ = env.dwi_interpolator(pos.to(env.device)) 93 | dwi_interpol_ = env.interpolate_dwi_at_state(pos.to(env.device)) 94 | 95 | dwi_data[i,:,:,:,:] = dwi_interpol_ 96 | actions[i] = action_ 97 | 98 | print("..done!") 99 | ds = SupervisedRewardDataset(dwi_data, actions)#rewards) 100 | train_loader = DataLoader(ds, batch_size=batch_size,shuffle=True) 101 | 102 | model = deepcopy(dqn) 103 | # optimizer 104 | optimizer = torch.optim.Adam(model.parameters(),lr=lr) 105 | #criterion = torch.nn.MSELoss() 106 | criterion = torch.nn.CrossEntropyLoss() 107 | 108 | print("Start pretraining DQN..") 109 | begin = time.time() 110 | with trange(epochs, unit="epochs", ascii=True) as pbar: 111 | for epoch in pbar: 112 | # Set current and total loss value 113 | current_loss = 0.0 114 | 115 | model.train() # Optional when not using Model Specific layer 116 | for i, data in enumerate(train_loader,0): 117 | x_batch = data[0].view([data[0].shape[0], -1]) 118 | y_batch = data[1].type(torch.int64) 119 | optimizer.zero_grad() 120 | pred = model(x_batch) 121 | 122 | loss = criterion(pred,y_batch) 123 | if wandb_log: 124 | wandb.log({"Pretraining: supervised loss": loss}) 125 | loss.backward() 126 | optimizer.step() 127 | 128 | if loss < 0.005: 129 | break 130 | 131 | if( loss < 0.0005): 132 | print("Early stop at loss %.4f" % (loss)) 133 | p_cp = os.path.join(path, 'defi_super_%d.pt' % (epoch)) 134 | save_model(p_cp, model, epoch, loss.item(), noActions) 135 | break 136 | if( (epoch % 50) == 0): 137 | p_cp = os.path.join(path, 'defi_super_%d.pt' % (epoch)) 138 | save_model(p_cp, model, epoch, loss.item(), noActions) 139 | 140 | pbar.set_postfix(loss=loss.item()) 141 | 142 | end = time.time() 143 | print("..done! time:", end - begin) 144 | return deepcopy(model) 145 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Reinforcement Learning environment 2 | There are a few classes and methods which can help you build your reinforcement learning environment for tractography. 3 | 4 | ## DWI Data representation 5 | Firstly, you can use `DataContainer` objects to store and retrieve DWI data: 6 | 7 | ### Initialization and preprocessing 8 | 9 | ```python 10 | from dfibert.data import DataPreprocessor 11 | preprocessor = DataPreprocessor().denoise().fa_estimate().normalize().crop() 12 | hcp_data = preprocessor.get_hcp("/path/to/hcp/dataset/") 13 | ismrm_data = preprocessor.get_ismrm("/path/to/ismrm/") 14 | ``` 15 | 16 | ### Coordinate system transforms 17 | ```python 18 | import numpy as np 19 | ijk_points = hcp_data.to_ijk(ras_points) # Transform to Image coordinate system 20 | ras_points = hcp_data.to_ras(ijk_points) # Transform to World RAS+ coordinate system 21 | 22 | np.array_equal(ras_points, hcp_data.to_ras(hcp_data.to_ijk(ras_points))) # True 23 | ``` 24 | 25 | ### DWI interpolation 26 | You can retrieve interpolated DWI values with the following method. If you pass `ignore_outside_points=True`, no error is thrown for points that lie outside of the DWI image. 27 | ```python 28 | interpolated_dwi = hcp_data.get_interpolated_dwi(ras_points, ignore_outside_points=False) 29 | ``` 30 | 31 | ### Fields 32 | The fields can be helpful for checks or additional calculations based on the loaded data.
33 | ```python 34 | hcp.fa # fa values if fa_estimate was part of the preprocessing pipeline 35 | hcp.dwi 36 | hcp.t1 37 | hcp.bvals 38 | hcp.bvecs 39 | ... 40 | ``` 41 | 42 | ## Tracking and Streamline Representation 43 | 44 | You can use `Tracker` objects to represent already tracked streamlines or to track streamlines using CSD or DTI: 45 | 46 | ### Loading Ground Truth Streamlines 47 | You can retrieve tracked `Tracker` Objects in multiple ways: 48 | ```python 49 | ismrm_sl = ISMRMReferenceStreamlinesTracker(ismrm_data, streamline_count=10000) 50 | ismrm_sl.track() 51 | 52 | file_sl = StreamlinesFromFileTracker("streamlines.vtk") 53 | file_sl.track() 54 | 55 | hcp_sl = CSDTracker(hcp_data, random_seeds=True, fa_threshold=0.15) 56 | hcp_sl.track() 57 | ``` 58 | Please keep in mind that the `CSDTracker` and `DTITracker` have internal caches, so your given DWI containers aren't re-tracked every time you call the track method, but only if there is no corresponding cache file or the cache is deactivated. Because the cache operates on names and paths, it is important that you don't replace DWI files with others that have identical names and paths without deleting the cache. 59 | 60 | ### Helpful methods 61 | ```python 62 | streamlines = file_sl.get_streamlines() # retrieve the actual streamlines 63 | filtered_streamlines = file_sl.filtered_streamlines_by_length(minimum=70) # filter streamlines 64 | 65 | hcp_sl.save_to_file("hcp_streamlines.vtk") 66 | ``` 67 | 68 | [TODO add Tracking and retrieving streamlines example]:: 69 | 70 | ## Config 71 | 72 | Furthermore, you can use the `Config` if you want to read and write your own parameters: 73 | 74 | Get the singleton with 75 | 76 | ```python 77 | config = Config.get_config() 78 | ``` 79 | 80 | You can read and set attributes: 81 | ```python 82 | config.set("section", "option1", value="value") 83 | 84 | string_config = config.get("section1", "option2", fallback="default") 85 | int_config = config.getint("section1", "numerical_value", fallback="0") 86 | float_config = config.getfloat("section", "option_f", fallback="1.2") 87 | bool_config = config.getboolean("section", "option_b", fallback="True") 88 | ``` 89 | Loading and saving is handled automatically, and the fallback values are added to the configuration file the first time they are requested. 90 | 91 | ## Helpful methods to use for you 92 | Lastly, I gathered a few methods that should assist you in creating your training environment without reinventing the wheel for the data processing: 93 | 94 | ### 1. data_container.get_interpolated_dwi(points, ignore_outside_points=False) 95 | Returns the interpolated DWI at the given points while keeping the dimensions of the given points. For example, you can put in a point ndarray of size `A x B x 3` and get back an ndarray of size `A x B x DWI`. 96 | ### 2. util.get_grid(grid_dimension) 97 | Takes a 3D tuple `(x,y,z)` as `grid_dimension` and generates a grid with the given dimensions, applicable to any point, or: 98 | ### 3. util.apply_rotation_matrix_to_grid(grid, rot_matrix) 99 | Takes a grid from the `util.get_grid` method and a list of rotation matrices and applies all rotations to the grid in parallel, returning a list of grids. 100 | ### 4. util.direction_to_classification(sphere, next_dir, include_stop=False, last_is_stop=False, stop_values=None) 101 | Takes a `dipy` Sphere with directions and a list of directions, and returns the classifier output weighted by the similarity to the given vector.
If `include_stop` is `True`, you can either provide `last_is_stop`, which defines that the last element fulfills the stop condition, or provide a stop value between 0 and 1 for every next_dir in `stop_values`, which will be added to the classifier output. 102 | ### 5. processing.calculate_item(data_container, previous_sl, next_dir) 103 | Takes a `DataContainer`, the streamline calculated up to this point, and the direction it is supposed to interpolate to. Returns an `(input, output)` tuple for the NN. Available for every processing method. -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name <env> --file <this file> 3 | # platform: linux-64 4 | _libgcc_mutex=0.1=conda_forge 5 | _openmp_mutex=4.5=1_gnu 6 | argon2-cffi=20.1.0=py38h27cfd23_1 7 | async_generator=1.10=pyhd3eb1b0_0 8 | attrs=21.2.0=pyhd3eb1b0_0 9 | backcall=0.2.0=pyhd3eb1b0_0 10 | blas=1.0=openblas 11 | bleach=4.0.0=pyhd3eb1b0_0 12 | blosc=1.21.0=h9c3ff4c_0 13 | brotli=1.0.9=h7f98852_6 14 | brotli-bin=1.0.9=h7f98852_6 15 | bzip2=1.0.8=h7f98852_4 16 | c-ares=1.18.1=h7f98852_0 17 | ca-certificates=2021.10.26=h06a4308_2 18 | cached-property=1.5.2=hd8ed1ab_1 19 | cached_property=1.5.2=pyha770c72_1 20 | certifi=2021.10.8=py38h06a4308_0 21 | cffi=1.14.6=py38h400218f_0 22 | cloudpickle=2.0.0=pyhd8ed1ab_0 23 | colorama=0.4.4=pyh9f0ad1d_0 24 | cudatoolkit=10.1.243=h6bb024c_0 25 | curl=7.78.0=h1ccaba5_0 26 | cvxpy=1.1.18=py38h578d9bd_0 27 | cvxpy-base=1.1.18=py38h43a58ef_0 28 | cycler=0.11.0=pyhd8ed1ab_0 29 | cytoolz=0.11.2=py38h497a2fe_1 30 | dask-core=2021.12.0=pyhd8ed1ab_0 31 | debugpy=1.5.1=py38h295c915_0 32 | decorator=5.1.0=pyhd3eb1b0_0 33 | defusedxml=0.7.1=pyhd3eb1b0_0 34 | dipy=1.4.1=py38hb5d20a5_0 35 | ecos=2.0.8=py38h6c62de6_2 36 | entrypoints=0.3=py38_0 37 | enum34=1.1.10=py38h32f6830_2 38 | expat=2.4.2=h9c3ff4c_0 39 | fonttools=4.28.5=py38h497a2fe_0 40 | freetype=2.10.4=h0708190_1 41 | fsspec=2021.11.1=pyhd8ed1ab_0 42 | fury=0.7.1=pyh6c4a22f_0 43 | future=0.18.2=py38h578d9bd_4 44 | h5py=3.2.1=nompi_py38h9915d05_100 45 | hdf4=4.2.15=h10796ff_3 46 | hdf5=1.10.6=nompi_h6a2412b_1114 47 | icu=68.2=h9c3ff4c_0 48 | imagecodecs-lite=2019.12.3=py38h6c62de6_4 49 | imageio=2.6.1=py38_0 50 | importlib-metadata=4.8.2=py38h06a4308_0 51 | importlib_metadata=4.8.2=hd3eb1b0_0 52 | intel-openmp=2021.4.0=h06a4308_3561 53 | ipykernel=6.4.1=py38h06a4308_1 54 | ipython=7.29.0=py38hb070fc8_0 55 | ipython_genutils=0.2.0=pyhd3eb1b0_1 56 | jedi=0.18.0=py38h06a4308_1 57 | jinja2=3.0.2=pyhd3eb1b0_0 58 | joblib=1.1.0=pyhd8ed1ab_0 59 | jpeg=9d=h36c2ea0_0 60 | jsoncpp=1.8.4=hc9558a2_1002 61 | jsonschema=3.2.0=pyhd3eb1b0_2 62 | jupyter_client=7.1.0=pyhd3eb1b0_0 63 | jupyter_core=4.9.1=py38h06a4308_0 64 | jupyterlab_pygments=0.1.2=py_0 65 | kiwisolver=1.3.2=py38h1fd1430_1 66 | krb5=1.19.2=hcc1bbae_3 67 | lcms2=2.12=hddcbb42_0 68 | ld_impl_linux-64=2.35.1=h7274673_9 69 | libblas=3.9.0=12_linux64_openblas 70 | libbrotlicommon=1.0.9=h7f98852_6 71 | libbrotlidec=1.0.9=h7f98852_6 72 | libbrotlienc=1.0.9=h7f98852_6 73 | libcblas=3.9.0=12_linux64_openblas 74 | libcurl=7.78.0=h0b77cf5_0 75 | libedit=3.1.20210910=h7f8727e_0 76 | libev=4.33=h516909a_1 77 | libffi=3.3=he6710b0_2 78 | libgcc-ng=11.2.0=h1d223b6_11 79 | libgfortran-ng=11.2.0=h69a702a_11 80 | libgfortran5=11.2.0=h5c6108e_11 81 | libgomp=11.2.0=h1d223b6_11 82 | libiconv=1.16=h516909a_0 83 |
liblapack=3.9.0=12_linux64_openblas 84 | libnetcdf=4.7.4=nompi_h56d31a8_107 85 | libnghttp2=1.43.0=h812cca2_1 86 | libopenblas=0.3.18=pthreads_h8fe5266_0 87 | libpng=1.6.37=h21135ba_2 88 | libsodium=1.0.18=h7b6447c_0 89 | libssh2=1.10.0=ha56f1ee_2 90 | libstdcxx-ng=11.2.0=he4da1e4_11 91 | libtiff=4.2.0=h85742a9_0 92 | libuuid=2.32.1=h7f98852_1000 93 | libwebp-base=1.2.1=h7f98852_0 94 | libxcb=1.13=h7f98852_1004 95 | libxml2=2.9.12=h72842e0_0 96 | locket=0.2.0=py_2 97 | lz4-c=1.9.2=he1b5a44_3 98 | lzo=2.10=h516909a_1000 99 | markupsafe=2.0.1=py38h27cfd23_0 100 | matplotlib-base=3.5.1=py38hf4fb855_0 101 | matplotlib-inline=0.1.2=pyhd3eb1b0_2 102 | mistune=0.8.4=py38h7b6447c_1000 103 | mkl=2021.4.0=h06a4308_640 104 | mock=4.0.3=py38h578d9bd_2 105 | munkres=1.1.4=pyh9f0ad1d_0 106 | nb_conda=2.2.1=py38h06a4308_1 107 | nb_conda_kernels=2.3.1=py38h06a4308_0 108 | nbclient=0.5.3=pyhd3eb1b0_0 109 | nbconvert=6.1.0=py38h06a4308_0 110 | nbformat=5.1.3=pyhd3eb1b0_0 111 | ncurses=6.3=h7f8727e_2 112 | nest-asyncio=1.5.1=pyhd3eb1b0_0 113 | networkx=2.6.3=pyhd8ed1ab_1 114 | nibabel=3.2.1=pyhd8ed1ab_0 115 | ninja=1.10.2=py38hd09550d_3 116 | nomkl=3.0=0 117 | notebook=6.4.6=py38h06a4308_0 118 | numexpr=2.8.0=py38h6045d29_100 119 | numpy=1.22.0=py38h6ae9a64_0 120 | olefile=0.46=pyh9f0ad1d_1 121 | openssl=1.1.1l=h7f8727e_0 122 | osqp=0.6.2.post0=py38h43a58ef_3 123 | packaging=21.3=pyhd8ed1ab_0 124 | pandas=1.3.5=py38h43a58ef_0 125 | pandocfilters=1.4.3=py38h06a4308_1 126 | parso=0.8.2=pyhd3eb1b0_0 127 | partd=1.2.0=pyhd8ed1ab_0 128 | pathlib=1.0.1=py38h578d9bd_5 129 | patsy=0.5.2=pyhd8ed1ab_0 130 | pexpect=4.8.0=pyhd3eb1b0_3 131 | pickleshare=0.7.5=pyhd3eb1b0_1003 132 | pillow=7.2.0=py38h9776b28_2 133 | pip=21.2.4=py38h06a4308_0 134 | prometheus_client=0.12.0=pyhd3eb1b0_0 135 | prompt-toolkit=3.0.20=pyhd3eb1b0_0 136 | pthread-stubs=0.4=h36c2ea0_1001 137 | ptyprocess=0.7.0=pyhd3eb1b0_2 138 | pycparser=2.21=pyhd3eb1b0_0 139 | pydicom=2.2.2=pyh6c4a22f_0 140 | pygments=2.10.0=pyhd3eb1b0_0 141 | pyparsing=3.0.6=pyhd8ed1ab_0 142 | pyrsistent=0.18.0=py38heee7806_0 143 | pytables=3.6.1=py38hc386592_3 144 | python=3.8.12=h12debd9_0 145 | python-dateutil=2.8.2=pyhd8ed1ab_0 146 | python_abi=3.8=2_cp38 147 | pytorch=1.4.0=py3.8_cuda10.1.243_cudnn7.6.3_0 148 | pytz=2021.3=pyhd8ed1ab_0 149 | pywavelets=1.2.0=py38h6c62de6_1 150 | pyyaml=6.0=py38h497a2fe_3 151 | pyzmq=22.3.0=py38h295c915_2 152 | qdldl-python=0.1.5=py38h43a58ef_2 153 | readline=8.1=h27cfd23_0 154 | scikit-image=0.19.1=py38h43a58ef_0 155 | scikit-learn=1.0.2=py38h1561384_0 156 | scipy=1.7.3=py38h56a6a73_0 157 | scs=3.0.0=py38h6afa1d1_1 158 | send2trash=1.8.0=pyhd3eb1b0_1 159 | setuptools=58.0.4=py38h06a4308_0 160 | six=1.16.0=pyh6c4a22f_0 161 | sqlite=3.37.0=hc218d9a_0 162 | statsmodels=0.13.1=py38h6c62de6_0 163 | tbb=2020.2=h4bd325d_4 164 | terminado=0.9.4=py38h06a4308_0 165 | testpath=0.5.0=pyhd3eb1b0_0 166 | threadpoolctl=3.0.0=pyh8a188c0_0 167 | tifffile=2019.7.26.2=py38_0 168 | tk=8.6.11=h1ccaba5_0 169 | toolz=0.11.2=pyhd8ed1ab_0 170 | tornado=6.1=py38h27cfd23_0 171 | tqdm=4.62.3=pyhd8ed1ab_0 172 | traitlets=5.1.1=pyhd3eb1b0_0 173 | unicodedata2=14.0.0=py38h497a2fe_0 174 | vtk=8.2.0=py38hf2e56f5_218 175 | wcwidth=0.2.5=pyhd3eb1b0_0 176 | webencodings=0.5.1=py38_1 177 | wheel=0.37.0=pyhd3eb1b0_1 178 | xorg-kbproto=1.0.7=h7f98852_1002 179 | xorg-libice=1.0.10=h7f98852_0 180 | xorg-libsm=1.2.3=hd9c2040_1000 181 | xorg-libx11=1.7.2=h7f98852_0 182 | xorg-libxau=1.0.9=h7f98852_0 183 | xorg-libxdmcp=1.1.3=h7f98852_0 184 | xorg-libxt=1.2.1=h7f98852_2 185 | 
xorg-xproto=7.0.31=h7f98852_1007 186 | xz=5.2.5=h7b6447c_0 187 | yaml=0.2.5=h516909a_0 188 | zeromq=4.3.4=h2531618_0 189 | zipp=3.6.0=pyhd3eb1b0_0 190 | zlib=1.2.11=h7f8727e_4 191 | zstd=1.4.5=h9ceee32_0 192 | -------------------------------------------------------------------------------- /examples/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import argparse 4 | 5 | import os, sys 6 | sys.path.insert(0,'..') 7 | 8 | from dfibert.tracker.nn.rainbow_agent import DQNAgent 9 | from dfibert.util import set_seed 10 | 11 | import dfibert.envs.RLTractEnvironment_fast as RLTe 12 | 13 | 14 | def train(path, pretraining=False, max_steps=3000000, batch_size=32, replay_memory_size=20000, gamma=0.99, network_update_every=10000, learning_rate=0.0000625, checkpoint_every=200000, wandb=False, step_width = 0.8, odf_mode = "CSD"): 15 | 16 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 17 | print("Device:", device) 18 | print("Init environment..") 19 | 20 | #seeds_CST = np.load('data/ismrm_seeds_CST.npy') 21 | #seeds_CST = torch.from_numpy(seeds_CST) 22 | 23 | env = RLTe.RLTractEnvironment(dataset = 'ISMRM', step_width=step_width, 24 | device = device, seeds = None, action_space=20, 25 | tracking_in_RAS = False, odf_state = False, odf_mode = odf_mode) 26 | 27 | print("..done!") 28 | print("Init agent..") 29 | 30 | 31 | agent = DQNAgent(env=env, memory_size = replay_memory_size, 32 | batch_size = batch_size, 33 | target_update = network_update_every, 34 | lr = learning_rate, 35 | gamma = gamma, 36 | device = device, 37 | wandb_log=wandb 38 | ) 39 | print("..done!") 40 | 41 | if pretraining: 42 | print("Start pretraining..") 43 | agent.pretrain(path=path+'super_checkpoints/') 44 | 45 | print("Start DQL...") 46 | agent.train(num_steps = max_steps, checkpoint_interval=checkpoint_every, path = path, plot=False) 47 | 48 | 49 | 50 | def resume(path, max_steps=3000000, batch_size=32, replay_memory_size=20000, gamma=0.99, network_update_every=10000, learning_rate=0.0000625, checkpoint_every=200000, wandb=False): 51 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 52 | print("Device:", device) 53 | print("Init environment..") 54 | 55 | #seeds_CST = np.load('data/ismrm_seeds_CST.npy') 56 | #seeds_CST = torch.from_numpy(seeds_CST) 57 | 58 | env = RLTe.RLTractEnvironment(dataset = 'ISMRM', step_width=0.2, 59 | device = device, seeds = None, action_space=20, 60 | tracking_in_RAS = False, odf_state = False, odf_mode = "CSD") 61 | 62 | print("..done!") 63 | print("Init agent..") 64 | 65 | agent = DQNAgent(env=env, memory_size = replay_memory_size, 66 | batch_size = batch_size, 67 | target_update = network_update_every, 68 | lr = learning_rate, 69 | gamma = gamma, 70 | device = device, 71 | wandb_log=wandb 72 | ) 73 | 74 | print("..done!") 75 | print("Resume training..") 76 | 77 | agent.resume_training(path=path, plot=False, wandb=wandb) 78 | 79 | if __name__ == "__main__": 80 | parser = argparse.ArgumentParser() 81 | parser.add_argument("--max_steps", default=3000000, type=int, help="Choose the maximum number of training steps") 82 | #parser.add_argument("--start_learning", default=2000, type=int, help="Set amount of steps after which epsilon will be decreased and the agent will be learning") 83 | parser.add_argument("--replay_memory_size", default=100000, type=int, help="Set the number of past experiences stored in the replay memory") 84 | #parser.add_argument("--eps_annealing_steps",
default=100000, type=int, help="Set amount of steps after which epsilon is decreased more slowly until max_steps") 85 | #parser.add_argument("--agent_history_length", default=1, type=int, help="Choose how many past states are included in each input to update the agent") 86 | #parser.add_argument("--evaluate_every", default=20000, type=int, help="Set evaluation interval") 87 | #parser.add_argument("--eval_runs", default=10, type=int, help="Set amount of runs performed during evaluation") 88 | parser.add_argument("--network_update_every", default=1000, type=int, help="Set target network update frequency") 89 | #parser.add_argument("--max_episode_length", default=550, type=int, help="Set maximum episode length") 90 | parser.add_argument("--batch_size", default=512, type=int, help="Set batch size retrieved from memory for learning") 91 | parser.add_argument("--learning_rate", default=0.0000625, type=float, help="Set learning rate") 92 | parser.add_argument("--gamma", default=0.99, type=float, help="Set discount factor for Bellman equation") 93 | #parser.add_argument("--eps_final", default=0.1, type=float, help="Set first value to which epsilon is lowered to after eps_annealing_steps") 94 | #parser.add_argument("--eps_final_step", default=0.01, type=float, help="Set the second value to which epsilon is lowered to from eps_final until max_steps") 95 | parser.add_argument("--checkpoint_every", default=200000, type=int, help="Set checkpointing interval") 96 | 97 | parser.add_argument("--path", default=".", type=str, help="Set default saving path of logs and checkpoints") 98 | parser.add_argument("--seed", default=42, type=int, help="Set a seed for the training run") 99 | 100 | parser.add_argument("--step_width", default=0.8, type=float, help="step width for tracking") 101 | parser.add_argument("--odf_mode", default="CSD", type=str, help="compute ODF in reward based on DTI or CSD?") 102 | 103 | parser.add_argument("--pretrain", action='store_true', help="Pretrain the DQN with supervised learning") 104 | parser.add_argument("--resume_training", dest="resume", action='store_true', help="Load checkpoint from path folder and resume training") 105 | 106 | parser.add_argument("--wandb", action='store_true', help="Log training on W&B") 107 | parser.add_argument("--wandb_project", default="deepFibreTracking", type=str, help="Set name of W&B project") 108 | parser.add_argument("--wandb_entity", default=None, type=str, help="Set entity of W&B project") 109 | #parser.add_argument("--odf-as-state-value",dest="odf_state", action='store_true') 110 | #parser.set_defaults(odf_state=False) 111 | 112 | args = parser.parse_args() 113 | 114 | set_seed(args.seed) 115 | 116 | if args.wandb: 117 | import wandb 118 | config = args 119 | 120 | if args.resume: 121 | if args.wandb: 122 | wandb.init(project=args.wandb_project, entity=args.wandb_entity, config=config, resume=True) 123 | resume(args.path, batch_size=args.batch_size, gamma=args.gamma, checkpoint_every=args.checkpoint_every, wandb=args.wandb) 124 | 125 | else: 126 | if args.wandb: 127 | wandb.init(project=args.wandb_project, entity=args.wandb_entity, config=config, resume='allow') 128 | train(args.path, pretraining=args.pretrain, max_steps=args.max_steps, replay_memory_size=args.replay_memory_size, 129 | batch_size=args.batch_size, gamma=args.gamma, 130 | network_update_every=args.network_update_every, learning_rate=args.learning_rate, 131 | checkpoint_every=args.checkpoint_every, wandb=args.wandb, step_width = args.step_width, odf_mode = args.odf_mode) 132 | 133 |
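# Example invocation (the flag names are taken from the parser above; the
# values themselves are illustrative assumptions, not recommended settings):
#   python train.py --path ./runs/ --pretrain --max_steps 1000000 \
#       --batch_size 512 --learning_rate 0.0000625 --step_width 0.8 \
#       --odf_mode CSD --wandb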
134 | -------------------------------------------------------------------------------- /examples/mlp_training.py: -------------------------------------------------------------------------------- 1 | """Simple MLP example training""" 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | from torch.utils import data as dataL 6 | 7 | from dfibert.data import DataPreprocessor 8 | from dfibert.data.postprocessing import Resample100 9 | from dfibert.dataset.processing import RegressionProcessing 10 | from dfibert.dataset import SingleDirectionsDataset # TODO update example 11 | from dfibert.tracker import get_csd_streamlines 12 | from dfibert.util import random_split 13 | 14 | # This is a simple MLP training to show example usage for the provided library 15 | 16 | 17 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 18 | 19 | 20 | # switch to cuda if available 21 | # this only works if you want to train on a single CPU/GPU, without parallelization 22 | 23 | class ModelMLP(nn.Module): # the NN as PyTorch module, this is best practice to keep the structure modular and simple. 24 | 'MLP Model Class' # but you can also just use an nn.Sequential(*layers) for simple models 25 | 26 | def __init__(self, hidden_sizes=None, activation_function=None, dropout=0.0, input_size=None): 27 | """Initializes the MLP model with given parameters: 28 | 29 | Arguments: 30 | hidden_sizes: a list containing the hidden sizes 31 | activation_function: any torch.nn.* activation function 32 | dropout: specifies the dropout applied after each layer, if zero, no dropout is applied. default: 0 33 | input_size: the input size of dwi data. 34 | """ 35 | super(ModelMLP, self).__init__() 36 | 37 | self.optimizer = None 38 | self.loss = None 39 | # construct layers 40 | hidden_sizes.insert(0, input_size) # input size 41 | layers = [nn.Flatten(start_dim=1)] 42 | for index in range(len(hidden_sizes) - 1): 43 | layers.append(nn.Linear(hidden_sizes[index], hidden_sizes[index + 1])) 44 | layers.append(activation_function) 45 | if dropout > 0: 46 | layers.append(nn.Dropout(p=dropout)) 47 | layers.append(nn.Linear(hidden_sizes[-1], 3)) 48 | layers.append(nn.Tanh()) 49 | 50 | self.main = nn.Sequential(*layers) 51 | 52 | def forward(self, 53 | x): # default function for propagating data through the network, do not change name or arguments. 54 | # called with mlp_model(data) 55 | """Pass data through the network. Returns the output tensor.
56 | 57 | Arguments: 58 | x: the data to pass 59 | """ 60 | return self.main(x) 61 | 62 | def compile_model(self, optimizer, loss): # to keep training internal, let optimizer and loss be set external 63 | self.optimizer = optimizer 64 | self.loss = loss 65 | 66 | def train_model(self, training_set, validation_set=None, epochs=200): 67 | "Trains the model" 68 | best_loss = 10000 69 | for epoch in range(0, epochs): # each epoch 70 | 71 | self.train() # change model to train mode to activate dropout and backpropagation 72 | loss = self._feed_model(training_set) 73 | print("Epoch {} - train: {:.6f}".format(epoch, loss), end="\r") 74 | 75 | if validation_set is not None: 76 | self.eval() # change model to evaluation mode to deactivate dropout and backpropagation 77 | validation_loss = self._feed_model(validation_set, validation=True) 78 | 79 | if validation_loss < best_loss: # save best model to file 80 | best_loss = validation_loss 81 | torch.save(self.state_dict(), 'best_model.pt') 82 | 83 | print("Epoch {} - train: {:.6f} - test: {:.6f}".format(epoch, loss, validation_loss)) 84 | 85 | def _feed_model(self, generator, 86 | validation=False): # prepares data in mini batches and passes them through the network 87 | """Feeds given data to model, returns average loss""" 88 | whole_loss = 0 89 | divisor = 0 90 | for batch, (dwi, next_dir) in enumerate(generator): 91 | dwi = dwi.to(device) 92 | next_dir = next_dir.to(device) 93 | 94 | print("Batch {}/{}".format(batch, len(generator)), end="\r") 95 | 96 | if not validation: 97 | self.optimizer.zero_grad() 98 | 99 | pred_next_dir = self(dwi) 100 | 101 | loss = self.loss(pred_next_dir, next_dir) 102 | 103 | if not validation: # backpropagation 104 | loss.backward() 105 | self.optimizer.step() 106 | 107 | whole_loss += next_dir.shape[ 108 | 0] * loss.item() # necessary for loss calculation, loss.item() is mean over batch 109 | divisor += next_dir.shape[0] 110 | 111 | # delete tensors to prevent ram usage until generation of next batch 112 | del dwi 113 | del next_dir 114 | del pred_next_dir 115 | 116 | return whole_loss / divisor 117 | 118 | 119 | def radians_loss(x, y): 120 | """Quick implementation of the radian loss 1 - cos^2(alpha). 121 | 122 | Arguments: 123 | x: the network output 124 | y: the supposed output 125 | """ 126 | mask = ((y == 0).sum(-1) < 3) # mask out all-zero vectors in the supposed output 127 | cossim = torch.nn.CosineSimilarity(dim=1) 128 | output = cossim(x, y) ** 2 129 | output = output[mask.squeeze() != 0] 130 | return 1 - torch.mean(output) 131 | 132 | 133 | def main(): 134 | """The main function""" 135 | preprocessor = DataPreprocessor().normalize().crop() # normalisation and cropping 136 | tracker_data = DataPreprocessor().get_hcp("data/HCP/100307") # Initialize DW-MRI image of HCP Participant #100307 137 | data = preprocessor.preprocess(tracker_data) 138 | print("Initialized data...") 139 | streamlines = get_csd_streamlines(tracker_data, random_seeds=True, seeds_count=10000) 140 | 141 | print("Initialized streamlines...") 142 | 143 | processing = RegressionProcessing(rotate=False, grid_dimension=(3, 3, 3), 144 | postprocessing=Resample100()) # choose a data Processing option for your training 145 | dataset = SingleDirectionsDataset(streamlines, data, processing, append_reverse=True, online_caching=True) 146 | # choose a dataset, this one is good for non-recurrent architectures 147 | 148 | training_set, validation_set = random_split( 149 | dataset) # randomly splits the dataset into training and validation.
Default 90% Training 150 | 151 | model = ModelMLP(hidden_sizes=[512, 512, 512], activation_function=nn.ReLU(), dropout=0.05, input_size=2700).to( 152 | device) # Initialize the model 153 | 154 | params = {'batch_size': 2048, 'num_workers': 0, 155 | 'shuffle': True} # !!!! NUM WORKERS > 0 does not work with caching yet !!! 156 | training_generator = dataL.DataLoader(training_set, **params) # specify a training and testing generator 157 | validation_generator = dataL.DataLoader(validation_set, **params) 158 | 159 | print("Initialized dataset & model...") 160 | 161 | optimizer = optim.Adam(model.parameters(), lr=1e-5) 162 | 163 | model.compile_model(optimizer, radians_loss) # choose optimizer and loss 164 | model.train_model(training_generator, validation_set=validation_generator, epochs=5000) # start training 165 | 166 | 167 | if __name__ == "__main__": 168 | main() 169 | -------------------------------------------------------------------------------- /dfibert/tracker/__init__.py: -------------------------------------------------------------------------------- 1 | """Implementing different tracking approaches""" 2 | 3 | import os 4 | from typing import List 5 | 6 | from dipy.tracking.utils import random_seeds_from_mask, seeds_from_mask 7 | from dipy.tracking.life import transform_streamlines 8 | from dipy.tracking.local_tracking import LocalTracking 9 | from dipy.tracking.stopping_criterion import ThresholdStoppingCriterion 10 | from dipy.tracking.streamline import Streamlines 11 | from dipy.tracking import metrics 12 | from dipy.reconst.dti import TensorModel 13 | from dipy.io.streamline import save_vtk_streamlines, load_vtk_streamlines 14 | from dipy.reconst.csdeconv import ConstrainedSphericalDeconvModel, auto_response_ssst 15 | from dipy.data import get_sphere, default_sphere 16 | from dipy.direction import peaks_from_model, DeterministicMaximumDirectionGetter 17 | import dipy.reconst.dti as dti 18 | 19 | import fury 20 | import vtk 21 | import numpy as np 22 | 23 | def _get_seeds(data_container, random_seeds=False, seeds_count=30000, seeds_per_voxel=False): 24 | if not random_seeds: 25 | return seeds_from_mask(data_container.binary_mask, affine=data_container.aff) 26 | else: 27 | return random_seeds_from_mask(data_container.binary_mask, 28 | seeds_count=seeds_count, 29 | seed_count_per_voxel=seeds_per_voxel, 30 | affine=data_container.aff) 31 | 32 | 33 | def get_csd_streamlines(data_container, random_seeds=False, seeds_count=30000, seeds_per_voxel=False, step_width=1.0, 34 | roi_r=10, auto_response_fa_threshold=0.7, fa_threshold=0.15, relative_peak_threshold=0.5, 35 | min_separation_angle=25): 36 | """ 37 | Tracks and returns CSD Streamlines for the given DataContainer. 38 | 39 | Parameters 40 | ---------- 41 | data_container 42 | The DataContainer we would like to track streamlines on 43 | random_seeds 44 | A boolean indicating whether we would like to use random seeds 45 | seeds_count 46 | If we use random seeds, this specifies the seed count 47 | seeds_per_voxel 48 | If True, the seed count is specified per voxel 49 | step_width 50 | The step width used while tracking 51 | roi_r 52 | The radii of the cuboid roi for the automatic estimation of single-shell single-tissue response function using FA. 53 | auto_response_fa_threshold 54 | The FA threshold for the automatic estimation of single-shell single-tissue response function using FA. 
55 |     fa_threshold
56 |         The FA threshold to use to stop tracking
57 |     relative_peak_threshold
58 |         The relative peak threshold to use to get peaks from the CSDModel
59 |     min_separation_angle
60 |         The minimal separation angle of peaks
61 |     Returns
62 |     -------
63 |     Streamlines
64 |         A list of Streamlines
65 |     """
66 |     seeds = _get_seeds(data_container, random_seeds, seeds_count, seeds_per_voxel)
67 |
68 |     response, _ = auto_response_ssst(data_container.gtab, data_container.dwi, roi_radii=roi_r, fa_thr=auto_response_fa_threshold)
69 |     csd_model = ConstrainedSphericalDeconvModel(data_container.gtab, response)
70 |
71 |     direction_getter = peaks_from_model(model=csd_model,
72 |                                         data=data_container.dwi,
73 |                                         sphere=get_sphere('symmetric724'),
74 |                                         mask=data_container.binary_mask,
75 |                                         relative_peak_threshold=relative_peak_threshold,
76 |                                         min_separation_angle=min_separation_angle,
77 |                                         parallel=False)
78 |
79 |     dti_fit = dti.TensorModel(data_container.gtab, fit_method='LS').fit(data_container.dwi, mask=data_container.binary_mask)
80 |     classifier = ThresholdStoppingCriterion(dti_fit.fa, fa_threshold)
81 |
82 |     streamlines_generator = LocalTracking(direction_getter, classifier, seeds, data_container.aff, step_size=step_width)
83 |     streamlines = Streamlines(streamlines_generator)
84 |
85 |     return streamlines
86 |
87 |
88 | def get_dti_streamlines(data_container, random_seeds=False, seeds_count=30000, seeds_per_voxel=False, step_width=1.0,
89 |                         max_angle=30.0, fa_threshold=0.15):
90 |     """
91 |     Tracks and returns DTI Streamlines for the given DataContainer.
92 |
93 |     Parameters
94 |     ----------
95 |     data_container
96 |         The DataContainer we would like to track streamlines on
97 |     random_seeds
98 |         A boolean indicating whether we would like to use random seeds
99 |     seeds_count
100 |         If we use random seeds, this specifies the seed count
101 |     seeds_per_voxel
102 |         If True, the seed count is specified per voxel
103 |     step_width
104 |         The step width used while tracking
105 |     fa_threshold
106 |         The FA threshold to use to stop tracking
107 |     max_angle
108 |         The maximum allowed angle between incoming and outgoing angle, float between 0.0 and 90.0 deg
109 |     Returns
110 |     -------
111 |     Streamlines
112 |         A list of Streamlines
113 |     """
114 |     seeds = _get_seeds(data_container, random_seeds, seeds_count, seeds_per_voxel)
115 |
116 |     dti_fit = TensorModel(data_container.gtab).fit(data_container.dwi, mask=data_container.binary_mask)
117 |     dti_fit_odf = dti_fit.odf(sphere=default_sphere)
118 |
119 |     direction_getter = DeterministicMaximumDirectionGetter.from_pmf(dti_fit_odf,
120 |                                                                     max_angle=max_angle,
121 |                                                                     sphere=default_sphere)
122 |     classifier = ThresholdStoppingCriterion(dti_fit.fa, fa_threshold)
123 |
124 |     streamlines_generator = LocalTracking(direction_getter, classifier, seeds, data_container.aff, step_size=step_width)
125 |     streamlines = Streamlines(streamlines_generator)
126 |
127 |     return streamlines
128 |
129 |
130 | def save_streamlines(streamlines: list, path: str, to_lps=True, binary=False):
131 |     """
132 |     Saves the given streamlines to a file in the VTK 4.2 file format for 3D Slicer compatibility.
133 |
134 |     Parameters
135 |     ----------
136 |     streamlines
137 |         The streamlines we want to save
138 |     path
139 |         The path we save the streamlines to
140 |     to_lps
141 |         A boolean indicating whether we want to save them in the LPS format instead of RAS (True by default)
142 |     binary
143 |         If True, the file will be written in a binary format. 
144 | Returns 145 | ------- 146 | 147 | """ 148 | if to_lps: 149 | # ras (mm) to lps (mm) 150 | to_lps = np.eye(4) 151 | to_lps[0, 0] = -1 152 | to_lps[1, 1] = -1 153 | streamlines = transform_streamlines(streamlines, to_lps) 154 | 155 | polydata, _ = fury.utils.lines_to_vtk_polydata(streamlines) 156 | 157 | file_extension = path.split(".")[-1].lower() 158 | 159 | writer = vtk.vtkPolyDataWriter() 160 | writer.SetFileName(path) 161 | writer.SetFileVersion(42) 162 | writer = fury.io.set_input(writer, polydata) 163 | 164 | writer.Update() 165 | writer.Write() 166 | 167 | 168 | 169 | def load_streamlines(path: str, to_lps=True) -> list: 170 | """ 171 | Loads streamlines from the given path. 172 | Parameters 173 | ---------- 174 | path 175 | The path to load streamlines from 176 | to_lps 177 | If True, we load streamlines under the assumption that they were stored in LPS (True by default) 178 | Returns 179 | ------- 180 | list 181 | The streamlines we are trying to load 182 | """ 183 | if os.path.isdir(path): 184 | streamlines = [] 185 | for file in os.listdir(path): 186 | if file.endswith(".fib") or file.endswith(".vtk"): 187 | sl = load_vtk_streamlines(os.path.join(path, file), to_lps) 188 | streamlines.extend(sl) 189 | else: 190 | streamlines = load_vtk_streamlines(path, to_lps) 191 | return streamlines 192 | 193 | 194 | def filtered_streamlines_by_length(streamlines: List, minimum=20, maximum=200) -> List: 195 | """ 196 | Returns filtered streamlines that are longer than minimum (in mm) and shorter than maximum (in mm) 197 | Parameters 198 | ---------- 199 | streamlines 200 | The streamlines we would like to filter 201 | minimum 202 | The minimum length in mm 203 | maximum 204 | The maximum length in mm 205 | Returns 206 | ------- 207 | List 208 | The filtered streamlines 209 | """ 210 | return [x for x in streamlines if minimum <= metrics.length(x) <= maximum] 211 | -------------------------------------------------------------------------------- /dfibert/util.py: -------------------------------------------------------------------------------- 1 | """Helpful functions required multiple times in different contexts 2 | """ 3 | import os 4 | import random 5 | import torch 6 | import numpy as np 7 | from dipy.core.sphere import Sphere 8 | from dipy.core.geometry import sphere_distance 9 | 10 | 11 | def set_seed(seed): 12 | torch.manual_seed(seed) 13 | torch.cuda.manual_seed_all(seed) 14 | torch.backends.cudnn.deterministic = True 15 | torch.backends.cudnn.benchmark = False 16 | np.random.seed(seed) 17 | random.seed(seed) 18 | os.environ['PYTHONHASHSEED'] = str(seed) 19 | 20 | 21 | def rotation_from_vectors_p(rot, vectors_orig, vectors_fin): 22 | vectors_orig = vectors_orig / np.linalg.norm(vectors_orig, axis=1)[:, None] 23 | vectors_fin = vectors_fin / np.linalg.norm(vectors_fin, axis=1)[:, None] 24 | axes = np.cross(vectors_orig, vectors_fin) 25 | axes_lens = np.linalg.norm(axes, axis=1) 26 | 27 | axes_lens[axes_lens == 0] = 1 28 | 29 | axes = axes / axes_lens[:, None] 30 | 31 | x = axes[:, 0] 32 | y = axes[:, 1] 33 | z = axes[:, 2] 34 | 35 | angles = np.arccos(np.sum(vectors_orig * vectors_fin, axis=1)) 36 | sa = np.sin(angles) 37 | ca = np.cos(angles) # cos 38 | rot[:, 0, 0] = 1.0 + (1.0 - ca) * (x ** 2 - 1.0) 39 | rot[:, 0, 1] = -z * sa + (1.0 - ca) * x * y 40 | rot[:, 0, 2] = y * sa + (1.0 - ca) * x * z 41 | rot[:, 1, 0] = z * sa + (1.0 - ca) * x * y 42 | rot[:, 1, 1] = 1.0 + (1.0 - ca) * (y ** 2 - 1.0) 43 | rot[:, 1, 2] = -x * sa + (1.0 - ca) * y * z 44 | rot[:, 2, 0] = -y * sa + (1.0 - 
ca) * x * z
45 |     rot[:, 2, 1] = x * sa + (1.0 - ca) * y * z
46 |     rot[:, 2, 2] = 1.0 + (1.0 - ca) * (z ** 2 - 1.0)
47 |
48 |
49 | def rotation_from_vectors(rot, vector_orig, vector_fin):
50 |     """Calculate the rotation matrix required to rotate from one vector to another.
51 |     For the rotation of one vector to another, there is an infinite number of possible rotation
52 |     matrices. Due to axial symmetry, the rotation axis can be any vector lying in the symmetry
53 |     plane between the two vectors. Hence the axis-angle convention will be used to construct the
54 |     matrix with the rotation axis defined as the cross product of the two vectors. The rotation
55 |     angle is the arccosine of the dot product of the two unit vectors.
56 |     Given a unit vector parallel to the rotation axis, w = [x, y, z] and the rotation angle a,
57 |     the rotation matrix R is::
58 |         | 1 + (1-cos(a))*(x*x-1)   -z*sin(a)+(1-cos(a))*x*y   y*sin(a)+(1-cos(a))*x*z |
59 |     R = | z*sin(a)+(1-cos(a))*x*y   1 + (1-cos(a))*(y*y-1)   -x*sin(a)+(1-cos(a))*y*z |
60 |         | -y*sin(a)+(1-cos(a))*x*z  x*sin(a)+(1-cos(a))*y*z   1 + (1-cos(a))*(z*z-1)  |
61 |     @param rot: The 3x3 rotation matrix to update.
62 |     @type rot: 3x3 numpy array
63 |     @param vector_orig: The unrotated vector defined in the reference frame.
64 |     @type vector_orig: numpy array, len 3
65 |     @param vector_fin: The rotated vector defined in the reference frame.
66 |     @type vector_fin: numpy array, len 3
67 |     """
68 |
69 |     # Convert the vectors to unit vectors.
70 |     vector_orig = vector_orig / np.linalg.norm(vector_orig)
71 |     vector_fin = vector_fin / np.linalg.norm(vector_fin)
72 |
73 |     # The rotation axis (normalised).
74 |     axis = np.cross(vector_orig, vector_fin)
75 |     axis_len = np.linalg.norm(axis)
76 |     if axis_len != 0.0:
77 |         axis = axis / axis_len
78 |
79 |     # Alias the axis coordinates.
80 |     x = axis[0]
81 |     y = axis[1]
82 |     z = axis[2]
83 |
84 |     # The rotation angle.
85 |     angle = np.arccos(np.dot(vector_orig, vector_fin))
86 |
87 |     # Trig functions (only need to do this maths once!).
88 |     ca = np.cos(angle)
89 |     sa = np.sin(angle)
90 |
91 |     # Calculate the rotation matrix elements.
92 |     rot[0, 0] = 1.0 + (1.0 - ca) * (x ** 2 - 1.0)
93 |     rot[0, 1] = -z * sa + (1.0 - ca) * x * y
94 |     rot[0, 2] = y * sa + (1.0 - ca) * x * z
95 |     rot[1, 0] = z * sa + (1.0 - ca) * x * y
96 |     rot[1, 1] = 1.0 + (1.0 - ca) * (y ** 2 - 1.0)
97 |     rot[1, 2] = -x * sa + (1.0 - ca) * y * z
98 |     rot[2, 0] = -y * sa + (1.0 - ca) * x * z
99 |     rot[2, 1] = x * sa + (1.0 - ca) * y * z
100 |     rot[2, 2] = 1.0 + (1.0 - ca) * (z ** 2 - 1.0)
101 |
102 |
103 | def get_reference_orientation(orientation="R+"):
104 |     """Get the reference orientation vector for the given orientation code.
105 |
106 |     Returns
107 |     -------
108 |     numpy.ndarray
109 |         The reference orientation vector usable for rotations.
110 |     """
111 |     orientation = orientation.upper()
112 |     ref = None
113 |     if orientation[0] == 'R':
114 |         ref = np.array([1, 0, 0])
115 |     elif orientation[0] == 'A':
116 |         ref = np.array([0, 1, 0])
117 |     elif orientation[0] == 'S':
118 |         ref = np.array([0, 0, 1])
119 |     if len(orientation) > 1 and orientation[1] == '-':
120 |         ref = ref * -1
121 |     return ref
122 |
123 |
124 | def get_2D_sphere(no_phis=16, no_thetas=16):
125 |     """Retrieve an evenly distributed 2D sphere out of a phi and theta count.
126 |
127 |
128 |     Parameters
129 |     ----------
130 |     no_phis : int, optional
131 |         The number of phis in the sphere, by default 16
132 |     no_thetas : int, optional
133 |         The number of thetas in the sphere, by default 16
134 |
135 |     Returns
136 |     -------
137 |     Sphere
138 |         The 2D sphere requested
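
    Examples
    --------
    A minimal sketch (the counts below are arbitrary):

    >>> sphere = get_2D_sphere(no_phis=8, no_thetas=8)  # 8 * 8 = 64 directions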
139 |     """
140 |     xi = np.arange(0, np.pi, np.pi / no_thetas)  # theta
141 |     yi = np.arange(-np.pi, np.pi, 2 * np.pi / no_phis)  # phi
142 |
143 |     basis = np.array(np.meshgrid(yi, xi))
144 |
145 |     sphere = Sphere(theta=basis[0, :], phi=basis[1, :])
146 |
147 |     return sphere
148 |
149 |
150 | def get_grid(grid_dimension):
151 |     """Calculates the grid for a given dimension
152 |
153 |     Parameters
154 |     ----------
155 |     grid_dimension : numpy.ndarray
156 |         The grid dimensions of the grid to calculate
157 |
158 |     Returns
159 |     -------
160 |     numpy.ndarray
161 |         The requested grid
162 |     """
163 |     (dx, dy, dz) = (grid_dimension - 1) / 2
164 |     return np.moveaxis(np.mgrid[-dx:dx + 1, -dy:dy + 1, -dz:dz + 1], 0, 3)
165 |
166 |
167 | def random_split(dataset, training_part=0.9):
168 |     """Randomly splits the given dataset into a training and a validation part.
169 |
170 |     Parameters
171 |     ----------
172 |     dataset : Dataset
173 |         The dataset to use
174 |     training_part : float, optional
175 |         The training part, by default 0.9 (90%)
176 |
177 |     Returns
178 |     -------
179 |     tuple
180 |         A tuple containing (train_dataset, validation_dataset)
181 |     """
182 |     train_len = int(training_part * len(dataset))
183 |     test_len = len(dataset) - train_len
184 |     (train_split, test_split) = torch.utils.data.random_split(dataset, (train_len, test_len))
185 |     return train_split, test_split
186 |
187 |
188 | def get_mask_from_lengths(lengths):
189 |     """Returns a mask for a given array of lengths
190 |
191 |     Parameters
192 |     ----------
193 |     lengths: Tensor
194 |         The lengths to pad to
195 |     Returns
196 |     -------
197 |     Tensor
198 |         The requested mask."""
199 |     return torch.arange(torch.max(lengths), device=lengths.device)[None, :] < lengths[:, None]
200 |
201 |
202 | def apply_rotation_matrix_to_grid(grid, rot_matrix):
203 |     """Applies the given list of rotation matrices to the given grid
204 |
205 |     Parameters
206 |     ----------
207 |     grid : numpy.ndarray
208 |         The grid
209 |     rot_matrix : numpy.ndarray
210 |         The rotation matrix with the dimensions (N, 3, 3)
211 |
212 |     Returns
213 |     -------
214 |     numpy.ndarray
215 |         The grid, rotated along the rotation_matrix; Shape: (N, ...grid_dimensions)
216 |     """
217 |     return (rot_matrix.repeat(grid.size // 3, axis=0)
218 |             @ grid[None, ].repeat(len(rot_matrix), axis=0).reshape(-1, 3, 1)
219 |             ).reshape((-1, *grid.shape))
220 |
221 |
222 | def direction_to_classification(sphere, next_dir, include_stop=False, last_is_stop=False, stop_values=None):
223 |     # code adapted from Benou "DeepTract",
224 |     # https://github.com/itaybenou/DeepTract/blob/master/utils/train_utils.py
225 |
226 |     sl_len = len(next_dir)
227 |     loop_len = sl_len - 1 if include_stop and last_is_stop else sl_len
228 |     l = len(sphere.theta) + 1 if include_stop else len(sphere.theta)
229 |     classification_output = np.zeros((sl_len, l))
230 |     for i in range(loop_len):
231 |         if not (next_dir[i, 0] == 0.0 and next_dir[i, 1] == 0.0 and next_dir[i, 2] == 0.0):
232 |             labels_odf = np.exp(-1 * sphere_distance(next_dir[i, :], np.asarray(
233 |                 [sphere.x, sphere.y, sphere.z]).T, radius=1, check_radius=False) * 10)
234 |             if include_stop:
235 |                 classification_output[i][:-1] = labels_odf / np.sum(labels_odf)
236 |                 
classification_output[i, -1] = 0.0 237 | else: 238 | classification_output[i] = labels_odf / np.sum(labels_odf) 239 | if include_stop and last_is_stop: 240 | classification_output[-1, -1] = 1 # stop condition or 241 | if include_stop and stop_values is not None: 242 | classification_output[:, -1] = stop_values # stop values 243 | return classification_output 244 | -------------------------------------------------------------------------------- /dfibert/data/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | The data module is handling all kinds of DWI-data. 3 | 4 | Use this as a starting point to represent your loaded DWI-scan. 5 | This module provides methods helping you to implement datasets, 6 | environments and all other kinds of modules with the requirement 7 | to work directly with the data. 8 | """ 9 | from __future__ import annotations 10 | 11 | import os 12 | import warnings 13 | from typing import Optional 14 | 15 | import dipy.reconst.dti as dti 16 | import nibabel as nb 17 | import numpy as np 18 | from dipy.core.gradients import gradient_table, GradientTable 19 | from dipy.denoise.localpca import localpca 20 | from dipy.denoise.pca_noise_estimate import pca_noise_estimate 21 | from dipy.io import read_bvals_bvecs 22 | from dipy.segment.mask import median_otsu 23 | from nibabel.affines import apply_affine 24 | from scipy.interpolate import RegularGridInterpolator 25 | 26 | from dfibert.data.exceptions import PointOutsideOfDWIError 27 | from dfibert.data.postprocessing import PostprocessingOption 28 | 29 | 30 | class DataPreprocessor(object): 31 | def __init__(self, parent: Optional[DataPreprocessor] = None): 32 | """ 33 | Creates a new empty DataPreprocessor. 34 | 35 | Parameters 36 | ---------- 37 | parent 38 | An optional previous DataPreprocessor we want to continue from 39 | """ 40 | self._parent = parent 41 | 42 | def _preprocess(self, data_container: DataContainer) -> DataContainer: 43 | if self._parent is None: 44 | return data_container 45 | else: 46 | return self._parent._preprocess(data_container) 47 | 48 | def preprocess(self, data_container: DataContainer) -> DataContainer: 49 | """ 50 | Returns a preprocessed DataContainer created by taking the given one and applying the given steps. 51 | Because data_containers are treated as immutable, the given data_container (and its numpy arrays) 52 | won't be modified. 
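        For example, a preprocessor built as `DataPreprocessor().normalize().crop()` returns a new,
        preprocessed container from `preprocess(data_container)`, while the passed container stays untouched.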
53 | 54 | Parameters 55 | ---------- 56 | data_container 57 | The given data_container 58 | 59 | Returns 60 | ------- 61 | DataContainer 62 | A new preprocessed DataContainer 63 | """ 64 | 65 | dc = data_container 66 | data_container = DataContainer(dc.bvals.copy(), dc.bvecs.copy(), dc.gtab.copy(), dc.t1.copy(), dc.dwi.copy(), 67 | dc.aff.copy(), 68 | dc.binary_mask.copy(), dc.b0.copy(), None if dc.fa is None else dc.fa.copy()) 69 | return self._preprocess(data_container) 70 | 71 | def denoise(self, smooth=3, patch_radius=3) -> DataPreprocessor: 72 | """ 73 | Denoises the data using Local PCA with empirical thresholds 74 | 75 | Parameters 76 | ---------- 77 | smooth 78 | the voxel radius used by the Gaussian filter for the noise estimate 79 | patch_radius 80 | the voxel radius used by the Local PCA algorithm to denoise 81 | Returns 82 | ------- 83 | DataPreprocessor 84 | A new DataPreprocessor, incorporating the previous steps plus the new denoise 85 | """ 86 | return _DataDenoiser(self, smooth, patch_radius) 87 | 88 | def normalize(self) -> DataPreprocessor: 89 | """ 90 | Normalize DWI Data based on b0 image. 91 | The weights are divided by their b0 value. 92 | 93 | Returns 94 | ------- 95 | DataPreprocessor 96 | A new DataPreprocessor, incorporating the previous steps plus the new normalize 97 | """ 98 | return _DataNormalizer(self) 99 | 100 | def crop(self, b_value=1000.0, max_deviation=100.0, b0_threshold=10.0) -> DataPreprocessor: 101 | """ 102 | Crops the dataset based on B-value. 103 | 104 | All measurements where the B-value deviates more than `max_deviation` from the `b_value` 105 | are removed from the dataset, except from values less than the specified b0 threshold. 106 | 107 | Parameters 108 | ---------- 109 | b0_threshold 110 | the b0 threshold - everything below will be kept as well. 111 | b_value 112 | the intended B-value 113 | max_deviation 114 | the maximum allowed deviation from the given b_value 115 | Returns 116 | ------- 117 | DataPreprocessor 118 | A new DataPreprocessor, incorporating the previous steps plus the new crop 119 | """ 120 | return _DataCropper(self, b_value, max_deviation, b0_threshold) 121 | 122 | def fa_estimate(self): 123 | """ 124 | Does the FA estimation at the current position in the pipeline 125 | 126 | Returns 127 | ------- 128 | DataPreprocessor 129 | A new DataPreprocessor, incorporating the previous steps plus the new fa estimate 130 | """ 131 | return _DataFAEstimator(self) 132 | 133 | def get_hcp(self, path: str, b0_threshold: float = 10.0) -> DataContainer: 134 | """ 135 | Loads a HCP Dataset and preprocesses it, returning a DataContainer 136 | 137 | Parameters 138 | ---------- 139 | path 140 | The path of the HCP Dataset 141 | b0_threshold 142 | The threshold for the b0 image 143 | Returns 144 | ------- 145 | DataContainer 146 | A newly created DataContainer with the preprocessed HCP data. 
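
        Examples
        --------
        A minimal sketch (the path is a placeholder):

        >>> dc = DataPreprocessor().normalize().crop().get_hcp("data/HCP/100307")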
147 | """ 148 | 149 | file_mapping = {'bvals': 'bvals', 'bvecs': 'bvecs', 'img': 'data.nii.gz', 150 | 't1': 'T1w_acpc_dc_restore_1.25.nii.gz', 'mask': 'nodif_brain_mask.nii.gz'} 151 | return self._get_from_file_mapping(path, file_mapping, b0_threshold) 152 | 153 | def get_ismrm(self, path: str, b0_threshold: float = 10.0) -> DataContainer: 154 | """ 155 | Loads a ISMRM Dataset and preprocesses it, returning a DataContainer 156 | 157 | Parameters 158 | ---------- 159 | path 160 | The path of the ISMRM Dataset 161 | b0_threshold 162 | The threshold for the b0 image 163 | Returns 164 | ------- 165 | DataContainer 166 | A newly created DataContainer with the preprocessed ISMRM data. 167 | """ 168 | file_mapping = {'bvals': 'Diffusion.bvals', 'bvecs': 'Diffusion.bvecs', 169 | 'img': 'Diffusion.nii.gz', 't1': 'T1.nii.gz'} 170 | return self._get_from_file_mapping(path, file_mapping, b0_threshold) 171 | 172 | def _get_from_file_mapping(self, path, file_mapping: dict, b0_threshold: float = 10.0): 173 | 174 | path_mapping = {key: os.path.join(path, file_mapping[key]) for key in file_mapping} 175 | bvals, bvecs = read_bvals_bvecs(path_mapping['bvals'], 176 | path_mapping['bvecs']) 177 | 178 | # img, t1, gradient table, affine and dwi 179 | img = nb.load(path_mapping['img']) 180 | t1 = nb.load(path_mapping['t1']).get_data() 181 | 182 | dwi = img.get_data().astype("float32") 183 | 184 | aff = img.affine 185 | 186 | # binary mask 187 | if 'mask' in path_mapping: 188 | binary_mask = nb.load(path_mapping['mask']).get_data() 189 | else: 190 | _, binary_mask = median_otsu(dwi[..., 0], 2, 1) 191 | 192 | # calculating b0 193 | b0 = dwi[..., bvals < b0_threshold].mean(axis=-1) 194 | 195 | # Do not generate fa yet 196 | fa = None 197 | gtab = gradient_table(bvals, bvecs) 198 | data_container = DataContainer(bvals, bvecs, gtab, t1, dwi, aff, binary_mask, b0, fa) 199 | return self._preprocess(data_container) 200 | 201 | 202 | class _DataCropper(DataPreprocessor): 203 | def __init__(self, parent, b_value, max_deviation, b0_threshold): 204 | super().__init__(parent) 205 | self.b_value = b_value 206 | self.max_deviation = max_deviation 207 | self.b0_threshold = b0_threshold 208 | 209 | def _preprocess(self, data_container: DataContainer) -> DataContainer: 210 | dc = \ 211 | super()._preprocess(data_container) 212 | 213 | mask = (np.abs(dc.bvals - self.b_value) < self.max_deviation) | (dc.bvals < self.b0_threshold) 214 | 215 | dwi = dc.dwi[..., mask] 216 | bvals = dc.bvals[mask] 217 | bvecs = dc.bvecs[mask] 218 | gtab = gradient_table(bvals, bvecs) 219 | return DataContainer(bvals, bvecs, gtab, dc.t1, dwi, dc.aff, dc.binary_mask, dc.b0, dc.fa) 220 | 221 | 222 | class _DataNormalizer(DataPreprocessor): 223 | def __init__(self, parent): 224 | super().__init__(parent) 225 | 226 | def _preprocess(self, data_container: DataContainer) -> DataContainer: 227 | dc = \ 228 | super()._preprocess(data_container) 229 | 230 | b0 = dc.b0[..., None] 231 | dwi = dc.dwi 232 | nb_erroneous_voxels = np.sum(dc.dwi > b0) 233 | if nb_erroneous_voxels != 0: 234 | dwi = np.minimum(dwi, b0) 235 | with warnings.catch_warnings(): 236 | warnings.simplefilter("ignore") 237 | dwi = dwi / b0 238 | dwi[np.logical_not(np.isfinite(dwi))] = 0. 
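            # b0 can be zero outside the brain, so the division above yields NaN/inf there;
            # those entries are reset to zero instead of being propagated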
239 | 240 | return DataContainer(dc.bvals, dc.bvecs, dc.gtab, dc.t1, dwi, dc.aff, dc.binary_mask, dc.b0, dc.fa) 241 | 242 | 243 | class _DataDenoiser(DataPreprocessor): 244 | def __init__(self, parent, smooth, patch_radius): 245 | super().__init__(parent) 246 | self.smooth = smooth 247 | self.patch_radius = patch_radius 248 | 249 | def _preprocess(self, data_container: DataContainer) -> DataContainer: 250 | dc = super()._preprocess(data_container) 251 | sigma = pca_noise_estimate(dc.dwi, dc.gtab, correct_bias=True, 252 | smooth=self.smooth) 253 | dwi = localpca(dc.dwi, sigma=sigma, 254 | patch_radius=self.patch_radius) 255 | return DataContainer(dc.bvals, dc.bvecs, dc.gtab, dc.t1, dwi, dc.aff, dc.binary_mask, dc.b0, dc.fa) 256 | 257 | 258 | class _DataFAEstimator(DataPreprocessor): 259 | def __init__(self, parent): 260 | super().__init__(parent) 261 | 262 | def _preprocess(self, data_container: DataContainer) -> DataContainer: 263 | dc = \ 264 | super()._preprocess(data_container) 265 | 266 | # calculating fractional anisotropy (fa) 267 | dti_model = dti.TensorModel(dc.gtab, fit_method='LS') 268 | dti_fit = dti_model.fit(dc.dwi, mask=dc.binary_mask) 269 | fa = dti_fit.fa 270 | return DataContainer(dc.bvals, dc.bvecs, dc.gtab, dc.t1, dc.dwi, dc.aff, dc.binary_mask, dc.b0, fa) 271 | 272 | 273 | class DataContainer(object): 274 | 275 | def __init__(self, bvals: np.ndarray, bvecs: np.ndarray, gtab: GradientTable, t1: np.ndarray, 276 | dwi: np.ndarray, aff: np.ndarray, binary_mask: np.ndarray, b0: np.ndarray, 277 | fa: Optional[np.ndarray]): 278 | self.bvals = bvals 279 | self.bvecs = bvecs 280 | self.t1 = t1 281 | self.dwi = dwi 282 | self.aff = aff 283 | self.binary_mask = binary_mask 284 | self.b0 = b0 285 | self.fa = fa 286 | self.gtab = gtab 287 | x_range = np.arange(dwi.shape[0]) 288 | y_range = np.arange(dwi.shape[1]) 289 | z_range = np.arange(dwi.shape[2]) 290 | self.fa_interpolator = RegularGridInterpolator((x_range, y_range, z_range), fa) if fa is not None else None 291 | self.interpolator = RegularGridInterpolator((x_range, y_range, z_range), dwi) 292 | 293 | def to_ijk(self, points: np.ndarray) -> np.ndarray: 294 | """ 295 | Converts given RAS+ points to IJK in DataContainers Image Coordinates. 296 | 297 | The conversion happens using the affine of the DWI image. 298 | It should be noted that the dimension of the given point array stays the same. 299 | 300 | Parameters 301 | ---------- 302 | points 303 | The points to convert. 304 | Returns 305 | ------- 306 | np.ndarray 307 | The converted points. 308 | """ 309 | 310 | aff = np.linalg.inv(self.aff) 311 | return apply_affine(aff, points) 312 | 313 | def to_ras(self, points: np.ndarray) -> np.ndarray: 314 | """ 315 | Converts given IJK points in DataContainers Coordinate System to RAS+. 316 | 317 | The conversion happens using the affine of the DWI image. 318 | It should be noted that the dimension of the given point array stays the same. 319 | 320 | 321 | Parameters 322 | ---------- 323 | points 324 | The points to convert. 325 | 326 | Returns 327 | ------- 328 | np.ndarray 329 | The converted points. 330 | """ 331 | return apply_affine(self.aff, points) 332 | 333 | def get_interpolated_dwi(self, points: np.ndarray, postprocessing: Optional[PostprocessingOption] = None, 334 | ignore_outside_points: bool = False) -> np.ndarray: 335 | """ 336 | Returns interpolated dwi for given RAS+ points. 
337 |
338 |         The shape of the input points will be retained for the return array,
339 |         only the last dimension will be changed from 3 to the (interpolated) DWI-size accordingly.
340 |
341 |         If you provide a postprocessing method, the interpolated data is then fed through this postprocessing option.
342 |
343 |         Parameters
344 |         ----------
345 |         points
346 |             The array containing the points. Shape is matched in output.
347 |         postprocessing
348 |             A postprocessing method, e.g. Resample100, Raw, SphericalHarmonics etc.,
349 |             which will be applied to the output.
350 |         ignore_outside_points
351 |             A boolean indicating whether an exception should be thrown if points lie outside of the DWI scan
352 |         Returns
353 |         -------
354 |         np.ndarray
355 |             The DWI values interpolated for the given points. The input shape is matched except for
356 |             the last dimension.
357 |         """
358 |         new_shape = (*points.shape[:-1], -1)
359 |
360 |         points = self.to_ijk(points).reshape(-1, 3)
361 |
362 |         is_outside = ((points[:, 0] < 0) + (points[:, 0] >= self.dwi.shape[0]) +  # OR
363 |                       (points[:, 1] < 0) + (points[:, 1] >= self.dwi.shape[1]) +
364 |                       (points[:, 2] < 0) + (points[:, 2] >= self.dwi.shape[2])) > 0
365 |
366 |         if np.sum(is_outside) > 0 and not ignore_outside_points:
367 |             raise PointOutsideOfDWIError(self, self.to_ras(points), self.to_ras(points[is_outside]))
368 |
369 |         points[is_outside, :] = 0
370 |         result = self.interpolator(points)
371 |
372 |         if postprocessing is not None:
373 |             result = postprocessing.process(self, points, result)
374 |         result[is_outside, :] = 0
375 |
376 |         result = result.reshape(new_shape)
377 |         return result
378 |
379 |     def get_fa(self, coordinate):
380 |         """Retrieves the interpolated FA value at a specific position (IJK).
381 |
382 |         Requires `fa_estimate()` in the preprocessing pipeline, otherwise `fa_interpolator` is None.
383 |
384 |         Returns
385 |         -------
386 |         ndarray
387 |             The fractional anisotropy (FA), interpolated at the given coordinate.
388 |
389 |         See Also
390 |         --------
391 |         DataPreprocessor.fa_estimate : The preprocessing step generating the FA values interpolated here. 
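
        Examples
        --------
        A minimal sketch (the coordinate is an arbitrary IJK position):

        >>> fa = data_container.get_fa(np.array([[45.0, 60.0, 38.0]]))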
390 | """ 391 | return self.fa_interpolator(coordinate) 392 | -------------------------------------------------------------------------------- /dfibert/tracker/nn/rl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import random 5 | import numpy as np 6 | 7 | 8 | 9 | class ReplayMemory(object): 10 | """Replay Memory that stores the last size=1,000,000 transitions""" 11 | def __init__(self, size=1000000, shape=(100,3,3,3), 12 | agent_history_length=1, batch_size=32): 13 | """ 14 | Args: 15 | size: Integer, Number of stored transitions 16 | frame_height: Integer, Height of a frame of an Atari game 17 | frame_width: Integer, Width of a frame of an Atari game 18 | agent_history_length: Integer, Number of frames stacked together to create a state 19 | batch_size: Integer, Number if transitions returned in a minibatch 20 | """ 21 | self.size = size 22 | self.agent_history_length = agent_history_length 23 | self.batch_size = batch_size 24 | self.count = 0 25 | self.current = 0 26 | self.shape = shape 27 | 28 | # Pre-allocate memory 29 | self.actions = np.empty(self.size, dtype=np.int32) 30 | self.rewards = np.empty(self.size, dtype=np.float32) 31 | self.states = np.empty((self.size, *self.shape), dtype=np.float32) 32 | self.new_states = np.empty((self.size, *self.shape), dtype=np.float32) 33 | self.terminal_flags = np.empty(self.size, dtype=np.bool) 34 | 35 | self._indices = np.empty(self.batch_size, dtype=np.int32) 36 | 37 | def add_experience(self, action, state, reward, new_state, terminal): 38 | """ 39 | Args: 40 | action: An integer between 0 and env.action_space.n - 1 41 | determining the action the agent perfomed 42 | state: A (100, 3, 3, 3) matrix of interpolated DWI data 43 | reward: A float determining the reward the agend received for performing an action 44 | new_state: A (100, 3, 3, 3) matrix of interpolated DWI data 45 | terminal: A bool stating whether the episode terminated 46 | """ 47 | 48 | self.actions[self.current] = action 49 | self.states[self.current] = state 50 | self.rewards[self.current] = reward 51 | self.new_states[self.current] = new_state 52 | self.terminal_flags[self.current] = terminal 53 | self.count = max(self.count, self.current+1) 54 | self.current = (self.current + 1) % self.size 55 | 56 | def get_minibatch(self): 57 | """ 58 | Returns a minibatch of self.batch_size = 32 transitions 59 | """ 60 | if self.count < self.batch_size: 61 | raise ValueError('Not enough memories to get a minibatch') 62 | 63 | self._indices = np.random.randint(self.count, size=self.batch_size) 64 | return self.states[self._indices], self.actions[self._indices], self.rewards[self._indices], self.new_states[self._indices], self.terminal_flags[self._indices] 65 | 66 | 67 | class DQN(nn.Module): 68 | """ 69 | Main modell class. First 4 layers are convolutional layers, after that the model is split into the 70 | advantage and value stream. See the documentation. The convolutional layers are initialized with Kaiming He initialization. 
71 | """ 72 | def __init__(self, input_shape, n_actions, hidden_size = 1024, num_hidden = 8, activation=torch.relu): 73 | super(DQN, self).__init__() 74 | self.linear_layers = nn.ModuleList() 75 | self.activation = activation 76 | self.init_layers(input_shape, n_actions, hidden_size, num_hidden) 77 | 78 | 79 | def init_layers(self, input_size, output_size, hidden_size, num_hidden): 80 | self.linear_layers.append(nn.Linear(input_size, hidden_size)) 81 | for _ in range(num_hidden): 82 | self.linear_layers.append(nn.Linear(hidden_size, hidden_size)) 83 | self.linear_layers.append(nn.Linear(hidden_size, output_size)) 84 | 85 | for m in self.linear_layers: 86 | if isinstance(m, nn.Linear): 87 | nn.init.xavier_normal_(m.weight) 88 | nn.init.constant_(m.bias, 0) 89 | 90 | def forward(self, x): 91 | x = x.reshape(x.size(0), -1) 92 | for i in range(len(self.linear_layers) - 1): 93 | x = self.linear_layers[i](x) 94 | x = self.activation(x) 95 | x = self.linear_layers[-1](x) 96 | return x 97 | 98 | 99 | class Agent(): 100 | """ 101 | The main agent that is optimized throughout the training process. The class consists of the two models (main and target), 102 | the optimizer and the memory. 103 | Args: 104 | n_actions: Integer, number of possible actions for the environment 105 | device: PyTorch device to which the models are sent to 106 | hidden: Integer, amount of hidden neurons in the model 107 | learning_rate: Float, learning rate for the training process 108 | gamma: Float, [0:1], discount factor for the loss function 109 | batch_size: Integer, specify size of each minibatch to process while learning 110 | agent_history_length: Integer, amount of stacked frames forming one transition 111 | memory_size: Integer, size of the replay memory 112 | """ 113 | def __init__(self, n_actions, device, inp_size, hidden=128, learning_rate=0.0000625, 114 | gamma=.9995, batch_size=32, agent_history_length=4, memory_size=1000000, 115 | epsilon=1.0, epsilon_decay=0.999, min_epsilon=0.01): 116 | 117 | self.n_actions = n_actions 118 | self.device = device 119 | self.inp_size = inp_size 120 | self.hidden = hidden 121 | self.lr = learning_rate 122 | self.batch_size = batch_size 123 | self.agent_history_length= agent_history_length 124 | self.gamma = gamma 125 | self.memory_size = memory_size 126 | self.epsilon = epsilon 127 | self.epsilon_decay = epsilon_decay 128 | self.min_epsilon = min_epsilon 129 | 130 | # Create 2 models 131 | self.main_dqn = DQN(n_actions=self.n_actions, input_shape=np.prod(np.array(self.inp_size))).to(device) 132 | self.target_dqn = DQN(n_actions=self.n_actions, input_shape=np.prod(np.array(self.inp_size))).to(device) 133 | # and send them to the device 134 | #self.main_dqn = self.main_dqn.to(self.device) 135 | #self.target_dqn = self.target_dqn.to(self.device) 136 | 137 | # Copy weights of the main model to the target model 138 | self.target_dqn.load_state_dict(self.main_dqn.state_dict()) 139 | # and freeze target model. The model will be updated every now an then (specified in main function) 140 | #self.target_dqn.eval() 141 | 142 | 143 | self.replay_memory = ReplayMemory(size=self.memory_size, shape=self.inp_size ,agent_history_length=self.agent_history_length, batch_size=self.batch_size) 144 | self.optimizer = torch.optim.Adam(self.main_dqn.parameters(), self.lr) 145 | #self.optimizer = torch.optim.SGD(self.main_dqn.parameters(), self.lr) 146 | 147 | def optimize(self): 148 | """ 149 | Optimize the main model. 
150 |         Returns:
151 |             Float, the loss between the predicted Q values from the main model and the target Q values from the target model
152 |         """
153 |         # get a minibatch of transitions
154 |         states, actions, rewards, new_states, terminal_flags = self.replay_memory.get_minibatch()
155 |
156 |         states = torch.from_numpy(states).to(self.device)
157 |         next_states = torch.from_numpy(new_states).to(self.device)
158 |         actions = torch.from_numpy(actions).unsqueeze(1).long().to(self.device)
159 |         rewards = torch.from_numpy(rewards).to(self.device)
160 |         terminal_flags = torch.from_numpy(terminal_flags).to(self.device)
161 |
162 |
163 |         state_action_values = self.main_dqn(states).gather(1, actions).squeeze(-1)
164 |         next_state_actions = torch.argmax(self.main_dqn(next_states), dim=1)
165 |         next_state_values = self.target_dqn(next_states).gather(1, next_state_actions.unsqueeze(-1)).squeeze(-1)
166 |
167 |         next_state_values[terminal_flags] = 0.0
168 |         expected_state_action_values = next_state_values.detach() * self.gamma + rewards
169 |
170 |         loss = nn.MSELoss()(state_action_values, expected_state_action_values)
171 |         #loss = torch.nn.SmoothL1Loss()(state_action_values, expected_state_action_values)
172 |         self.optimizer.zero_grad()
173 |         loss.backward()
174 |         self.optimizer.step()
175 |
176 |         return loss
177 |
178 |     def reduce_epsilon(self):
179 |         self.epsilon = max(self.epsilon * self.epsilon_decay, self.min_epsilon)
180 |
181 |     def predict_action(self, state):
182 |         if random.random() < self.epsilon:
183 |             action = np.random.randint(self.n_actions)  # either random action
184 |         else:  # or action from agent
185 |             self.main_dqn.eval()
186 |             with torch.no_grad():
187 |                 state_v = torch.from_numpy(state.getValue()).unsqueeze(0).float().to(self.device)
188 |                 action = torch.argmax(self.main_dqn(state_v)).item()
189 |             self.main_dqn.train()
190 |         return action
191 |
192 |     def save_model(self, path_checkpoint, epoch, mean_reward, max_steps, start_learning, network_update_every, max_episode_length, evaluate_every, eval_runs):
193 |         print("Writing checkpoint to %s" % (path_checkpoint))
194 |         checkpoint = {}
195 |         checkpoint["epoch"] = epoch
196 |         checkpoint["mean_reward"] = mean_reward
197 |         checkpoint["max_steps"] = max_steps
198 |         checkpoint["start_learning"] = start_learning
199 |         checkpoint["network_update_every"] = network_update_every
200 |         checkpoint["max_episode_length"] = max_episode_length
201 |         checkpoint["evaluate_every"] = evaluate_every
202 |         checkpoint["eval_runs"] = eval_runs
203 |
204 |         checkpoint["model"] = self.main_dqn.state_dict()
205 |         checkpoint["epsilon"] = self.epsilon
206 |         checkpoint["batch_size"] = self.batch_size
207 |         checkpoint["gamma"] = self.gamma
208 |         checkpoint["memory_size"] = self.memory_size
209 |         checkpoint["learning_rate"] = self.lr
210 |         checkpoint["state_shape"] = self.inp_size
211 |         checkpoint["n_actions"] = self.n_actions
212 |         torch.save(checkpoint, path_checkpoint)
213 |
214 |
215 |     def load_model(self, path_checkpoint, overwrite=False):
216 |         print("Loading checkpoint from %s" % (path_checkpoint))
217 |         checkpoint = torch.load(path_checkpoint)
218 |
219 |         # load crucial parameters
220 |         epoch = checkpoint['epoch']
221 |         self.inp_size = checkpoint["state_shape"]
222 |         self.n_actions = checkpoint["n_actions"]
223 |         self.epsilon = checkpoint['epsilon']
224 |
225 |         # re-initialize the models with the loaded hyperparameters and state dict
226 |         self.main_dqn = DQN(n_actions=self.n_actions, input_shape=np.prod(np.array(self.inp_size))).to(self.device)
227 |         self.target_dqn = DQN(n_actions=self.n_actions, input_shape=np.prod(np.array(self.inp_size))).to(self.device)
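        # both networks are rebuilt from the same checkpoint, mirroring __init__,
        # where the target network starts out as a copy of the main network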
228 |         self.main_dqn.load_state_dict(checkpoint['model'])
229 |         self.target_dqn.load_state_dict(checkpoint['model'])
230 |
231 |         # load external parameters
232 |         mean_reward = checkpoint['mean_reward']
233 |         max_steps = checkpoint["max_steps"]
234 |         start_learning = checkpoint["start_learning"]
235 |         network_update_every = checkpoint["network_update_every"]
236 |         max_episode_length = checkpoint["max_episode_length"]
237 |         evaluate_every = checkpoint["evaluate_every"]
238 |         eval_runs = checkpoint["eval_runs"]
239 |
240 |         # overwrite internal hyperparameters set at initialization with saved ones
241 |         if overwrite:
242 |             self.batch_size = checkpoint["batch_size"]
243 |             self.gamma = checkpoint["gamma"]
244 |             self.memory_size = checkpoint["memory_size"]
245 |             self.lr = checkpoint["learning_rate"]
246 |
247 |         return epoch, mean_reward, max_steps, start_learning, network_update_every, max_episode_length, evaluate_every, eval_runs
248 |
249 |
250 | '''
251 | class Action_Scheduler():
252 |     """Determines an action according to an epsilon greedy strategy with annealing epsilon"""
253 |     def __init__(self, num_actions, model, eps_initial=1, eps_final=0.1, eps_final_step=0.01,
254 |                  eps_annealing_steps=1000000, replay_memory_start_size=50000,
255 |                  max_steps=25000000):
256 |         """
257 |         Args:
258 |             num_actions: Integer, number of possible actions
259 |             model: A DQN object
260 |             eps_initial: Float, Exploration probability for the first
261 |                 replay_memory_start_size frames
262 |             eps_final: Float, Exploration probability after
263 |                 replay_memory_start_size + eps_annealing_frames frames
264 |             eps_final_frame: Float, Exploration probability after max_frames frames
265 |             eps_evaluation: Float, Exploration probability during evaluation
266 |             eps_annealing_frames: Int, Number of frames over which the
267 |                 exploration probability is annealed from eps_initial to eps_final
268 |             replay_memory_start_size: Integer, Number of frames during
269 |                 which the agent only explores
270 |             max_frames: Integer, Total number of frames shown to the agent
271 |         """
272 |         self.num_actions = num_actions
273 |         self.eps_initial = eps_initial
274 |         self.eps_final = eps_final
275 |         self.eps_final_frame = eps_final_step
276 |         self.eps_annealing_frames = eps_annealing_steps
277 |         self.replay_memory_start_size = replay_memory_start_size
278 |         self.max_frames = max_steps
279 |         self.model = model
280 |
281 |         self.eps_current = self.eps_initial
282 |         self.slope = -(self.eps_initial - self.eps_final) / self.eps_annealing_frames
283 |         self.intercept = self.eps_initial - self.slope * self.replay_memory_start_size
284 |         self.slope_2 = -(self.eps_final - self.eps_final_frame) / (self.max_frames - self.eps_annealing_frames - self.replay_memory_start_size)
285 |         self.intercept_2 = self.eps_final_frame - self.slope_2 * self.max_frames
286 |
287 |     def get_action(self, frame_number, state, evaluation=False):
288 |         if evaluation:
289 |             self.eps_current = 0.0
290 |         elif frame_number < self.replay_memory_start_size:
291 |             self.eps_current = self.eps_initial
292 |         elif frame_number >= self.replay_memory_start_size and frame_number < self.replay_memory_start_size + self.eps_annealing_frames:
293 |             self.eps_current = self.slope * frame_number + self.intercept
294 |         elif frame_number >= self.replay_memory_start_size + self.eps_annealing_frames:
295 |             self.eps_current = self.slope_2 * frame_number + self.intercept_2
296 |
297 |         if np.random.rand(1) < 
self.eps_current: 298 | return np.random.randint(0, self.num_actions) 299 | else: 300 | with torch.no_grad(): 301 | q_vals = self.model(state) 302 | #print("Q Values: ", q_vals) 303 | action = torch.argmax(q_vals).item() 304 | #print("Action: ", action ) 305 | return action 306 | ''' -------------------------------------------------------------------------------- /dfibert/dataset/processing.py: -------------------------------------------------------------------------------- 1 | """The processing submodule contains processing options for the raw streamline and DWI data. 2 | 3 | Classes 4 | ------- 5 | Processing 6 | The base class for all processing instructions 7 | RegressionProcessing 8 | The basic processing, calculates direction vectors out of streamlines and interpolates DWI along a grid 9 | ClassificationProcessing 10 | Based on RegressionProcessing, however it reshapes the regression problem of the direction vector as a classification problem. 11 | """ 12 | from types import SimpleNamespace 13 | import numpy as np 14 | 15 | from dipy.core.sphere import Sphere 16 | from dipy.data import get_sphere 17 | 18 | from dfibert.data import DataContainer 19 | from dfibert.util import get_reference_orientation, get_grid, apply_rotation_matrix_to_grid, \ 20 | direction_to_classification, rotation_from_vectors_p 21 | 22 | 23 | class Processing: 24 | """The basic Processing class. 25 | 26 | Every Processing should extend this function and implement the following: 27 | 28 | Methods 29 | ------- 30 | calculate_streamline(data_container, streamline) 31 | Calculates the (input, output) tuple for a complete streamline 32 | calculate_item(data_container, sl, next_direction) 33 | Calculates the (input, output) tuple for a single last streamline point 34 | 35 | The methods can work together, but they do not have to. 36 | The existence of both must be guaranteed to be able to use every dataset. 37 | """ 38 | 39 | # TODO - Live Calculation for Tracker 40 | def calculate_streamline(self, data_container, streamline): 41 | """Calculates the (input, output) tuple for a whole streamline. 42 | 43 | Arguments 44 | --------- 45 | data_container : DataContainer 46 | The DataContainer the streamline is associated with 47 | streamline: Tensor 48 | The streamline the input and output data should be calculated for 49 | 50 | Raises 51 | ------ 52 | NotImplementedError 53 | If the Processing subclass didn't overwrite the function. 54 | 55 | Returns 56 | ------- 57 | tuple 58 | The (input, output) data for the requested item. 59 | 60 | """ 61 | raise NotImplementedError 62 | 63 | def calculate_item(self, data_container : DataContainer, previous_sl, next_dir): 64 | """Calculates the (input, output) tuple for a single streamline point. 65 | 66 | Arguments 67 | --------- 68 | previous_sl 69 | The previous streamline 70 | data_container 71 | The DataContainer the streamline is associated with 72 | point: Tensor 73 | The point the data should be calculated for in RAS* 74 | next_dir: Tensor, optional 75 | The next direction, you do not have to provide it if you only need the input part. 76 | Raises 77 | ------ 78 | NotImplementedError 79 | If the Processing subclass didn't overwrite the function. 80 | 81 | Returns 82 | ------- 83 | tuple 84 | The (input, output) data for the requested item. 85 | """ 86 | raise NotImplementedError 87 | 88 | 89 | class RegressionProcessing(Processing): 90 | """Provides a Processing option for regression training. 91 | 92 | There are many configuration options specified in the constructor. 
93 |     An instance of this class has to be passed onto a Dataset.
94 |
95 |     Attributes
96 |     ----------
97 |     options: SimpleNamespace
98 |         An object holding all configuration options of this dataset.
99 |     grid: numpy.ndarray
100 |         The grid, precalculated for this processing option
101 |     id: str
102 |         An ID representing this processing configuration. It is not unique to an instance, but derived from the chosen parameters.
103 |
104 |     Methods
105 |     -------
106 |     calculate_streamline(data_container, streamline)
107 |         Calculates the (input, output) tuple for a complete streamline
108 |     calculate_item(data_container, previous_sl, next_dir)
109 |         Calculates the (input, output) tuple for a single streamline point
110 |
111 |     """
112 |
113 |     def __init__(self, rotate=True, grid_dimension=(3, 3, 3), grid_spacing=1.0, postprocessing=None, normalize=None,
114 |                  normalize_mean=(9.8811e-01, 2.6814e-04, 1.2876e-03), normalize_std=(0.0262, 0.1064, 0.1078)):
115 |         """
116 |
117 |         All defaults are given by the keyword arguments below.
118 |
119 |         Parameters
120 |         ----------
121 |         rotate : bool, optional
122 |             Indicates whether the grid should be rotated along the fiber, by default True
123 |         grid_dimension : numpy.ndarray, optional
124 |             Grid dimension (X,Y,Z) of the interpolation grid, by default (3, 3, 3)
125 |         grid_spacing : float, optional
126 |             Grid spacing, by default 1.0
127 |         postprocessing : data.postprocessing, optional
128 |             The postprocessing to be done on the interpolated DWI, by default None
129 |         normalize : bool, optional
130 |             Indicates whether the output directions should be normalized, by default the value of rotate
131 |         normalize_mean : numpy.ndarray, optional
132 |             The mean used for normalization, by default a precomputed value
133 |         normalize_std : numpy.ndarray, optional
134 |             The std used for normalization, by default a precomputed value
135 |         """
136 |         if isinstance(grid_dimension, tuple):
137 |             grid_dimension = np.array(grid_dimension)
138 |
139 |         normalize = normalize if normalize is not None else rotate
140 |
141 |         self.options = SimpleNamespace()
142 |
143 |         if rotate and normalize:
144 |             if isinstance(normalize_mean, tuple):
145 |                 normalize_mean = np.array(normalize_mean)
146 |
147 |             if isinstance(normalize_std, tuple):
148 |                 normalize_std = np.array(normalize_std)
149 |
150 |             self.options.normalize_mean = normalize_mean
151 |             self.options.normalize_std = normalize_std
152 |
153 |         self.options.rotate = rotate
154 |         self.options.normalize = normalize
155 |         self.options.grid_dimension = grid_dimension
156 |         self.options.grid_spacing = grid_spacing
157 |         self.options.postprocessing = postprocessing
158 |         self.grid = get_grid(grid_dimension) * grid_spacing
159 |
160 |         self.id = "RegressionProcessing-r{}-grid{}x{}x{}-spacing{}-postprocessing-{}".format(rotate, *grid_dimension,
161 |                                                                                             grid_spacing,
162 |                                                                                             postprocessing.id)
163 |
164 |     def calculate_item(self, data_container, previous_sl, next_dir):
165 |         """Calculates the (input, output) tuple for the last streamline point.
166 |
167 |         Arguments
168 |         ---------
169 |         data_container : DataContainer
170 |             The DataContainer the streamline is associated with
171 |         previous_sl: np.array
172 |             The previous streamline points including the point the data should be calculated for in RAS*
173 |         next_dir: Tensor
174 |             The next direction, provide a null vector [0,0,0] if it is irrelevant.
175 |
176 |         Returns
177 |         -------
178 |         tuple
179 |             The (input, output) data for the requested item. 
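
        Notes
        -----
        A minimal sketch (the argument names are placeholders):

        >>> proc = RegressionProcessing(rotate=False, postprocessing=Resample100())
        >>> dwi, next_dir = proc.calculate_item(data_container, previous_sl, np.zeros(3))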
180 | """ 181 | # create artificial next_dirs consisting of last and next dir for rot_mat calculation 182 | next_dirs = np.concatenate(((previous_sl[1:] - previous_sl[:-1])[-1:], next_dir[np.newaxis, ...])) 183 | # TODO - normalize direction vectors 184 | next_dirs, rot_matrix = self._apply_rot_matrix(next_dirs) 185 | 186 | next_dir = next_dirs[-1] 187 | rot_matrix = None if rot_matrix is None else rot_matrix[np.newaxis, -1] 188 | dwi, _ = self._get_dwi(data_container, previous_sl[np.newaxis, -1], rot_matrix=rot_matrix) 189 | if self.options.postprocessing is not None: 190 | dwi = self.options.postprocessing(dwi, data_container.b0, 191 | data_container.bvecs, 192 | data_container.bvals) 193 | dwi = dwi.squeeze(axis=0) 194 | if self.options.normalize: 195 | next_dir = (next_dir - self.options.normalize_mean) / self.options.normalize_std 196 | return dwi, next_dir 197 | 198 | def calculate_streamline(self, data_container, streamline): 199 | """Calculates the (input, output) tuple for a whole streamline. 200 | 201 | Arguments 202 | --------- 203 | data_container : DataContainer 204 | The DataContainer the streamline is associated with 205 | streamline: Tensor 206 | The streamline the input and output data should be calculated for 207 | 208 | Returns 209 | ------- 210 | tuple 211 | The (input, output) data for the requested item. 212 | 213 | """ 214 | next_dir = self._get_next_direction(streamline) 215 | next_dir, rot_matrix = self._apply_rot_matrix(next_dir) 216 | dwi, _ = self._get_dwi(data_container, streamline, rot_matrix=rot_matrix, 217 | postprocessing=self.options.postprocessing) 218 | if self.options.postprocessing is not None: 219 | dwi = self.options.postprocessing(dwi, data_container.b0, 220 | data_container.bvecs, 221 | data_container.bvals) 222 | if self.options.normalize: 223 | next_dir = (next_dir - self.options.normalize_mean) / self.options.normalize_std 224 | return (dwi, next_dir) 225 | 226 | def _get_dwi(self, data_container, streamline, rot_matrix=None, postprocessing=None): 227 | points = self._get_grid_points(streamline, rot_matrix=rot_matrix) 228 | dwi = data_container.get_interpolated_dwi(points, postprocessing=postprocessing) 229 | return dwi, points 230 | 231 | def _get_next_direction(self, streamline): 232 | next_dir = streamline[1:] - streamline[:-1] 233 | next_dir = next_dir / np.linalg.norm(next_dir, axis=1)[:, None] 234 | next_dir = np.concatenate((next_dir, np.array([[0, 0, 0]]))) 235 | return next_dir 236 | 237 | def _apply_rot_matrix(self, next_dir): 238 | if not self.options.rotate: 239 | return next_dir, None 240 | reference = get_reference_orientation() 241 | rot_matrix = np.empty([len(next_dir), 3, 3]) 242 | # rot_mat (N, 3, 3) 243 | # next dir (N, 3) 244 | rot_matrix[0] = np.eye(3) 245 | rotation_from_vectors_p(rot_matrix[1:, :, :], reference[None, :], next_dir[:-1]) 246 | 247 | rot_next_dir = (rot_matrix.transpose((0, 2, 1)) @ next_dir[:, :, None]).squeeze(2) 248 | return rot_next_dir, rot_matrix 249 | 250 | def _get_grid_points(self, streamline, rot_matrix=None): 251 | grid = self.grid 252 | if rot_matrix is not None: 253 | grid = apply_rotation_matrix_to_grid(grid, rot_matrix) 254 | # shape [N x R x A x S x 3] or [R x A x S x 3] 255 | points = streamline[:, None, None, None, :] + grid 256 | return points 257 | 258 | 259 | class ClassificationProcessing(RegressionProcessing): 260 | """Provides a Processing option for regression training. 261 | 262 | There are many configuration options specified in the constructor. 
263 |     An instance of this class has to be passed onto a Dataset.
264 |
265 |     Attributes
266 |     ----------
267 |     options: SimpleNamespace
268 |         An object holding all configuration options of this dataset.
269 |     grid: numpy.ndarray
270 |         The grid, precalculated for this processing option
271 |     id: str
272 |         An ID representing this processing configuration. It is not unique to an instance, but derived from the chosen parameters.
273 |
274 |     Methods
275 |     -------
276 |     calculate_streamline(data_container, streamline)
277 |         Calculates the (input, output) tuple for a complete streamline
278 |     calculate_item(data_container, previous_sl, next_dir)
279 |         Calculates the (input, output) tuple for a single streamline point
280 |     """
281 |
282 |     def __init__(self, rotate=True, grid_dimension=(3, 3, 3), grid_spacing=1.0, postprocessing=None,
283 |                  sphere="repulsion724"):
284 |         """
285 |
286 |         The defaults mirror those of `RegressionProcessing`.
287 |
288 |         Parameters
289 |         ----------
290 |         rotate : bool, optional
291 |             Indicates whether the grid should be rotated along the fiber, by default True
292 |         grid_dimension : numpy.ndarray, optional
293 |             Grid dimension (X,Y,Z) of the interpolation grid, by default (3, 3, 3)
294 |         grid_spacing : float, optional
295 |             Grid spacing, by default 1.0
296 |         postprocessing : data.postprocessing, optional
297 |             The postprocessing to be done on the interpolated DWI, by default None
298 |         sphere : Sphere or str, optional
299 |             The sphere whose vertices serve as classification directions, by default "repulsion724"
300 |         """
301 |
302 |         RegressionProcessing.__init__(self, rotate=rotate, grid_dimension=grid_dimension,
303 |                                       grid_spacing=grid_spacing, postprocessing=postprocessing,
304 |                                       normalize=False)
305 |         if isinstance(sphere, Sphere):
306 |             real_sphere = sphere
307 |             sphere = "custom"
308 |         else:
309 |             real_sphere = get_sphere(sphere)
310 |         self.sphere = real_sphere
311 |         self.options.sphere = sphere
312 |         self.id = ("ClassificationProcessing-r{}-sphere-{}-grid{}x{}x{}-spacing{}-postprocessing-{}"
313 |                    .format(self.options.rotate, self.options.sphere, *self.options.grid_dimension,
314 |                            self.options.grid_spacing, self.options.postprocessing.id))
315 |
316 |     def calculate_streamline(self, data_container, streamline):
317 |         """Calculates the classification (input, output) tuple for a whole streamline.
318 |
319 |         Arguments
320 |         ---------
321 |         data_container : DataContainer
322 |             The DataContainer the streamline is associated with
323 |         streamline: Tensor
324 |             The streamline the input and output data should be calculated for
325 |
326 |         Returns
327 |         -------
328 |         tuple
329 |             The (input, output) data for the requested item.
330 |
331 |         """
332 |         dwi, next_dir = RegressionProcessing.calculate_streamline(self, data_container, streamline)
333 |         classification_output = direction_to_classification(self.sphere, next_dir, include_stop=True, last_is_stop=True)
334 |         return dwi, classification_output
335 |
336 |     def calculate_item(self, data_container, previous_sl, next_dir):
337 |         """Calculates the classification (input, output) tuple for the last streamline point.
338 |
339 |         Arguments
340 |         ---------
341 |         data_container : DataContainer
342 |             The DataContainer the streamline is associated with
343 |         previous_sl: np.array
344 |             The previous streamline points including the point the data should be calculated for in RAS*
345 |         next_dir: Tensor
346 |             The next direction, provide a null vector [0,0,0] if it is irrelevant.
347 |
348 |         Returns
349 |         -------
350 |         tuple
351 |             The (input, output) data for the requested item. 
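
        Notes
        -----
        The regression target is converted into a distribution over the sphere
        directions via `direction_to_classification`, with one additional "stop"
        class appended as the last entry.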
352 | """ 353 | dwi, next_dir = RegressionProcessing.calculate_item(data_container, previous_sl, next_dir) 354 | classification_output = direction_to_classification(self.sphere, next_dir[None, ...], include_stop=True, 355 | last_is_stop=True).squeeze(axis=0) 356 | return dwi, classification_output 357 | -------------------------------------------------------------------------------- /examples/rlWorkflow.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "source": [ 7 | "%load_ext autoreload\n", 8 | "%autoreload 2" 9 | ], 10 | "outputs": [], 11 | "metadata": {} 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 2, 16 | "source": [ 17 | "import torch\n", 18 | "import torch.nn.functional as F\n", 19 | "import numpy as np\n", 20 | "import random\n", 21 | "import sys\n", 22 | "sys.path.insert(0,'..')\n", 23 | "\n", 24 | "from dfibert.tracker.nn.rl import Agent\n", 25 | "import dfibert.envs.RLTractEnvironment_fast as RLTe\n", 26 | "\n", 27 | "from dfibert.tracker import save_streamlines\n", 28 | "\n", 29 | "import matplotlib.pyplot as plt\n", 30 | "%matplotlib notebook\n", 31 | "\n", 32 | "#from train import load_model" 33 | ], 34 | "outputs": [], 35 | "metadata": {} 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "source": [ 40 | "# I. HCP Tracking\n", 41 | "The environment is able to run tracking on a fixed set of datasets. At the moment, it is able to load HCP data as well as ISMRM data. The following cells shows the initalisation of our environment on HCP dataset `100307` while seed points are automatically determined at voxels with fa-value >= 0.2 via `seeds = None`." 42 | ], 43 | "metadata": {} 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "source": [ 49 | "env = RLTe.RLTractEnvironment(step_width=0.8, dataset = '100307',\n", 50 | " device = 'cpu', seeds = None, tracking_in_RAS = False,\n", 51 | " odf_state = False, odf_mode = \"DTI\")" 52 | ], 53 | "outputs": [], 54 | "metadata": {} 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "source": [ 60 | "streamlines = env.track()" 61 | ], 62 | "outputs": [], 63 | "metadata": { 64 | "scrolled": true 65 | } 66 | }, 67 | { 68 | "cell_type": "markdown", 69 | "source": [ 70 | "We can also directly visualize our streamlines in this notebook by `ax.plot3d`. However, a single streamline is typically very hard to comprehend so this is merely one tool to qualitatively reason about major bugs in our tracking code." 71 | ], 72 | "metadata": {} 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": null, 77 | "source": [ 78 | "%matplotlib notebook\n", 79 | "streamline_index = 9\n", 80 | "streamline_np = np.stack(streamlines[streamline_index])\n", 81 | "\n", 82 | "fig = plt.figure()\n", 83 | "ax = plt.axes(projection='3d')\n", 84 | "#ax.plot3D(env.referenceStreamline_ijk.T[0], env.referenceStreamline_ijk.T[1], env.referenceStreamline_ijk.T[2], '-*')\n", 85 | "ax.plot3D(streamline_np[:,0], streamline_np[:,1], streamline_np[:,2])\n", 86 | "#plt.legend(['gt', 'agent'])\n", 87 | "plt.legend('agent')" 88 | ], 89 | "outputs": [], 90 | "metadata": { 91 | "scrolled": true 92 | } 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "source": [ 97 | "# II. Evaluation of Cortico Spinal Tract @ ISMRM benchmark data\n", 98 | "We will now be using our environment along with our reward function to track streamlines on the ISMRM dataset. 
For this purpose, we first initialise our environment and place seed points in the corticospinal tract. We precomputed seed points in IJK for our ISMRM dataset. These seeds will now be loaded into our environment."
99 |    ],
100 |    "metadata": {}
101 |   },
102 |   {
103 |    "cell_type": "code",
104 |    "execution_count": 4,
107 |    "source": [
108 |     "seeds_CST = np.load('data/ismrm_seeds_CST.npy')\n",
109 |     "seeds_CST = torch.from_numpy(seeds_CST)"
110 |    ],
111 |    "outputs": [],
112 |    "metadata": {}
113 |   },
114 |   {
115 |    "cell_type": "code",
116 |    "execution_count": null,
117 |    "source": [
118 |     "env = RLTe.RLTractEnvironment(dataset = 'ISMRM', step_width=0.8,\n",
119 |     "                              device = 'cpu', seeds = seeds_CST[0:100,:], action_space=100,\n",
120 |     "                              tracking_in_RAS = False, odf_state = False, odf_mode = \"DTI\")"
121 |    ],
122 |    "outputs": [],
123 |    "metadata": {}
124 |   },
125 |   {
126 |    "cell_type": "markdown",
127 |    "source": [
128 |     "Tracking itself can now be done by calling the `.track()` function, which tracks streamlines from each of the provided seed points in forward and backward directions."
129 |    ],
130 |    "metadata": {}
131 |   },
132 |   {
133 |    "cell_type": "code",
134 |    "execution_count": null,
135 |    "source": [
136 |     "streamlines = env.track()"
137 |    ],
138 |    "outputs": [],
139 |    "metadata": {}
140 |   },
141 |   {
142 |    "cell_type": "markdown",
143 |    "source": [
144 |     "The streamlines are now stored as a VTK file. The nice thing about this format is that we can directly import the streamlines into 3dSlicer via the slicer-dMRI extension."
145 |    ],
146 |    "metadata": {}
147 |   },
148 |   {
149 |    "cell_type": "code",
150 |    "execution_count": null,
151 |    "source": [
152 |     "streamlines_ras = [env.dataset.to_ras(sl) for sl in streamlines]"
153 |    ],
154 |    "outputs": [],
155 |    "metadata": {}
156 |   },
157 |   {
158 |    "cell_type": "code",
159 |    "execution_count": null,
160 |    "source": [
161 |     "save_streamlines(streamlines=streamlines_ras, path=\"ismrm_cst2_ras_100actions_hemi.vtk\")"
162 |    ],
163 |    "outputs": [],
164 |    "metadata": {}
165 |   },
166 |   {
167 |    "cell_type": "code",
168 |    "execution_count": null,
169 |    "source": [
170 |     "def convPoint(p, dims):\n",
171 |     "    dims = dims - 1\n",
172 |     "    return (p - dims/2.) 
/ (dims/2.)\n",
173 |     "\n",
174 |     "def interpolate3dAtt(data, positions):\n",
175 |     "    # Data is supposed to be CxHxWxD\n",
176 |     "    # normalise coordinates into range [-1,1]\n",
177 |     "    pts = positions.to(torch.float)\n",
178 |     "    pts = convPoint(pts, torch.tensor(data.shape[1:4]))\n",
179 |     "    # reverse pts\n",
180 |     "    pts = pts[:,(2,1,0)]\n",
181 |     "    # nearest-neighbour sampling on the voxel grid (mode=\"nearest\")\n",
182 |     "    return torch.nn.functional.grid_sample(data.unsqueeze(0), \n",
183 |     "                                           pts.unsqueeze(0).unsqueeze(0).unsqueeze(0),\n",
184 |     "                                           align_corners = False, mode = \"nearest\")"
185 |    ],
186 |    "outputs": [],
187 |    "metadata": {}
188 |   },
189 |   {
190 |    "cell_type": "code",
191 |    "execution_count": null,
192 |    "source": [
193 |     "interpolate3dAtt(env.tractMasksAllBundles, torch.from_numpy(np.array([[30,50,30]]))).squeeze().shape"
194 |    ],
195 |    "outputs": [],
196 |    "metadata": {}
197 |   },
198 |   {
199 |    "cell_type": "code",
200 |    "execution_count": null,
201 |    "source": [
202 |     "torch.mean(na_reward_history, dim = 0)"
203 |    ],
204 |    "outputs": [],
205 |    "metadata": {}
206 |   },
207 |   {
208 |    "cell_type": "code",
209 |    "execution_count": null,
210 |    "source": [
211 |     "na_reward_history[0,:] = 1"
212 |    ],
213 |    "outputs": [],
214 |    "metadata": {}
215 |   },
216 |   {
217 |    "cell_type": "code",
218 |    "execution_count": null,
219 |    "source": [
220 |     "na_reward_history = torch.zeros((env.maxSteps, env.tractMasksAllBundles.shape[0]))"
221 |    ],
222 |    "outputs": [],
223 |    "metadata": {}
224 |   },
225 |   {
226 |    "cell_type": "code",
227 |    "execution_count": null,
228 |    "source": [
229 |     "from torch.utils.data import Dataset, DataLoader\n",
230 |     "from dfibert.tracker import save_streamlines, load_streamlines\n", "from dfibert.data import DataPreprocessor\n", "from dfibert.envs.neuroanatomical_utils import FiberBundleDataset\n",
231 |     "\n",
232 |     "class FiberBundleDatasetv2(Dataset):\n",
233 |     "    def __init__(self, path_to_files, b_val = 1000, device = \"cpu\", dataset = None):\n",
234 |     "        streamlines = load_streamlines(path=path_to_files)\n",
235 |     "        \n",
236 |     "        if(dataset is None):\n",
237 |     "            preprocessor = DataPreprocessor().normalize().crop(b_val).fa_estimate()\n",
238 |     "            dataset = preprocessor.get_ismrm(f\"data/ISMRM2015/\")\n",
239 |     "        self.dataset = dataset\n",
240 |     "        self.streamlines = [torch.from_numpy(self.dataset.to_ijk(sl)).to(device) for sl in streamlines]\n",
241 |     "        self.tractMask = torch.zeros(self.dataset.binary_mask.shape)\n",
242 |     "        \n",
243 |     "        for sl in self.streamlines:\n",
244 |     "            pi = torch.floor(sl).to(torch.long)\n",
245 |     "            self.tractMask[pi.chunk(chunks=3, dim = 1)] = 1\n",
246 |     "        \n",
247 |     "    def __len__(self):\n",
248 |     "        return len(self.streamlines)\n",
249 |     "    \n",
250 |     "    def __getitem__(self, idx):\n",
251 |     "        streamline = self.streamlines[idx]\n",
252 |     "        sl_1 = streamline[0:-2]\n",
253 |     "        sl_2 = streamline[1:-1]\n",
254 |     "        return sl_1, sl_2\n"
255 |    ],
256 |    "outputs": [],
257 |    "metadata": {}
258 |   },
259 |   {
260 |    "cell_type": "code",
261 |    "execution_count": null,
262 |    "source": [
263 |     "fibv2.streamlines[0].chunk(chunks=3, dim = 1)[3]"
264 |    ],
265 |    "outputs": [],
266 |    "metadata": {}
267 |   },
268 |   {
269 |    "cell_type": "code",
270 |    "execution_count": null,
271 |    "source": [
272 |     "fibv2 = FiberBundleDatasetv2(path_to_files=\"data/ISMRM2015/gt_bundles/SLF_left.fib\", dataset = dataset)"
273 |    ],
274 |    "outputs": [],
275 |    "metadata": {}
276 |   },
277 |   {
278 |    "cell_type": "code",
279 |    "execution_count": null,
280 |    "source": [
281 |     "fibv1 = FiberBundleDataset(path_to_files=\"data/ISMRM2015/gt_bundles/SLF_left.fib\", dataset = dataset)"
282 |    ],
283 |    "outputs": [],
284 |    "metadata": {}
285 |   },
286 
| { 287 | "cell_type": "code", 288 | "execution_count": null, 289 | "source": [ 290 | "torch.sum(fibv2.tractMask)" 291 | ], 292 | "outputs": [], 293 | "metadata": {} 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": null, 298 | "source": [ 299 | "torch.sum(fibv1.tractMask)" 300 | ], 301 | "outputs": [], 302 | "metadata": {} 303 | }, 304 | { 305 | "cell_type": "markdown", 306 | "source": [ 307 | "# Reinforcement Learning\n", 308 | "## DQN" 309 | ], 310 | "metadata": {} 311 | }, 312 | { 313 | "cell_type": "code", 314 | "execution_count": null, 315 | "source": [ 316 | "%matplotlib inline\n", 317 | "from dfibert.envs.NARLTractEnvironment import NARLTractEnvironment as RLEnv" 318 | ], 319 | "outputs": [], 320 | "metadata": {} 321 | }, 322 | { 323 | "cell_type": "code", 324 | "execution_count": null, 325 | "source": [ 326 | "# init env\n", 327 | "#seeds_CST = np.load('data/ismrm_seeds_CST.npy')\n", 328 | "#seeds_CST = torch.from_numpy(seeds_CST)\n", 329 | "env = RLEnv(dataset = 'ISMRM', step_width=0.2,\n", 330 | " device = 'cpu', action_space=20,\n", 331 | " odf_mode = \"CSD\")#, seeds = seeds_CST)" 332 | ], 333 | "outputs": [], 334 | "metadata": {} 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": null, 339 | "source": [ 340 | "from dipy.tracking.utils import random_seeds_from_mask\n", 341 | "temp_seeds = env.seeds\n", 342 | "env.seeds = random_seeds_from_mask(env.dataset.binary_mask,\n", 343 | "seeds_count=10000,\n", 344 | "seed_count_per_voxel=False,\n", 345 | "affine=env.dataset.aff)" 346 | ], 347 | "outputs": [], 348 | "metadata": {} 349 | }, 350 | { 351 | "cell_type": "code", 352 | "execution_count": null, 353 | "source": [ 354 | "from dfibert.tracker.nn.rainbow_agent import DQNAgent" 355 | ], 356 | "outputs": [], 357 | "metadata": {} 358 | }, 359 | { 360 | "cell_type": "code", 361 | "execution_count": null, 362 | "source": [ 363 | "# Hyperparameters:\n", 364 | "replay_memory_size = 100000\n", 365 | "batch_size = 512\n", 366 | "target_update = 10000\n", 367 | "gamma = 0.95\n", 368 | "max_steps = 60000000\n", 369 | "\n", 370 | "path = './training_lower_stepwidth'" 371 | ], 372 | "outputs": [], 373 | "metadata": {} 374 | }, 375 | { 376 | "cell_type": "code", 377 | "execution_count": null, 378 | "source": [ 379 | "agent = DQNAgent(env=env, memory_size = replay_memory_size,\n", 380 | " batch_size = batch_size,\n", 381 | " target_update = target_update,\n", 382 | " gamma = gamma)" 383 | ], 384 | "outputs": [], 385 | "metadata": {} 386 | }, 387 | { 388 | "cell_type": "code", 389 | "execution_count": null, 390 | "source": [ 391 | "# start training\n", 392 | "%matplotlib inline\n", 393 | "agent.train(num_steps = max_steps, checkpoint_interval=2000, path = path, plot=True)" 394 | ], 395 | "outputs": [], 396 | "metadata": { 397 | "scrolled": false 398 | } 399 | }, 400 | { 401 | "cell_type": "code", 402 | "execution_count": null, 403 | "source": [ 404 | "# resume the training process\n", 405 | "agent.resume_training(path='./training_test/checkpoints/rainbow_14000_16.65.pth', plot=True)" 406 | ], 407 | "outputs": [], 408 | "metadata": {} 409 | }, 410 | { 411 | "cell_type": "code", 412 | "execution_count": null, 413 | "source": [ 414 | "# load a saved checkpoint\n", 415 | "agent = DQNAgent(env=env, \n", 416 | " memory_size = replay_memory_size, # memory + batch size and target update will be overwritten with the\n", 417 | " batch_size = batch_size, # saved parameters\n", 418 | " target_update = target_update)\n", 419 | "num_steps, rewards, losses, max_steps = 
agent._load_model('./training_test/checkpoints/rainbow_248000_15.00.pth')" 420 | ], 421 | "outputs": [], 422 | "metadata": {} 423 | }, 424 | { 425 | "cell_type": "code", 426 | "execution_count": null, 427 | "source": [ 428 | "seeds = env.seeds\n", 429 | "env.seeds = env.seeds[:100]" 430 | ], 431 | "outputs": [], 432 | "metadata": {} 433 | }, 434 | { 435 | "cell_type": "code", 436 | "execution_count": null, 437 | "source": [ 438 | "# calculate and save tractogram with trained agent\n", 439 | "streamlines = agent.create_tractogram(path=\"ismrm_defi_15.0.vtk\")\n", 440 | "#streamlines = env.track()" 441 | ], 442 | "outputs": [], 443 | "metadata": {} 444 | }, 445 | { 446 | "cell_type": "code", 447 | "execution_count": null, 448 | "source": [ 449 | "save_streamlines(streamlines=streamlines, path=\"./ismrm_defi_15.0.vtk\")" 450 | ], 451 | "outputs": [], 452 | "metadata": {} 453 | }, 454 | { 455 | "cell_type": "code", 456 | "execution_count": null, 457 | "source": [ 458 | "counter = 0\n", 459 | "for i in range(len(streamlines)):\n", 460 | " if len(streamlines[i])>10:\n", 461 | " counter +=1\n", 462 | "\n", 463 | "print(counter)" 464 | ], 465 | "outputs": [], 466 | "metadata": {} 467 | }, 468 | { 469 | "cell_type": "code", 470 | "execution_count": null, 471 | "source": [ 472 | "# plot rewards and losses for loaded checkpoint\n", 473 | "%matplotlib inline\n", 474 | "agent._plot(num_steps, rewards, losses)" 475 | ], 476 | "outputs": [], 477 | "metadata": {} 478 | } 479 | ], 480 | "metadata": { 481 | "kernelspec": { 482 | "name": "python3", 483 | "display_name": "Python 3.9.7 64-bit ('defi': conda)" 484 | }, 485 | "language_info": { 486 | "codemirror_mode": { 487 | "name": "ipython", 488 | "version": 3 489 | }, 490 | "file_extension": ".py", 491 | "mimetype": "text/x-python", 492 | "name": "python", 493 | "nbconvert_exporter": "python", 494 | "pygments_lexer": "ipython3", 495 | "version": "3.9.7" 496 | }, 497 | "latex_envs": { 498 | "LaTeX_envs_menu_present": true, 499 | "autoclose": false, 500 | "autocomplete": true, 501 | "bibliofile": "biblio.bib", 502 | "cite_by": "apalike", 503 | "current_citInitial": 1, 504 | "eqLabelWithNumbers": true, 505 | "eqNumInitial": 1, 506 | "hotkeys": { 507 | "equation": "Ctrl-E", 508 | "itemize": "Ctrl-I" 509 | }, 510 | "labels_anchors": false, 511 | "latex_user_defs": false, 512 | "report_style_numbering": false, 513 | "user_envs_cfg": false 514 | }, 515 | "interpreter": { 516 | "hash": "f91776d31cb0b04e7fab63167ea94962184dc442b6d58ab0d535daf835db8614" 517 | } 518 | }, 519 | "nbformat": 4, 520 | "nbformat_minor": 4 521 | } 522 | -------------------------------------------------------------------------------- /dfibert/envs/NARLTractEnvironment.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import Optional 4 | from collections import deque 5 | import dipy.reconst.dti as dti 6 | import gym 7 | import numpy as np 8 | from dipy.core.sphere import HemiSphere 9 | from dipy.core.sphere import disperse_charges 10 | from dipy.reconst.csdeconv import ConstrainedSphericalDeconvModel 11 | from dipy.reconst.csdeconv import (mask_for_response_ssst, 12 | response_from_mask_ssst) 13 | from dipy.tracking import utils 14 | from gym.spaces import Discrete 15 | from tqdm import trange 16 | 17 | from dfibert.data import DataPreprocessor 18 | from dfibert.data.postprocessing import Resample, Resample100 19 | from dfibert.util import get_grid 20 | from ..tracker import load_streamlines 21 | 22 
| import torch
23 | 
24 | class TorchGridInterpolator:
25 |     def __init__(self, data) -> None:
26 |         self.data = data.float() # [X,Y,Z,C]
27 |         self.data = self.data.permute((3,0,1,2)).unsqueeze(0) # [1, C, X, Y, Z]; compatible to torch>1.7 (permute self.data so the float() cast is kept)
28 |         #self.data = self.data.moveaxis(3,0).unsqueeze(0) # [1, C, X, Y, Z], requires torch>1.10
29 |         self.interpol_transform = (torch.tensor(data.shape[:3], device=self.data.device) - 1) / 2
30 | 
31 |     def _convert_points(self, pts):
32 |         return ((pts / self.interpol_transform) - 1)[:, (2, 1, 0)] # normalise each axis into [-1,1] first, then reverse to (z,y,x) for grid_sample
33 | 
34 |     def __call__(self, pts) -> torch.Tensor:
35 |         new_shape = (*pts.shape[:-1], -1)
36 |         pts = pts.reshape(-1, 3)
37 |         pts = self._convert_points(pts.float())
38 |         pts = pts.unsqueeze(0).unsqueeze(0).unsqueeze(0)
39 |         #print(pts.dtype, pts.shape)
40 |         #print(self.data.dtype, self.data.shape)
41 |         interpolated = torch.nn.functional.grid_sample(self.data, pts, align_corners=True, mode="bilinear")
42 |         # [1 , C, 1, 1, N]
43 |         interpolated = interpolated.reshape((self.data.shape[1], -1))
44 |         return interpolated.permute((1,0)).reshape(new_shape)
45 | 
46 | 
47 | def get_uniform_hemisphere_with_points(action_space: int, seed=42) -> HemiSphere:
48 |     if seed is not None:
49 |         np.random.seed(seed)
50 | 
51 |     phi = np.pi * np.random.rand(action_space)
52 |     theta = 2 * np.pi * np.random.rand(action_space)
53 |     sphere = HemiSphere(theta=theta, phi=phi)  # Sphere(theta=theta, phi=phi)
54 |     sphere, _ = disperse_charges(sphere, 5000)  # enforce uniform distribution of our points
55 | 
56 |     return sphere
57 | 
58 | 
59 | def get_tract_mask(path_to_files, dataset):
60 |     streamlines = load_streamlines(path=str(path_to_files))
61 |     streamlines = [dataset.to_ijk(sl) for sl in streamlines]
62 |     tract_mask = np.zeros(dataset.binary_mask.shape)
63 | 
64 |     for sl in streamlines:
65 |         pi = np.floor(sl).astype(int)
66 |         tract_mask[pi[:, 0], pi[:, 1], pi[:, 2]] = 1
67 |     return tract_mask
68 | 
69 | 
70 | def get_all_tract_mask(bundle_path, dataset):
71 |     tract_masks = np.stack([get_tract_mask(bundle_path / file, dataset=dataset) for file in os.listdir(bundle_path)])
72 |     return np.moveaxis(tract_masks, 0, -1)  # X * Y * Z * Bundle
73 | 
74 | 
75 | class TractographyState:
76 |     def __init__(self, coordinate, interpol_func):
77 |         self.coordinate = coordinate
78 |         self.interpol_func = interpol_func
79 |         self.interpol_dwi = None
80 | 
81 |     def get_coordinate(self):
82 |         return self.coordinate
83 | 
84 |     def get_value(self):
85 |         if self.interpol_dwi is None:
86 |             # interpolate DWI value at self.coordinate
87 |             self.interpol_dwi = self.interpol_func(self.coordinate)
88 |         return self.interpol_dwi
89 | 
90 |     def __add__(self, other):
91 |         if isinstance(other, torch.Tensor):
92 |             return TractographyState(self.get_coordinate() + other, self.interpol_func)
93 |         elif isinstance(other, TractographyState):
94 |             return TractographyState(self.get_coordinate() + other.get_coordinate(), self.interpol_func)
95 |         else:
96 |             raise NotImplementedError()
97 | 
98 |     def __sub__(self, other):
99 |         if isinstance(other, torch.Tensor):
100 |             return TractographyState(self.get_coordinate() - other, self.interpol_func)
101 |         elif isinstance(other, TractographyState):
102 |             return TractographyState(self.get_coordinate() - other.get_coordinate(), self.interpol_func)
103 |         else:
104 |             raise NotImplementedError()
105 | 
106 | 
107 | class RLTractEnvironment(gym.Env):
108 |     def __init__(self, device, seeds=None, dataset="100307", step_width=0.8, b_val=1000, action_space=100,
109 |                  grid_dim=(3, 3, 3),
110 |                  max_steps=2000, fa_threshold=0.2, bundles_path="data/gt_bundles/", 
odf_mode="CSD"): 111 | 112 | print("Loading dataset # ", dataset) 113 | self.device = device 114 | preprocessor = DataPreprocessor().normalize().crop(b_val).fa_estimate() 115 | if dataset == 'ISMRM': 116 | self.dataset = preprocessor.get_ismrm(f"data/ISMRM2015/") 117 | else: 118 | self.dataset = preprocessor.get_hcp(f"data/HCP/{dataset}/") 119 | self.sphere = get_uniform_hemisphere_with_points(action_space=action_space) 120 | self.directions = torch.from_numpy(self.sphere.vertices).to(device=device) 121 | self.grid = torch.from_numpy(get_grid(np.array(grid_dim))).to(device=device) 122 | self.action_space = Discrete(action_space) 123 | 124 | if seeds is None: 125 | seeds = utils.seeds_from_mask(self.dataset.binary_mask, self.dataset.aff) 126 | self.seeds = seeds.to(device=device).float() # IJK 127 | print("[I] seeds are supposed to be in IJK") 128 | self.max_steps = max_steps 129 | self.step_width = step_width 130 | 131 | self.dwi = torch.from_numpy(Resample100().process(self.dataset, None, self.dataset.dwi)).to(device=device).float() 132 | self.dwi_processor = TorchGridInterpolator(self.dwi) 133 | self.binary_mask = torch.from_numpy(self.dataset.binary_mask).to(device=device) 134 | self.fa_interpolator = TorchGridInterpolator(torch.from_numpy(self.dataset.fa).to(device=device).unsqueeze(-1).float()) 135 | self.fa_threshold = fa_threshold 136 | 137 | self._init_na(Path(bundles_path)) 138 | self._init_odf(odf_mode=odf_mode) 139 | self.ras_aff = torch.from_numpy(self.dataset.aff).to(device=device).float() 140 | self.ijk_aff = self.ras_aff.inverse().float() 141 | self.state: Optional[TractographyState] = None 142 | self.no_steps = 0 143 | self.state_history = torch.zeros((self.max_steps + 1, 3)) 144 | self.na_reward_history = torch.zeros((self.max_steps, self.tract_masks.shape[-1])) 145 | self.reset() 146 | 147 | def _init_na(self, bundles_path): 148 | self.tract_masks = torch.from_numpy(get_all_tract_mask(bundles_path, self.dataset)).to(device=self.device).float() 149 | self.na_interpolator = TorchGridInterpolator(self.tract_masks) 150 | 151 | def _init_odf(self, odf_mode): 152 | print("Initialising ODF") 153 | # fit DTI model to data 154 | if odf_mode == "DTI": 155 | print("DTI-based ODF computation") 156 | dti_model = dti.TensorModel(self.dataset.gtab, fit_method='LS') 157 | dti_fit = dti_model.fit(self.dataset.dwi, mask=self.dataset.binary_mask) 158 | # compute ODF 159 | odf = dti_fit.odf(self.sphere) 160 | elif odf_mode == "CSD": 161 | print("CSD-based ODF computation") 162 | mask = mask_for_response_ssst(self.dataset.gtab, self.dataset.dwi, roi_radii=10, fa_thr=0.7) 163 | response, ratio = response_from_mask_ssst(self.dataset.gtab, self.dataset.dwi, mask) 164 | dti_model = ConstrainedSphericalDeconvModel(self.dataset.gtab, response) 165 | dti_fit = dti_model.fit(self.dataset.dwi) 166 | odf = dti_fit.odf(self.sphere) 167 | else: 168 | raise NotImplementedError("ODF mode not found") 169 | # -- set up interpolator for odf evaluation 170 | odf = torch.from_numpy(odf).to(device=self.device).float() 171 | 172 | self.odf_interpolator = TorchGridInterpolator(odf) 173 | 174 | def step(self, action, backwards=False): 175 | ijk_coordinate = self.state.get_coordinate() 176 | odf_cur = self.odf_interpolator(ijk_coordinate).squeeze() 177 | 178 | if torch.max(odf_cur) > 0: 179 | odf_cur = odf_cur / torch.max(odf_cur) 180 | 181 | if self.no_steps >= self.max_steps: 182 | #print("#1") 183 | return self.state.get_value(), 0., True, {} 184 | if self.fa_interpolator(ijk_coordinate) < self.fa_threshold: 185 
| #print("#2") 186 | #print(ijk_coordinate) 187 | #print(self.fa_interpolator(ijk_coordinate)) 188 | return self.state.get_value(), 0., True, {} 189 | 190 | if self.binary_mask[int(ijk_coordinate[0]), int(ijk_coordinate[1]), int(ijk_coordinate[2])] == 0: 191 | #print("#3") 192 | return self.state.get_value(), 0., True, {} 193 | 194 | next_dir = self.directions[action].clone().detach().float() 195 | if self.no_steps > 0: 196 | prev_dir = self.state_history[self.no_steps,:] - self.state_history[self.no_steps - 1,:] 197 | prev_dir = prev_dir / torch.linalg.norm(prev_dir) 198 | if torch.dot(next_dir, prev_dir) < 0: 199 | next_dir = next_dir * -1 200 | 201 | if torch.dot(next_dir, prev_dir) < 0.5: 202 | return self.state.get_value(), 0., True, {} 203 | else: 204 | if backwards: 205 | next_dir = next_dir * -1 206 | prev_dir = next_dir 207 | 208 | step_width = self.step_width if self.no_steps > 0 else 0.5 * self.step_width 209 | self.state = self.state + (step_width * next_dir) 210 | self.no_steps += 1 211 | self.state_history[self.no_steps, :] = self.state.get_coordinate() 212 | 213 | ijk_coordinate = self.state.get_coordinate() 214 | 215 | local_na_reward = self.na_interpolator(ijk_coordinate) 216 | #print(local_na_reward.shape) 217 | #print(self.na_reward_history.shape) 218 | self.na_reward_history[self.no_steps - 1, :] = local_na_reward 219 | 220 | if self.no_steps > 1: 221 | mean_na_reward = torch.mean(self.na_reward_history[0: self.no_steps - 1], dim=0) 222 | na_reward = mean_na_reward + local_na_reward 223 | else: 224 | na_reward = local_na_reward 225 | reward = odf_cur[action] * torch.dot(next_dir, prev_dir) + torch.max(na_reward) 226 | return self.get_observation_from_state(self.state), reward, False, {} 227 | 228 | def to_ras(self, points): 229 | new_shape = (points.shape) 230 | points = points.reshape(-1, 3) 231 | return (torch.mm(self.ras_aff[:3,:3],points.T) + self.ras_aff[:3,3:4]).T.reshape(new_shape) 232 | 233 | def to_ijk(self, points): 234 | new_shape = (points.shape) 235 | points = points.reshape(-1, 3) 236 | res_pts = (torch.mm(self.ijk_aff[:3,:3],points.T) + self.ijk_aff[:3,3:4]).T.reshape(new_shape) 237 | return res_pts 238 | 239 | 240 | def interpolate_dwi_at_state(self, points): 241 | ras_points = self.grid + self.to_ras(points.float()) 242 | new_shape = (*points.shape[:-1], -1) 243 | points = self.to_ijk(ras_points.float()).reshape(-1, 3) 244 | 245 | is_outside = ((points[:, 0] < 0) + (points[:, 0] >= self.dwi.shape[0]) + # OR 246 | (points[:, 1] < 0) + (points[:, 1] >= self.dwi.shape[1]) + 247 | (points[:, 2] < 0) + (points[:, 2] >= self.dwi.shape[2])) > 0 248 | 249 | if torch.sum(is_outside) > 0: 250 | return None 251 | 252 | result = self.dwi_processor(points) 253 | 254 | 255 | result = result.reshape(new_shape) 256 | return result 257 | 258 | 259 | def _next_pos_and_reward(self, backwards=False): 260 | next_dirs = self.directions.clone().detach() 261 | if self.no_steps > 0: 262 | prev_dir = self.state_history[self.no_steps] - self.state_history[self.no_steps - 1] 263 | prev_dir = prev_dir / torch.linalg.norm(prev_dir) 264 | 265 | should_be_inverted = torch.sum(next_dirs * prev_dir, dim=1) < 0 266 | next_dirs[should_be_inverted] = -next_dirs[should_be_inverted] 267 | elif backwards: 268 | next_dirs = -next_dirs 269 | rewards = self._get_reward_for_move(torch.arange(0, self.directions.shape[0]), next_dirs) 270 | 271 | if self.no_steps > 0: 272 | angle_to_sharp = torch.sum(next_dirs * prev_dir, dim=1) < 0.5 273 | rewards[angle_to_sharp] = 0 274 | return next_dirs, rewards 
275 | 
276 | 
277 |     def _get_reward_for_move(self, actions, next_directions):
278 | 
279 |         cur_pos_ijk = self.state.get_coordinate()
280 | 
281 |         odf_cur = self.odf_interpolator(cur_pos_ijk.unsqueeze(0)).squeeze()  # [L]
282 |         if torch.max(odf_cur) > 0:
283 |             odf_cur = odf_cur / torch.max(odf_cur)
284 |         # actions : [N], next_directions: [N, 3]
285 |         # bundles : [K]
286 |         step_width = self.step_width if self.no_steps > 0 else 0.5 * self.step_width
287 | 
288 |         next_positions = cur_pos_ijk + next_directions * step_width
289 |         ijk_coordinates = next_positions  # [N, 3]
290 |         local_na_reward = self.na_interpolator(ijk_coordinates)  # [N, K]
291 | 
292 |         if self.no_steps > 0:
293 |             mean_na_reward = torch.mean(self.na_reward_history[0: self.no_steps], dim=0)  # [K]
294 |             na_reward = mean_na_reward + local_na_reward  # [N, K]
295 |             na_reward = torch.max(na_reward, dim = 1).values.squeeze()  # [N]; max over the bundle axis (dim 1 of [N, K])
296 | 
297 |             prev_dir = self.state_history[self.no_steps] - self.state_history[self.no_steps - 1]
298 |             prev_dir = prev_dir / torch.linalg.norm(prev_dir)
299 |             return odf_cur[actions] * torch.sum(next_directions * prev_dir, dim=1) + na_reward
300 |         else:
301 |             return odf_cur[actions] + torch.max(local_na_reward, dim = 1).values.squeeze()  # max over the bundle axis
302 | 
303 | 
304 |     def track(self, with_best_action=True, agent=None):
305 |         streamlines = []
306 |         for i in trange(len(self.seeds)):
307 |             streamline = []
308 |             self.reset(seed_index=i)
309 | 
310 |             # -- forward tracking --
311 |             terminal = False
312 |             while not terminal:
313 |                 # current position
314 |                 # get the best choice from environment
315 |                 if with_best_action:
316 |                     _, reward = self._next_pos_and_reward()
317 |                     action = torch.argmax(reward)
318 |                 else:
319 |                     action = agent(self.state.get_value())
320 |                     #raise NotImplementedError
321 |                 # take a step
322 |                 _, reward, terminal, _ = self.step(action)
323 |                 # step function now returns dwi values --> due to compatibility to rainbow agent or stable baselines
324 |                 if not terminal:
325 |                     streamline.append(self.to_ras(self.state.get_coordinate().float()).cpu().detach().numpy())
326 | 
327 |             # -- backward tracking --
328 |             self.reset(seed_index=i)
329 |             # reset function now returns dwi values --> due to compatibility to rainbow agent or stable baselines
330 | 
331 |             streamline = streamline[::-1]
332 |             terminal = False  # re-arm the loop; the flag is still True from forward tracking
333 |             while not terminal:
334 |                 if with_best_action:
335 |                     _, reward = self._next_pos_and_reward(backwards=True)
336 |                     action = torch.argmax(reward)
337 |                 else:
338 |                     action = agent(self.state.get_value())
339 |                 # take a step
340 |                 _, reward, terminal, _ = self.step(action, backwards=True)
341 |                 # step function now returns dwi values --> due to compatibility to rainbow agent or stable baselines
342 |                 if not terminal:
343 |                     streamline.append(self.to_ras(self.state.get_coordinate().float()).cpu().detach().numpy())
344 |             streamline = np.array(streamline)
345 |             streamlines.append(streamline)
346 | 
347 |         return streamlines
348 | 
349 | 
350 |     # adapter to experiment with different representations of our state
351 |     def get_observation_from_state(self, state):
352 |         dwi_values = state  #.getValue().flatten()
353 |         # TODO -> currently only on dwi values, not on past states
354 |         #past_coordinates = np.array(list(self.state_history)).flatten()
355 |         #return np.concatenate((dwi_values, past_coordinates))
356 |         return dwi_values
357 | 
358 | 
359 |     # reset the game and returns the observed data from the last episode
360 |     def reset(self, seed_index=None):
361 |         if seed_index is None:
362 |             seed_index = torch.randint(len(self.seeds), size=(1,1))[0][0]
363 | 
self.seed_index = seed_index 364 | #seed_ras = self.seeds[self.seed_index] 365 | self.seed_ijk = self.seeds[self.seed_index].to(self.device) 366 | self.no_steps = 0 367 | self.na_reward_history = [] 368 | self.state = TractographyState(self.seed_ijk, self.interpolate_dwi_at_state) 369 | self.state_history = torch.zeros((self.max_steps + 1, 3)).to(self.device) 370 | self.state_history[0,:] = self.state.get_coordinate() 371 | #self.state_history = deque([self.state]*4, maxlen=4) 372 | self.na_reward_history = torch.zeros((self.max_steps, self.tract_masks.shape[-1])).to(self.device) 373 | 374 | return self.get_observation_from_state(self.state) 375 | 376 | def render(self, mode="human"): 377 | pass 378 | -------------------------------------------------------------------------------- /dfibert/envs/RLTractEnvironment_fast.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | import os 3 | import gym 4 | import numpy as np 5 | import torch 6 | import dipy.reconst.dti as dti 7 | from dipy.core.interpolation import trilinear_interpolate4d 8 | from dipy.core.sphere import HemiSphere, Sphere 9 | from dipy.core.sphere_stats import random_uniform_on_sphere 10 | 11 | from dipy.data import get_sphere 12 | from dipy.direction import peaks_from_model 13 | from dipy.reconst.csdeconv import ConstrainedSphericalDeconvModel 14 | from dipy.reconst.csdeconv import (mask_for_response_ssst, 15 | response_from_mask_ssst) 16 | from dipy.reconst.shm import order_from_ncoef, sph_harm_lookup 17 | from dipy.tracking import utils 18 | from gym.spaces import Discrete, Box 19 | from scipy.interpolate import RegularGridInterpolator 20 | from tqdm import trange 21 | 22 | from dfibert.data import DataContainer, DataPreprocessor, PointOutsideOfDWIError 23 | from dfibert.data.postprocessing import Resample, Resample100 24 | from dfibert.util import get_grid, set_seed 25 | from ._state import TractographyState 26 | 27 | 28 | from .neuroanatomical_utils import FiberBundleDataset, interpolate3dAt 29 | from .NARLTractEnvironment import TorchGridInterpolator 30 | 31 | 32 | class RLTractEnvironment(gym.Env): 33 | def __init__(self, device, seeds=None, step_width=0.8, dataset='100307', grid_dim=(3, 3, 3), 34 | neuroanatomical_reward=True, tracking_in_RAS=False, fa_threshold=0.1, b_val=1000, 35 | odf_state=True, odf_mode="CSD", action_space=100, pFolderBundles = "data/gt_bundles/", rnd_seed = 2342): 36 | print("Will be deprecated by NARLTractEnvironment as soon as Jos fixes all bugs in the reward function.") 37 | self.state_history = None 38 | self.reference_seed_point_ijk = None 39 | self.points_visited = None 40 | self.past_reward = None 41 | self.reward = None 42 | self.stepCounter = None 43 | self.done = None 44 | self.seed_index = None 45 | self.step_angles = None 46 | self.line = None 47 | self.neuroanatomical_reward = neuroanatomical_reward 48 | if(neuroanatomical_reward): 49 | self.na_reward_history = None 50 | self.av_na_reward = None 51 | self.past_bundle = None 52 | self.device = device 53 | 54 | if(dataset.__class__ == DataContainer): 55 | print("Using preloaded data.") 56 | self.dataset = dataset 57 | else: 58 | print("Loading dataset # %s" % (dataset)) 59 | preprocessor = DataPreprocessor().normalize().crop(b_val).fa_estimate() 60 | if dataset == 'ISMRM': 61 | self.dataset = preprocessor.get_ismrm(f"data/ISMRM2015/") 62 | else: 63 | self.dataset = preprocessor.get_hcp(f"data/HCP/{dataset}/") 64 | 65 | self.step_width = step_width 66 | self.dtype = 
torch.FloatTensor # vs. torch.cuda.FloatTensor 67 | self.dti_model = None 68 | self.dti_fit = None 69 | self.odf_interpolator = None 70 | self.sh_coefficient = None 71 | self.odf_mode = odf_mode 72 | 73 | # build DWI object by interpolating at all IJK coordinates 74 | interpol_pts = None 75 | # permute into CxHxWxD 76 | self.dwi = torch.from_numpy(Resample100().process(self.dataset, None, self.dataset.dwi)).to(device=device).float() 77 | 78 | set_seed(rnd_seed) 79 | X = random_uniform_on_sphere(n=action_space) 80 | self.sphere = HemiSphere(xyz=X) 81 | self.sphere_odf = self.sphere 82 | 83 | # -- interpolation function of state's value -- 84 | self.state_interpol_func = self.interpolate_dwi_at_state 85 | if odf_state: 86 | print("Interpolating ODF as state Value") 87 | self.state_interpol_func = self.interpolate_odf_at_state 88 | 89 | self.directions = torch.from_numpy(self.sphere.vertices).to(device) 90 | no_actions, _ = self.directions.shape 91 | self.directions_odf = torch.from_numpy(self.sphere_odf.vertices).to(device) 92 | 93 | self.action_space = Discrete(no_actions) # spaces.Discrete(no_actions+1) 94 | self.dwi_postprocessor = Resample(sphere=get_sphere('repulsion100')) # resample(sphere=sphere) 95 | self.referenceStreamline_ijk = None 96 | self.grid = get_grid(np.array(grid_dim)) 97 | self.grid = torch.from_numpy(self.grid).to(self.device) 98 | self.tracking_in_RAS = tracking_in_RAS 99 | 100 | # -- load streamlines -- 101 | self.fa_threshold = fa_threshold 102 | self.maxSteps = 2000 103 | 104 | # -- init seeds -- 105 | self.seeds = seeds 106 | if self.seeds is None: 107 | if self.dti_fit is None: 108 | self._init_odf() 109 | 110 | dti_model = dti.TensorModel(self.dataset.gtab, fit_method='LS') 111 | dti_fit = dti_model.fit(self.dataset.dwi, mask=self.dataset.binary_mask) 112 | 113 | fa_img = dti_fit.fa 114 | seed_mask = fa_img.copy() 115 | seed_mask[seed_mask >= 0.2] = 1 116 | seed_mask[seed_mask < 0.2] = 0 117 | 118 | seeds = utils.seeds_from_mask(seed_mask, affine=np.eye(4), density=1) # tracking in IJK 119 | self.seeds = torch.from_numpy(seeds) 120 | 121 | # -- init bundles for neuroanatomical reward -- 122 | if(neuroanatomical_reward): 123 | print("Init tract masks for neuroanatomical reward") 124 | fibers = [] 125 | self.bundleNames = os.listdir(pFolderBundles) 126 | for fibFile in self.bundleNames: 127 | pFibre = pFolderBundles + fibFile 128 | #print(" @ " + pFibre) 129 | fibers.append(FiberBundleDataset(path_to_files=pFibre, dataset = self.dataset).tractMask) 130 | 131 | ## Define our interpolators 132 | self.tractMasks = torch.stack(fibers, dim = 0).to(self.device).permute((1,2,3,0)) # [X,Y,Z,C] 133 | print(self.tractMasks.shape) 134 | self.tractMask_interpolator = TorchGridInterpolator(self.tractMasks) 135 | 136 | # -- init interpolators -- 137 | self.binary_mask = torch.from_numpy(self.dataset.binary_mask).to(device=device) 138 | self.fa_interpolator = TorchGridInterpolator(torch.from_numpy(self.dataset.fa).to(device=device).unsqueeze(-1).float()) 139 | self.dwi_interpolator = TorchGridInterpolator(self.dwi.to(self.device)) 140 | self.brainMask_interpolator = TorchGridInterpolator(torch.from_numpy(self.dataset.binary_mask).to(self.device).unsqueeze(-1).float()) 141 | 142 | # -- set default values -- 143 | self.reset() 144 | 145 | 146 | def _init_odf(self): 147 | print("Initialising ODF") 148 | # fit DTI model to data 149 | if self.odf_mode == "DTI": 150 | print("DTI-based ODF computation") 151 | self.dti_model = dti.TensorModel(self.dataset.gtab, fit_method='LS') 152 | 
self.dti_fit = self.dti_model.fit(self.dataset.dwi, mask=self.dataset.binary_mask) 153 | # compute ODF 154 | odf = self.dti_fit.odf(self.sphere_odf) 155 | elif self.odf_mode == "CSD": 156 | print("CSD-based ODF computation") 157 | mask = mask_for_response_ssst(self.dataset.gtab, self.dataset.dwi, roi_radii=10, fa_thr=0.7) 158 | num_voxels = np.sum(mask) 159 | print(num_voxels) 160 | response, ratio = response_from_mask_ssst(self.dataset.gtab, self.dataset.dwi, mask) 161 | print(response) 162 | self.dti_model = ConstrainedSphericalDeconvModel(self.dataset.gtab, response) 163 | self.dti_fit = self.dti_model.fit(self.dataset.dwi) 164 | odf = self.dti_fit.odf(self.sphere_odf) 165 | 166 | # -- set up interpolator for odf evaluation 167 | odf = torch.from_numpy(odf).to(device=self.device).float() 168 | self.odf_interpolator = TorchGridInterpolator(odf) 169 | print("..done!") 170 | 171 | def interpolate_dwi_at_state(self, stateCoordinates): 172 | # torch 173 | ijk_pts = self.grid + stateCoordinates 174 | new_shape = (*ijk_pts.shape[:-1], -1) 175 | 176 | interpolated_dwi = self.dwi_interpolator(ijk_pts) 177 | interpolated_dwi = interpolated_dwi.reshape(new_shape) 178 | 179 | return interpolated_dwi 180 | 181 | 182 | def interpolate_odf_at_state(self, stateCoordinates): 183 | # torch 184 | if self.odf_interpolator is None: 185 | self._init_odf() 186 | 187 | new_shape = (*stateCoordinates.shape[:-1], -1) 188 | 189 | interpol_odf = self.odf_interpolator(stateCoordinates) 190 | interpol_odf = interpol_odf.reshape(new_shape) 191 | return interpol_odf 192 | 193 | 194 | def step(self, action, direction="forward"): 195 | self.stepCounter += 1 196 | cur_position = self.state.getCoordinate().view(-1, 3).to(self.device) 197 | 198 | 199 | # -- Termination conditions -- 200 | # I. number of steps larger than maximum 201 | if self.stepCounter >= self.maxSteps: 202 | return self.get_observation_from_state(self.state), 0., True, {} 203 | 204 | # II. fa below threshold? stop tracking 205 | if(self.fa_interpolator(cur_position) < self.fa_threshold): 206 | #if self.dataset.get_fa(self.state.getCoordinate().cpu()) < self.fa_threshold: 207 | return self.get_observation_from_state(self.state), 0., True, {} 208 | 209 | # III. 
leaving brain mask
210 |         if(self.brainMask_interpolator(cur_position) == 0):
211 |             return self.get_observation_from_state(self.state), 0., True, {}
212 | 
213 |         # -- Tracking --
214 |         cur_tangent = self.directions[action].view(-1, 3)  # get direction from action (action = vertex id on (half)sphere)
215 |         if(direction == "backward"):
216 |             cur_tangent = cur_tangent * -1
217 |         next_position = cur_position + self.step_width * cur_tangent
218 |         next_state = TractographyState(next_position, self.state_interpol_func)
219 | 
220 |         # -- REWARD --
221 |         reward = self.reward_for_state_action_pair(self.state, action, direction)  # prev_tangent => None
222 | 
223 |         # -- book keeping --
224 |         self.state_history.append(next_state)
225 |         self.state = next_state
226 | 
227 |         return self.get_observation_from_state(next_state), reward, False, {}
228 | 
229 | 
230 |     def reward_for_state(self, state, direction, prev_direction = None):
231 |         my_position = state.getCoordinate().squeeze(0).to(self.device)
232 |         # -- main peaks from ODF --
233 |         pmf_cur = self.interpolate_odf_at_state(my_position)
234 |         reward = pmf_cur / torch.max(pmf_cur)
235 | 
236 |         #if(prev_direction != None):
237 |         #    print("[Warning] cosine similarity loss not used anymore due to results of ablation study.")
238 | 
239 |         ## ablation study on CST found that peak finding not required
240 |         '''
241 |         peak_indices = self._get_odf_peaks(reward, window_width=int(self.action_space.n/3))
242 |         mask = torch.zeros_like(reward, device = self.device)
243 |         mask[peak_indices] = 1
244 |         reward *= mask
245 |         '''
246 | 
247 |         ## ablation study on CST found that angular deviation not needed
248 |         '''
249 |         # -- limit angular deviation --
250 |         if prev_direction is not None:
251 |             reward = reward * abs(torch.nn.functional.cosine_similarity(self.directions, prev_direction.view(1,-1))).view(-1)  # noActions
252 |         # neuroanatomical reward
253 |         '''
254 | 
255 | 
256 |         # -- neuroanatomical reward --
257 |         orientation = self.step_width * self.directions
258 |         if(direction == "backward"):
259 |             orientation = -1 * orientation
260 |         next_pos = my_position.view(1,-1) + orientation  # gets next positions for all directions, actions X 3
261 | 
262 |         if(self.neuroanatomical_reward):
263 |             local_reward_na = self.tractMask_interpolator(next_pos)  # noActions x noTracts
264 | 
265 |             reward_na_mu_hist = torch.mean(self.na_reward_history[0:max(self.stepCounter-1,1), :], dim = 0).view(1,-1)  # 1 x no_tracts
266 |             local_reward_na = local_reward_na + reward_na_mu_hist  # noActions x noTracts
267 |             reward_na, _ = torch.max(local_reward_na, dim = 1)  # marginalize tracts
268 |             reward_na = reward_na.view(-1)  # noActions
269 | 
270 |             # reward_na_arg = torch.argmax(local_reward_na, dim = 0)  # get dominant tract per action
271 | 
272 |             reward = reward + reward_na
273 | 
274 |         return reward
275 | 
276 | 
277 |     def reward_for_state_action_pair(self, state, action, direction, prev_direction = None):
278 |         reward = self.reward_for_state(state, direction, prev_direction)
279 |         return reward[action]
280 | 
281 | 
282 |     def _get_best_action(self, state, direction="forward", prev_direction = None):
283 |         reward = self.reward_for_state(state, direction, prev_direction)
284 |         return torch.argmax(reward)
285 | 
286 | 
287 |     def track(self, agent=None):
288 |         streamlines = []
289 |         for i in trange(len(self.seeds), ascii=True):
290 |             streamline, _ = self._track_single_streamline(i, agent)
291 |             streamlines.append((streamline))
292 | 
293 |         return streamlines
294 | 
295 | 
296 |     def _track_single_streamline(self, index, agent=None): 
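        # Tracks a single streamline for seed `index`: first forward from the seed,
        # then backward from the same seed; the forward half is reversed before the
        # backward half is appended, yielding one continuous streamline.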
297 | all_states = [] 298 | self.reset(seed_index=index) 299 | state = self.state # reset function now returns dwi values --> due to compatibility to rainbow agent or stable baselines 300 | seed_position = state.getCoordinate().to(self.device) 301 | current_direction = None 302 | all_states.append(seed_position.squeeze(0)) 303 | streamline_reward = 0 304 | # -- forward tracking -- 305 | terminal = False 306 | eval_steps = 0 307 | while not terminal: 308 | # current position 309 | # get the best choice from environment 310 | if agent is None: 311 | action = self._get_best_action(state, direction="forward", prev_direction=current_direction) 312 | else: 313 | action = agent(self.get_observation_from_state(self.state)) 314 | # store tangent for next time step 315 | current_direction = self.directions[action] #.numpy() 316 | # take a step 317 | _, reward, terminal, _ = self.step(action) 318 | streamline_reward += reward 319 | state = self.state # step function now returns dwi values --> due to compatibility to rainbow agent or stable baselines 320 | if not terminal: 321 | all_states.append(state.getCoordinate().squeeze(0)) 322 | eval_steps = eval_steps + 1 323 | 324 | # -- backward tracking -- 325 | self.reset(seed_index=index, terminal_F=True) 326 | state = self.state # reset function now returns dwi values --> due to compatibility to rainbow agent or stable baselines 327 | current_direction = None # potentially take tangent of first step of forward tracker 328 | terminal = False 329 | all_states = all_states[::-1] 330 | while not terminal: 331 | # current position 332 | my_position = state.getCoordinate().double().squeeze(0) 333 | # get the best choice from environment 334 | if agent is None: 335 | action = self._get_best_action(state, direction="forward", prev_direction=current_direction) 336 | else: 337 | action = agent(self.get_observation_from_state(self.state)) 338 | # store tangent for next time step 339 | current_direction = self.directions[action]#.numpy() 340 | # take a step 341 | _, reward, terminal, _ = self.step(action, direction="backward") 342 | state = self.state 343 | my_position = my_position.to(self.device) # DIRTY!!! 
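            # append only if the step actually moved the tracker away from its previous
            # position; this guards against duplicated points when step() terminates in place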
344 | my_coord = state.getCoordinate().squeeze(0).to(self.device) 345 | if (False in torch.eq(my_coord, my_position)) & (not terminal): 346 | all_states.append(my_coord) 347 | 348 | return all_states, streamline_reward 349 | 350 | 351 | def get_observation_from_state(self, state): 352 | dwi_values = state.getValue().flatten() 353 | # TODO -> currently only on dwi values, not on past states 354 | #past_coordinates = np.array(list(self.state_history)).flatten() 355 | #return np.concatenate((dwi_values, past_coordinates)) 356 | return dwi_values 357 | 358 | 359 | # reset the game and returns the observed data from the last episode 360 | def reset(self, seed_index=None, terminal_F=False, terminal_B=False): 361 | # self.seed_index = seed_index 362 | if seed_index is not None: 363 | self.seed_index = seed_index 364 | elif not terminal_F and not terminal_B or terminal_F and terminal_B: 365 | self.seed_index = np.random.randint(len(self.seeds)) 366 | 367 | if self.tracking_in_RAS: 368 | reference_seed_point_ras = self.seeds[self.seed_index] 369 | reference_seed_point_ijk = self.dataset.to_ijk( 370 | reference_seed_point_ras) 371 | else: 372 | reference_seed_point_ijk = self.seeds[self.seed_index] 373 | 374 | self.done = False 375 | self.stepCounter = 0 376 | self.reward = 0 377 | self.past_reward = 0 378 | self.points_visited = 1 # position_index 379 | 380 | self.reference_seed_point_ijk = reference_seed_point_ijk.to(self.device) 381 | self.state = TractographyState(self.reference_seed_point_ijk, self.state_interpol_func) 382 | self.state_history = deque([self.state]*4, maxlen=4) 383 | 384 | if(self.neuroanatomical_reward): 385 | self.na_reward_history = torch.zeros((self.maxSteps, self.tractMasks.shape[-1]), device = self.device) 386 | 387 | return self.get_observation_from_state(self.state) 388 | 389 | def render(self, mode="human"): 390 | pass 391 | -------------------------------------------------------------------------------- /dfibert/ext/soft_dtw_cuda.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2020 Mehran Maghoumi 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
22 | # ---------------------------------------------------------------------------------------------------------------------- 23 | 24 | import numpy as np 25 | import torch 26 | import torch.cuda 27 | from numba import jit 28 | from torch.autograd import Function 29 | from numba import cuda 30 | import math 31 | 32 | # ---------------------------------------------------------------------------------------------------------------------- 33 | @cuda.jit 34 | def compute_softdtw_cuda(D, gamma, bandwidth, max_i, max_j, n_passes, R): 35 | """ 36 | :param seq_len: The length of the sequence (both inputs are assumed to be of the same size) 37 | :param n_passes: 2 * seq_len - 1 (The number of anti-diagonals) 38 | """ 39 | # Each block processes one pair of examples 40 | b = cuda.blockIdx.x 41 | # We have as many threads as seq_len, because the most number of threads we need 42 | # is equal to the number of elements on the largest anti-diagonal 43 | tid = cuda.threadIdx.x 44 | 45 | # Compute I, J, the indices from [0, seq_len) 46 | 47 | # The row index is always the same as tid 48 | I = tid 49 | 50 | inv_gamma = 1.0 / gamma 51 | 52 | # Go over each anti-diagonal. Only process threads that fall on the current on the anti-diagonal 53 | for p in range(n_passes): 54 | 55 | # The index is actually 'p - tid' but need to force it in-bounds 56 | J = max(0, min(p - tid, max_j - 1)) 57 | 58 | # For simplicity, we define i, j which start from 1 (offset from I, J) 59 | i = I + 1 60 | j = J + 1 61 | 62 | # Only compute if element[i, j] is on the current anti-diagonal, and also is within bounds 63 | if I + J == p and (I < max_i and J < max_j): 64 | # Don't compute if outside bandwidth 65 | if not (abs(i - j) > bandwidth > 0): 66 | r0 = -R[b, i - 1, j - 1] * inv_gamma 67 | r1 = -R[b, i - 1, j] * inv_gamma 68 | r2 = -R[b, i, j - 1] * inv_gamma 69 | rmax = max(max(r0, r1), r2) 70 | rsum = math.exp(r0 - rmax) + math.exp(r1 - rmax) + math.exp(r2 - rmax) 71 | softmin = -gamma * (math.log(rsum) + rmax) 72 | R[b, i, j] = D[b, i - 1, j - 1] + softmin 73 | 74 | # Wait for other threads in this block 75 | cuda.syncthreads() 76 | 77 | # ---------------------------------------------------------------------------------------------------------------------- 78 | @cuda.jit 79 | def compute_softdtw_backward_cuda(D, R, inv_gamma, bandwidth, max_i, max_j, n_passes, E): 80 | k = cuda.blockIdx.x 81 | tid = cuda.threadIdx.x 82 | 83 | # Indexing logic is the same as above, however, the anti-diagonal needs to 84 | # progress backwards 85 | I = tid 86 | 87 | for p in range(n_passes): 88 | # Reverse the order to make the loop go backward 89 | rev_p = n_passes - p - 1 90 | 91 | # convert tid to I, J, then i, j 92 | J = max(0, min(rev_p - tid, max_j - 1)) 93 | 94 | i = I + 1 95 | j = J + 1 96 | 97 | # Only compute if element[i, j] is on the current anti-diagonal, and also is within bounds 98 | if I + J == rev_p and (I < max_i and J < max_j): 99 | 100 | if math.isinf(R[k, i, j]): 101 | R[k, i, j] = -math.inf 102 | 103 | # Don't compute if outside bandwidth 104 | if not (abs(i - j) > bandwidth > 0): 105 | a = math.exp((R[k, i + 1, j] - R[k, i, j] - D[k, i + 1, j]) * inv_gamma) 106 | b = math.exp((R[k, i, j + 1] - R[k, i, j] - D[k, i, j + 1]) * inv_gamma) 107 | c = math.exp((R[k, i + 1, j + 1] - R[k, i, j] - D[k, i + 1, j + 1]) * inv_gamma) 108 | E[k, i, j] = E[k, i + 1, j] * a + E[k, i, j + 1] * b + E[k, i + 1, j + 1] * c 109 | 110 | # Wait for other threads in this block 111 | cuda.syncthreads() 112 | 113 | # 
---------------------------------------------------------------------------------------------------------------------- 114 | class _SoftDTWCUDA(Function): 115 | """ 116 | CUDA implementation is inspired by the diagonal one proposed in https://ieeexplore.ieee.org/document/8400444: 117 | "Developing a pattern discovery method in time series data and its GPU acceleration" 118 | """ 119 | 120 | @staticmethod 121 | def forward(ctx, D, gamma, bandwidth): 122 | dev = D.device 123 | dtype = D.dtype 124 | gamma = torch.cuda.FloatTensor([gamma]) 125 | bandwidth = torch.cuda.FloatTensor([bandwidth]) 126 | 127 | B = D.shape[0] 128 | N = D.shape[1] 129 | M = D.shape[2] 130 | threads_per_block = max(N, M) 131 | n_passes = 2 * threads_per_block - 1 132 | 133 | # Prepare the output array 134 | R = torch.ones((B, N + 2, M + 2), device=dev, dtype=dtype) * math.inf 135 | R[:, 0, 0] = 0 136 | 137 | # Run the CUDA kernel. 138 | # Set CUDA's grid size to be equal to the batch size (every CUDA block processes one sample pair) 139 | # Set the CUDA block size to be equal to the length of the longer sequence (equal to the size of the largest diagonal) 140 | compute_softdtw_cuda[B, threads_per_block](cuda.as_cuda_array(D.detach()), 141 | gamma.item(), bandwidth.item(), N, M, n_passes, 142 | cuda.as_cuda_array(R)) 143 | ctx.save_for_backward(D, R.clone(), gamma, bandwidth) 144 | return R[:, -2, -2] 145 | 146 | @staticmethod 147 | def backward(ctx, grad_output): 148 | dev = grad_output.device 149 | dtype = grad_output.dtype 150 | D, R, gamma, bandwidth = ctx.saved_tensors 151 | 152 | B = D.shape[0] 153 | N = D.shape[1] 154 | M = D.shape[2] 155 | threads_per_block = max(N, M) 156 | n_passes = 2 * threads_per_block - 1 157 | 158 | D_ = torch.zeros((B, N + 2, M + 2), dtype=dtype, device=dev) 159 | D_[:, 1:N + 1, 1:M + 1] = D 160 | 161 | R[:, :, -1] = -math.inf 162 | R[:, -1, :] = -math.inf 163 | R[:, -1, -1] = R[:, -2, -2] 164 | 165 | E = torch.zeros((B, N + 2, M + 2), dtype=dtype, device=dev) 166 | E[:, -1, -1] = 1 167 | 168 | # Grid and block sizes are set same as done above for the forward() call 169 | compute_softdtw_backward_cuda[B, threads_per_block](cuda.as_cuda_array(D_), 170 | cuda.as_cuda_array(R), 171 | 1.0 / gamma.item(), bandwidth.item(), N, M, n_passes, 172 | cuda.as_cuda_array(E)) 173 | E = E[:, 1:N + 1, 1:M + 1] 174 | return grad_output.view(-1, 1, 1).expand_as(E) * E, None, None 175 | 176 | 177 | # ---------------------------------------------------------------------------------------------------------------------- 178 | # 179 | # The following is the CPU implementation based on https://github.com/Sleepwalking/pytorch-softdtw 180 | # Credit goes to Kanru Hua. 181 | # I've added support for batching and pruning. 
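# Both CPU kernels below implement the standard soft-DTW recurrence
#   R[b, i, j] = D[b, i-1, j-1] + softmin_gamma(R[b, i-1, j-1], R[b, i-1, j], R[b, i, j-1]),
# with softmin_gamma(a, b, c) = -gamma * log(exp(-a/gamma) + exp(-b/gamma) + exp(-c/gamma)),
# evaluated in a numerically stable way by shifting with the running maximum (rmax).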
182 | # 183 | # ---------------------------------------------------------------------------------------------------------------------- 184 | @jit(nopython=True) 185 | def compute_softdtw(D, gamma, bandwidth): 186 | B = D.shape[0] 187 | N = D.shape[1] 188 | M = D.shape[2] 189 | R = np.ones((B, N + 2, M + 2)) * np.inf 190 | R[:, 0, 0] = 0 191 | for b in range(B): 192 | for j in range(1, M + 1): 193 | for i in range(1, N + 1): 194 | 195 | # Check the pruning condition 196 | if 0 < bandwidth < np.abs(i - j): 197 | continue 198 | 199 | r0 = -R[b, i - 1, j - 1] / gamma 200 | r1 = -R[b, i - 1, j] / gamma 201 | r2 = -R[b, i, j - 1] / gamma 202 | rmax = max(max(r0, r1), r2) 203 | rsum = np.exp(r0 - rmax) + np.exp(r1 - rmax) + np.exp(r2 - rmax) 204 | softmin = - gamma * (np.log(rsum) + rmax) 205 | R[b, i, j] = D[b, i - 1, j - 1] + softmin 206 | return R 207 | 208 | # ---------------------------------------------------------------------------------------------------------------------- 209 | @jit(nopython=True) 210 | def compute_softdtw_backward(D_, R, gamma, bandwidth): 211 | B = D_.shape[0] 212 | N = D_.shape[1] 213 | M = D_.shape[2] 214 | D = np.zeros((B, N + 2, M + 2)) 215 | E = np.zeros((B, N + 2, M + 2)) 216 | D[:, 1:N + 1, 1:M + 1] = D_ 217 | E[:, -1, -1] = 1 218 | R[:, :, -1] = -np.inf 219 | R[:, -1, :] = -np.inf 220 | R[:, -1, -1] = R[:, -2, -2] 221 | for k in range(B): 222 | for j in range(M, 0, -1): 223 | for i in range(N, 0, -1): 224 | 225 | if np.isinf(R[k, i, j]): 226 | R[k, i, j] = -np.inf 227 | 228 | # Check the pruning condition 229 | if 0 < bandwidth < np.abs(i - j): 230 | continue 231 | 232 | a0 = (R[k, i + 1, j] - R[k, i, j] - D[k, i + 1, j]) / gamma 233 | b0 = (R[k, i, j + 1] - R[k, i, j] - D[k, i, j + 1]) / gamma 234 | c0 = (R[k, i + 1, j + 1] - R[k, i, j] - D[k, i + 1, j + 1]) / gamma 235 | a = np.exp(a0) 236 | b = np.exp(b0) 237 | c = np.exp(c0) 238 | E[k, i, j] = E[k, i + 1, j] * a + E[k, i, j + 1] * b + E[k, i + 1, j + 1] * c 239 | return E[:, 1:N + 1, 1:M + 1] 240 | 241 | # ---------------------------------------------------------------------------------------------------------------------- 242 | class _SoftDTW(Function): 243 | """ 244 | CPU implementation based on https://github.com/Sleepwalking/pytorch-softdtw 245 | """ 246 | 247 | @staticmethod 248 | def forward(ctx, D, gamma, bandwidth): 249 | dev = D.device 250 | dtype = D.dtype 251 | gamma = torch.Tensor([gamma]).to(dev).type(dtype) # dtype fixed 252 | bandwidth = torch.Tensor([bandwidth]).to(dev).type(dtype) 253 | D_ = D.detach().cpu().numpy() 254 | g_ = gamma.item() 255 | b_ = bandwidth.item() 256 | R = torch.Tensor(compute_softdtw(D_, g_, b_)).to(dev).type(dtype) 257 | ctx.save_for_backward(D, R, gamma, bandwidth) 258 | return R[:, -2, -2] 259 | 260 | @staticmethod 261 | def backward(ctx, grad_output): 262 | dev = grad_output.device 263 | dtype = grad_output.dtype 264 | D, R, gamma, bandwidth = ctx.saved_tensors 265 | D_ = D.detach().cpu().numpy() 266 | R_ = R.detach().cpu().numpy() 267 | g_ = gamma.item() 268 | b_ = bandwidth.item() 269 | E = torch.Tensor(compute_softdtw_backward(D_, R_, g_, b_)).to(dev).type(dtype) 270 | return grad_output.view(-1, 1, 1).expand_as(E) * E, None, None 271 | 272 | # ---------------------------------------------------------------------------------------------------------------------- 273 | class SoftDTW(torch.nn.Module): 274 | """ 275 | The soft DTW implementation that optionally supports CUDA 276 | """ 277 | 278 | def __init__(self, use_cuda, gamma=1.0, normalize=False, 
bandwidth=None, dist_func=None): 279 | """ 280 | Initializes a new instance using the supplied parameters 281 | :param use_cuda: Flag indicating whether the CUDA implementation should be used 282 | :param gamma: sDTW's gamma parameter 283 | :param normalize: Flag indicating whether to perform normalization 284 | (as discussed in https://github.com/mblondel/soft-dtw/issues/10#issuecomment-383564790) 285 | :param bandwidth: Sakoe-Chiba bandwidth for pruning. Passing 'None' will disable pruning. 286 | :param dist_func: Optional point-wise distance function to use. If 'None', then a default Euclidean distance function will be used. 287 | """ 288 | super(SoftDTW, self).__init__() 289 | self.normalize = normalize 290 | self.gamma = gamma 291 | self.bandwidth = 0 if bandwidth is None else float(bandwidth) 292 | self.use_cuda = use_cuda 293 | 294 | # Set the distance function 295 | if dist_func is not None: 296 | self.dist_func = dist_func 297 | else: 298 | self.dist_func = SoftDTW._euclidean_dist_func 299 | 300 | def _get_func_dtw(self, x, y): 301 | """ 302 | Checks the inputs and selects the proper implementation to use. 303 | """ 304 | bx, lx, dx = x.shape 305 | by, ly, dy = y.shape 306 | # Make sure the dimensions match 307 | assert bx == by # Equal batch sizes 308 | assert dx == dy # Equal feature dimensions 309 | 310 | use_cuda = self.use_cuda 311 | 312 | if use_cuda and (lx > 1024 or ly > 1024): # We should be able to spawn enough threads in CUDA 313 | print("SoftDTW: Cannot use CUDA because the sequence length > 1024 (the maximum block size supported by CUDA)") 314 | use_cuda = False 315 | 316 | # Finally, return the correct function 317 | return _SoftDTWCUDA.apply if use_cuda else _SoftDTW.apply 318 | 319 | @staticmethod 320 | def _euclidean_dist_func(x, y): 321 | """ 322 | Calculates the Euclidean distance between each element in x and y per timestep 323 | """ 324 | n = x.size(1) 325 | m = y.size(1) 326 | d = x.size(2) 327 | x = x.unsqueeze(2).expand(-1, n, m, d) 328 | y = y.unsqueeze(1).expand(-1, n, m, d) 329 | return torch.pow(x - y, 2).sum(3) 330 | 331 | def forward(self, X, Y): 332 | """ 333 | Compute the soft-DTW value between X and Y 334 | :param X: One batch of examples, batch_size x seq_len x dims 335 | :param Y: The other batch of examples, batch_size x seq_len x dims 336 | :return: The computed results 337 | """ 338 | 339 | # Check the inputs and get the correct implementation 340 | func_dtw = self._get_func_dtw(X, Y) 341 | 342 | if self.normalize: 343 | # Stack everything up and run 344 | x = torch.cat([X, X, Y]) 345 | y = torch.cat([Y, X, Y]) 346 | D = self.dist_func(x, y) 347 | out = func_dtw(D, self.gamma, self.bandwidth) 348 | out_xy, out_xx, out_yy = torch.split(out, X.shape[0]) 349 | return out_xy - 1 / 2 * (out_xx + out_yy) 350 | else: 351 | D_xy = self.dist_func(X, Y) 352 | return func_dtw(D_xy, self.gamma, self.bandwidth) 353 | 354 | # ---------------------------------------------------------------------------------------------------------------------- 355 | def timed_run(a, b, sdtw): 356 | """ 357 | Runs a and b through sdtw, and times the forward and backward passes. 358 | Assumes that a requires gradients. 
# ----------------------------------------------------------------------------------------------------------------------
class SoftDTW(torch.nn.Module):
    """
    The soft-DTW implementation that optionally supports CUDA
    """

    def __init__(self, use_cuda, gamma=1.0, normalize=False, bandwidth=None, dist_func=None):
        """
        Initializes a new instance using the supplied parameters
        :param use_cuda: Flag indicating whether the CUDA implementation should be used
        :param gamma: sDTW's gamma parameter
        :param normalize: Flag indicating whether to perform normalization
                          (as discussed in https://github.com/mblondel/soft-dtw/issues/10#issuecomment-383564790)
        :param bandwidth: Sakoe-Chiba bandwidth for pruning. Passing 'None' will disable pruning.
        :param dist_func: Optional point-wise distance function to use. If 'None', then a default Euclidean distance function will be used.
        """
        super(SoftDTW, self).__init__()
        self.normalize = normalize
        self.gamma = gamma
        self.bandwidth = 0 if bandwidth is None else float(bandwidth)
        self.use_cuda = use_cuda

        # Set the distance function
        if dist_func is not None:
            self.dist_func = dist_func
        else:
            self.dist_func = SoftDTW._euclidean_dist_func

    def _get_func_dtw(self, x, y):
        """
        Checks the inputs and selects the proper implementation to use.
        """
        bx, lx, dx = x.shape
        by, ly, dy = y.shape
        # Make sure the dimensions match
        assert bx == by  # Equal batch sizes
        assert dx == dy  # Equal feature dimensions

        use_cuda = self.use_cuda

        if use_cuda and (lx > 1024 or ly > 1024):  # We should be able to spawn enough threads in CUDA
            print("SoftDTW: Cannot use CUDA because the sequence length > 1024 (the maximum block size supported by CUDA)")
            use_cuda = False

        # Finally, return the correct function
        return _SoftDTWCUDA.apply if use_cuda else _SoftDTW.apply

    @staticmethod
    def _euclidean_dist_func(x, y):
        """
        Calculates the squared Euclidean distance between each pair of timesteps in x and y
        """
        n = x.size(1)
        m = y.size(1)
        d = x.size(2)
        x = x.unsqueeze(2).expand(-1, n, m, d)
        y = y.unsqueeze(1).expand(-1, n, m, d)
        return torch.pow(x - y, 2).sum(3)

    def forward(self, X, Y):
        """
        Compute the soft-DTW value between X and Y
        :param X: One batch of examples, batch_size x seq_len x dims
        :param Y: The other batch of examples, batch_size x seq_len x dims
        :return: The computed results
        """

        # Check the inputs and get the correct implementation
        func_dtw = self._get_func_dtw(X, Y)

        if self.normalize:
            # Stack everything up and run in one pass; note that this requires X and Y
            # to have equal sequence lengths, since they are concatenated batch-wise
            x = torch.cat([X, X, Y])
            y = torch.cat([Y, X, Y])
            D = self.dist_func(x, y)
            out = func_dtw(D, self.gamma, self.bandwidth)
            out_xy, out_xx, out_yy = torch.split(out, X.shape[0])
            return out_xy - 1 / 2 * (out_xx + out_yy)
        else:
            D_xy = self.dist_func(X, Y)
            return func_dtw(D_xy, self.gamma, self.bandwidth)
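
# ----------------------------------------------------------------------------------------------------------------------
# Editor's sketch (hypothetical, not part of the original file): end to end, the module
# behaves like any differentiable loss; with normalize=True, identical inputs score
# (approximately) zero.
def _example_softdtw_loss():
    sdtw = SoftDTW(use_cuda=False, gamma=0.1, normalize=True)
    x = torch.rand(4, 16, 3, requires_grad=True)  # e.g. 4 streamlines, 16 points, 3D coordinates
    y = torch.rand(4, 16, 3)
    loss = sdtw(x, y).mean()
    loss.backward()                               # gradients flow back into x
    return loss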
# ----------------------------------------------------------------------------------------------------------------------
def timed_run(a, b, sdtw):
    """
    Runs a and b through sdtw, and times the forward and backward passes.
    Assumes that a requires gradients.
    :return: timing, forward result, backward result
    """
    from timeit import default_timer as timer

    # Forward pass
    start = timer()
    forward = sdtw(a, b)
    end = timer()
    t = end - start

    grad_outputs = torch.ones_like(forward)

    # Backward
    start = timer()
    grads = torch.autograd.grad(forward, a, grad_outputs=grad_outputs)[0]
    end = timer()

    # Total time
    t += end - start

    return t, forward, grads

# ----------------------------------------------------------------------------------------------------------------------
def profile(batch_size, seq_len_a, seq_len_b, dims, tol_backward):
    sdtw = SoftDTW(False, gamma=1.0, normalize=False)
    sdtw_cuda = SoftDTW(True, gamma=1.0, normalize=False)
    n_iters = 6

    print("Profiling forward() + backward() times for batch_size={}, seq_len_a={}, seq_len_b={}, dims={}...".format(batch_size, seq_len_a, seq_len_b, dims))

    times_cpu = []
    times_gpu = []

    for i in range(n_iters):
        a_cpu = torch.rand((batch_size, seq_len_a, dims), requires_grad=True)
        b_cpu = torch.rand((batch_size, seq_len_b, dims))
        a_gpu = a_cpu.cuda()
        b_gpu = b_cpu.cuda()

        # GPU
        t_gpu, forward_gpu, backward_gpu = timed_run(a_gpu, b_gpu, sdtw_cuda)

        # CPU
        t_cpu, forward_cpu, backward_cpu = timed_run(a_cpu, b_cpu, sdtw)

        # Verify that both implementations agree
        assert torch.allclose(forward_cpu, forward_gpu.cpu())
        assert torch.allclose(backward_cpu, backward_gpu.cpu(), atol=tol_backward)

        if i > 0:  # Skip the first iteration: cold-start timings (JIT compilation, CUDA context creation) are not representative
            times_cpu += [t_cpu]
            times_gpu += [t_gpu]

    # Average and log
    avg_cpu = np.mean(times_cpu)
    avg_gpu = np.mean(times_gpu)
    print("\tCPU: ", avg_cpu)
    print("\tGPU: ", avg_gpu)
    print("\tSpeedup: ", avg_cpu / avg_gpu)
    print()

# ----------------------------------------------------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(1234)

    profile(128, 17, 15, 2, tol_backward=1e-6)
    profile(512, 64, 64, 2, tol_backward=1e-4)
    profile(512, 256, 256, 2, tol_backward=1e-3)
--------------------------------------------------------------------------------
/dfibert/dataset/__init__.py:
--------------------------------------------------------------------------------
"""
The dataset module handles the datasets usable for training and testing.
"""
import json
import os
from types import SimpleNamespace

import torch
import numpy as np

from .exceptions import WrongDatasetTypePassedError, FeatureShapesNotEqualError, DeviceNotRetrievableError


class MovableData():
    """
    This class can be used, via simple inheritance, to make classes handling multiple
    tensors easily movable between devices.

    All managed tensors must be instances of `torch.Tensor` or `MovableData`.
    Also, they have to be direct attributes of the object and are not allowed to be nested.

    Attributes
    ----------
    device: torch.device, optional
        The device the movable data currently is located on.

    Inheritance
    -----------
    To modify and inherit the `MovableData` class, overwrite the following functions:

    `_get_tensors()`
        This should return all `torch.Tensor` and `MovableData` instances of your class,
        in a key value pair `dict`.

    `_set_tensor(key, tensor)`
        This should replace the reference to the tensor with given key with the new, moved tensor.

    If those two methods are properly inherited, the visible functions should work as normal.
    If you plan to add other class types to the `_get_tensors` method, make sure that they implement
    the cuda, cpu, to and get_device methods in the same manner as `torch.Tensor` instances.
    """
    def __init__(self, device=None):
        """
        Parameters
        ----------
        device : torch.device, optional
            The device which the `MovableData` should be moved to on load, by default cpu.
        """
        if device is None:
            device = torch.device("cpu")
        self.device = device

    def _get_tensors(self):
        """
        Returns a dict containing all `torch.Tensor` and `MovableData` instances
        and their assigned keys.

        The default implementation searches for those on the attribute level.
        If your child class contains tensors at other positions, it is advisable to
        overwrite this function and the `_set_tensor` function.

        Returns
        -------
        dict
            The dict containing every `torch.Tensor` and `MovableData` with their assigned keys.

        See Also
        --------
        _set_tensor: implementations depend on each other
        """
        tensors = {}
        for key, value in vars(self).items():
            if isinstance(value, (torch.Tensor, MovableData)):
                tensors[key] = value
        return tensors

    def _set_tensor(self, key, tensor):
        """
        Replaces the tensor stored under the given key with the given value.

        In the default implementation, this works analogously to `_get_tensors`:
        it sets the attribute with the name key to the given object/tensor.
        If your child class contains tensors at other positions, it is advisable to
        overwrite this function and the `_get_tensors` function.

        Parameters
        ----------
        key : str
            The key of the original tensor.
        tensor : object
            The new tensor which should replace the original one.

        See Also
        --------
        _get_tensors: implementations depend on each other
        """
        setattr(self, key, tensor)
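
    # Editor's note (hypothetical sketch after the class, see below): thanks to the
    # attribute-level discovery in _get_tensors(), a subclass only needs to store its
    # tensors as attributes to become movable as a whole.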
    def cuda(self, device=None, non_blocking=False, memory_format=torch.preserve_format):
        """
        Returns this object in CUDA memory.

        If this object is already in CUDA memory and on the correct device,
        then no movement is performed and the original object is returned.

        Parameters
        ----------
        device : `torch.device`, optional
            The destination GPU device. Defaults to the current CUDA device.
        non_blocking : `bool`, optional
            If `True` and the source is in pinned memory, the copy will be asynchronous with
            respect to the host. Otherwise, the argument has no effect, by default `False`.
        memory_format : `torch.memory_format`, optional
            The desired memory format of the returned tensors, by default `torch.preserve_format`.

        Returns
        -------
        MovableData
            The object moved to the specified device
        """
        for attribute, tensor in self._get_tensors().items():
            cuda_tensor = tensor.cuda(device=device, non_blocking=non_blocking,
                                      memory_format=memory_format)
            self._set_tensor(attribute, cuda_tensor)
            self.device = cuda_tensor.device
        return self

    def cpu(self, memory_format=torch.preserve_format):
        """
        Moves this object to CPU memory.

        If this object is already in CPU memory,
        then no copy is performed and the original object is returned.

        Parameters
        ----------
        memory_format : `torch.memory_format`, optional
            The desired memory format of the returned tensors, by default `torch.preserve_format`.

        Returns
        -------
        MovableData
            The object moved to the CPU
        """
        for attribute, tensor in self._get_tensors().items():
            cpu_tensor = tensor.cpu(memory_format=memory_format)
            self._set_tensor(attribute, cpu_tensor)
        self.device = torch.device('cpu')
        return self

    def to(self, *args, **kwargs):
        """
        Performs Tensor dtype and/or device conversion.
        A `torch.dtype` and `torch.device` are inferred from the arguments of
        `self.to(*args, **kwargs)`.

        Here are the ways to call `to`:

        `to(dtype, non_blocking=False, copy=False, memory_format=torch.preserve_format)` -> MovableData
            Returns MovableData with specified `dtype`

        `to(device=None, dtype=None, non_blocking=False, copy=False,
        memory_format=torch.preserve_format)` -> MovableData
            Returns MovableData on specified `device`

        `to(other, non_blocking=False, copy=False)` -> MovableData
            Returns MovableData with same `dtype` and `device` as `other`

        Returns
        -------
        MovableData
            The object moved to the specified device
        """
        for attribute, tensor in self._get_tensors().items():
            tensor = tensor.to(*args, **kwargs)
            self._set_tensor(attribute, tensor)
            self.device = tensor.device
        return self

    def get_device(self):
        """
        For CUDA tensors, this function returns the device ordinal of the GPU on which the
        data resides. For CPU tensors, an error is raised.

        Returns
        -------
        int
            The device ordinal

        Raises
        ------
        DeviceNotRetrievableError
            Raised if the data is currently on the CPU and therefore no device ordinal exists.
        """
        if self.device.type == "cpu":
            raise DeviceNotRetrievableError(self.device)
        return self.device.index
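
# Editor's sketch (hypothetical, not part of the library): a minimal subclass whose
# tensors are plain attributes; _ExamplePair().to(torch.float64) converts both tensors,
# and .cuda() moves them together (get_device() then returns the ordinal) if a GPU exists.
class _ExamplePair(MovableData):
    def __init__(self, device=None):
        MovableData.__init__(self, device=device)
        self.inputs = torch.zeros(10, 3)
        self.targets = torch.zeros(10, 1)
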
class BaseDataset(MovableData):
    """The base class for datasets in this library.

    It extends `MovableData`.

    Attributes
    ----------
    device: torch.device, optional
        The device the movable data currently is located on.
    data_container: DataContainer
        The DataContainer the dataset is based on
    id: str
        An ID describing this dataset. It is not unique to an instance; it is assembled
        from the used parameters and data, so identically configured datasets share an ID.

    Methods
    -------
    cuda(device=None, non_blocking=False, memory_format=torch.preserve_format)
        Moves the MovableData to the specified or default CUDA device.
    cpu(memory_format=torch.preserve_format)
        Moves the MovableData to the CPU.
    to(*args, **kwargs)
        Moves the MovableData to the specified device.
        See `torch.Tensor.to(...)` for more details on usage.
    get_device()
        Returns the CUDA device number if possible. Raises `DeviceNotRetrievableError` otherwise.

    Inheritance
    -----------
    See `MovableData` for details.
    """
    def __init__(self, data_container, device=None):
        """
        Parameters
        ----------
        data_container: DataContainer
            The DataContainer the dataset uses
        device : torch.device, optional
            The device which the `MovableData` should be moved to on load, by default cpu.
        """
        MovableData.__init__(self, device=device)
        self.data_container = data_container
        self.id = str(self.__class__.__name__)
        if data_container is not None:
            self.id = self.id + "[" + str(data_container.id) + "]"


class IterableDataset(BaseDataset, torch.utils.data.Dataset):
    def __init__(self, data_container, device=None):
        BaseDataset.__init__(self, data_container, device=device)
        torch.utils.data.Dataset.__init__(self)

    def __len__(self):
        # Guard against direct use of the abstract base class; subclasses override this.
        if type(self) is IterableDataset:
            raise NotImplementedError() from None

    def __getitem__(self, index):
        if type(self) is IterableDataset:
            raise NotImplementedError() from None


class SaveableDataset(IterableDataset):
    def __init__(self, data_container, device=None):
        IterableDataset.__init__(self, data_container, device=device)

    def _get_variable_elements_data(self):
        # Collects the length of every element; assumes the dataset is non-empty and
        # that input and output of each element have the same length.
        lengths = np.zeros(len(self), dtype=int)
        for i, (inp, out) in enumerate(self):
            assert len(inp) == len(out)
            lengths[i] = len(inp)
        return lengths, inp.shape[1:], out.shape[1:]

    def saveToPath(self, path):
        os.makedirs(path, exist_ok=True)
        lengths, in_shape, out_shape = self._get_variable_elements_data()
        data_length = int(np.sum(lengths))

        in_shape = tuple([data_length] + list(in_shape))
        out_shape = tuple([data_length] + list(out_shape))

        inp_memmap = np.memmap(os.path.join(path, 'input.npy'), dtype='float32', shape=in_shape, mode='w+')
        out_memmap = np.memmap(os.path.join(path, 'output.npy'), dtype='float32', shape=out_shape, mode='w+')

        idx = 0
        assert len(self) == len(lengths)
        for i in range(len(self)):
            inp, out = self[i]
            assert len(inp) == lengths[i]
            inp_memmap[idx:(idx + lengths[i])] = inp.numpy()
            out_memmap[idx:(idx + lengths[i])] = out.numpy()
            idx = idx + lengths[i]
            print("{}/{}".format(i, len(lengths)), end="\r")
        np.save(os.path.join(path, 'lengths.npy'), lengths)
        with open(os.path.join(path, 'info.json'), 'w') as infofile:
            json.dump({"id": self.id, "input_shape": in_shape, "output_shape": out_shape}, infofile)
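
# Editor's sketch (hypothetical paths, not part of the library): saveToPath() writes two
# flat float32 memmaps ('input.npy', 'output.npy') holding all elements back to back,
# plus 'lengths.npy' and 'info.json'; cumulative sums of the lengths recover each
# element's row range, which is exactly how LoadedDataset below indexes the memmaps.
def _example_inspect_export(path="/tmp/exported_dataset"):
    lengths = np.load(os.path.join(path, 'lengths.npy'))
    starts = np.append(0, np.cumsum(lengths))
    return starts  # element i occupies rows starts[i]:starts[i + 1]
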
class LoadedDataset(IterableDataset):
    def __init__(self, path, device=None, passSingleElements=False):
        IterableDataset.__init__(self, None, device=device)
        self.path = path

        with open(os.path.join(self.path, 'info.json')) as infofile:
            info_data = json.load(infofile)
        self.id = info_data["id"] + "-loaded"

        inp_shape = tuple(info_data["input_shape"])
        out_shape = tuple(info_data["output_shape"])
        self.feature_shapes = np.prod(info_data["input_shape"][1:]), np.prod(info_data["output_shape"][1:])

        if not passSingleElements:
            self.sl_lengths = np.load(os.path.join(self.path, 'lengths.npy'))
        else:
            # Treat every stored point as its own element; integer lengths keep the
            # cumulative start indices usable for slicing.
            self.sl_lengths = np.ones(inp_shape[0], dtype=int)
        self.sl_start_indices = np.append(0, np.cumsum(self.sl_lengths))

        self.inp_memmap = np.memmap(os.path.join(self.path, 'input.npy'), dtype='float32', shape=inp_shape, mode='r')
        self.out_memmap = np.memmap(os.path.join(self.path, 'output.npy'), dtype='float32', shape=out_shape, mode='r')

    def __len__(self):
        return len(self.sl_lengths)

    def __getitem__(self, index):
        inp = torch.from_numpy(self.inp_memmap[self.sl_start_indices[index]:self.sl_start_indices[index + 1]]).to(self.device)
        out = torch.from_numpy(self.out_memmap[self.sl_start_indices[index]:self.sl_start_indices[index + 1]]).to(self.device)
        return (inp, out)

    def get_feature_shapes(self):
        return self.feature_shapes


class ConcatenatedDataset(SaveableDataset):
    def __init__(self, datasets, device=None):
        SaveableDataset.__init__(self, None, device=device)
        self.id = self.id + "["
        self.__lens = [0]
        for index, ds in enumerate(datasets):
            if not isinstance(ds, IterableDataset):
                raise WrongDatasetTypePassedError(self, ds,
                                                  ("Dataset {} does not inherit IterableDataset; "
                                                   "it is {}.").format(index, type(ds))
                                                  ) from None
            ds.to(self.device)
            self.id = self.id + ds.id + ", "
            self.__lens.append(len(ds) + self.__lens[-1])
        self.id = self.id[:-2] + "]"
        self.datasets = datasets
        self.options = SimpleNamespace()

    def __len__(self):
        return self.__lens[-1]

    def __getitem__(self, index):
        if index >= len(self):
            raise IndexError('index {i} out of bounds for ConcatenatedDataset with length {l}.'
                             .format(i=index, l=len(self))) from None
        # Linear scan over the cumulative lengths to find the owning dataset.
        i = 0
        while self.__lens[i + 1] <= index:
            i = i + 1

        return self.datasets[i][index - self.__lens[i]]

    def get_feature_shapes(self):
        # assert that each dataset has the same feature shapes
        (inp, out) = self.datasets[0].get_feature_shapes()
        for i in range(1, len(self.datasets)):
            (inp2, out2) = self.datasets[i].get_feature_shapes()
            if (not torch.all(torch.tensor(inp).eq(torch.tensor(inp2))) or
                    not torch.all(torch.tensor(out).eq(torch.tensor(out2)))):
                raise FeatureShapesNotEqualError(i, (inp, out), (inp2, out2))
        return (inp, out)

    def cuda(self, device=None, non_blocking=False, memory_format=torch.preserve_format):
        for dataset in self.datasets:
            dataset.cuda(device=device, non_blocking=non_blocking,
                         memory_format=memory_format)
            self.device = dataset.device
        return self

    def cpu(self, memory_format=torch.preserve_format):
        for dataset in self.datasets:
            dataset.cpu(memory_format=memory_format)
            self.device = dataset.device
        return self

    def to(self, *args, **kwargs):
        for dataset in self.datasets:
            dataset.to(*args, **kwargs)
            self.device = dataset.device
        return self
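
# Editor's sketch: ConcatenatedDataset resolves a global index via the cumulative
# lengths; e.g. for sub-dataset sizes 3 and 5 the cumulative list is [0, 3, 8], so
# global index 4 maps to element 4 - 3 = 1 of the second dataset. The same arithmetic,
# written with numpy (hypothetical helper, not used by the class):
def _example_resolve_index(cumulative_lens, index):
    i = int(np.searchsorted(cumulative_lens, index, side='right') - 1)
    return i, index - cumulative_lens[i]  # ([0, 3, 8], 4) -> (1, 1)
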
class StreamlineDataset(SaveableDataset):

    def __init__(self, streamlines, data_container, processing,
                 device=None, append_reverse=True, online_caching=True):
        SaveableDataset.__init__(self, data_container, device=device)
        self.streamlines = streamlines
        self.id = self.id + "-({})".format(processing.id)
        self.options = SimpleNamespace()
        self.options.append_reverse = append_reverse
        self.options.online_caching = online_caching
        self.options.processing = processing
        if online_caching:
            self.cache = [None] * len(self)
        self.feature_shapes = None

    def _get_variable_elements_data(self):
        lengths = np.zeros(len(self), dtype=int)
        for i, sl in enumerate(self.streamlines):
            lengths[i] = len(sl)
        if self.options.append_reverse:
            lengths[len(self.streamlines):] = lengths[:len(self.streamlines)]
        (inp, out) = self[0]
        return lengths, inp.shape[1:], out.shape[1:]

    def __len__(self):
        if self.options.append_reverse:
            return 2 * len(self.streamlines)
        return len(self.streamlines)

    def __getitem__(self, index):
        if self.options.online_caching and self.cache[index] is not None:
            return self.cache[index]
        (inp, output) = self._calculate_item(index)
        inp = torch.from_numpy(inp).to(device=self.device, dtype=torch.float32)  # TODO work on dtypes
        output = torch.from_numpy(output).to(device=self.device, dtype=torch.float32)

        if self.options.online_caching:
            self.cache[index] = (inp, output)
            return self.cache[index]
        else:
            return (inp, output)

    def _calculate_item(self, index):
        streamline = self._get_streamline(index)
        return self.options.processing.calculate_streamline(self.data_container, streamline)

    def _get_streamline(self, index):
        reverse = False
        if self.options.append_reverse and index >= len(self.streamlines):
            reverse = True
            index = index - len(self.streamlines)
        if reverse:
            streamline = self.streamlines[index][::-1]
        else:
            streamline = self.streamlines[index]
        return streamline

    def get_feature_shapes(self):
        if self.feature_shapes is None:
            dwi, next_dir = self[0]
            # assert that every type of data processing maintains same shape
            # and that every element has same shape
            input_shape = torch.tensor(dwi.shape)
            input_shape[0] = 1

            output_shape = torch.tensor(next_dir.shape)
            output_shape[0] = 1
            self.feature_shapes = (torch.prod(input_shape).item(), torch.prod(output_shape).item())
        return self.feature_shapes

    def cuda(self, device=None, non_blocking=False, memory_format=torch.preserve_format):
        if not self.options.online_caching:
            return self
        dwi = None
        for index, el in enumerate(self.cache):
            if el is None:
                continue
            dwi, next_dir = el
            dwi = dwi.cuda(device=device, non_blocking=non_blocking,
                           memory_format=memory_format)
            next_dir = next_dir.cuda(device=device, non_blocking=non_blocking,
                                     memory_format=memory_format)
            self.cache[index] = (dwi, next_dir)
            if self.device == dwi.device:  # move is unnecessary
                return self
        if dwi is not None:
            self.device = dwi.device
        return self

    def cpu(self, memory_format=torch.preserve_format):
        if not self.options.online_caching:
            return self
        dwi = None
        for index, el in enumerate(self.cache):
            if el is None:
                continue
            dwi, next_dir = el
            dwi = dwi.cpu(memory_format=memory_format)
            next_dir = next_dir.cpu(memory_format=memory_format)
            self.cache[index] = (dwi, next_dir)
            if self.device == dwi.device:  # move is unnecessary
                return self
        if dwi is not None:
            self.device = dwi.device
        return self

    def to(self, *args, **kwargs):
        if not self.options.online_caching:
            return self
        dwi = None
        for index, el in enumerate(self.cache):
            if el is None:
                continue
            dwi, next_dir = el
            dwi = dwi.to(*args, **kwargs)
            next_dir = next_dir.to(*args, **kwargs)
            self.cache[index] = (dwi, next_dir)
            if self.device == dwi.device:  # move is unnecessary
                return self
        if dwi is not None:
            self.device = dwi.device
        return self
--------------------------------------------------------------------------------
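Since every element of these datasets is a whole streamline of varying length, batching across streamlines needs equal lengths or a custom collate function. Note that `IterableDataset` above is, despite its name, a map-style `torch.utils.data.Dataset`, so it can be fed to a `DataLoader` directly; a minimal, hypothetical helper (names invented, not part of the library) is:

import torch

def iterate_streamlines(dataset):  # 'dataset' may be any IterableDataset subclass from above
    loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)
    for inp, out in loader:
        # inp/out carry a leading batch dimension of 1 holding a single streamline
        yield inp[0], out[0]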