├── .gitignore
├── HiddenStateExtractor
│   ├── __init__.py
│   ├── cv2_feature.py
│   ├── deprecated
│   │   ├── cpca.py
│   │   ├── morphology_clustering.py
│   │   ├── movement_clustering.py
│   │   └── vq_vae_extra.py
│   ├── losses.py
│   ├── naive_imagenet.py
│   ├── resnet.py
│   ├── vae.py
│   ├── vq_vae.py
│   └── vq_vae_supp.py
├── LICENSE
├── NNsegmentation
│   ├── __init__.py
│   ├── data.py
│   ├── layers.py
│   ├── models.py
│   └── run.py
├── NOVEMBER_Analysis.ipynb
├── NOVEMBER_Progenitor.ipynb
├── README.md
├── SingleCellPatch
│   ├── __init__.py
│   ├── extract_patches.py
│   ├── generate_trajectories.py
│   └── instance_clustering.py
├── configs
│   ├── .config_run_patch.yml
│   ├── __init__.py
│   ├── bryant_rubella_experiments.yml
│   ├── config_example.yml
│   ├── config_reader.py
│   └── config_run_patch.yml
├── documents
│   ├── 1-preprocessing.md
│   ├── 2-segmentation.md
│   ├── 3-patching.md
│   ├── 4-latent_encoding.md
│   └── 5-dim_reduction.md
├── graphicalabstract_dynamorph.jpg
├── pipeline.jpg
├── pipeline
│   ├── __init__.py
│   ├── patch_VAE.py
│   ├── preprocess.py
│   ├── segmentation.py
│   ├── segmentation_validation.py
│   └── train_utils.py
├── plot_scripts
│   ├── B4_temp.py
│   ├── PC_samples.py
│   ├── plotting_cm.py
│   ├── plottings.ipynb
│   ├── plottings.py
│   └── recon_loss.py
├── requirements.txt
├── requirements
│   ├── default.txt
│   └── pwrai_docker.txt
├── run_VAE.py
├── run_dim_reduction.py
├── run_patch.py
├── run_preproc.py
├── run_segmentation.py
└── run_training.py

/.gitignore:
--------------------------------------------------------------------------------
  1 | *.png
  2 | *.eps
  3 | *.gif
  4 | 
  5 | # Byte-compiled / optimized / DLL files
  6 | __pycache__/
  7 | *.py[cod]
  8 | *$py.class
  9 | 
 10 | # C extensions
 11 | *.so
 12 | 
 13 | # Distribution / packaging
 14 | .Python
 15 | build/
 16 | develop-eggs/
 17 | dist/
 18 | downloads/
 19 | eggs/
 20 | .eggs/
 21 | lib/
 22 | lib64/
 23 | parts/
 24 | sdist/
 25 | var/
 26 | wheels/
 27 | *.egg-info/
 28 | .installed.cfg
 29 | *.egg
 30 | MANIFEST
 31 | 
 32 | # PyInstaller
 33 | # Usually these files are written by a python script from a template
 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # pyenv 80 | .python-version 81 | 82 | # celery beat schedule file 83 | celerybeat-schedule 84 | 85 | # SageMath parsed files 86 | *.sage.py 87 | 88 | # Environments 89 | .env 90 | .venv 91 | env/ 92 | venv/ 93 | ENV/ 94 | env.bak/ 95 | venv.bak/ 96 | 97 | # Spyder project settings 98 | .spyderproject 99 | .spyproject 100 | 101 | # Rope project settings 102 | .ropeproject 103 | 104 | # mkdocs documentation 105 | /site 106 | 107 | # mypy 108 | .mypy_cache/ 109 | -------------------------------------------------------------------------------- /HiddenStateExtractor/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mehta-lab/dynamorph/b3321f4368002707fbe39d727bc5c23bd5e7e199/HiddenStateExtractor/__init__.py -------------------------------------------------------------------------------- /HiddenStateExtractor/cv2_feature.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Mon Aug 12 09:53:46 2019 5 | 6 | @author: michaelwu 7 | """ 8 | import cv2 9 | import numpy as np 10 | import scipy 11 | import pickle 12 | import random 13 | import os 14 | import cmath 15 | import matplotlib.pyplot as plt 16 | from .naive_imagenet import preprocess, read_file_path, CHANNEL_MAX 17 | import multiprocessing as mp 18 | 19 | 20 | def extract_features(x, vector_size=32): 21 | """ Calculate KAZE features for input image 22 | 23 | Args: 24 | x (np.array): input image mat 25 | vector_size (int, optional): feature vector size 26 | 27 | Returns: 28 | np.array: KAZE features 29 | 30 | """ 31 | x = x.astype('uint8') 32 | try: 33 | dscs = [] 34 | alg = cv2.KAZE_create() 35 | for x_slice in x: 36 | # finding image keypoints 37 | kps = alg.detect(x_slice) 38 | kps = sorted(kps, key=lambda x: -x.response)[:vector_size] 39 | # computing descriptors vector 40 | kps, dsc = alg.compute(x_slice, kps) 41 | dsc = dsc.flatten() 42 | # Padding 43 | needed_size = (vector_size * 64) 44 | if dsc.size < needed_size: 45 | dsc = np.concatenate([dsc, np.zeros(needed_size - dsc.size)]) 46 | dscs.append(dsc) 47 | dscs = np.stack(dscs, 0) 48 | except Exception as e: 49 | print('Error: ' + str(e)) 50 | return None 51 | return dscs 52 | 53 | 54 | def worker(f_n): 55 | """ Helper function for parallelization """ 56 | x = preprocess(f_n, cs=[0, 1], channel_max=CHANNEL_MAX) 57 | y = extract_features(x, vector_size=32) 58 | return y 59 | 60 | 61 | def get_size(mask): 62 | """ Calculate cell size based on mask 63 | 64 | Args: 65 | mask (np.array): segmentation mask of a single cell 66 | 67 | Returns: 68 | int: number of pixels in the cell area 69 | int: size of the cell contour 70 | 71 | """ 72 | 73 | _, contours, _ = cv2.findContours(mask.astype('uint8'), 1, 2) 74 | areas = [cv2.contourArea(cnt) for 
cnt in contours] 75 | return mask.sum(), np.max(areas) 76 | 77 | 78 | def get_intensity_profile(dat, mask=None): 79 | """ Calculate peak phase/retardance values 80 | 81 | See docs of `get_size` for input details 82 | 83 | Args: 84 | dat (list): list of 2D np.arrays for each channel of the patch 85 | mask (np.array, optional): segmentation mask 86 | 87 | Returns: 88 | list (of tuples): list of intensity properties for each channel 89 | max phase intensity; 90 | 95th percentile phase intensity; 91 | 200-th value of top phase intensities; 92 | summed phase intensities. 93 | 94 | """ 95 | 96 | output = [] 97 | for channel_ind in range(len(dat)): 98 | channel_slice = dat[channel_ind] 99 | channel_slice = channel_slice / 65535. 100 | 101 | # bg = np.median(channel_slice[np.where(mask == 0)]) 102 | bg = 0. 103 | 104 | peak_int = ((channel_slice - bg) * mask).max() 105 | sum_int = ((channel_slice - bg) * mask).sum() 106 | intensities = (channel_slice - bg)[np.where(mask)] 107 | quantile_int = np.percentile(intensities, 95) 108 | top200_int = np.mean(sorted(intensities)[-200:]) 109 | 110 | output.append((peak_int, quantile_int, top200_int, sum_int)) 111 | 112 | return output 113 | 114 | 115 | # def get_aspect_ratio(dat): 116 | # """ Calcualte aspect ratio (cv2.minAreaRect) 117 | 118 | # This function is deprecated and should be replaced by `get_angle_apr` 119 | 120 | # See docs of `get_size` for input details. 121 | 122 | # Args: 123 | # dat (np.array): single cell patch 124 | 125 | # Returns: 126 | # float: width 127 | # float: height 128 | # float: angle of long axis 129 | 130 | # """ 131 | # _, contours, _ = cv2.findContours(dat[:, :, 2].astype('uint8'), 1, 2) 132 | # areas = [cv2.contourArea(cnt) for cnt in contours] 133 | # rect = cv2.minAreaRect(contours[np.argmax(areas)]) 134 | # w, h = rect[1] 135 | # ang = rect[2] 136 | # if w < h: 137 | # ang = ang - 90 138 | # return w, h, ang 139 | 140 | 141 | def rotate_bound(image, angle): 142 | """ Rotate target mask 143 | 144 | Args: 145 | image (np.array): target mask of single cell patch 146 | angle (float): rotation angle 147 | 148 | Returns: 149 | np.array: rotated mask 150 | 151 | """ 152 | # grab the dimensions of the image and then determine the 153 | # center 154 | (h, w) = image.shape[:2] 155 | (cX, cY) = (w // 2, h // 2) 156 | # grab the rotation matrix (applying the negative of the 157 | # angle to rotate clockwise), then grab the sine and cosine 158 | M = cv2.getRotationMatrix2D((cX, cY), -angle, 1.0) 159 | cos = np.abs(M[0, 0]) 160 | sin = np.abs(M[0, 1]) 161 | # compute the new bounding dimensions of the image 162 | nW = int((h * sin) + (w * cos)) 163 | nH = int((h * cos) + (w * sin)) 164 | # adjust the rotation matrix to take into account translation 165 | M[0, 2] += (nW / 2) - cX 166 | M[1, 2] += (nH / 2) - cY 167 | # perform the actual rotation and return the image 168 | return cv2.warpAffine(image, M, (nW, nH)) 169 | 170 | 171 | def get_angle_apr(mask): 172 | """ Find long axis and calcualte aspect ratio 173 | 174 | See docs of `get_size` for input details. 
175 | 
176 |     Args:
177 |         mask (np.array): segmentation mask of a single cell
178 | 
179 |     Returns:
180 |         int: width of the aligned bounding box (long axis length)
181 |         int: height of the aligned bounding box (short axis length)
182 |         float: angle of long axis
183 | 
184 |     """
185 |     y, x = np.nonzero(mask)
186 |     x = x - np.mean(x)
187 |     y = y - np.mean(y)
188 |     coords = np.stack([x, y], 0)
189 |     cov = np.cov(coords)
190 |     evals, evecs = np.linalg.eig(cov)
191 |     main_axis = evecs[:, np.argmax(evals)]  # Eigenvector with largest eigenvalue
192 |     angle = cmath.polar(complex(*main_axis))[1]
193 | 
194 |     rotated = rotate_bound(mask, -angle/np.pi * 180)
195 |     _, contours, _ = cv2.findContours(rotated.astype('uint8'), 1, 2)
196 |     areas = [cv2.contourArea(cnt) for cnt in contours]
197 |     rect = cv2.boundingRect(contours[np.argmax(areas)])
198 |     return rect[2], rect[3], angle
199 | 
200 | 
201 | def get_aspect_ratio_no_rotation(mask):
202 |     """ Calculate aspect ratio of untouched target mask
203 | 
204 |     See docs of `get_size` for input details.
205 | 
206 |     Args:
207 |         mask (np.array): segmentation mask of a single cell
208 | 
209 |     Returns:
210 |         int: width
211 |         int: height
212 | 
213 |     """
214 |     _, contours, _ = cv2.findContours(mask.astype('uint8'), 1, 2)
215 |     areas = [cv2.contourArea(cnt) for cnt in contours]
216 |     rect = cv2.boundingRect(contours[np.argmax(areas)])
217 |     return rect[2], rect[3]
218 | 
219 | 
220 | # if __name__ == '__main__':
221 | #     path = '/mnt/comp_micro/Projects/CellVAE'
222 | #     fs = read_file_path(os.path.join(path, 'Data', 'StaticPatches'))
223 | #     sites = ['D%d-Site_%d' % (i, j) for j in range(9) for i in range(3, 6)]
224 | 
225 | #     sizes = {}
226 | #     densities = {}
227 | #     aprs = {}
228 | #     aprs_nr = {}
229 | #     for site in sites:
230 | #         dat = pickle.load(open('./%s_all_patches.pkl' % site, 'rb'))
231 | #         for f in dat:
232 | #             d = dat[f]["masked_mat"]
233 | #             sizes[f] = get_size(d)
234 | #             densities[f] = get_density(d)
235 | #             aprs[f] = get_aspect_ratio(d)
236 | #             aprs_nr[f] = get_aspect_ratio_no_rotation(d)
--------------------------------------------------------------------------------
/HiddenStateExtractor/deprecated/cpca.py:
--------------------------------------------------------------------------------
  1 | import numpy as np
  2 | import pickle
  3 | import contrastive
  4 | from HiddenStateExtractor.naive_imagenet import read_file_path, DATA_ROOT
  5 | import matplotlib
  6 | from matplotlib import cm
  7 | matplotlib.use('AGG')
  8 | import matplotlib.pyplot as plt
  9 | import os, cv2, torch  # used by the os.mkdir / cv2.imwrite / torch.load calls below
 10 | dats = pickle.load(open('./save_0005_bkp4.pkl', 'rb'))
 11 | fs = pickle.load(open('./HiddenStateExtractor/file_paths_bkp.pkl', 'rb'))
 12 | trajs = pickle.load(open('./HiddenStateExtractor/trajectory_in_inds.pkl', 'rb'))
 13 | site_dat = torch.load('../data_temp/B4_all_adjusted_static_patches.pt')
 14 | 
 15 | B4_dats = pickle.load(open('./save_0005_bkp4_B4.pkl', 'rb'))
 16 | B4_fs = sorted(B4_dats.keys())
 17 | B4_dats = np.stack([B4_dats[f] for f in B4_fs], 0).reshape((len(B4_fs), -1))
 18 | B4_trajs = pickle.load(open('./HiddenStateExtractor/B4_trajectory_in_inds.pkl', 'rb'))
 19 | 
 20 | mdl = contrastive.CPCA()
 21 | projected_data, alphas = mdl.fit_transform(B4_dats, dats, return_alphas=True)
 22 | 
 23 | 
 24 | for fold in range(1, 4):
 25 |     #os.mkdir('/data/michaelwu/PC_samples/cpca_alpha%d_PC1' % fold)
 26 |     #os.mkdir('/data/michaelwu/PC_samples/cpca_alpha%d_PC2' % fold)
 27 |     dats_ = projected_data[fold]
 28 |     plt.clf()
 29 |     fig, ax = plt.subplots()
 30 |     ax.scatter(dats_[:, 0], dats_[:, 1], s=0.5, edgecolors='none')
 31 |     plt.savefig('/data/michaelwu/PC_samples/cpca_alpha%d.png' % fold, dpi=300)
 32 | 
 33 |     names = []
 34 | 
out_paths = [] 35 | PC1s = dats_[:, 0] 36 | for i in range(5): 37 | rang = [np.quantile(PC1s, i * 0.2), np.quantile(PC1s, (i+1) * 0.2)] 38 | rang_fs = [f for i, f in enumerate(B4_fs) if rang[0] <= PC1s[i] < rang[1]] 39 | ct = 0 40 | base = np.zeros((128, 128), dtype=float) 41 | for j, f in enumerate(rang_fs): 42 | ind = B4_fs.index(f) 43 | slic = site_dat[ind][0][0].cpu().numpy().astype('float') 44 | base = base + slic 45 | ct += 1 46 | aver = base/ct 47 | aver = (aver * 65535).astype('uint16') 48 | cv2.imwrite('/data/michaelwu/PC_samples/cpca_alpha%d_PC1_fold%d_aver.png' % (fold, i), enhance_contrast(aver, a=2, b=-50000)) 49 | for j, f in enumerate(np.random.choice(rang_fs, (20,), replace=False)): 50 | names.append(f) 51 | out_paths.append('/data/michaelwu/PC_samples/cpca_alpha%d_PC1/PC1_%d_%d_sample%d.png' % (fold, i, i+1, j)) 52 | 53 | PC2s = dats_[:, 1] 54 | for i in range(5): 55 | rang = [np.quantile(PC2s, i * 0.2), np.quantile(PC2s, (i+1) * 0.2)] 56 | rang_fs = [f for i, f in enumerate(B4_fs) if rang[0] <= PC2s[i] < rang[1]] 57 | ct = 0 58 | base = np.zeros((128, 128), dtype=float) 59 | for j, f in enumerate(rang_fs): 60 | ind = B4_fs.index(f) 61 | slic = site_dat[ind][0][0].cpu().numpy().astype('float') 62 | base = base + slic 63 | ct += 1 64 | aver = base/ct 65 | aver = (aver * 65535).astype('uint16') 66 | cv2.imwrite('/data/michaelwu/PC_samples/cpca_alpha%d_PC2_fold%d_aver.png' % (fold, i), enhance_contrast(aver, a=2, b=-50000)) 67 | for j, f in enumerate(np.random.choice(rang_fs, (20,), replace=False)): 68 | names.append(f) 69 | out_paths.append('/data/michaelwu/PC_samples/cpca_alpha%d_PC2/PC2_%d_%d_sample%d.png' % (fold, i, i+1, j)) 70 | 71 | for name, out_path in zip(names, out_paths): 72 | ind = B4_fs.index(name) 73 | slic = (site_dat[ind][0][0].cpu().numpy() * 65535).astype('uint16') 74 | cv2.imwrite(out_path, enhance_contrast(slic, a=2, b=-50000)) 75 | 76 | 77 | 78 | # names = [] 79 | # out_paths = [] 80 | # np.random.seed(122) 81 | 82 | # PC1s = dats_[:, 0] 83 | # lower_ = np.quantile(PC1s, 0.2) 84 | # lower_fs = [f for i, f in enumerate(B4_fs) if PC1s[i] < lower_] 85 | # upper_ = np.quantile(PC1s, 0.8) 86 | # upper_fs = [f for i, f in enumerate(B4_fs) if PC1s[i] > upper_] 87 | # for i, f in enumerate(np.random.choice(lower_fs, (50,), replace=False)): 88 | # names.append(f) 89 | # out_paths.append('/home/michaelwu/cpca_PC1_lower_sample%d.png' % i) 90 | # for i, f in enumerate(np.random.choice(upper_fs, (50,), replace=False)): 91 | # names.append(f) 92 | # out_paths.append('/home/michaelwu/cpca_PC1_upper_sample%d.png' % i) 93 | 94 | 95 | # PC1_range = (np.quantile(PC1s, 0.4), np.quantile(PC1s, 0.6)) 96 | # PC2s = dats_[:, 1] 97 | # lower_ = np.quantile(PC2s, 0.2) 98 | # lower_fs = [f for i, f in enumerate(B4_fs) if PC2s[i] < lower_ and PC1_range[0] < PC1s[i] < PC1_range[1]] 99 | # upper_ = np.quantile(PC2s, 0.8) 100 | # upper_fs = [f for i, f in enumerate(B4_fs) if PC2s[i] > upper_ and PC1_range[0] < PC1s[i] < PC1_range[1]] 101 | # for i, f in enumerate(np.random.choice(lower_fs, (50,), replace=False)): 102 | # names.append(f) 103 | # out_paths.append('/home/michaelwu/cpca_PC2_lower_sample%d.png' % i) 104 | # for i, f in enumerate(np.random.choice(upper_fs, (50,), replace=False)): 105 | # names.append(f) 106 | # out_paths.append('/home/michaelwu/cpca_PC2_upper_sample%d.png' % i) 107 | 108 | 109 | # def enhance_contrast(mat, a=1.5, b=-10000): 110 | # mat2 = cv2.addWeighted(mat, a, mat, 0, b) 111 | # return mat2 112 | 113 | 114 | 
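# `enhance_contrast`, called by the cv2.imwrite lines above, is only defined in
# the commented-out block above. A minimal definition consistent with that sketch
# (a linear intensity rescale, out = a * mat + b, via cv2.addWeighted; to actually
# run this script it would need to appear before its first use):
def enhance_contrast(mat, a=1.5, b=-10000):
    # blend mat with itself: a * mat + 0 * mat + b, saturated to the dtype range
    return cv2.addWeighted(mat, a, mat, 0, b)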
--------------------------------------------------------------------------------
/HiddenStateExtractor/deprecated/morphology_clustering.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import numpy as np
  3 | from sklearn.decomposition import PCA
  4 | from sklearn.manifold import TSNE
  5 | from sklearn.cluster import KMeans
  6 | import matplotlib.pyplot as plt
  7 | from .naive_imagenet import DATA_ROOT, read_file_path
  8 | import pickle
  9 | import cv2
 10 | import h5py
 11 | from matplotlib.patches import Rectangle
 12 | from matplotlib import cm
 13 | import imageio
 14 | #import tifffile
 15 | 
 16 | def generate_cell_sizes(fs, out_path=None):
 17 |     sizes = {}
 18 |     for i, f_n in enumerate(fs):
 19 |         if i % 1000 == 0:
 20 |             print("Processed %d" % i)
 21 |             if not out_path is None:
 22 |                 with open(out_path, 'wb') as f_w:
 23 |                     pickle.dump(sizes, f_w)
 24 |         with h5py.File(f_n, 'r') as f:
 25 |             size = f['masked_mat'][:, :, 2].sum()
 26 |             sizes[f_n] = size
 27 |     if not out_path is None:
 28 |         with open(out_path, 'wb') as f_w:
 29 |             pickle.dump(sizes, f_w)
 30 |     return sizes
 31 | 
 32 | def generate_cell_aspect_ratios(fs, out_path=None):
 33 |     aps = {}
 34 |     for i, f_n in enumerate(fs):
 35 |         if i % 1000 == 0:
 36 |             print("Processed %d" % i)
 37 |             if not out_path is None:
 38 |                 with open(out_path, 'wb') as f_w:
 39 |                     pickle.dump(aps, f_w)
 40 |         with h5py.File(f_n, 'r') as f:
 41 |             mask = np.array(f['masked_mat'][:, :, 2]).astype('uint8')
 42 | 
 43 |             cnts = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)[-2]
 44 |             cnt = sorted(cnts, key=cv2.contourArea, reverse=True)[0]
 45 |             rbox = cv2.minAreaRect(cnt)
 46 |             aps[f_n] = rbox[1][0]/rbox[1][1]
 47 |     if not out_path is None:
 48 |         with open(out_path, 'wb') as f_w:
 49 |             pickle.dump(aps, f_w)
 50 |     return aps
 51 | 
 52 | def select_clean_trajecteories(dats_, trajs):
 53 |     clean_trajs = {}
 54 |     traj_diffs_dict = {}
 55 |     for t in trajs:
 56 |         traj_dats_ = dats_[np.array(trajs[t])]
 57 |         traj_diffs = np.linalg.norm(traj_dats_[1:] - traj_dats_[:-1], ord=2, axis=1)
 58 |         traj_diffs_dict[t] = traj_diffs
 59 |     thr = np.quantile(np.concatenate(list(traj_diffs_dict.values())), 0.9)
 60 |     for t in trajs:
 61 |         if np.quantile(traj_diffs_dict[t], 0.7) < thr:
 62 |             clean_trajs[t] = trajs[t]
 63 |     return clean_trajs
 64 | 
 65 | def read_trajectories(fs, out_path=None):
 66 |     latent_space_trajs = {}
 67 |     sites = ['D%d-Site_%d' % (i, j) for j in range(9) for i in range(3, 6)]
 68 |     for site in sites:
 69 |         trajectories = pickle.load(open(DATA_ROOT + '/Data/DynamicPatches/%s/mg_traj.pkl' % site, 'rb'))[0]  # Select from [trajectories, trajectories_positions]
 70 |         for i, t in enumerate(trajectories):
 71 |             names = [DATA_ROOT + '/Data/StaticPatches/%s/%d_%d.h5' % (site, k, t[k]) for k in sorted(t.keys())]
 72 |             inds = [fs.index(name) for name in names if name in fs]
 73 |             latent_space_trajs[site + '/%d' % i] = inds
 74 |     if not out_path is None:
 75 |         with open(out_path, 'wb') as f:
 76 |             pickle.dump(latent_space_trajs, f)
 77 |     return latent_space_trajs
 78 | 
 79 | def step_displacement_histogram(vs, trajs):
 80 |     np.random.seed(123)
 81 |     traj_step_sizes = []
 82 |     random_traj_step_sizes = []
 83 |     for traj in trajs:
 84 |         traj_ = np.stack([vs[i] for i in traj], 0)
 85 |         step_sizes = np.linalg.norm(traj_[1:] - traj_[:-1], ord=2, axis=1)
 86 |         traj_step_sizes.append(step_sizes)
 87 | 
 88 |         random_traj = np.random.randint(0, len(vs), size=(len(traj),))
 89 |         random_traj_ = np.stack([vs[i] for i in random_traj], 0)
 90 |         random_step_sizes = np.linalg.norm(random_traj_[1:] - random_traj_[:-1], ord=2, axis=1)
 91 |         random_traj_step_sizes.append(random_step_sizes)
 92 | 
 93 |     traj_step_sizes = np.concatenate(traj_step_sizes)
 94 |     random_traj_step_sizes = np.concatenate(random_traj_step_sizes)
 95 | 
 96 |     traj_step_sizes = np.array(traj_step_sizes)/np.median(random_traj_step_sizes)
 97 |     random_traj_step_sizes = np.array(random_traj_step_sizes)/np.median(random_traj_step_sizes)
 98 |     plt.clf()
 99 |     plt.hist(random_traj_step_sizes, bins=np.arange(0, 2, 0.02), color=(0, 0, 1, 0.5), label='random')
100 |     plt.hist(traj_step_sizes, bins=np.arange(0, 2, 0.02), color=(0, 1, 0, 0.5), label='trajectory, mean: %f' % np.mean(traj_step_sizes))
101 |     plt.legend()
102 | 
103 | def generate_short_traj_morphorlogy(vs, traj_list, length=5):
104 |     short_trajs = []
105 |     for t in traj_list:
106 |         n_sub_trajs = len(t) - (length - 1)
107 |         for i in range(n_sub_trajs):
108 |             sub_traj = t[i:(i+length)]
109 | 
110 |             sub_v = vs[np.array(sub_traj)]
111 |             short_trajs.append(sub_v)
112 |     short_trajs = np.stack(short_trajs, 0).reshape((len(short_trajs), -1))
113 |     return short_trajs
114 | 
115 | def Kmean_on_short_trajs(vs, trajs, length=5, n_clusters=4):
116 |     short_trajs = generate_short_traj_morphorlogy(vs, list(trajs.values()), length=length)
117 |     short_trajs = short_trajs.reshape((len(short_trajs), -1))
118 | 
119 |     clustering = KMeans(n_clusters=n_clusters)
120 |     clustering.fit(short_trajs)
121 |     predicted_classes = {}
122 |     for t in trajs:
123 |         sub_trajs = generate_short_traj_morphorlogy(vs, [trajs[t]], length=length)
124 |         sub_trajs = sub_trajs.reshape((len(sub_trajs), -1))
125 |         labels = clustering.predict(sub_trajs)
126 |         predicted_classes[t] = labels
127 |     return predicted_classes
128 | 
129 | def Kmean_on_short_traj_diffs(vs, trajs, length=5, n_clusters=4):
130 |     short_trajs = generate_short_traj_morphorlogy(vs, list(trajs.values()), length=length)
131 |     short_traj_diffs = (short_trajs[:, 1:] - short_trajs[:, :-1]).reshape((len(short_trajs), -1))
132 | 
133 |     clustering = KMeans(n_clusters=n_clusters)
134 |     clustering.fit(short_traj_diffs)
135 |     predicted_classes = {}
136 |     for t in trajs:
137 |         sub_trajs = generate_short_traj_morphorlogy(vs, [trajs[t]], length=length)
138 |         sub_traj_diffs = (sub_trajs[:, 1:] - sub_trajs[:, :-1]).reshape((len(sub_trajs), -1))
139 |         labels = clustering.predict(sub_traj_diffs)
140 |         predicted_classes[t] = labels
141 |     return predicted_classes
142 | 
143 | def save_traj(k, output_path=None):
144 |     input_path = DATA_ROOT + '/Data/DynamicPatches/%s/mg_traj_%s.tif' % (k.split('/')[0], k.split('/')[1])
145 |     # images = tifffile.imread(input_path)
146 |     _, images = cv2.imreadmulti(input_path, flags=cv2.IMREAD_ANYDEPTH)
147 |     images = np.array(images)
148 |     if output_path is None:
149 |         output_path = './%s.gif' % (k[:9] + '_' + k[10:])
150 |     imageio.mimsave(output_path, images)
151 |     return
152 | 
153 | 
154 | if __name__ == '__main__':
155 | 
156 |     feat = 'save_0005_before'
157 |     sites = ['D%d-Site_%d' % (i, j) for j in range(9) for i in range(3, 6)]
158 |     fs = sorted(read_file_path(DATA_ROOT + '/Data/StaticPatches'))
159 | 
160 |     # TRAJECTORIES
161 |     #trajs = read_trajectories(fs, './trajectory_in_inds.pkl')
162 |     trajs = pickle.load(open('./trajectory_in_inds.pkl', 'rb'))
163 | 
164 |     # IMAGE REPRESENTATIONS
165 |     dats = pickle.load(open(DATA_ROOT + '/Data/%s.pkl' % feat, 'rb'))
166 |     ks = sorted([k for k in dats.keys() if dats[k] is not None])
167 |     assert ks == fs
168 |     vs = [dats[k] for k in ks]
169 |     vs = np.stack(vs, 0).reshape((len(ks), -1))
170 | 
171 |     # CELL 
SIZES 172 | #sizes = generate_cell_sizes(fs, path + '/Data/EncodedSizes.pkl') 173 | sizes = pickle.load(open(DATA_ROOT + '/Data/EncodedSizes.pkl', 'rb')) 174 | ss = [sizes[k] for k in ks] 175 | 176 | ########################################### 177 | step_displacement_histogram(vs, list(trajs.values())) 178 | ########################################### 179 | # pca = PCA(n_components=0.5) 180 | # dats_ = pca.fit_transform(vs) 181 | # with open('./%s_PCA.pkl' % feat, 'wb') as f: 182 | # pickle.dump(dats_, f) 183 | length = 5 184 | n_clusters = 3 185 | dats_ = pickle.load(open('./%s_PCA.pkl' % feat, 'rb')) 186 | 187 | clean_trajs = select_clean_trajecteories(dats_, trajs) 188 | 189 | traj_classes = Kmean_on_short_trajs(dats_, trajs, length=length, n_clusters=n_clusters) 190 | 191 | 192 | representative_trajs = {} 193 | try: 194 | os.mkdir('%s_clustered_traj_diffs' % feat) 195 | except: 196 | pass 197 | traj_names = list(traj_classes.keys()) 198 | np.random.shuffle(traj_names) 199 | for t in traj_names: 200 | if np.unique(traj_classes[t]).shape[0] == 1: 201 | cl = str(traj_classes[t][0]) 202 | elif np.unique(traj_classes[t]).shape[0] == 2: 203 | if np.unique(traj_classes[t][:5]).shape[0] == 1 and \ 204 | np.unique(traj_classes[t][-5:]).shape[0] == 1 and \ 205 | traj_classes[t][0] != traj_classes[t][-1]: 206 | cl = str(traj_classes[t][0]) + '_' + str(traj_classes[t][-1]) 207 | else: 208 | continue 209 | else: 210 | continue 211 | if not cl in representative_trajs: 212 | try: 213 | os.mkdir('./%s_clustered_traj_diffs/%s' % (feat, cl)) 214 | except: 215 | pass 216 | representative_trajs[cl] = [] 217 | representative_trajs[cl].append(t) 218 | if len(representative_trajs[cl]) < 50: 219 | save_traj(t, output_path='./%s_clustered_traj_diffs/%s/%s.gif' % (feat, cl, t[:9] + '_' + t[10:])) 220 | 221 | ############################################## 222 | color_range = [np.array((0., 0., 1., 0.5)), 223 | np.array((1., 0., 0., 0.5))] 224 | range_min = np.log(min(ss)) 225 | range_max = np.log(max(ss)) 226 | colors = [(np.log(s) - range_min)/(range_max - range_min) * color_range[0] + \ 227 | (range_max - np.log(s))/(range_max - range_min) * color_range[1] for s in ss] 228 | 229 | 230 | plt.clf() 231 | plt.scatter(dats_[:, 0], dats_[:, 1], c=colors, s=0.1) 232 | plt.legend() 233 | plt.xlabel("PC1") 234 | plt.ylabel("PC2") 235 | plt.savefig('/home/michaelwu/pca_%s.png' % feat, dpi=300) 236 | 237 | 238 | 239 | single_patch_classes = -np.ones((len(dats_),), dtype=int) 240 | for cl in range(n_clusters): 241 | trajs_cl = representative_trajs[str(cl)] 242 | for t in trajs_cl: 243 | for ind in trajs[t]: 244 | single_patch_classes[ind] = cl 245 | cmap = cm.get_cmap('tab10') 246 | colors = list(cmap.colors[:(n_clusters+1)]) 247 | colors[-1] = (0.8, 0.8, 0.8) 248 | 249 | unplotted = np.where(single_patch_classes < 0)[0] 250 | plotted = np.where(single_patch_classes >= 0)[0] 251 | plt.clf() 252 | 253 | plt.scatter(dats_[np.where(single_patch_classes==0)[0][0], 1], 254 | dats_[np.where(single_patch_classes==0)[0][0], 2], c=colors[0], s=1., label='Cluster 0') 255 | plt.scatter(dats_[np.where(single_patch_classes==1)[0][0], 1], 256 | dats_[np.where(single_patch_classes==1)[0][0], 2], c=colors[1], s=1., label='Cluster 1') 257 | plt.scatter(dats_[np.where(single_patch_classes==2)[0][0], 1], 258 | dats_[np.where(single_patch_classes==2)[0][0], 2], c=colors[2], s=1., label='Cluster 2') 259 | plt.scatter(dats_[unplotted][:, 0], dats_[unplotted][:, 1], c=(0.8, 0.8, 0.8), s=0.1) 260 | plt.scatter(dats_[plotted][:, 0], 
dats_[plotted][:, 1], c=[colors[i] for i in single_patch_classes[plotted]], s=0.1)
261 | 
262 |     plt.legend()
263 |     plt.xlabel("PC1")
264 |     plt.ylabel("PC2")
265 |     plt.savefig('/home/michaelwu/pca_%s2.png' % feat, dpi=300)
266 | 
267 |     # range0 = [np.quantile(dats_[:, 0], 0.02), np.quantile(dats_[:, 0], 0.98)]
268 |     # range1 = [np.quantile(dats_[:, 1], 0.02), np.quantile(dats_[:, 1], 0.98)]
269 |     # range2 = [np.quantile(dats_[:, 2], 0.02), np.quantile(dats_[:, 2], 0.98)]
270 |     # range0 = [range0[0] - (range0[1] - range0[0]) * 0.2, range0[1] + (range0[1] - range0[0]) * 0.2]
271 |     # range1 = [range1[0] - (range1[1] - range1[0]) * 0.2, range1[1] + (range1[1] - range1[0]) * 0.2]
272 |     # range2 = [range2[0] - (range2[1] - range2[0]) * 0.2, range2[1] + (range2[1] - range2[0]) * 0.2]
273 |     # plt.xlim(range0)
274 |     # plt.ylim(range1)
275 | 
276 | 
277 | 
278 | 
--------------------------------------------------------------------------------
/HiddenStateExtractor/deprecated/movement_clustering.py:
--------------------------------------------------------------------------------
  1 | #!/usr/bin/env python3
  2 | # -*- coding: utf-8 -*-
  3 | """
  4 | Created on Wed Sep 4 16:38:26 2019
  5 | 
  6 | @author: michaelwu
  7 | """
  8 | import os
  9 | import numpy as np
 10 | import pickle
 11 | import matplotlib.pyplot as plt
 12 | from sklearn.cluster import KMeans
 13 | from sklearn.decomposition import PCA
 14 | from matplotlib import cm
 15 | import imageio
 16 | # import tifffile
 17 | import statsmodels.api as sm  # used by plot_MSD below
 18 | from .naive_imagenet import DATA_ROOT
 19 | import cv2  # used by save_traj below
 20 | def generate_MSD_distri(trajectories_positions):
 21 |     MSD = {i: [] for i in range(1, 15)}
 22 |     for t in trajectories_positions:
 23 |         for t1 in sorted(t.keys()):
 24 |             for t2 in range(t1+1, min(max(t.keys())+1, t1+14)):
 25 |                 if t2 in t:
 26 |                     dist = np.linalg.norm(t[t2] - t[t1], ord=2)
 27 |                     MSD[t2-t1].append(dist**2)
 28 |     return MSD
 29 | 
 30 | def plot_MSD(trajectories_positions, fit=True, with_intercept=False, first_n_points=5):
 31 |     MSD = generate_MSD_distri(trajectories_positions)
 32 |     ks = sorted(MSD.keys())
 33 |     points = np.array([(k, np.mean(MSD[k])) for k in ks])
 34 | 
 35 |     plt.plot(points[:, 0], points[:, 1], '.-', label='MSD')
 36 | 
 37 |     X = points[:first_n_points, 0]
 38 |     y = points[:first_n_points, 1]
 39 |     if with_intercept:
 40 |         X = sm.add_constant(X)
 41 |         res = sm.OLS(y, X).fit()
 42 |         slope = res.params[1]
 43 |         intercept = res.params[0]
 44 |         plt.plot(points[:, 0], points[:, 0] * slope + intercept, '--', label='Linear Control')
 45 |     else:
 46 |         res = sm.OLS(y, X).fit()
 47 |         slope = res.params[0]
 48 |         plt.plot(points[:, 0], points[:, 0] * slope, '--', label='Linear Control')
 49 |     plt.legend()
 50 |     return
 51 | 
 52 | def generate_short_traj_collections(trajectories_positions, length=5, raw=False):
 53 |     short_trajs = []
 54 |     for t in trajectories_positions:
 55 |         t_keys = sorted(t.keys())
 56 |         assert len(t_keys) > length
 57 |         for t_point in range(len(t_keys) - (length - 1)):
 58 |             if raw:
 59 |                 short_trajs.append({t_keys[t_point+i]: t[t_keys[t_point+i]] for i in range(length)})
 60 |             else:
 61 |                 short_t = [t[t_keys[t_point + i]] for i in range(length)]
 62 | 
 63 |                 short_t_ = []
 64 |                 #initial_position = short_t[0]
 65 |                 for i in range(length - 1):
 66 |                     d = np.linalg.norm(short_t[i+1] - short_t[i], ord=2)
 67 |                     #d2 = np.linalg.norm(short_t[i+1] - initial_position, ord=2)
 68 |                     short_t_.append(d)
 69 |                     #short_t_2.append(d2/np.sqrt(i+1))
 70 |                 short_trajs.append(short_t_)
 71 |     return short_trajs
 72 | 
 73 | def save_traj(k, output_path=None):
 74 |     input_path = DATA_ROOT + '/Data/DynamicPatches/%s/mg_traj_%s.tif' % (k.split('/')[0], k.split('/')[1])
 75 |     # images = tifffile.imread(input_path)
 76 |     _, images = cv2.imreadmulti(input_path, flags=cv2.IMREAD_ANYDEPTH)
 77 |     images = np.array(images)
 78 |     if output_path is None:
 79 |         output_path = './%s.gif' % (k[:9] + '_' + k[10:])
 80 |     imageio.mimsave(output_path, images)
 81 |     return
 82 | 
 83 | if __name__ == '__main__':
 84 |     sites = ['D%d-Site_%d' % (i, j) for j in range(9) for i in range(3, 6)]
 85 | 
 86 | 
 87 |     all_mg_trajs = {}
 88 |     all_non_mg_trajs = {}
 89 |     for site in sites:
 90 |         _, mg_trajectories_positions = pickle.load(open(DATA_ROOT + '/Data/DynamicPatches/%s/mg_traj.pkl' % site, 'rb'))
 91 |         _, non_mg_trajectories_positions = pickle.load(open(DATA_ROOT + '/Data/DynamicPatches/%s/non_mg_traj.pkl' % site, 'rb'))
 92 | 
 93 |         for i, traj in enumerate(mg_trajectories_positions):
 94 |             all_mg_trajs[site + '/%d' % i] = traj
 95 |         for i, traj in enumerate(non_mg_trajectories_positions):
 96 |             all_non_mg_trajs[site + '/%d' % i] = traj
 97 | 
 98 |     # Clustering
 99 |     np.random.seed(123)
100 |     traj_length = 9
101 |     n_clusters = 3
102 |     short_trajs = generate_short_traj_collections(all_mg_trajs.values(), length=traj_length)
103 |     clustering = KMeans(n_clusters=n_clusters)
104 |     clustering.fit(short_trajs)
105 | 
106 | 
107 |     clustering_labels = {
108 |         0: '00',
109 |         1: '0',
110 |         2: '000'}
111 | 
112 |     # PCA
113 |     pca = PCA(n_components=3)
114 |     short_trajs_ = pca.fit_transform(short_trajs)
115 |     short_trajs_labels = clustering.predict(short_trajs)
116 |     cmap = cm.get_cmap('tab10')
117 |     plt.clf()
118 |     for i in range(n_clusters):
119 |         plt.scatter(short_trajs_[:, 0][np.where(short_trajs_labels == i)],
120 |                     short_trajs_[:, 2][np.where(short_trajs_labels == i)],
121 |                     s=0.1,
122 |                     color=cmap.colors[i],
123 |                     label='cluster_%s' % clustering_labels[i])
124 |     plt.legend()
125 |     plt.xlabel("PC1")
126 |     plt.ylabel("PC2")
127 |     plt.savefig('/home/michaelwu/pca_movement.png', dpi=300)
128 | 
129 |     plt.clf()
130 |     plt.plot(pca.components_[0]);
131 |     plt.savefig('/home/michaelwu/pc1_movement_components.png', dpi=300)
132 |     plt.clf()
133 |     plt.plot(pca.components_[1]);
134 |     plt.savefig('/home/michaelwu/pc2_movement_components.png', dpi=300)
135 | 
136 |     # Generate representative trajs
137 | 
138 |     stagnant_trajs = {}
139 |     minor_moving_trajs = {}
140 |     moving_trajs = {}
141 |     other_trajs = {}
142 |     for k in all_mg_trajs:
143 |         sub_trajs = generate_short_traj_collections([all_mg_trajs[k]], length=traj_length)
144 |         labels = [clustering_labels[l] for l in clustering.predict(sub_trajs)]
145 | 
146 |         if labels.count('0') > 0.7 * len(labels):
147 |             stagnant_trajs[k] = labels
148 |         elif set(labels) <= set(['0', '00']) or labels.count('00') > 0.7 * len(labels):
149 |             # Contains ('00' only) and ('0' and '00')
150 |             minor_moving_trajs[k] = labels
151 |         elif set(labels) <= set(['00', '000']) or labels.count('000') > 0.4 * len(labels):
152 |             # Contains all trajectories with '000' and '0000' but not '0'
153 |             moving_trajs[k] = labels
154 |         else:
155 |             other_trajs[k] = labels
156 | 
157 |     clustered = {"stagnant": list(stagnant_trajs.keys()),
158 |                  "minor_moving": list(minor_moving_trajs.keys()),
159 |                  "moving": list(moving_trajs.keys())}
160 | 
161 | 
162 |     # os.mkdir('./movement_clustered_trajs')
163 |     # os.mkdir('./movement_clustered_trajs/stagnant')
164 |     # os.mkdir('./movement_clustered_trajs/minor_moving')
165 |     # os.mkdir('./movement_clustered_trajs/moving')
166 |     # os.mkdir('./movement_clustered_trajs/other')
167 |     # for k in 
np.random.choice(list(stagnant_trajs.keys()), (30,), replace=False): 168 | # save_traj(k, './movement_clustered_trajs/stagnant/%s.gif' % (k[:9] + '_' + k[10:])) 169 | # for k in np.random.choice(list(minor_moving_trajs.keys()), (30,), replace=False): 170 | # save_traj(k, './movement_clustered_trajs/minor_moving/%s.gif' % (k[:9] + '_' + k[10:])) 171 | # for k in np.random.choice(list(moving_trajs.keys()), (30,), replace=False): 172 | # save_traj(k, './movement_clustered_trajs/moving/%s.gif' % (k[:9] + '_' + k[10:])) 173 | # for k in np.random.choice(list(other_trajs.keys()), (30,), replace=False): 174 | # save_traj(k, './movement_clustered_trajs/other/%s.gif' % (k[:9] + '_' + k[10:])) 175 | 176 | # MSD curve 177 | plt.clf() 178 | plot_MSD(list(all_mg_trajs.values())) 179 | plt.xlabel("time step") 180 | plt.ylabel("distance^2") 181 | plt.savefig("/home/michaelwu/all_microglia_combined.png", dpi=300) 182 | plt.clf() 183 | plot_MSD([all_mg_trajs[t] for t in clustered["stagnant"]]) 184 | plt.xlabel("time step") 185 | plt.ylabel("distance^2") 186 | plt.savefig("/home/michaelwu/mg_stagnant.png", dpi=300) 187 | plt.clf() 188 | plot_MSD([all_mg_trajs[t] for t in clustered["minor_moving"]]) 189 | plt.xlabel("time step") 190 | plt.ylabel("distance^2") 191 | plt.savefig("/home/michaelwu/mg_minor_moving.png", dpi=300) 192 | plt.clf() 193 | plot_MSD([all_mg_trajs[t] for t in clustered["moving"]]) 194 | plt.xlabel("time step") 195 | plt.ylabel("distance^2") 196 | plt.savefig("/home/michaelwu/mg_moving.png", dpi=300) 197 | plt.clf() 198 | plot_MSD(list(all_non_mg_trajs.values())) 199 | plt.xlabel("time step") 200 | plt.ylabel("distance^2") 201 | plt.savefig("/home/michaelwu/all_non_microglia_combined.png", dpi=300) 202 | 203 | 204 | -------------------------------------------------------------------------------- /HiddenStateExtractor/deprecated/vq_vae_extra.py: -------------------------------------------------------------------------------- 1 | # 32 * 32 * 128, strong decoder 2 | #class VQ_VAE(nn.Module): 3 | # def __init__(self, 4 | # num_inputs=3, 5 | # num_hiddens=128, 6 | # num_residual_hiddens=64, 7 | # num_residual_layers=2, 8 | # num_embeddings=128, 9 | # commitment_cost=0.25, 10 | # channel_var=CHANNEL_VAR, 11 | # alpha=0.1, 12 | # **kwargs): 13 | # super(VQ_VAE, self).__init__(**kwargs) 14 | # self.num_inputs = num_inputs 15 | # self.num_hiddens = num_hiddens 16 | # self.num_residual_layers = num_residual_layers 17 | # self.num_residual_hiddens = num_residual_hiddens 18 | # self.num_embeddings = num_embeddings 19 | # self.commitment_cost = commitment_cost 20 | # self.channel_var = nn.Parameter(t.from_numpy(channel_var).float().reshape((1, 3, 1, 1)), requires_grad=False) 21 | # self.alpha = alpha 22 | # self.enc = nn.Sequential( 23 | # nn.Conv2d(self.num_inputs, self.num_hiddens//2, 1), 24 | # nn.Conv2d(self.num_hiddens//2, self.num_hiddens//2, 4, stride=2, padding=1), 25 | # nn.BatchNorm2d(self.num_hiddens//2), 26 | # nn.ReLU(), 27 | # nn.Conv2d(self.num_hiddens//2, self.num_hiddens, 4, stride=2, padding=1), 28 | # nn.BatchNorm2d(self.num_hiddens), 29 | # nn.ReLU(), 30 | # nn.Conv2d(self.num_hiddens, self.num_hiddens, 3, padding=1), 31 | # nn.BatchNorm2d(self.num_hiddens), 32 | # ResidualBlock(self.num_hiddens, self.num_residual_hiddens, self.num_residual_layers)) 33 | # self.vq = VectorQuantizer(self.num_hiddens, self.num_embeddings, commitment_cost=self.commitment_cost) 34 | # self.dec = nn.Sequential( 35 | # nn.Conv2d(self.num_hiddens, self.num_hiddens, 3, padding=1), 36 | # 
ResidualBlock(self.num_hiddens, self.num_residual_hiddens, self.num_residual_layers), 37 | # nn.ConvTranspose2d(self.num_hiddens, self.num_hiddens//2, 4, stride=2, padding=1), 38 | # nn.BatchNorm2d(self.num_hiddens//2), 39 | # nn.ReLU(), 40 | # nn.ConvTranspose2d(self.num_hiddens//2, self.num_hiddens//4, 4, stride=2, padding=1), 41 | # nn.BatchNorm2d(self.num_hiddens//4), 42 | # nn.ReLU(), 43 | # nn.Conv2d(self.num_hiddens//4, self.num_inputs, 1)) 44 | 45 | # 16*16*16, strong decoder 46 | #class VQ_VAE(nn.Module): 47 | # def __init__(self, 48 | # num_inputs=3, 49 | # num_hiddens=16, 50 | # num_residual_hiddens=64, 51 | # num_residual_layers=2, 52 | # num_embeddings=64, 53 | # commitment_cost=0.25, 54 | # channel_var=CHANNEL_VAR, 55 | # alpha=0.1, 56 | # **kwargs): 57 | # super(VQ_VAE, self).__init__(**kwargs) 58 | # self.num_inputs = num_inputs 59 | # self.num_hiddens = num_hiddens 60 | # self.num_residual_layers = num_residual_layers 61 | # self.num_residual_hiddens = num_residual_hiddens 62 | # self.num_embeddings = num_embeddings 63 | # self.commitment_cost = commitment_cost 64 | # self.channel_var = nn.Parameter(t.from_numpy(channel_var).float().reshape((1, 3, 1, 1)), requires_grad=False) 65 | # self.alpha = alpha 66 | # self.enc = nn.Sequential( 67 | # nn.Conv2d(self.num_inputs, self.num_hiddens//2, 1), 68 | # nn.Conv2d(self.num_hiddens//2, self.num_hiddens//2, 4, stride=2, padding=1), 69 | # nn.BatchNorm2d(self.num_hiddens//2), 70 | # nn.ReLU(), 71 | # nn.Conv2d(self.num_hiddens//2, self.num_hiddens, 4, stride=2, padding=1), 72 | # nn.BatchNorm2d(self.num_hiddens), 73 | # nn.ReLU(), 74 | # nn.Conv2d(self.num_hiddens, self.num_hiddens, 4, stride=2, padding=1), 75 | # nn.BatchNorm2d(self.num_hiddens), 76 | # nn.ReLU(), 77 | # nn.Conv2d(self.num_hiddens, self.num_hiddens, 3, padding=1), 78 | # nn.BatchNorm2d(self.num_hiddens), 79 | # ResidualBlock(self.num_hiddens, self.num_residual_hiddens, self.num_residual_layers)) 80 | # self.vq = VectorQuantizer(self.num_hiddens, self.num_embeddings, commitment_cost=self.commitment_cost) 81 | # self.dec = nn.Sequential( 82 | # nn.Conv2d(self.num_hiddens, self.num_hiddens, 3, padding=1), 83 | # ResidualBlock(self.num_hiddens, self.num_residual_hiddens, self.num_residual_layers), 84 | # nn.ConvTranspose2d(self.num_hiddens, self.num_hiddens//2, 4, stride=2, padding=1), 85 | # nn.BatchNorm2d(self.num_hiddens//2), 86 | # nn.ReLU(), 87 | # nn.ConvTranspose2d(self.num_hiddens//2, self.num_hiddens//4, 4, stride=2, padding=1), 88 | # nn.BatchNorm2d(self.num_hiddens//4), 89 | # nn.ReLU(), 90 | # nn.ConvTranspose2d(self.num_hiddens//4, self.num_hiddens//4, 4, stride=2, padding=1), 91 | # nn.BatchNorm2d(self.num_hiddens//4), 92 | # nn.ReLU(), 93 | # nn.Conv2d(self.num_hiddens//4, self.num_inputs, 1)) 94 | 95 | # 32*32*128, weak decoder 96 | #class VQ_VAE(nn.Module): 97 | # def __init__(self, 98 | # num_inputs=3, 99 | # num_hiddens=128, 100 | # num_residual_hiddens=64, 101 | # num_residual_layers=2, 102 | # num_embeddings=32, 103 | # commitment_cost=0.25, 104 | # channel_var=CHANNEL_VAR, 105 | # alpha=0.1, 106 | # **kwargs): 107 | # super(VQ_VAE, self).__init__(**kwargs) 108 | # self.num_inputs = num_inputs 109 | # self.num_hiddens = num_hiddens 110 | # self.num_residual_layers = num_residual_layers 111 | # self.num_residual_hiddens = num_residual_hiddens 112 | # self.num_embeddings = num_embeddings 113 | # self.commitment_cost = commitment_cost 114 | # self.channel_var = nn.Parameter(t.from_numpy(channel_var).float().reshape((1, 3, 1, 1)), requires_grad=False) 
115 | #        self.alpha = alpha
116 | #        self.enc = nn.Sequential(
117 | #            nn.Conv2d(self.num_inputs, self.num_hiddens//2, 1),
118 | #            nn.Conv2d(self.num_hiddens//2, self.num_hiddens//2, 4, stride=2, padding=1),
119 | #            nn.BatchNorm2d(self.num_hiddens//2),
120 | #            nn.ReLU(),
121 | #            nn.Conv2d(self.num_hiddens//2, self.num_hiddens, 4, stride=2, padding=1),
122 | #            nn.BatchNorm2d(self.num_hiddens),
123 | #            nn.ReLU(),
124 | #            nn.Conv2d(self.num_hiddens, self.num_hiddens, 3, padding=1),
125 | #            nn.BatchNorm2d(self.num_hiddens),
126 | #            ResidualBlock(self.num_hiddens, self.num_residual_hiddens, self.num_residual_layers))
127 | #        self.vq = VectorQuantizer(self.num_hiddens, self.num_embeddings, commitment_cost=self.commitment_cost)
128 | #        self.dec = nn.Sequential(
129 | #            nn.ConvTranspose2d(self.num_hiddens, self.num_hiddens//2, 4, stride=2, padding=1),
130 | #            nn.ReLU(),
131 | #            nn.ConvTranspose2d(self.num_hiddens//2, self.num_hiddens//4, 4, stride=2, padding=1),
132 | #            nn.ReLU(),
133 | #            nn.Conv2d(self.num_hiddens//4, self.num_inputs, 1))
--------------------------------------------------------------------------------
/HiddenStateExtractor/losses.py:
--------------------------------------------------------------------------------
  1 | import torch.nn as nn
  2 | import torch
  3 | 
  4 | from typing import Tuple
  5 | 
  6 | 
  7 | class TripletMiner(nn.Module):
  8 |     """Triplet Miner
  9 |     Adapted from https://github.com/TowardHumanizedInteraction/TripletTorch
 10 |     Triplet mining base class.
 11 |     Attributes
 12 |     ----------
 13 |     margin: float
 14 |         Margin distance between positive and negative samples from anchor
 15 |         perspective. Defaults to 0.5.
 16 |     """
 17 | 
 18 |     def __init__(self: 'TripletMiner', margin: float = .5) -> None:
 19 |         """Init
 20 |         Parameters
 21 |         ----------
 22 |         margin: float
 23 |             Margin distance between positive and negative samples from
 24 |             anchor perspective. Defaults to 0.5.
 25 |         """
 26 |         super(TripletMiner, self).__init__()
 27 |         self.margin = margin
 28 | 
 29 |     def _pairwise_dist(
 30 |             self: 'TripletMiner',
 31 |             embeddings: torch.Tensor
 32 |     ) -> torch.Tensor:
 33 |         """Compute pairwise euclidean distances
 34 |         Parameters
 35 |         ----------
 36 |         embeddings: torch.Tensor
 37 |             Embeddings is the output of the neural network given data
 38 |             samples from the dataset.
 39 |         Returns
 40 |         -------
 41 |         pairwise_dist: torch.Tensor
 42 |             Pairwise distances between samples.
 43 |         """
 44 |         dot_product = torch.matmul(embeddings, embeddings.t())
 45 |         square_norm = torch.diag(dot_product)
 46 |         pairwise_dist = square_norm.unsqueeze(0) - \
 47 |             2. * dot_product + square_norm.unsqueeze(1)
 48 |         pdn = pairwise_dist < 0.
 49 |         pairwise_dist[pdn] = 0.
 50 |         return pairwise_dist
 51 | 
 52 |     def forward(
 53 |             self: 'TripletMiner',
 54 |             ids: torch.Tensor,
 55 |             embeddings: torch.Tensor
 56 |     ) -> Tuple[torch.Tensor]:
 57 |         """Forward
 58 |         Parameters
 59 |         ----------
 60 |         ids : torch.Tensor
 61 |             Labels of samples from the dataset respectively to the
 62 |             embeddings.
 63 |         embeddings: torch.Tensor
 64 |             Embeddings is the output of the neural network given data
 65 |             samples from the dataset.
 66 |         Raises
 67 |         ------
 68 |         NotImplementedError: Base class does not provide implementation of the
 69 |             mining technique.
 70 |         """
 71 |         raise NotImplementedError('Mining function not implemented yet!')
 72 | 
 73 | 
 74 | class AllTripletMiner(TripletMiner):
 75 |     """AllTripletMiner
 76 |     The class provides mining for all valid triplets from a given dataset.
 77 |     Attributes
 78 |     ----------
 79 |     margin: float
 80 |         Margin distance between positive and negative samples from anchor
 81 |         perspective. Defaults to 0.5.
 82 |     """
 83 | 
 84 |     def __init__(self: 'AllTripletMiner', margin: float = .5) -> None:
 85 |         """Init
 86 |         Params
 87 |         ------
 88 |         margin: float
 89 |             Margin distance between positive and negative samples from anchor
 90 |             perspective. Defaults to 0.5.
 91 |         """
 92 |         super(AllTripletMiner, self).__init__(margin)
 93 | 
 94 |     def _triplet_mask(self: 'AllTripletMiner', ids: torch.Tensor) -> torch.Tensor:
 95 |         """TripletMask
 96 |         Parameters
 97 |         ----------
 98 |         ids: torch.Tensor
 99 |             Labels of samples from the dataset respectively to the
100 |             embeddings.
101 |         Returns
102 |         -------
103 |         mask: torch.Tensor
104 |             Mask for every valid triplet from the selected samples.
105 |         """
106 |         # eye = torch.eye( ids.size( 0 ), requires_grad = False ).cuda( ) if ids.is_cuda else \
107 |         eye = torch.eye(ids.size(0), requires_grad=False).to(ids.device)
108 | 
109 |         ids_not_eq = (1 - eye).bool()
110 |         i_not_eq_j = ids_not_eq.unsqueeze(2)
111 |         i_not_eq_k = ids_not_eq.unsqueeze(1)
112 |         j_not_eq_k = ids_not_eq.unsqueeze(0)
113 |         distinct_idx = ((i_not_eq_j & i_not_eq_k) & j_not_eq_k)
114 | 
115 |         ids_eq = ids.unsqueeze(0) == ids.unsqueeze(1)
116 |         i_eq_j = ids_eq.unsqueeze(2)
117 |         i_eq_k = ids_eq.unsqueeze(1)
118 | 
119 |         valid_ids = (i_eq_j & ~i_eq_k)
120 |         mask = distinct_idx & valid_ids
121 |         return mask
122 | 
123 |     def forward(
124 |             self: 'AllTripletMiner',
125 |             ids: torch.Tensor,
126 |             embeddings: torch.Tensor
127 |     ) -> Tuple[torch.Tensor]:
128 |         """Forward
129 |         Parameters
130 |         ----------
131 |         ids : torch.Tensor
132 |             Labels of samples from the dataset respectively to the
133 |             embeddings.
134 |         embeddings: torch.Tensor
135 |             Embeddings is the output of the neural network given data
136 |             samples from the dataset.
137 |         Returns
138 |         -------
139 |         loss : torch.Tensor
140 |             Loss obtained with the AllTripletMiner sampling technique.
141 |         f_pos_tri: torch.Tensor
142 |             Proportion of positive triplets. Less is better. The value
143 |             should decrease with training.
144 |         """
145 |         pairwise_dist = self._pairwise_dist(embeddings)
146 |         pos_dist = pairwise_dist.unsqueeze(2)
147 |         neg_dist = pairwise_dist.unsqueeze(1)
148 | 
149 |         mask = self._triplet_mask(ids).float()
150 |         loss = pos_dist - neg_dist + self.margin
151 |         loss *= mask
152 |         loss = torch.clamp(loss, min=0.)
153 | 
154 |         n_pos_tri = torch.sum((loss > 1e-16).float())
155 |         n_val_tri = torch.sum(mask)
156 |         f_pos_tri = n_pos_tri / (n_val_tri + 1e-16)
157 | 
158 |         loss = torch.sum(loss) / (n_pos_tri + 1e-16)
159 |         # loss = torch.mean(loss)
160 | 
161 |         return loss, f_pos_tri
162 | 
163 | 
164 | class HardNegativeTripletMiner(TripletMiner):
165 |     """HardNegativeTripletMiner
166 |     The class provides mining for hard negative triplets only.
167 |     Attributes
168 |     ----------
169 |     margin: float
170 |         Margin distance between positive and negative samples from anchor
171 |         perspective. Defaults to 0.5.
172 |     """
173 | 
174 |     def __init__(self: 'HardNegativeTripletMiner', margin: float = .5) -> None:
175 |         """Init
176 |         Params
177 |         ------
178 |         margin: float
179 |             Margin distance between positive and negative samples from anchor
180 |             perspective. Defaults to 0.5.
181 |         """
182 |         super(HardNegativeTripletMiner, self).__init__(margin)
183 | 
184 |     def _pos_dist(
185 |             self: 'HardNegativeTripletMiner',
186 |             ids: torch.Tensor,
187 |             pairwise_dist: torch.Tensor
188 |     ) -> torch.Tensor:
189 |         """PositiveDistances
190 |         Parameters
191 |         ----------
192 |         ids : torch.Tensor
193 |             Labels of samples from the dataset respectively to the
194 |             embeddings.
195 |         pairwise_dist: torch.Tensor
196 |             Pairwise distances between samples.
197 |         Returns
198 |         -------
199 |         anc_pos_dist: torch.Tensor
200 |             Distances between positives and anchors.
201 |         """
202 |         eye = torch.eye(ids.size(0), requires_grad=False).cuda() if ids.is_cuda else \
203 |             torch.eye(ids.size(0), requires_grad=False)
204 | 
205 |         mask_anc_pos = (~eye.bool() & (ids.unsqueeze(0) == ids.unsqueeze(1)))
206 |         anc_pos_dist = mask_anc_pos.float() * pairwise_dist
207 |         anc_pos_dist, _ = anc_pos_dist.max(axis=1, keepdim=True)
208 |         return anc_pos_dist
209 | 
210 |     def _neg_dist(
211 |             self: 'HardNegativeTripletMiner',
212 |             ids: torch.Tensor,
213 |             pairwise_dist: torch.Tensor
214 |     ) -> torch.Tensor:
215 |         """NegativeDistances
216 |         Parameters
217 |         ----------
218 |         ids : torch.Tensor
219 |             Labels of samples from the dataset respectively to the
220 |             embeddings.
221 |         pairwise_dist: torch.Tensor
222 |             Pairwise distances between samples.
223 |         Returns
224 |         -------
225 |         anc_neg_dist: torch.Tensor
226 |             Distances between negatives and anchors.
227 |         """
228 |         mask_anc_neg = ids.unsqueeze(0) != ids.unsqueeze(1)
229 |         max_anc_neg_dist, _ = pairwise_dist.max(axis=1, keepdim=True)
230 |         anc_neg_dist = pairwise_dist + \
231 |             max_anc_neg_dist * (1. - mask_anc_neg.float())
232 |         anc_neg_dist = anc_neg_dist.mean(axis=1, keepdim=False)
233 |         return anc_neg_dist
234 | 
235 |     def forward(
236 |             self: 'HardNegativeTripletMiner',
237 |             ids: torch.Tensor,
238 |             embeddings: torch.Tensor
239 |     ) -> Tuple[torch.Tensor]:
240 |         """Forward
241 |         Parameters
242 |         ----------
243 |         ids : torch.Tensor
244 |             Labels of samples from the dataset respectively to the
245 |             embeddings.
246 |         embeddings: torch.Tensor
247 |             Embeddings is the output of the neural network given data
248 |             samples from the dataset.
249 |         Returns
250 |         -------
251 |         loss : torch.Tensor
252 |             Loss obtained with the HardNegativeTripletMiner sampling
253 |             technique.
254 |         _ : None
255 |             To match the format of the AllTripletMiner
256 |         """
257 |         pairwise_dist = self._pairwise_dist(embeddings)
258 |         pos_dist = self._pos_dist(ids, pairwise_dist)
259 |         neg_dist = self._neg_dist(ids, pairwise_dist)
260 | 
261 |         loss = torch.clamp(pos_dist - neg_dist + self.margin, min=0.)
262 |         loss = loss.mean()
263 |         return loss, None
264 | 
--------------------------------------------------------------------------------
/HiddenStateExtractor/naive_imagenet.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import numpy as np
  3 | import h5py
  4 | import pickle
  5 | import cv2
  6 | 
  7 | 
  8 | CHANNEL_MAX = 65535.
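# Usage note for HiddenStateExtractor/losses.py above -- a minimal, hedged sketch
# of how the triplet miners are typically called (shapes and values here are
# illustrative assumptions, not pipeline settings):
#
#     import torch
#     from HiddenStateExtractor.losses import AllTripletMiner
#     miner = AllTripletMiner(margin=0.5)
#     ids = torch.tensor([0, 0, 1, 1])   # one class id per embedding
#     emb = torch.randn(4, 128)          # (batch_size, embedding_dim)
#     loss, f_pos_tri = miner(ids, emb)  # mean loss over violating triplets,
#                                        # fraction of triplets still violating the margin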
9 | 10 | 11 | def read_file_path(root): 12 | """ Find all .h5 files 13 | 14 | Args: 15 | root (str): root folder path 16 | 17 | Returns: 18 | list of str: .h5 files 19 | 20 | """ 21 | files = [] 22 | for dir_name, dirs, fs in os.walk(root): 23 | for f in fs: 24 | if f.endswith('.h5'): 25 | files.append(os.path.join(dir_name, f)) 26 | return files 27 | 28 | 29 | def initiate_model(): 30 | """ Initialize a ResNet50 model with ImageNet pretrained weights 31 | 32 | Returns: 33 | keras model: ResNet50 model 34 | fn: data preprocessing function 35 | 36 | """ 37 | from classification_models.resnet import ResNet50 38 | from keras.models import Model 39 | from classification_models.resnet import preprocess_input as preprocess_input_resnet50 40 | model = ResNet50((224, 224, 3), weights='imagenet') 41 | target_layer = [l for l in model.layers if l.name == 'pool1'][0] 42 | hidden_extractor = Model(model.input, target_layer.output) 43 | hidden_extractor.compile(loss='mean_squared_error', optimizer='sgd') 44 | return hidden_extractor, preprocess_input_resnet50 45 | 46 | 47 | def initiate_model_inception(): 48 | """ Initialize a InceptionV2 model with ImageNet pretrained weights 49 | 50 | Returns: 51 | keras model: InceptionV2 model 52 | fn: data preprocessing function 53 | 54 | """ 55 | import classification_models.keras_applications as ka 56 | model = ka.inception_resnet_v2.InceptionResNetV2(input_shape=(224, 224, 3), 57 | weights='imagenet', 58 | include_top=False, 59 | pooling='avg') 60 | preprocess_fn = ka.inception_resnet_v2.preprocess_input 61 | return model, preprocess_fn 62 | 63 | 64 | def preprocess(f_n, cs=[0, 1], channel_max=CHANNEL_MAX): 65 | """ Preprocessing function (before model-specific preprocess_fn) 66 | 67 | Args: 68 | f_n (str): file path/single cell patch identifier 69 | cs (list of int, optional): channels in the input 70 | channel_max (list of float, optional): max val for each channel 71 | 72 | Returns: 73 | np.array: input with intensities scaled to [0, 255] 74 | 75 | """ 76 | dat = h5py.File(f_n, 'r')['masked_mat'] 77 | if cs is None: 78 | cs = np.arange(dat.shape[2]) 79 | stacks = [] 80 | for c in cs: 81 | patch_c = cv2.resize(np.array(dat[:, :, c]).astype(float), (224, 224)) 82 | stacks.append(np.stack([patch_c] * 3, 2)) 83 | 84 | x = np.stack(stacks, 0) 85 | x = x/np.array(channel_max).reshape((-1, 1, 1, 1)) 86 | x = x * 255. 
 87 |     return x
 88 | 
 89 | 
 90 | def predict(fs,
 91 |             extractor,
 92 |             preprocess_fn,
 93 |             batch_size=128,
 94 |             cs=[0, 1],
 95 |             channel_max=CHANNEL_MAX):
 96 |     """ Use ImageNet pretrained model to encode inputs
 97 | 
 98 |     Args:
 99 |         fs (list of str): list of input file paths
100 |         extractor (keras model): pretrained model
101 |         preprocess_fn (fn): model-specific preprocessing function
102 |         batch_size (int, optional): batch size
103 |         cs (list of int, optional): channels in the input
104 |         channel_max (list of float, optional): max val for each channel
105 | 
106 |     Returns:
107 |         list of np.array: encoded vectors
108 | 
109 |     """
110 |     ys = []  # encoded outputs, one entry per input file
111 |     temp_xs = []
112 |     for ct, f_n in enumerate(fs):
113 |         if ct > 0 and ct % 1000 == 0:
114 |             print(ct)
115 |         x = preprocess_fn(preprocess(f_n, cs=cs, channel_max=channel_max))
116 |         temp_xs.append(x)
117 |         if len(temp_xs) >= batch_size:
118 |             temp_ys = extractor.predict(np.concatenate(temp_xs, 0))
119 |             slice_num = len(temp_ys) // len(temp_xs)
120 |             for i in range(len(temp_xs)):
121 |                 y = temp_ys[i*slice_num:(i+1)*slice_num]
122 |                 ys.append(y)
123 |             temp_xs = []
124 |     if temp_xs:  # encode the final (partial) batch
125 |         temp_ys = extractor.predict(np.concatenate(temp_xs, 0))
126 |         slice_num = len(temp_ys) // len(temp_xs)
127 |         for i in range(len(temp_xs)):
128 |             y = temp_ys[i*slice_num:(i+1)*slice_num]
129 |             ys.append(y)
130 |     assert len(ys) == len(fs)
131 |     return ys
132 | 
133 | 
134 | # if __name__ == '__main__':
135 | #     path = '/mnt/comp_micro/Projects/CellVAE'
136 | #     fs = read_file_path(os.path.join(path, 'Data', 'StaticPatches'))
137 | #     extractor, preprocess_fn = initiate_model()
138 | #     ys = predict(fs, extractor, preprocess_fn, cs=[0, 1], channel_max=CHANNEL_MAX)
139 | 
140 | #     output = {}
141 | #     for f_n, y in zip(fs, ys):
142 | #         output[f_n] = y
143 | #     with open(os.path.join(path, 'Data', 'EncodedResNet50.pkl'), 'wb') as f:
144 | #         pickle.dump(output, f)
145 | 
146 | 
147 | 
--------------------------------------------------------------------------------
/HiddenStateExtractor/resnet.py:
--------------------------------------------------------------------------------
  1 | from collections import OrderedDict
  2 | import torchvision.models as models
  3 | import torch
  4 | from torch import nn
  5 | from HiddenStateExtractor.losses import AllTripletMiner
  6 | 
  7 | class ResNetEncoder(models.resnet.ResNet):
  8 |     """Wrapper for TorchVision ResNet Model
  9 |     This was needed to remove the final FC Layer from the ResNet Model"""
 10 |     def __init__(self, block, layers, num_inputs=2, cifar_head=False):
 11 |         """
 12 |         Args:
 13 |             block (nn.Module): block to build the network
 14 |             layers (list): number to repeat each block
 15 |             num_inputs (int): number of input channels
 16 |             cifar_head (bool): Use modified network for cifar-10 data if True
 17 |         """
 18 |         super().__init__(block, layers)
 19 |         self.cifar_head = cifar_head
 20 |         if cifar_head:
 21 |             self.conv1 = nn.Conv2d(num_inputs, 64, kernel_size=3, stride=1, padding=1, bias=False)
 22 |             self.bn1 = self._norm_layer(64)
 23 |             self.relu = nn.ReLU(inplace=True)
 24 |         else:
 25 |             self.conv1 = nn.Conv2d(num_inputs, 64, kernel_size=7, stride=2, padding=3,
 26 |                                    bias=False)
 27 | 
 28 |         print('** Using avgpool **')
 29 | 
 30 |     def forward(self, x):
 31 |         x = self.conv1(x)
 32 |         x = self.bn1(x)
 33 |         x = self.relu(x)
 34 |         if not self.cifar_head:
 35 |             x = self.maxpool(x)
 36 | 
 37 |         x = self.layer1(x)
 38 |         x = self.layer2(x)
 39 |         x = self.layer3(x)
 40 |         x = self.layer4(x)
 41 | 
 42 |         x = self.avgpool(x)
 43 |         x = torch.flatten(x, 1)
 44 | 
 45 |         return x
 46 | 
 47 | class ResNet18(ResNetEncoder):
 48 |     def __init__(self, num_inputs=2, cifar_head=True):
 49 |         super().__init__(models.resnet.BasicBlock, [2, 2, 2, 2], num_inputs=num_inputs, cifar_head=cifar_head)
 50 | 
 51 | 
 52 | class ResNet50(ResNetEncoder):
 53 |     def __init__(self, num_inputs=2, cifar_head=True):
 54 |         super().__init__(models.resnet.Bottleneck, [3, 4, 6, 3], num_inputs=num_inputs, cifar_head=cifar_head)
 55 | 
 56 | class ResNet101(ResNetEncoder):
 57 |     def __init__(self, num_inputs=2, cifar_head=True):
 58 |         super().__init__(models.resnet.Bottleneck, [3, 4, 23, 3], num_inputs=num_inputs, cifar_head=cifar_head)
 59 | 
 60 | class ResNet152(ResNetEncoder):
 61 |     def __init__(self, num_inputs=2, cifar_head=True):
 62 |         super().__init__(models.resnet.Bottleneck, [3, 8, 36, 3], num_inputs=num_inputs, cifar_head=cifar_head)
 63 | 
 64 | class BatchNorm1dNoBias(nn.BatchNorm1d):
 65 |     def __init__(self, *args, **kwargs):
 66 |         super().__init__(*args, **kwargs)
 67 |         self.bias.requires_grad = False
 68 | 
 69 | 
 70 | class EncodeProject(nn.Module):
 71 |     def __init__(self,
 72 |                  arch='ResNet50',
 73 |                  loss=AllTripletMiner(margin=1),
 74 |                  num_inputs=2,
 75 |                  cifar_head=False,
 76 |                  device='cuda:0'):
 77 | 
 78 |         super().__init__()
 79 | 
 80 |         if arch == 'ResNet50':
 81 |             self.convnet = ResNet50(num_inputs=num_inputs, cifar_head=cifar_head)
 82 |             self.encoder_dim = 2048
 83 |         elif arch == 'ResNet101':
 84 |             self.convnet = ResNet101(num_inputs=num_inputs, cifar_head=cifar_head)
 85 |             self.encoder_dim = 2048
 86 |         elif arch == 'ResNet152':
 87 |             self.convnet = ResNet152(num_inputs=num_inputs, cifar_head=cifar_head)
 88 |             self.encoder_dim = 2048
 89 |         elif arch == 'ResNet18':
 90 |             self.convnet = ResNet18(num_inputs=num_inputs, cifar_head=cifar_head)
 91 |             self.encoder_dim = 512
 92 |         else:
 93 |             raise NotImplementedError
 94 | 
 95 |         num_params = sum(p.numel() for p in self.convnet.parameters() if p.requires_grad)
 96 | 
 97 |         print(f'======> Encoder: output dim {self.encoder_dim} | {num_params/1e6:.3f}M parameters')
 98 | 
 99 |         self.proj_dim = 128
100 |         projection_layers = [
101 |             ('fc1', nn.Linear(self.encoder_dim, self.encoder_dim, bias=False)),
102 |             ('bn1', nn.BatchNorm1d(self.encoder_dim)),
103 |             ('relu1', nn.ReLU()),
104 |             ('fc2', nn.Linear(self.encoder_dim, 128, bias=False)),
105 |             ('bn2', BatchNorm1dNoBias(128)),
106 |         ]
107 | 
108 |         self.projection = nn.Sequential(OrderedDict(projection_layers))
109 |         self.loss = loss
110 |         self.device = device
111 | 
112 |     def encode(self, x, out='z'):
113 |         h = self.convnet(x)
114 |         if out == 'h':
115 |             return h
116 |         elif out == 'z':
117 |             z = self.projection(h)
118 |             return z
119 |         else:
120 |             raise ValueError('"out" can only be "h" or "z", not {}'.format(out))
121 | 
122 |     def forward(self, x, labels=None, time_matching_mat=None, batch_mask=None):
123 |         z = self.encode(x)
124 |         loss, f_pos_tri = self.loss(labels, z)
125 |         loss_dict = {'total_loss': loss, 'positive_triplet': f_pos_tri}
126 |         return z, loss_dict
127 | 
128 | class LogisticRegression(nn.Module):
129 |     def __init__(self, input_dim, n_class, device='cuda:0'):
130 |         super().__init__()
131 |         self.linear = nn.Linear(input_dim, n_class)
132 |         self.linear.weight.data.zero_()
133 |         self.linear.bias.data.zero_()
134 |         self.device = device
135 | 
136 |     def forward(self, x, labels=None, time_matching_mat=None, batch_mask=None):
137 |         z = self.linear(x)
138 |         loss = nn.functional.cross_entropy(z, labels)
139 |         acc = (z.argmax(1) == labels).float().mean()
140 |         loss_dict = {'total_loss': loss, 'acc': acc}
141 |         return z, loss_dict
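# Minimal usage sketch for EncodeProject (illustrative only; the batch shape,
# label pattern and margin below are assumptions -- inputs are 2-channel patches,
# e.g. phase + retardance):
#
#     import torch
#     from HiddenStateExtractor.resnet import EncodeProject
#     model = EncodeProject(arch='ResNet50', num_inputs=2, device='cpu')
#     x = torch.randn(8, 2, 128, 128)           # (batch, channels, H, W)
#     labels = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3])
#     z, loss_dict = model(x, labels=labels)    # z: (8, 128) projection vectors
#     loss_dict['total_loss'].backward()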
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2020, Chan Zuckerberg Biohub 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | -------------------------------------------------------------------------------- /NNsegmentation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mehta-lab/dynamorph/b3321f4368002707fbe39d727bc5c23bd5e7e199/NNsegmentation/__init__.py -------------------------------------------------------------------------------- /NNsegmentation/layers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Thu Feb 14 17:51:20 2019 5 | 6 | @author: zqwu 7 | """ 8 | 9 | import tensorflow as tf 10 | import numpy as np 11 | from tensorflow import keras 12 | # from keras import backend as K 13 | # from keras.models import Model, load_model 14 | # from keras.layers import Dense, Layer, Input 15 | from sklearn.metrics import roc_auc_score, f1_score 16 | 17 | 18 | class SplitSlice(keras.layers.Layer): 19 | """ Customized layer for tensor reshape 20 | 21 | Used for 2.5D segmentation 22 | """ 23 | def __init__(self, 24 | n_channels, 25 | x_size, 26 | y_size, 27 | **kwargs): 28 | self.n_channels = n_channels 29 | self.x_size = x_size 30 | self.y_size = y_size 31 | super(SplitSlice, self).__init__(**kwargs) 32 | 33 | def build(self, input_shape): 34 | super(SplitSlice, self).build(input_shape) 35 | 36 | def call(self, x): 37 | # Input shape: (batch_size, n_channel, n_slice, x_size, y_size) 38 | # Output shape: (batch_size * n_slice, n_channel, x_size, y_size) 39 | _x = keras.backend.permute_dimensions(x, (0, 2, 1, 3, 4)) 40 | target_shape = (-1, self.n_channels, self.x_size, self.y_size) 41 | output = keras.backend.reshape(_x, target_shape) 42 | return output 43 | 44 | def compute_output_shape(self, input_shape): 45 | return tuple([input_shape[0], # batch 46 | input_shape[1], # c 47 | input_shape[-2], # x 48 | input_shape[-1]]) # y 49 | 50 | 51 | class MergeSlices(keras.layers.Layer): 52 | """ Customized layer for tensor reshape 53 | """ 54 | def __init__(self, 55 | n_slice=5, 56 | n_channel=32, 57 | **kwargs): 58 | self.n_slice = n_slice 59 | self.n_channel = n_channel 60 | self.output_dim = self.n_slice * self.n_channel 61 | super(MergeSlices, self).__init__(**kwargs) 62 | 63 | def build(self, input_shape): 64 | super(MergeSlices, self).build(input_shape) 65 | 66 | def call(self, x): 67 | # Input shape: (batch_size * n_slice, n_channel, x_size, y_size) 68 | # Output shape: (batch_size, n_slice * n_channel, x_size, y_size) 69 | x_shape = keras.backend.shape(x) 70 | _x = keras.backend.reshape(x, [x_shape[0]//self.n_slice, # Batch size 71 | self.n_slice, # n_slice 72 | self.n_channel, # n_channel 73 | x_shape[2], # x 74 | x_shape[3]]) # y 75 | 76 | output = keras.backend.reshape(_x, [x_shape[0]//self.n_slice, # Batch size 77 | self.output_dim, # n_slice * n_channel 78 | x_shape[2], # x 79 | x_shape[3]]) # y 80 | return output 81 | 82 | def compute_output_shape(self, input_shape): 83 | return tuple([input_shape[0], # avoiding None 84 | self.output_dim, 85 | input_shape[2], 86 | input_shape[3]]) 87 | 88 | 89 | class weighted_binary_cross_entropy(object): 90 | """ Customized loss function 91 | """ 92 | def __init__(self, n_classes=2): 93 | self.n_classes = n_classes 94 | self.__name__ = "weighted_binary_cross_entropy" 95 | 96 | def __call__(self, y_true, y_pred): 97 | """ 98 | Args: 99 | y_true (tensor): in shape (batch_size, x_size, y_size, n_classes + 1) 100 | first `n_classes` slices of the last dimension are labels 101 | last 
slice of the last dimension is weight 102 | y_pred (tensor): in shape (batch_size, x_size, y_size, n_classes) 103 | model predictions 104 | 105 | """ 106 | 107 | w = y_true[:, -1] 108 | y_true = y_true[:, :-1] 109 | 110 | # Switch to channel last form 111 | y_true = keras.backend.permute_dimensions(y_true, (0, 2, 3, 1)) 112 | y_pred = keras.backend.permute_dimensions(y_pred, (0, 2, 3, 1)) 113 | 114 | loss = keras.backend.categorical_crossentropy(y_true, y_pred, from_logits=True) * w 115 | return loss 116 | 117 | 118 | class ValidMetrics(keras.callbacks.Callback): 119 | """ Customized callback function for validation data evaluation 120 | 121 | Calculate ROC-AUC and F1 on validation data and test data (if applicable) 122 | after each epoch 123 | 124 | """ 125 | 126 | def __init__(self, valid_data=None, test_data=None): 127 | self.valid_data = valid_data 128 | self.test_data = test_data 129 | 130 | def on_epoch_end(self, epoch, logs={}): 131 | if self.valid_data is not None: 132 | y_pred = self.model.predict(self.valid_data[0])[:, 0] 133 | y_true = self.valid_data[1][:, 0] > 0.5 134 | roc = roc_auc_score(y_true.flatten(), y_pred.flatten()) 135 | f1 = f1_score(y_true.flatten(), y_pred.flatten()>0.5) 136 | print('\r valid-roc-auc: %f valid-f1: %f\n' % (roc, f1)) 137 | if self.test_data is not None: 138 | y_pred = self.model.predict(self.test_data[0])[:, 0] 139 | y_true = self.test_data[1][:, 0] > 0.5 140 | roc = roc_auc_score(y_true.flatten(), y_pred.flatten()) 141 | f1 = f1_score(y_true.flatten(), y_pred.flatten()>0.5) 142 | print('\r test-roc-auc: %f test-f1: %f\n' % (roc, f1)) 143 | return 144 | 145 | 146 | -------------------------------------------------------------------------------- /NNsegmentation/models.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Wed Feb 6 13:22:55 2019 5 | 6 | @author: zqwu 7 | """ 8 | 9 | import tensorflow as tf 10 | import numpy as np 11 | from tensorflow import keras 12 | keras.backend.set_image_data_format('channels_first') 13 | import tempfile 14 | import os 15 | import scipy 16 | from scipy.special import logsumexp 17 | from copy import deepcopy 18 | # from keras import backend as K 19 | # from keras.models import Model, load_model 20 | # from keras.layers import Dense, Layer, Input, BatchNormalization, Conv2D, Lambda 21 | import segmentation_models 22 | from .layers import weighted_binary_cross_entropy, ValidMetrics, SplitSlice, MergeSlices 23 | from .data import load_input, preprocess 24 | 25 | 26 | def _softmax(arr, axis=-1): 27 | """ Helper function for performing softmax operation """ 28 | softmax_arr = np.exp(arr - logsumexp(arr, axis=axis, keepdims=True)) 29 | return softmax_arr 30 | 31 | 32 | class Segment(object): 33 | """ Semantic segmentation model based on U-Net """ 34 | 35 | def __init__(self, 36 | input_shape=(2, 256, 256), 37 | n_classes=3, 38 | freeze_encoder=False, 39 | model_path=None, 40 | **kwargs): 41 | """ Define model 42 | 43 | Args: 44 | input_shape (tuple of int, optional): shape of input features 45 | (without batch dimension), should be in the order of 46 | (c, x, y) or (c, z, x, y) 47 | n_classes (int, optional): number of prediction classes 48 | freeze_encoder (bool, optional): if to freeze backbone weights 49 | model_path (str or None, optional): path to save model weights 50 | if not given, a temp folder will be used 51 | 52 | """ 53 | 54 | self.input_shape = input_shape 55 | self.n_channels = self.input_shape[0] 
56 | self.x_size, self.y_size = self.input_shape[-2:] 57 | 58 | self.n_classes = n_classes 59 | 60 | self.freeze_encoder = freeze_encoder 61 | if model_path is None: 62 | self.model_path = tempfile.mkdtemp() 63 | else: 64 | self.model_path = model_path 65 | self.call_backs = [keras.callbacks.TerminateOnNaN(), 66 | keras.callbacks.ReduceLROnPlateau(patience=5, min_lr=1e-7), 67 | keras.callbacks.ModelCheckpoint(self.model_path + '/weights.{epoch:02d}-{val_loss:.2f}.hdf5')] 68 | self.valid_score_callback = ValidMetrics() 69 | self.loss_func = weighted_binary_cross_entropy(n_classes=self.n_classes) 70 | self.build_model() 71 | 72 | 73 | def build_model(self): 74 | """ Define model structure and compile """ 75 | 76 | self.input = keras.layers.Input(shape=self.input_shape, dtype='float32') 77 | self.pre_conv = keras.layers.Conv2D(3, (1, 1), activation=None, name='pre_conv')(self.input) 78 | 79 | self.unet = segmentation_models.Unet( 80 | backbone_name='resnet34', 81 | input_shape=(3, self.x_size, self.y_size), 82 | classes=self.n_classes, 83 | activation='linear', 84 | encoder_weights='imagenet', 85 | encoder_features='default', 86 | decoder_block_type='upsampling', 87 | decoder_filters=(256, 128, 64, 32, 16), 88 | decoder_use_batchnorm=True) 89 | 90 | output = self.unet(self.pre_conv) 91 | 92 | self.model = keras.models.Model(self.input, output) 93 | self.model.compile(optimizer='Adam', 94 | loss=self.loss_func, 95 | metrics=[]) 96 | 97 | 98 | def fit(self, 99 | patches, 100 | label_input='prob', 101 | batch_size=8, 102 | n_epochs=10, 103 | valid_patches=None, 104 | valid_label_input='prob', 105 | class_weights=None, 106 | **kwargs): 107 | """ Fit model 108 | 109 | Args: 110 | patches (list): list of input-label pairs 111 | see docs of `generate_patches` 112 | label_input (str or None, optional): 'prob' or 'annotation' or None 113 | label input type, probabilities or discrete annotation 114 | batch_size (int, optional): default=8, batch size 115 | n_epochs (int, optional): default=10, number of epochs 116 | valid_patches (list or None, optional): if given, input-label pairs 117 | of validation data 118 | valid_label_input (str, optional): 'prob' or 'annotation' 119 | label input type of `valid_patches` (if applicable) 120 | class_weights (None of list, optional): if given, specify training 121 | weights for different classes 122 | **kwargs: Other keyword arguments for keras model `fit` function 123 | 124 | """ 125 | 126 | if not os.path.exists(self.model_path): 127 | os.mkdir(self.model_path) 128 | # `X` and `y` should originally be 5 dimensional: (batch, c, z, x, y), 129 | # in default model z=1 will be neglected 130 | X, y = preprocess(patches, 131 | n_classes=self.n_classes, 132 | label_input=label_input, 133 | class_weights=class_weights) 134 | X = X.reshape(self.batch_input_shape) 135 | y = y.reshape(self.batch_label_shape) 136 | assert X.shape[0] == y.shape[0] 137 | 138 | validation_data = None 139 | if valid_patches is not None: 140 | valid_X, valid_y = preprocess(valid_patches, 141 | n_classes=self.n_classes, 142 | label_input=valid_label_input) 143 | valid_X = valid_X.reshape(self.batch_input_shape) 144 | valid_y = valid_y.reshape(self.batch_label_shape) 145 | assert valid_X.shape[0] == valid_y.shape[0] 146 | self.valid_score_callback.valid_data = (valid_X, valid_y) 147 | validation_data = (valid_X, valid_y) 148 | 149 | self.model.fit(x=X, 150 | y=y, 151 | batch_size=batch_size, 152 | epochs=n_epochs, 153 | verbose=1, 154 | callbacks=self.call_backs + [self.valid_score_callback], 155 | 
validation_data=validation_data, 156 | **kwargs) 157 | 158 | 159 | def predict(self, patches, label_input='prob'): 160 | """ Generate prediction for given data 161 | 162 | Args: 163 | patches (list): list of input-label pairs (label could be None) 164 | see docs of `generate_patches` 165 | label_input (str or None, optional): 'prob' or 'annotation' or None 166 | label input type, probabilities or discrete annotation 167 | 168 | """ 169 | 170 | if patches.__class__ is list: 171 | X, _ = preprocess(patches, label_input=label_input) 172 | X = X.reshape(self.batch_input_shape) 173 | y_pred = self.model.predict(X) 174 | elif patches.__class__ is np.ndarray: 175 | X = patches.reshape(self.batch_input_shape) 176 | y_pred = self.model.predict(X) 177 | else: 178 | raise ValueError("Input format not supported") 179 | y_pred = _softmax(y_pred, 1) 180 | assert y_pred.shape[1:] == (self.n_classes, self.x_size, self.y_size) 181 | y_pred = np.expand_dims(y_pred, 2) # Manually add z dimension 182 | return y_pred 183 | 184 | 185 | @property 186 | def batch_input_shape(self): 187 | return tuple([-1,] + list(self.input_shape)) 188 | 189 | 190 | @property 191 | def batch_label_shape(self): 192 | return tuple([-1, self.n_classes + 1, self.x_size, self.y_size]) 193 | 194 | 195 | def save(self, path): 196 | """ Save model weights to `path` """ 197 | self.model.save_weights(path) 198 | 199 | 200 | def load(self, path): 201 | """ Load model weights from `path` """ 202 | self.model.load_weights(path) 203 | 204 | 205 | 206 | class SegmentWithMultipleSlice(Segment): 207 | """ Semantic segmentation model with inputs having multiple time/z slices """ 208 | 209 | def __init__(self, 210 | unet_feat=32, 211 | **kwargs): 212 | """ Define model 213 | 214 | Args: 215 | unet_feat (int, optional): output dimension of unet (used as 216 | hidden units) 217 | **kwargs: keyword arguments for `Segment` 218 | note that `input_shape` should have 4 dimensions 219 | 220 | """ 221 | 222 | self.unet_feat = unet_feat 223 | # `n_slices` is derived inside `build_model`, which `Segment.__init__` calls 224 | super(SegmentWithMultipleSlice, self).__init__(**kwargs) 225 | 226 | 227 | def build_model(self): 228 | """ Define model structure and compile """ 229 | self.n_slices = self.input_shape[1] # input shape is (c, z, x, y) 230 | # input shape: batch_size, n_channel, n_slice, x_size, y_size 231 | self.input = keras.layers.Input(shape=self.input_shape, dtype='float32') 232 | 233 | # Combine time slice dimension and batch dimension 234 | inp = SplitSlice(self.n_channels, self.x_size, self.y_size)(self.input) 235 | self.pre_conv = keras.layers.Conv2D(3, (1, 1), activation=None, name='pre_conv')(inp) 236 | 237 | self.unet = segmentation_models.Unet( 238 | backbone_name='resnet34', 239 | input_shape=(3, self.x_size, self.y_size), 240 | classes=self.unet_feat, 241 | activation='linear', 242 | encoder_weights='imagenet', 243 | encoder_features='default', 244 | decoder_block_type='upsampling', 245 | decoder_filters=(256, 128, 64, 32, 16), 246 | decoder_use_batchnorm=True) 247 | 248 | output = self.unet(self.pre_conv) 249 | 250 | # Split time slice dimension and merge to channel dimension 251 | output = MergeSlices(self.n_slices, self.unet_feat)(output) 252 | output = keras.layers.Conv2D(self.unet_feat, (1, 1), activation='relu', name='post_conv')(output) 253 | output = keras.layers.Conv2D(self.n_classes, (1, 1), activation=None, name='pred_head')(output) 254 | 255 | self.model = keras.models.Model(self.input, output) 256 | self.model.compile(optimizer='Adam', 257 | loss=self.loss_func, 258 | metrics=[]) 259 | 
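The custom loss used by these models consumes labels in an `(n_classes + 1)`-channel layout (class maps plus a per-pixel weight map), which `preprocess` assembles and the `batch_label_shape` property describes. A toy illustration with illustrative shapes:

```python
# Sketch: build one label array in the format weighted_binary_cross_entropy
# expects -- the first n_classes channels are class probability/one-hot maps,
# the last channel is a per-pixel training weight.
import numpy as np

n_classes, x_size, y_size = 3, 256, 256
class_maps = np.zeros((1, n_classes, x_size, y_size), dtype='float32')
class_maps[0, 0] = 1.0                         # label every pixel as class 0
weights = np.ones((1, 1, x_size, y_size), dtype='float32')
weights[0, 0, :32] = 0.0                       # e.g. down-weight an image border

y = np.concatenate([class_maps, weights], axis=1)
assert y.shape == (1, n_classes + 1, x_size, y_size)  # = batch_label_shape
```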
-------------------------------------------------------------------------------- /NNsegmentation/run.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Thu Feb 7 18:10:01 2019 5 | 6 | @author: zqwu 7 | """ 8 | 9 | # Sample script for model training 10 | 11 | import tensorflow as tf 12 | import numpy as np 13 | import os 14 | os.environ['KERAS_BACKEND'] = 'tensorflow' 15 | import pickle 16 | from .data import generate_patches, generate_ordered_patches, predict_whole_map 17 | from .models import Segment 18 | from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score 19 | import cv2 20 | 21 | # Data path 22 | TRAIN_DATA_PATH = { 23 | 'annotation': '/mnt/comp_micro/Projects/CellVAE/Data/NNSegment/Annotations_8Sites.pkl', 24 | 'RFBG': '/mnt/comp_micro/Projects/CellVAE/Data/NNSegment/Annotations_BGRF_4Sites.pkl' 25 | } 26 | 27 | sites = ['D%d-Site_%d' % (i, j) for j in range(9) for i in range(3, 6)] 28 | TEST_DATA_PATH = { 29 | site: '/mnt/comp_micro/Projects/CellVAE/Combined/%s.npy' % site for site in sites 30 | } 31 | 32 | # Training patches from human annotations 33 | train_patches = pickle.load(open(TRAIN_DATA_PATH['annotation'], 'rb')) 34 | 35 | # Supplementary training patches from RF predictions (background only) 36 | train_patches2 = pickle.load(open(TRAIN_DATA_PATH['RFBG'], 'rb')) 37 | combined = train_patches + train_patches2 38 | np.random.shuffle(combined) 39 | 40 | # Random patches used for monitoring 41 | test_patches = [train_patches[i] for i in np.random.choice(np.arange(len(train_patches)), (50,), replace=False)] 42 | 43 | # Define model 44 | model_path = './temp_save/' 45 | if not os.path.exists(model_path): 46 | os.mkdir(model_path) 47 | model = Segment(input_shape=(2, 256, 256), # channel-first (c, x, y): Phase + Retardance 48 | n_classes=3, 49 | model_path=model_path) 50 | 51 | 52 | 53 | # In the first phase of training, only use human annotations 54 | for st in range(5): 55 | model.fit(train_patches, 56 | label_input='annotation', 57 | n_epochs=200, 58 | valid_patches=test_patches, 59 | valid_label_input='annotation') 60 | model.save(model.model_path + '/stage%d.h5' % st) 61 | 62 | # In the second phase of training, add in RF background patches to refine edges 63 | for st in range(5): 64 | model.fit(combined, 65 | label_input='annotation', 66 | n_epochs=50, 67 | valid_patches=test_patches, 68 | valid_label_input='annotation') 69 | model.save(model.model_path + '/stage%d.h5' % (st + 5)) # continue numbering so phase-1 checkpoints are kept 70 | 71 | # Generate predictions for all data 72 | for site in TEST_DATA_PATH: 73 | print(site) 74 | predict_whole_map(TEST_DATA_PATH[site], 75 | model, 76 | n_classes=3, 77 | batch_size=8, 78 | n_supp=5) 79 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DynaMorph 2 | 3 | This repository is for sharing code related to **DynaMorph: self-supervised learning of morphodynamic states of live cells**. The related preprint is [here](https://www.biorxiv.org/content/10.1101/2020.07.20.213074v1). 4 | 5 | We summarize the components of the DynaMorph pipeline and the structure of this repository below. 
6 | 7 | ![pipeline_fig](graphicalabstract_dynamorph.jpg) 8 | 9 | ### Table of contents: 10 | 11 | - [Requirements](#requirements) 12 | - [Getting Started](#getting-started) 13 | - [DynaMorph Pipeline](#dynamorph-pipeline) 14 | - [Label-free Imaging](#label-free-imaging) 15 | - [Cell Segmentation and Tracking](#cell-segmentation-and-tracking) 16 | - [Latent Representations of Morphology](#latent-representations-of-morphology) 17 | - [Usage](#usage) 18 | - [Citing DynaMorph](#citing-dynamorph) 19 | 20 | ## Requirements 21 | 22 | DynaMorph is developed and tested under Python 3.7; the packages below are required. 23 | 24 | For U-Net segmentation 25 | - [TensorFlow](https://www.tensorflow.org/) ==v.2.1 26 | - [segmentation-models](https://github.com/qubvel/segmentation_models) ==v1.0.1 27 | 28 | For preprocessing, patching, latent-space encoding, and latent-space training 29 | - [imageio](https://imageio.github.io/) 30 | - [tifffile](https://pypi.org/project/tifffile/) 31 | - [Matplotlib](https://matplotlib.org/) 32 | - [OpenCV](https://opencv.org/about/) 33 | - [PyTorch](https://pytorch.org/) 34 | - [SciPy](https://www.scipy.org/) 35 | - [scikit-learn](https://scikit-learn.org/) 36 | - [umap-learn](https://umap-learn.readthedocs.io/en/latest/#) 37 | - [pyyaml](https://pyyaml.org/) 38 | - [h5py](https://docs.h5py.org/en/stable/) 39 | - [POT](https://pythonot.github.io/) 40 | 41 | ## Installation 42 | 43 | To install this codebase, you must first install `git`: 44 | 45 | [getting-started-installing-git](https://git-scm.com/book/en/v2/Getting-Started-Installing-Git) 46 | 47 | Then you can clone this repository onto your computer: 48 | 49 | ```text 50 | git clone https://github.com/czbiohub/dynamorph.git 51 | ``` 52 | 53 | If you are missing any dependencies (listed in /requirements/default.txt), you can install them using `pip`: 54 | ```text 55 | pip install -r requirements/default.txt 56 | ``` 57 | ## Downloading the data 58 | Example data to test the pipeline can be downloaded [here](https://drive.google.com/drive/folders/11GoWDwaBo1PE5FO5tcnGCOzA4pzjf-Tk?usp=sharing). 59 | 60 | ## Issues and bug reports 61 | The pipeline is currently in beta and we are actively working to make it easy to use. If you encounter any bugs, please report them via Issues on this repository. 62 | 63 | ## Getting Started 64 | 65 | DynaMorph utilizes a broad set of deep learning and machine learning tools to analyze cell imaging data; the [pipeline](https://github.com/czbiohub/dynamorph/tree/master/pipeline) folder contains wrapper methods for easy access to the functionalities of DynaMorph. 66 | We also maintain example scripts `run_preproc.py`, `run_segmentation.py`, `run_patch.py` and `run_VAE.py` to facilitate parallelization of data processing. 67 | See the [section](#cell-segmentation-and-tracking) below for the functionalities this repo provides. 68 | 69 | ## DynaMorph Pipeline 70 | 71 | DynaMorph starts with raw image files from cell imaging experiments and sequentially applies a set of segmentation and encoding tools. Below we briefly introduce the main processing steps. 72 | 73 | 74 | ### Label-free Imaging, Cell Segmentation, and Tracking 75 | 76 | ![pipeline_fig](pipeline.jpg) 77 | 78 | The pipeline starts from raw microscopy data, formatted as single-page .tif series or multi-page stacks acquired from micro-manager (panel A). In the dynamorph paper, we used Phase and Retardance images measured with Quantitative Label-Free Imaging with Phase and Polarization microscopy as the input. 
79 | Then use a segmentation model of your choice to generate semantic segmentation maps from the input (panel C). 80 | 81 | Instance segmentation in this work is based on clustering; related methods can be found in `SingleCellPatch/extract_patches.py`. Cell tracking methods can be found in `SingleCellPatch/generate_trajectories.py`. 82 | 83 | To generate segmentation and tracking from scratch, follow the steps below: 84 | 85 | ##### 1. (optional) prepare training images and labels for training segmentation models 86 | 87 | ##### 2. (optional) train a segmentation model to provide per-pixel class probabilities, see scripts in `NNsegmentation/run.py` 88 | 89 | ##### 3. prepare inputs as 5-D numpy arrays of shape (Ntime frames, Nchannels, Nslices, height, width), see `run_preproc.py` for an example 90 | 91 | ##### 4. apply trained segmentation model for semantic segmentation, see method `pipeline.segmentation.segmentation` or `run_segmentation.py` 92 | 93 | ##### 5. use predicted class probability maps for instance segmentation, see method `pipeline.segmentation.instance_segmentation` or `run_segmentation.py` 94 | 95 | ### Latent Representations of Morphology 96 | DynaMorph uses VQ-VAE to encode and reconstruct cell image patches, from which latent vectors are used as morphology descriptors. 97 | 98 | To extract single cell patches and apply morphology encoding, follow the steps below: 99 | 100 | ##### 6. extract cell patches based on instance segmentation, see method `pipeline.patch_VAE.extract_patches` or `run_patch.py -m 'extract_patches'` 101 | 102 | ##### 7. extract cell trajectories based on instance segmentation, see method `pipeline.patch_VAE.build_trajectories` or `run_patch.py -m 'build_trajectories'` 103 | 104 | ##### 8. train a VAE for cell patch latent-encoding, see the script `run_training.py` 105 | 106 | ##### 9. assemble cell patches generated in steps 6-7 into model-compatible datasets, see method `pipeline.patch_VAE.assemble_VAE` or `run_VAE.py -m 'assemble'` 107 | 108 | ##### 10. generate latent representations for cell patches using trained VAE models, see method `pipeline.patch_VAE.process_VAE` or `run_VAE.py -m 'process'` 109 | 110 | 111 | ## Usage 112 | 113 | The dataset accompanying this repository is large and currently available upon request for demonstration. 114 | 115 | Scripts `run_preproc.py`, `run_segmentation.py`, `run_patch.py`, `run_VAE.py` and `run_training.py` provide command line interfaces to run each module. For details, check each script's `-h` option. 116 | Each CLI requires a configuration file (.yaml format) that contains parameters for each stage; see the example in `configs/config_example.yml`. 117 | 118 | To run the dynamorph pipeline, data should first be assembled into 5-D numpy arrays ([step 3](#step3)). 119 | 120 | Semantic segmentation ([step 4](#step4)) and instance segmentation ([step 5](#step5)): 121 | 122 | python run_segmentation.py -m "segmentation" -c <config_path> 123 | python run_segmentation.py -m "instance_segmentation" -c <config_path> 124 | 125 | Extract patches from segmentation results ([step 6](#step6)), then connect them into trajectories ([step 7](#step7)): 126 | 127 | python run_patch.py -m "extract_patches" -c <config_path> 128 | python run_patch.py -m "build_trajectories" -c <config_path> 129 | 130 | Train a DNN model (VQ-VAE) to learn a representation of your image data ([step 8](#step8)): 131 | 132 | python run_training.py -c <config_path> 133 | 134 | Transform image patches into the DNN model's (VQ-VAE) latent space by running inference 
([step 9](#step9) and [10](#step10)): 135 | 136 | python run_VAE.py -m "assemble" -c <config_path> 137 | python run_VAE.py -m "process" -c <config_path> 138 | 139 | Reduce the dimension of latent vectors for visualization by fitting a PCA or UMAP model to the data. For PCA: 140 | 141 | python run_dim_reduction.py -m "pca" -c <config_path> 142 | 143 | 144 | ## Citing DynaMorph 145 | 146 | To cite DynaMorph, please use the bibtex entry below: 147 | 148 | ``` 149 | @article{wu2020dynamorph, 150 | title={DynaMorph: learning morphodynamic states of human cells with live imaging and sc-RNAseq}, 151 | author={Wu, Zhenqin and Chhun, Bryant B and Schmunk, Galina and Kim, Chang and Yeh, Li-Hao and Nowakowski, Tomasz J and Zou, James and Mehta, Shalin B}, 152 | journal={bioRxiv}, 153 | year={2020}, 154 | publisher={Cold Spring Harbor Laboratory} 155 | } 156 | ``` 157 | 158 | ## Contact Us 159 | 160 | If you have any questions regarding this work or code in this repo, feel free to raise an issue or reach out to us through: 161 | - Zhenqin Wu 162 | - Bryant Chhun 163 | - Syuan-Ming Guo 164 | - Shalin Mehta 165 | -------------------------------------------------------------------------------- /SingleCellPatch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mehta-lab/dynamorph/b3321f4368002707fbe39d727bc5c23bd5e7e199/SingleCellPatch/__init__.py -------------------------------------------------------------------------------- /SingleCellPatch/instance_clustering.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon Feb 8 21:27:26 2021 4 | 5 | @author: Zhenqin Wu 6 | """ 7 | 8 | import cv2 9 | import numpy as np 10 | import os 11 | import matplotlib 12 | matplotlib.use('AGG') 13 | import matplotlib.pyplot as plt 14 | import pickle 15 | from sklearn.cluster import DBSCAN 16 | from copy import copy 17 | 18 | """ Functions for clustering single cells from semantic segmentation """ 19 | 20 | def within_range(r, pos): 21 | """ Check if a given position is in window 22 | 23 | Args: 24 | r (tuple): window, ((int, int), (int, int)) in the form of 25 | ((x_low, x_up), (y_low, y_up)) 26 | pos (tuple): (int, int) in the form of (x, y) 27 | 28 | Returns: 29 | bool: True if `pos` is in `r`, False otherwise 30 | 31 | """ 32 | if pos[0] >= r[0][1] or pos[0] < r[0][0]: 33 | return False 34 | if pos[1] >= r[1][1] or pos[1] < r[1][0]: 35 | return False 36 | return True 37 | 38 | 39 | def check_segmentation_dim(segmentation): 40 | """ Check segmentation mask dimension. 
41 | Add a background channel if n(channels)==1 42 | 43 | Args: 44 | segmentation: (np.array): segmentation mask for the frame 45 | 46 | """ 47 | 48 | assert len(segmentation.shape) == 4, "Semantic segmentation should be formatted with dimension (c, z, x, y)" 49 | n_channels, _, _, _ = segmentation.shape 50 | 51 | # binary segmentation has only foreground channel, add background channel 52 | if n_channels == 1: 53 | segmentation = np.concatenate([1 - segmentation, segmentation], axis=0) 54 | assert np.allclose(segmentation.sum(0), 1.), "Semantic segmentation doesn't sum up to 1" 55 | return segmentation 56 | 57 | 58 | def instance_clustering(cell_segmentation, 59 | ct_thr=(500, 12000), 60 | instance_map=True, 61 | map_path=None, 62 | fg_thr=0.3, 63 | DBSCAN_thr=(10, 250)): 64 | """ Perform instance clustering on a static frame 65 | 66 | Args: 67 | cell_segmentation (np.array): segmentation mask for the frame, 68 | size (n_classes(3), z(1), x, y) 69 | ct_thr (tuple, optional): lower and upper threshold for cell size 70 | (number of pixels in segmentation mask) 71 | instance_map (bool, optional): whether to save instance segmentation as an 72 | image 73 | map_path (str or None, optional): path to the image (if `instance_map` 74 | is True) 75 | fg_thr (float, optional): threshold of foreground; any pixel with 76 | predicted background prob less than this value would be regarded as 77 | foreground (MG or Non-MG) 78 | DBSCAN_thr (tuple, optional): parameters for DBSCAN, (eps, min_samples) 79 | 80 | Returns: 81 | list: detected cells; each entry in the list is a tuple of 82 | cell ID and cell center position 83 | np.array: array of x, y coordinates of foreground pixels 84 | np.array: array of cell IDs of foreground pixels 85 | 86 | """ 87 | cell_segmentation = check_segmentation_dim(cell_segmentation) 88 | all_cells = np.mean(cell_segmentation[0], axis=0) < fg_thr 89 | positions = np.array(list(zip(*np.where(all_cells)))) 90 | if len(positions) < 1000: 91 | # Too few foreground pixels; treat as no cells detected 92 | return [], np.zeros((0, 2), dtype=int), np.zeros((0,), dtype=int) 93 | 94 | # DBSCAN clustering of cell pixels 95 | clustering = DBSCAN(eps=DBSCAN_thr[0], min_samples=DBSCAN_thr[1]).fit(positions) 96 | positions_labels = clustering.labels_ 97 | cell_ids, point_cts = np.unique(positions_labels, return_counts=True) 98 | 99 | cell_positions = [] 100 | for cell_id, ct in zip(cell_ids, point_cts): 101 | if cell_id < 0: 102 | # neglect unclustered pixels 103 | continue 104 | if ct <= ct_thr[0] or ct >= ct_thr[1]: 105 | # neglect cells that are too small/big 106 | continue 107 | points = positions[np.where(positions_labels == cell_id)[0]] 108 | # calculate cell center 109 | mean_pos = np.mean(points, 0).astype(int) 110 | # define window 111 | window = [(mean_pos[0]-128, mean_pos[0]+128), (mean_pos[1]-128, mean_pos[1]+128)] 112 | # skip if cell has too many outlying points 113 | outliers = [p for p in points if not within_range(window, p)] 114 | if len(outliers) > len(points) * 0.05: 115 | continue 116 | cell_positions.append((cell_id, mean_pos)) 117 | 118 | # Save instance segmentation results as image 119 | if instance_map and map_path is not None: 120 | x_size, y_size = cell_segmentation.shape[-2:] 121 | # bg as -1 122 | segmented = np.zeros((x_size, y_size)) - 1 123 | for cell_id, mean_pos in cell_positions: 124 | points = positions[np.where(positions_labels == cell_id)[0]] 125 | for p in points: 126 | segmented[p[0], p[1]] = cell_id%10 127 | plt.clf() 128 | # cmap = 
matplotlib.cm.get_cmap('tab10') 129 | cmap = copy(matplotlib.cm.get_cmap("tab10")) 130 | cmap.set_under(color='k') 131 | plt.imshow(segmented, cmap=cmap, vmin=-0.001, vmax=10.001) 132 | font = {'color': 'white', 'size': 4} 133 | for cell_id, mean_pos in cell_positions: 134 | plt.text(mean_pos[1], mean_pos[0], str(cell_id), fontdict=font) 135 | plt.axis('off') 136 | plt.savefig(map_path, dpi=300) 137 | return cell_positions, positions, positions_labels 138 | 139 | 140 | def process_site_instance_segmentation(raw_data, 141 | raw_data_segmented, 142 | site_supp_files_folder, 143 | **kwargs): 144 | """ 145 | Wrapper method for instance segmentation 146 | 147 | Results will be saved to the supplementary data folder as: 148 | "cell_positions.pkl": list of cells in each frame (IDs and positions); 149 | "cell_pixel_assignments.pkl": pixel compositions of cells; 150 | "segmentation_*.png": image of instance segmentation results. 151 | 152 | 153 | :param raw_data: (str) path to image stack (.npy) 154 | :param raw_data_segmented: (str) path to semantic segmentation stack (.npy) 155 | :param site_supp_files_folder: (str) path to the folder where supplementary files will be saved 156 | :param kwargs: 157 | :return: 158 | """ 159 | 160 | # TODO: Size is hardcoded here 161 | # Should be of size (n_frame, n_channels, z(1), x(2048), y(2048)), uint16 162 | print(f"\tLoading {raw_data}") 163 | image_stack = np.load(raw_data) 164 | # Should be of size (n_frame, n_classes, z(1), x(2048), y(2048)), float 165 | print(f"\tLoading {raw_data_segmented}") 166 | segmentation_stack = np.load(raw_data_segmented) 167 | 168 | cell_positions = {} 169 | cell_pixel_assignments = {} 170 | for t_point in range(image_stack.shape[0]): 171 | print("\tClustering time %d" % t_point) 172 | cell_segmentation = segmentation_stack[t_point] 173 | instance_map_path = os.path.join(site_supp_files_folder, 'segmentation_%d.png' % t_point) 174 | #TODO: expose instance clustering parameters in config 175 | res = instance_clustering(cell_segmentation, instance_map=True, map_path=instance_map_path) 176 | cell_positions[t_point] = res[0] # List of cell: (cell_id, mean_pos) 177 | cell_pixel_assignments[t_point] = res[1:] 178 | with open(os.path.join(site_supp_files_folder, 'cell_positions.pkl'), 'wb') as f: 179 | pickle.dump(cell_positions, f) 180 | with open(os.path.join(site_supp_files_folder, 'cell_pixel_assignments.pkl'), 'wb') as f: 181 | pickle.dump(cell_pixel_assignments, f) 182 | return 183 | -------------------------------------------------------------------------------- /configs/.config_run_patch.yml: -------------------------------------------------------------------------------- 1 | 2 | files: 3 | raw_dirs: ['/CompMicro/projects/cardiomyocytes/200721_CM_Mock_SPS_Fluor/20200721_CM_Mock_SPS/dnm_input_tstack', 4 | '/CompMicro/projects/cardiomyocytes/20200722CM_LowMOI_SPS_Fluor/20200722 CM_LowMOI_SPS/dnm_input_tstack'] 5 | 6 | supp_dirs: ['/CompMicro/projects/cardiomyocytes/200721_CM_Mock_SPS_Fluor/20200721_CM_Mock_SPS/dnm_supp_tstack', 7 | '/CompMicro/projects/cardiomyocytes/20200722CM_LowMOI_SPS_Fluor/20200722 CM_LowMOI_SPS/dnm_supp_tstack'] 8 | 9 | train_dirs: ['/CompMicro/projects/cardiomyocytes/200721_CM_Mock_SPS_Fluor/20200721_CM_Mock_SPS/dnm_train_tstack', 10 | '/CompMicro/projects/cardiomyocytes/20200722CM_LowMOI_SPS_Fluor/20200722 CM_LowMOI_SPS/dnm_train_tstack'] 11 | 12 | val_dirs: ['/CompMicro/projects/cardiomyocytes/200721_CM_Mock_SPS_Fluor/20200721_CM_Mock_SPS/dnm_train_tstack/val', 13 | 
'/CompMicro/projects/cardiomyocytes/20200722CM_LowMOI_SPS_Fluor/20200722 CM_LowMOI_SPS/dnm_train_tstack/val'] 14 | 15 | weights_dir: '/CompMicro/projects/cardiomyocytes/200721_CM_Mock_SPS_Fluor/20200721_CM_Mock_SPS/dnm_train_tstack/weights' 16 | # path to save weights during training, subdirectory of train_dirs 17 | 18 | 19 | preprocess: 20 | channels: ["Retardance", "Phase2D", "Brightfield"] 21 | fov: ['C5-Site_0','C5-Site_1'] 22 | multipage: False 23 | z_slice: 2 24 | 25 | 26 | patch: 27 | channels: [0,1,2] 28 | fov: ['C5-Site_0','C5-Site_1'] 29 | gpus: 1 30 | window_size: 256 31 | save_fig: False 32 | reload: True 33 | skip_boundary: False 34 | 35 | 36 | inference: 37 | model: 'VQ_VAE_z16' 38 | # VQ_VAE_z16 or UNet 39 | 40 | weights: '/gpfs/CompMicro/projects/dynamorph/CellVAE/save_0005_bkp4.pt' 41 | # pytorch weight file 42 | 43 | save_output: True 44 | # write .jpg results from run_vae.process 45 | 46 | gpus: 1 47 | gpu_id: 3 48 | fov: ['C5-Site_0','C5-Site_1'] 49 | channels: [0,1,2] 50 | num_classes: 3 51 | window_size: 256 52 | batch_size: 8 53 | num_pred_rnd: 5 54 | seg_val_cat: 'mg' 55 | 56 | 57 | training: 58 | # model definitions 59 | model: 'VQ_VAE_z32' 60 | num_inputs: 2 61 | num_hiddens: 64 62 | num_residual_hiddens: 64 63 | num_residual_layers: 2 64 | num_embeddings: 512 65 | 66 | # normalization 67 | channel_mean: [0.4, 0, 0.5] 68 | channel_std: [0.05, 0.05, 0.05] 69 | 70 | w_a: 1 71 | w_t: 0.5 72 | 73 | # training parameters 74 | commitment_cost: 0.25 75 | alpha: 0.002 76 | epochs: 5000 77 | learning_rate: 0.0001 78 | batch_size: 96 79 | gpus: True 80 | gpu_id: 3 81 | shuffle_data: False 82 | transform: True 83 | 84 | 85 | 86 | 87 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mehta-lab/dynamorph/b3321f4368002707fbe39d727bc5c23bd5e7e199/configs/__init__.py -------------------------------------------------------------------------------- /configs/bryant_rubella_experiments.yml: -------------------------------------------------------------------------------- 1 | 2 | 3 | preprocess: 4 | image_dirs: ['/gpfs/CompMicro/projects/dynamorph/microglia/20191107_1209_1_GW23/blank_bg_stabilized'] 5 | target_dirs: ['/gpfs/CompMicro/projects/dynamorph/microglia/rubella_experiments/NOVEMBER/new-3-17-2021/raw'] 6 | 7 | channels: ["Phase2D", "Retardance"] 8 | # list of channels to include. Resulting array is (phase, retardance, brightfield) = (0, 1, 2) index 9 | 10 | fov: ['C4-Site_0', 'C4-Site_1', 'C4-Site_2', 'C4-Site_3', 'C4-Site_4', 'C4-Site_5', 'C4-Site_6', 'C4-Site_7', 'C4-Site_8'] 11 | # fov: ['C4-Site_0'] 12 | 13 | multipage: True 14 | # if images are multipage-tiffs 15 | 16 | z_slice: 2 17 | # single integer to select the z-plane. 
"z###" must exist in file names 18 | 19 | pos_dir: True 20 | # if singlepage tiffs with z-stacks, select the z-slice to stack 21 | 22 | 23 | segmentation_inference: 24 | raw_dirs: ['/gpfs/CompMicro/projects/dynamorph/microglia/rubella_experiments/NOVEMBER/new-3-17-2021/raw'] 25 | supp_dirs: ['/gpfs/CompMicro/projects/dynamorph/microglia/rubella_experiments/NOVEMBER/new-3-17-2021/supp'] 26 | validation_dirs: ['/gpfs/CompMicro/projects/dynamorph/microglia/rubella_experiments/NOVEMBER/new-3-17-2021/val'] 27 | network: 'UNet' 28 | weights: '/gpfs/CompMicro/projects/dynamorph/model_save/final_reformat.h5' 29 | gpu_ids: [1,2,3] 30 | fov: ['C4-Site_0'] 31 | channels: [0,1] 32 | 33 | num_classes: 3 34 | window_size: 256 35 | batch_size: 32 36 | num_pred_rnd: 5 37 | seg_val_cat: 'mg' 38 | 39 | 40 | patch: 41 | raw_dirs: ['/gpfs/CompMicro/projects/dynamorph/microglia/rubella_experiments/NOVEMBER/new-3-17-2021/seg_test/raw'] 42 | supp_dirs: ['/gpfs/CompMicro/projects/dynamorph/microglia/rubella_experiments/NOVEMBER/new-3-17-2021/seg_test/supp'] 43 | channels: [0,1] 44 | fov: ['C4-Site_0'] 45 | # fov: ['C4-Site_1', 'C4-Site_2', 'C4-Site_3', 'C4-Site_4', 'C4-Site_5', 'C4-Site_6', 'C4-Site_7', 'C4-Site_8'] 46 | 47 | num_cpus: 4 48 | window_size: 256 49 | save_fig: False 50 | reload: False 51 | skip_boundary: False 52 | 53 | 54 | latent_encoding: 55 | raw_dirs: ['/gpfs/CompMicro/projects/dynamorph/microglia/rubella_experiments/NOVEMBER/new-3-17-2021/raw'] 56 | supp_dirs: ['/gpfs/CompMicro/projects/dynamorph/microglia/rubella_experiments/NOVEMBER/new-3-17-2021/supp'] 57 | weights: '/gpfs/CompMicro/projects/dynamorph/model_save/VQVAE_save.pt' 58 | save_output: True 59 | 60 | gpu_ids: [1,2,3] 61 | fov: ['C4-Site_0', 'C4-Site_1', ] 62 | # fov: ['C4-Site_0', 'C4-Site_1', 'C4-Site_2', 'C4-Site_3', 'C4-Site_4', 'C4-Site_5', 'C4-Site_6', 'C4-Site_7', 'C4-Site_8'] 63 | channels: [0,1] 64 | channel_mean: [0.4, 0] 65 | channel_std: [0.05, 0.05] 66 | patch_type: "masked_mat" 67 | 68 | network: 'VQ_VAE_z16' 69 | num_hiddens: 16 70 | num_residual_hiddens: 32 71 | num_embeddings: 64 72 | commitment_cost: 0.25 73 | 74 | 75 | dim_reduction: 76 | input_dirs: ['/gpfs/CompMicro/projects/dynamorph/microglia/rubella_experiments/NOVEMBER/new-3-17-2021/raw'] 77 | output_dirs: ['/gpfs/CompMicro/projects/dynamorph/microglia/rubella_experiments/NOVEMBER/new-3-17-2021/pca_out'] 78 | file_name_prefixes: ['C4'] 79 | weights_dir: '/gpfs/CompMicro/projects/dynamorph/microglia/rubella_experiments/NOVEMBER/new-3-17-2021/pca_out/weights' 80 | fit_model: False 81 | conditions: ['rubella'] 82 | 83 | 84 | training: 85 | # model definitions 86 | network: 'VQ_VAE_z32' 87 | num_inputs: 2 88 | num_hiddens: 16 89 | num_residual_hiddens: 32 90 | num_residual_layers: 2 91 | num_embeddings: 64 92 | 93 | # normalization 94 | channel_mean: [0.4, 0, 0.5] 95 | channel_std: [0.05, 0.05, 0.05] 96 | 97 | w_a: 1 98 | w_t: 0.5 99 | 100 | # training parameters 101 | commitment_cost: 0.25 102 | 103 | alpha: 0.002 104 | n_epochs: 5000 105 | learn_rate: 0.0001 106 | batch_size: 96 107 | gpus: True 108 | gpu_id: 3 109 | shuffle_data: False 110 | transform: True 111 | 112 | 113 | 114 | 115 | -------------------------------------------------------------------------------- /configs/config_example.yml: -------------------------------------------------------------------------------- 1 | 2 | 3 | preprocess: 4 | image_dirs: [ 5 | '', 6 | '' 7 | ] 8 | target_dirs: [ 9 | '', 10 | '' 11 | ] 12 | 13 | channels: ["Retardance", "Phase2D", "Brightfield"] 14 | # list of 
channels to include. Resulting array is (phase, retardance, brightfield) = (0, 1, 2) index 15 | 16 | fov: ['C5-Site_0','C5-Site_1'] 17 | # list of subfolder or position indices that identifies a field-of-view 18 | # ex: 'all' to preprocess all positions 19 | # ex: ['C5-Site_0', 'C5-Site_1', 'C5-Site_2'] 20 | # ex: [1, 3, 10, 100] 21 | 22 | multipage: False 23 | # if images are multipage-tiffs 24 | 25 | z_slice: 2 26 | # single integer to select the in-focus z-plane. 27 | # "z###" must exist in file names 28 | # only required for multipage-tiff stacks 29 | 30 | pos_dir: True 31 | # whether each position is in a subdirectory (True), or in the same directory (False) 32 | 33 | 34 | segmentation_inference: 35 | raw_dirs: [ 36 | '', 37 | '' 38 | ] 39 | supp_dirs: [ 40 | '', 41 | '' 42 | ] 43 | validation_dirs: [ 44 | '', 45 | '' 46 | ] 47 | 48 | network: 'UNet' 49 | # only UNet was implemented for the dynamorph paper 50 | 51 | weights: '' 52 | 53 | gpu_ids: [1,2,3] 54 | # list of GPUs to distribute inference across 55 | 56 | fov: ['C4-Site_0', 'C4-Site_1', 'C4-Site_2'] 57 | # well positions (.npy is the output of "preprocess" above) 58 | 59 | channels: [0,1] 60 | num_classes: 3 61 | window_size: 256 62 | batch_size: 8 63 | num_pred_rnd: 5 64 | seg_val_cat: 'mg' 65 | 66 | 67 | patch: 68 | raw_dirs: [ 69 | '', 70 | '' 71 | ] 72 | supp_dirs: [ 73 | '', 74 | '' 75 | ] 76 | 77 | channels: [0,1] 78 | fov: ['C5-Site_0','C5-Site_1'] 79 | num_cpus: 4 80 | window_size: 256 81 | save_fig: False 82 | reload: False # ??? not functional? 83 | skip_boundary: False 84 | # True to skip patches whose edges exceed the image boundaries 85 | # False to pad patches with mean background values 86 | 87 | 88 | latent_encoding: 89 | raw_dirs: [ 90 | '', 91 | '' 92 | ] 93 | supp_dirs: [ 94 | '', 95 | '' 96 | ] 97 | 98 | weights: [''] 99 | # pytorch weight file 100 | 101 | save_output: True 102 | # write .jpg results from run_vae.process 103 | 104 | gpu_ids: [0,1,2] 105 | # list of GPU ids to run inference across. 
106 | 107 | fov: ['C5-Site_0','C5-Site_1'] 108 | patch_type: "masked_mat" 109 | 110 | channels: [0,1] 111 | channel_mean: [0.4, 0] 112 | channel_std: [0.05, 0.05] 113 | # for each channel in channels, hardcode a mean and std for zscoring 114 | 115 | network: 'VQ_VAE_z16' 116 | # VQ_VAE_z16 or UNet 117 | num_hiddens: 16 118 | num_residual_hiddens: 32 119 | num_embeddings: 64 120 | commitment_cost: 0.25 121 | 122 | 123 | dim_reduction: 124 | input_dirs: [ 125 | '', 126 | '' 127 | ] 128 | output_dirs: [ 129 | '', 130 | '' 131 | ] 132 | # PCA transform outputs are written to output dirs 133 | 134 | weights_dir: '' 135 | file_name_prefixes: [''] 136 | fit_model: True 137 | conditions: [''] 138 | # fields as defined in DIM_REDUCTION of configs/config_reader.py 139 | 140 | 141 | training: 142 | raw_dirs: [ 143 | '/CompMicro/projects/cardiomyocytes/200721_CM_Mock_SPS_Fluor/20200721_CM_Mock_SPS/dnm_input_tstack', 144 | '/CompMicro/projects/cardiomyocytes/20200722CM_LowMOI_SPS_Fluor/20200722 CM_LowMOI_SPS/dnm_input_tstack', 145 | ] 146 | supp_dirs: [ 147 | '/CompMicro/projects/cardiomyocytes/200721_CM_Mock_SPS_Fluor/20200721_CM_Mock_SPS/dnm_supp_tstack', 148 | '/CompMicro/projects/cardiomyocytes/20200722CM_LowMOI_SPS_Fluor/20200722 CM_LowMOI_SPS/dnm_supp_tstack', 149 | ] 150 | 151 | weights_dirs: [ 152 | '/CompMicro/projects/cardiomyocytes/200721_CM_Mock_SPS_Fluor/20200721_CM_Mock_SPS/dnm_train_tstack', 153 | '/CompMicro/projects/cardiomyocytes/20200722CM_LowMOI_SPS_Fluor/20200722 CM_LowMOI_SPS/dnm_train_tstack', 154 | ] 155 | 156 | # model hyperparameters 157 | network: 'VQ_VAE_z32' 158 | num_inputs: 2 159 | num_hiddens: 64 160 | num_residual_hiddens: 64 161 | num_residual_layers: 2 162 | num_embeddings: 512 163 | commitment_cost: 0.25 164 | weight_matching: 100 165 | margin: 1 166 | w_a: 1 167 | w_t: 0.5 168 | w_n: -0.5 169 | 170 | # normalization 171 | channel_mean: null 172 | channel_std: null 173 | 174 | ### microglia data#### 175 | # channel_mean: [0.4, 0, 0.5] 176 | # channel_std: [0.05, 0.05, 0.05] 177 | 178 | # training parameters 179 | 180 | n_epochs: 5000 181 | learn_rate: 0.0001 182 | batch_size: 768 183 | val_split_ratio: 0.15 184 | shuffle_data: False 185 | transform: True 186 | patience: 100 187 | n_pos_samples: 4 188 | gpu_id: 1 189 | start_model_path: null 190 | retrain: False 191 | start_epoch: 0 192 | earlystop_metric: 'positive_triplet' 193 | model_name: 'CM+kidney+A549_vqvae32_test' 194 | use_mask: False 195 | 196 | 197 | 198 | 199 | -------------------------------------------------------------------------------- /configs/config_reader.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import logging 3 | 4 | 5 | # replicated from aicsimageio's logging mechanism 6 | ############################################################################### 7 | 8 | # modify the logging.ERROR level lower for more info 9 | # CRITICAL 10 | # ERROR 11 | # WARNING 12 | # INFO 13 | # DEBUG 14 | # NOTSET 15 | #TODO: Save log file to train or supp folders 16 | logging.basicConfig( 17 | level=logging.INFO, 18 | format="[%(levelname)4s: %(module)s:%(lineno)4s %(asctime)s] %(message)s", 19 | ) 20 | log = logging.getLogger(__name__) 21 | 22 | ############################################################################### 23 | 24 | # to add a new configuration parameter, simply add the string to the appropriate set here 25 | 26 | PREPROCESS = { 27 | 'image_dirs', 28 | 'target_dirs', 29 | 'channels', 30 | 'fov', 31 | 'pos_dir', 32 | 'multipage', 33 | 'z_slice', 34 | } 35 | 36 | SEGMENTATION_INFERENCE = { 37 | 'raw_dirs', 38 | 'supp_dirs', 39 | 'validation_dirs', 40 | 'network', 41 | 'weights', 42 | 
'gpu_ids', 43 | 'fov', 44 | 'channels', 45 | 46 | 'num_classes', 47 | 'window_size', 48 | 'batch_size', 49 | 'num_pred_rnd', 50 | 'seg_val_cat' 51 | } 52 | 53 | PATCH = { 54 | 'raw_dirs', 55 | 'supp_dirs', 56 | 'channels', 57 | 'fov', 58 | 59 | 'num_cpus', 60 | 'window_size', 61 | 'save_fig', 62 | 'reload', 63 | 'skip_boundary' 64 | } 65 | 66 | # fields for the latent-space encoding (inference) stage 67 | LATENT_ENCODING = { 68 | 'raw_dirs', 69 | 'supp_dirs', 70 | 'weights', 71 | 'save_output', 72 | 'gpu_ids', 73 | 'fov', 74 | 'patch_type', 75 | 76 | 'channels', 77 | 'channel_mean', 78 | 'channel_std', 79 | 80 | 'network', 81 | 'num_classes', 82 | 'num_hiddens', 83 | 'num_residual_hiddens', 84 | 'num_embeddings', 85 | 'commitment_cost', 86 | } 87 | 88 | DIM_REDUCTION = { 89 | 'input_dirs', 90 | 'output_dirs', 91 | 'weights_dir', 92 | 'file_name_prefixes', 93 | 'fit_model', 94 | 'conditions' 95 | } 96 | 97 | TRAINING = { 98 | 'raw_dirs', 99 | 'supp_dirs', 100 | 'weights_dirs', 101 | 102 | 'network', 103 | 'num_inputs', 104 | 'num_hiddens', 105 | 'num_residual_hiddens', 106 | 'num_residual_layers', 107 | 'num_embeddings', 108 | 'weight_matching', 109 | 'margin', 110 | 'w_a', 111 | 'w_t', 112 | 'w_n', 113 | 'channel_mean', 114 | 'channel_std', 115 | 116 | 'commitment_cost', 117 | 'n_epochs', 118 | 'learn_rate', 119 | 'batch_size', 120 | 'val_split_ratio', 121 | 'shuffle_data', 122 | 'transform', 123 | 'patience', 124 | 'n_pos_samples', 125 | 'num_workers', 126 | 'gpu_id', 127 | 'start_model_path', 128 | 'retrain', 129 | 'start_epoch', 130 | 'earlystop_metric', 131 | 'model_name', 132 | 'use_mask', 133 | } 134 | 135 | 136 | class Object: 137 | pass 138 | 139 | 140 | class YamlReader(Object): 141 | 142 | def __init__(self): 143 | self.config = None 144 | 145 | # easy way to assign attributes to each category 146 | # self.files = Object() 147 | self.preprocess = Object() 148 | self.segmentation = Object() 149 | self.segmentation.inference = Object() 150 | self.patch = Object() 151 | self.latent_encoding = Object() 152 | self.dim_reduction = Object() 153 | self.training = Object() 154 | 155 | def read_config(self, yml_config): 156 | with open(yml_config, 'r') as f: 157 | self.config = yaml.load(f, Loader=yaml.FullLoader) 158 | 159 | self._parse_preprocessing() 160 | self._parse_segmentation() 161 | self._parse_patch() 162 | self._parse_inference() 163 | self._parse_dim_reduction() 164 | self._parse_training() 165 | 166 | def _parse_preprocessing(self): 167 | for key, value in self.config['preprocess'].items(): 168 | if key in PREPROCESS: 169 | setattr(self.preprocess, key, value) 170 | else: 171 | log.warning(f"yaml PREPROCESS config field {key} is not recognized") 172 | 173 | def _parse_segmentation(self): 174 | for key, value in self.config['segmentation_inference'].items(): 175 | if key in SEGMENTATION_INFERENCE: 176 | setattr(self.segmentation.inference, key, value) 177 | else: 178 | log.warning(f"yaml SEGMENTATION config field {key} is not recognized") 179 | 180 | def _parse_patch(self): 181 | for key, value in self.config['patch'].items(): 182 | if key in PATCH: 183 | setattr(self.patch, key, value) 184 | else: 185 | log.warning(f"yaml PATCH config field {key} is not recognized") 186 | 187 | def _parse_inference(self): 188 | for key, value in self.config['latent_encoding'].items(): 189 | if key in LATENT_ENCODING: 190 | setattr(self.latent_encoding, key, value) 191 | else: 192 | log.warning(f"yaml LATENT_ENCODING config field {key} is not recognized") 193 | 194 | def _parse_dim_reduction(self): 195 | for key, value in 
self.config['dim_reduction'].items(): 196 | if key in DIM_REDUCTION: 197 | setattr(self.dim_reduction, key, value) 198 | else: 199 | log.warning(f"yaml DIM REDUCTION config field {key} is not recognized") 200 | 201 | def _parse_training(self): 202 | for key, value in self.config['training'].items(): 203 | if key in TRAINING: 204 | setattr(self.training, key, value) 205 | else: 206 | log.warning(f"yaml TRAINING config field {key} is not recognized") 207 | 208 | 209 | 210 | 211 | -------------------------------------------------------------------------------- /configs/config_run_patch.yml: -------------------------------------------------------------------------------- 1 | 2 | files: 3 | raw_dirs: ['/gpfs/CompMicro/projects/dynamorph/microglia/subset_for_tests/mg/JUNE/raw'] 4 | 5 | supp_dirs: ['/gpfs/CompMicro/projects/dynamorph/microglia/subset_for_tests/mg/JUNE/supp'] 6 | 7 | train_dirs: [''] 8 | 9 | val_dirs: [''] 10 | 11 | weights_dir: '' 12 | # path to save weights during training, subdirectory of train_dirs 13 | 14 | 15 | preprocess: 16 | channels: ["Retardance", "Phase2D", "Brightfield"] 17 | fov: ['C5-Site_0','C5-Site_1'] 18 | multipage: False 19 | z_slice: 2 20 | 21 | 22 | patch: 23 | channels: [0,1,2] 24 | fov: ['D3-Site_0','D3-Site_1','D3-Site_2'] 25 | gpus: 1 26 | window_size: 256 27 | save_fig: True 28 | reload: True 29 | skip_boundary: False 30 | 31 | 32 | inference: 33 | model: 'VQ_VAE_z16' 34 | # VQ_VAE_z16 or UNet 35 | 36 | weights: '/gpfs/CompMicro/projects/dynamorph/CellVAE/save_0005_bkp4.pt' 37 | # pytorch weight file 38 | 39 | save_output: True 40 | # write .jpg results from run_vae.process 41 | 42 | gpus: 1 43 | gpu_id: 3 44 | fov: ['C5-Site_0','C5-Site_1'] 45 | channels: [0,1,2] 46 | num_classes: 3 47 | window_size: 256 48 | batch_size: 8 49 | num_pred_rnd: 5 50 | seg_val_cat: 'mg' 51 | 52 | 53 | training: 54 | # model definitions 55 | model: 'VQ_VAE_z32' 56 | num_inputs: 2 57 | num_hiddens: 64 58 | num_residual_hiddens: 64 59 | num_residual_layers: 2 60 | num_embeddings: 512 61 | 62 | # normalization 63 | channel_mean: [0.4, 0, 0.5] 64 | channel_std: [0.05, 0.05, 0.05] 65 | 66 | w_a: 1 67 | w_t: 0.5 68 | 69 | # training parameters 70 | commitment_cost: 0.25 71 | alpha: 0.002 72 | epochs: 5000 73 | learning_rate: 0.0001 74 | batch_size: 96 75 | gpus: True 76 | gpu_id: 3 77 | shuffle_data: False 78 | transform: True 79 | 80 | 81 | 82 | 83 | -------------------------------------------------------------------------------- /documents/1-preprocessing.md: -------------------------------------------------------------------------------- 1 | # Preprocessing 2 | 3 | ## Purpose 4 | 5 | Subsequent stages of the dynamorph pipeline require raw data to be formatted properly. 6 | 7 | This section addresses necessary data structures to execute this CLI: 8 | 9 | ```text 10 | python run_preproc.py -c 11 | ``` 12 | 13 | ## input 14 | 15 | #### **folder structure and config file** 16 | 17 | If all the image files are contained in "position directories" such as : 18 | - "pos0", "pos1", "pos2" ... (generated by micro-manager grid generator), or such as: 19 | - "A1-Site_0", "A1-Site_1", "A2-Site_0" ... 
(generated by the micro-manager HCS site generator), 20 | you should specify the config file flag `pos_dir` as `True`. 21 | 22 | If all the image files are **NOT** contained in "position directories", this module will assume position information is encoded in the file name: 23 | - file name structure will contain a "t###_p###_z###" string 24 | - this module will parse the necessary values based on the above string structure 25 | 26 | If you wish to process a subset of positions, you can specify those in the config file under "fov" 27 | 28 | #### **file structure** 29 | 30 | The desired input data format for the preprocessing CLI should be one of either `single-page tiffs` or `multi-page tiffs` 31 | 32 | `single-page tiffs`: 33 | 34 | Single-page tiffs are parsed by filename and must contain string elements: 35 | - one of "Phase", "Retardance", "Brightfield" 36 | - a 3-digit integer representing the z-slice 37 | - each single-page tiff is a .tif image file of shape (Y, X) 38 | 39 | `multi-page tiffs`: 40 | 41 | For the dynamorph paper, time series data showed slight jitter between time points. We generated large z-stacks that were "stabilized" using Gunnar Farnebäck optical flow. The resulting "raw data" output had the following properties: 42 | 43 | - one of "Phase", "Retardance", "Brightfield" in the file name 44 | - dimensions of (T, Y, X) 45 | 46 | 47 | In both cases, there is a check that every image of the series has the same X-Y dimensions. 48 | 49 | ## output 50 | 51 | The output file will be: 52 | 53 | - a single .npy of shape `(T, C, Z, Y, X)` 54 | - C represents "channel" and will always be length 3 55 | - Channel index 0, 1, 2 will correspond to "Phase", "Retardance", and "Brightfield" respectively 56 | - If one of the above channels is not present in the raw data, that array will be empty in the output. 57 | - .npy file is named after the position. So it would appear as "A1-Site_0.npy", "A1-Site_1.npy" etc., or "pos0.npy", "pos1.npy" ... 58 | 
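For orientation, a minimal sketch of assembling an output stack in this layout (sizes here are illustrative; real stacks in this project are e.g. (T, 3, 1, 2048, 2048) uint16):

```python
# Sketch: the preprocessing output for one position, as described above.
import numpy as np

t, c, z, y, x = 4, 3, 1, 512, 512
stack = np.zeros((t, c, z, y, x), dtype=np.uint16)   # (T, C, Z, Y, X)
# channel 0 = Phase, 1 = Retardance, 2 = Brightfield; a channel that is
# missing from the raw data simply stays empty in the assembled array
np.save('A1-Site_0.npy', stack)
```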
-------------------------------------------------------------------------------- /documents/2-segmentation.md: --------------------------------------------------------------------------------
1 | # Segmentation
2 | 
3 | ## Purpose
4 | 
5 | In order to generate single-cell image patches, we must have some segmentation tool. This document outlines how to generate segmentations from the .npy files built in the preprocessing step (`run_preproc.py`), and how to perform instance segmentation on those.
6 | 
7 | 
8 | The relevant CLI is:
9 | ```text
10 | python run_segmentation.py -m <method> -c <path-to-config>
11 | ```
12 | 
13 | where `<method>` is one of "segmentation" or "instance_segmentation" and
14 | where `<path-to-config>` is the full path to a .yml configuration file as specified in `configs/config_example.yml`
15 | 
16 | --------------------------------------------
17 | #### **method = "segmentation"**
18 | 
19 | This method selection will generate a "NNProbabilities.npy" file for each of the sites specified in the "fov" field, given a model architecture (only 'UNet' is supported currently) and a path to the UNet weights (defined in the config file).
20 | 
21 | ```text
22 | python run_segmentation.py -m segmentation -c myconfig.yml
23 | ```
24 | 
25 | where `myconfig.yml` contains fields under `segmentation_inference`:
26 | ```text
27 | segmentation_inference:
28 |     model: 'UNet'
29 |     weights: <path to .h5 weights file>
30 |     fov: ['C4-Site_5', 'C4-Site_1', etc...]
31 | ```
32 | 
33 | **inputs**
34 | From "raw" directory
35 | - reads the `<site>.npy` file generated from `run_preproc.py`
36 | 
37 | From any directory
38 | - reads the `.h5` weights file generated by whatever UNet training procedure you choose
39 | 
40 | **outputs**
41 | To "raw" directory
42 | - writes `<site>_NNProbabilities.npy` --> will be of shape (T, Y, X, C), where C is 3 in the case of the dynamorph paper
43 | - writes `<site>_NNpred.png`
44 | 
45 | -------------------------------------------
46 | #### **method = "instance_segmentation"**
47 | 
48 | This method selection will use the raw data and the probability map to generate labels and mappings to cell instances from the segmentation.
49 | 
50 | ```text
51 | python run_segmentation.py -m instance_segmentation -c myconfig.yml
52 | ```
53 | 
54 | Instance segmentation is done using the clustering method DBSCAN (sklearn.cluster). The process is as follows (a code sketch appears at the end of this document):
55 | 
56 | ```text
57 | for each time point
58 |   1. keep pixels whose foreground probability clears the threshold `fg_thr` (0.3 in the paper)
59 |   2. perform DBSCAN clustering with `eps = 10` and `min_samples = 250` (values used in dynamorph paper)
60 |   3. position_labels is the output of step 2
61 |   4. cell_ids, point_counts are the unique values (and their counts) from position_labels
62 |   5. for each (cell_id, point_count) pair
63 |        define a "mean position" around each cluster
64 |        define a window of 256x256 around that mean
65 |        exclude clusters that have too many outliers outside that window (> 5% of points)
66 |   6. append (cell_id, mean_pos) for each qualifying cell to the `cell_positions` list
67 |   7. assign the output of 6 to the dictionary `cell_positions[time_point]`
68 | ```
69 | 
70 | **inputs**
71 | From "raw" directory
72 | - reads `<site>.npy`
73 | - reads `<site>_NNProbabilities.npy`
74 | 
75 | **outputs**
76 | To "`<well>`-supps/`<site>`/" directory
77 | - writes `cell_positions.pkl`
78 | - writes `cell_pixel_assignments.pkl`
79 | 
80 | where `cell_positions.pkl` is a dictionary of {key:value} = {timepoint: (microglia-cell-map, non-microglia-cell-map, other-cell-map)}
81 | and where `-cell-map` represents `[ (cell_id, np.array(mean-x-pos, mean-y-pos)), (next_cell_id, np.array(mean-x-pos, mean-y-pos)), ... ]`
82 | 
83 | 
84 | where `cell_pixel_assignments.pkl` is a dictionary of {key:value} = {timepoint: (positions, position_labels)}
85 | and where `positions` represents an array of (X, Y) coordinates of foreground pixels
86 | and where `position_labels` represents an array of cell_IDs of those foreground pixels
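87 | 
88 | A minimal sketch of steps 1-4 above (illustration only -- the actual implementation lives in `SingleCellPatch/instance_clustering.py`; here `fg_prob` is assumed to be the foreground probability channel of `<site>_NNProbabilities.npy` for one time point):
89 | 
90 | ```text
91 | import numpy as np
92 | from sklearn.cluster import DBSCAN
93 | 
94 | fg_thr = 0.3
95 | positions = np.stack(np.where(fg_prob > fg_thr), 1)        # (N, 2) foreground pixel coordinates
96 | position_labels = DBSCAN(eps=10, min_samples=250).fit_predict(positions)
97 | cell_ids, point_counts = np.unique(position_labels, return_counts=True)   # -1 marks noise pixels
98 | ```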
-------------------------------------------------------------------------------- /documents/3-patching.md: --------------------------------------------------------------------------------
1 | # Patching
2 | 
3 | ## Purpose
4 | 
5 | This document describes the process of extracting patches of single cells identified from the segmentation step. It uses the metadata generated from `segmentation` and `instance_segmentation`.
6 | 
7 | The relevant CLI is:
8 | ```text
9 | python run_patch.py -m <method> -c <path-to-config>
10 | ```
11 | 
12 | where `<method>` is one of "extract_patches" or "build_trajectories" and
13 | where `<path-to-config>` is the full path to a .yml configuration file as specified in `configs/config_example.yml`
14 | 
15 | --------------------------------------------
16 | #### **method = "extract_patches"**
17 | 
18 | This method generates a `stacks_<t_point>.pkl` file for each time point and rewrites `cell_positions.pkl` (see the outputs below)
19 | 
20 | ```text
21 | python run_patch.py -m extract_patches -c myconfig.yml
22 | ```
23 | 
24 | where `myconfig.yml` contains fields under `patch`:
25 | ```text
26 | patch:
27 |     channels: [0,1]
28 |     fov: ['C4-Site_5', 'C4-Site_1', etc...]
29 | ```
30 | 
31 | **inputs**
32 | From "raw" directory
33 | - `<site>.npy` file generated from `run_preproc.py`
34 | - `<site>_NNProbabilities.npy` file generated from `run_segmentation.py -m segmentation`
35 | 
36 | From "`<well>`-supps/`<site>`/" directory
37 | - `cell_pixel_assignments.pkl` file generated from `run_segmentation.py -m instance_segmentation`
38 | - `cell_positions.pkl` file generated from `run_segmentation.py -m instance_segmentation`
39 | 
40 | **outputs**
41 | To "`<well>`-supps/`<site>`/" directory
42 | - `stacks_<t_point>.pkl`
43 |     - is a dictionary of {key:value} = {`full_path/<t_point>_<cell_id>.h5`: `matrix_dict` }
44 |     - where `matrix_dict` is a dictionary of {key:value} = {"mat": <patch pixel array>},
45 |       {"masked_mat": <patch pixel array with non-cell pixels masked>}
46 | - `cell_positions.pkl` --> rewrites the cell_positions.pkl from the inputs
47 | 
48 | -------------------------------------------
49 | #### **method = "build_trajectories"**
50 | 
51 | This method builds a `cell_traj.pkl` file that describes cell motion between frames
52 | 
53 | ```text
54 | python run_patch.py -m build_trajectories -c myconfig.yml
55 | ```
56 | 
57 | Methodology
58 | ```text
59 | for each fov's supplementary folder
60 |   1. gather the cell centroid positions and sizes from the `cell_positions.pkl` and the `cell_pixel_assignments.pkl` files.
61 |   2. for each time point `T`
62 |   3. gather all cell positions and sizes at `T`, as well as the cells for timepoint `T+1`
63 |   4. generate pairs of "matched" cells by using `scipy.optimize.linear_sum_assignment`,
64 |      whose cost matrix is based on a 100 pixel distance cutoff (a sketch appears at the end of this document)
65 |   5. with all pairwise "matchings" for all timepoints, generate full trajectories using:
66 |      - "Robust single-particle tracking in live-cell time-lapse sequences" (https://www.nature.com/articles/nmeth.1237)
67 |      - an approach to model gaps, splits and merges of objects over time.
68 | ```
69 | 
70 | **inputs**
71 | From "raw" directory
72 | - reads `<site>.npy`
73 | 
74 | From "`<well>`-supps" subdirectory
75 | - reads `cell_positions.pkl`
76 | - reads `cell_pixel_assignments.pkl`
77 | 
78 | **outputs**
79 | To "`<well>`-supps" directory
80 | - writes `cell_traj.pkl`
81 |     - is a list of `[cell_trajectories, cell_trajectory_positions]`
82 |       where cell_trajectories is a list of dictionaries (one per trajectory) of {t_point: cell_ID}
83 |       and cell_trajectory_positions is a list of dictionaries (one per trajectory) of {t_point: cell_center_position}
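84 | 
85 | A minimal sketch of the frame matching in step 4 (illustration only -- the actual cost construction lives in `SingleCellPatch/generate_trajectories.py`):
86 | 
87 | ```text
88 | import numpy as np
89 | from scipy.optimize import linear_sum_assignment
90 | 
91 | # pos_t, pos_t1: (N, 2) and (M, 2) arrays of cell centroids at timepoints T and T+1
92 | cost = np.linalg.norm(pos_t[:, None] - pos_t1[None, :], axis=-1)
93 | cost[cost > 100] = 1e6                        # the 100 pixel distance cutoff
94 | rows, cols = linear_sum_assignment(cost)
95 | pairs = [(r, c) for r, c in zip(rows, cols) if cost[r, c] <= 100]
96 | ```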
will become "C5_file_paths.pkl") 70 | - `_file_paths.pkl` 71 | - `_static_patchkes.pkl` 72 | - `_static_patches_relations.pkl` 73 | - `_static_patches_labels.pkl` 74 | 75 | ------------------------------------------- 76 | #### **method = "process"** 77 | 78 | Run DNN inference on the assembled data 79 | Loads the `_static_patches.pkl`, zscores it, then casts as a TensorDataset for Pytorch inference 80 | 81 | 82 | ```text 83 | python run_VAE.py -m process -c myconfig.yml 84 | ``` 85 | 86 | **inputs** 87 | From "raw" directory 88 | - reads `_file_paths.pkl` 89 | - reads `_static_patchkes.pkl` 90 | 91 | **outputs** 92 | To "raw" directory 93 | - writes `_latent_space.pkl` 94 | which is the latent representation of the data before the quantizer 95 | - writes `_latent_space_after.pkl` 96 | which is the latent representation of the data after the quantizer 97 | 98 | 99 | ------------------------------------------- 100 | #### **method = "trajectory_matching"** 101 | 102 | Runs the trajectory matching already executed as part of `method = "assemble"`, but this time indepedent of the assemble 103 | 104 | ```text 105 | python run_VAE.py -m trajectory_matching -c myconfig.yml 106 | ``` 107 | 108 | **inputs** 109 | From "raw" directory 110 | - reads `_file_paths.pkl` 111 | 112 | From "-supps/" directory 113 | - reads `cell_traj.pkl` 114 | 115 | **outputs** 116 | To "raw" directory 117 | - writes `_trajectories.pkl` 118 | 119 | 120 | -------------------------------------------------------------------------------- /documents/5-dim_reduction.md: -------------------------------------------------------------------------------- 1 | # dimensionality reduction 2 | 3 | ## Purpose 4 | 5 | Given vectors representing the learned representation of the images, fit a PCA model and run inference 6 | 7 | The relevant CLI is: 8 | ```text 9 | python run_dim_reduction.py -m -c 10 | ``` 11 | 12 | where is one of "pca" or "umap" 13 | where is the full path to a .yml configuration file as specified in `.configs/config_example.yml` 14 | 15 | -------------------------------------------- 16 | #### **method = "pca"** 17 | 18 | Fit a PCA model on all latent space representations of the data and on all well prefixes specified in the config 19 | 20 | ```text 21 | python run_dim_reduction.py -m pca -c myconfig.yml 22 | ``` 23 | 24 | where `myconfig.yml` contains fields under `dim_reduction`: 25 | 26 | **important config fields** 27 | ```text 28 | dim_reduction: 29 | input_dirs: 30 | output_dirs: 31 | weights_dir: 32 | file_name_prefixes: _latent_space.pkl` outputted from run_vae.py 33 | 34 | fit_model: True or False 35 | ``` 36 | 37 | methodology: 38 | ```text 39 | For `fit_model: True`: 40 | 41 | 1. loops over all directories listed in config's `input_dirs` 42 | 2. loops over all prefixes in config's `file_name_prefixes` 43 | 3. [aggregate all data]: searches for `_latent_space_after.pkl` files in the input dirs and concatenates them in a vector list for subsequent PCA fitting 44 | 4. Fitting will write a model `pca_model.pkl` to the config's `weights_dir` directory. 45 | 5. Fitting will write a figure `PCA.png` to the config's `weights_dir` directory 46 | 6. finally, will loop over all pairs of `input_dirs` and `output_dirs` in the config: 47 | 7. will run inference on all individual `_latent_space_.pkl` in `input_dir` folder, where `suffix='after'` hardcoded. And where the supplied model is the one generated from step 4 above. 48 | 7. 
-------------------------------------------------------------------------------- /documents/5-dim_reduction.md: --------------------------------------------------------------------------------
1 | # Dimensionality Reduction
2 | 
3 | ## Purpose
4 | 
5 | Given vectors representing the learned representation of the images, fit a PCA model and run inference
6 | 
7 | The relevant CLI is:
8 | ```text
9 | python run_dim_reduction.py -m <method> -c <path-to-config>
10 | ```
11 | 
12 | where `<method>` is one of "pca" or "umap" and
13 | where `<path-to-config>` is the full path to a .yml configuration file as specified in `configs/config_example.yml`
14 | 
15 | --------------------------------------------
16 | #### **method = "pca"**
17 | 
18 | Fit a PCA model on all latent space representations of the data and on all well prefixes specified in the config
19 | 
20 | ```text
21 | python run_dim_reduction.py -m pca -c myconfig.yml
22 | ```
23 | 
24 | where `myconfig.yml` contains fields under `dim_reduction`:
25 | 
26 | **important config fields**
27 | ```text
28 | dim_reduction:
29 |     input_dirs: <list of input directories>
30 |     output_dirs: <list of output directories>
31 |     weights_dir: <directory for fitted models>
32 |     file_name_prefixes: <well prefixes of the `_latent_space.pkl` files output from run_VAE.py>
33 | 
34 |     fit_model: True or False
35 | ```
36 | 
37 | Methodology (a code sketch appears at the end of this document):
38 | ```text
39 | For `fit_model: True`:
40 | 
41 | 1. loops over all directories listed in config's `input_dirs`
42 | 2. loops over all prefixes in config's `file_name_prefixes`
43 | 3. [aggregate all data]: searches for `<prefix>_latent_space_after.pkl` files in the input dirs and concatenates them in a vector list for subsequent PCA fitting
44 | 4. Fitting will write a model `pca_model.pkl` to the config's `weights_dir` directory.
45 | 5. Fitting will write a figure `PCA.png` to the config's `weights_dir` directory
46 | 6. finally, will loop over all pairs of `input_dirs` and `output_dirs` in the config:
47 | 7. will run inference on all individual `<prefix>_latent_space_<suffix>.pkl` in the `input_dir` folder, where `suffix='after'` is hardcoded, and where the supplied model is the one generated in step 4 above.
48 | 8. the output of each inference is `<prefix>_latent_space_after_PCAed.pkl` and is saved to the corresponding `output_dir` from step 6
49 | ```
50 | 
51 | ```text
52 | For `fit_model: False`:
53 | 1. loops over all pairs of directories listed in config's `input_dirs` / `output_dirs`
54 | 2. loops over all prefixes in config's `file_name_prefixes`
55 | 3. assumes the `weights_dir` supplied in the config is a directory, and looks for the `pca_model.pkl` file there.
56 | 4. runs inference on `<prefix>_latent_space_<suffix>.pkl` where `suffix='after'` is hardcoded.
57 | 5. writes the transformed vectors to `<prefix>_latent_space_<suffix>_PCAed.pkl` in the corresponding `output_dir` directory
58 | ```
59 | 
60 | **inputs**
61 | From `config.dim_reduction.input_dirs`
62 | - `<prefix>_latent_space_after.pkl` files in the input dirs.
63 | - `pca_model.pkl` in the `weights_dir` directory if `fit_model: False`
64 | 
65 | **outputs**
66 | To `config.dim_reduction.output_dirs`
67 | - `<prefix>_latent_space_after_PCAed.pkl`
68 | 
69 | To `config.dim_reduction.weights_dir`, if `fit_model: True`
70 | - `pca_model.pkl`
71 | 
72 | 
73 | -------------------------------------------
74 | #### **method = "umap"**
75 | 
76 | Fit a UMAP model on all latent space representations of the data and on all well prefixes specified in the config.
77 | 
78 | **inputs**
79 | Inputs are the same as PCA except the UMAP method takes only `fit_model: True`.
80 | 
81 | **outputs**
82 | To `config.dim_reduction.weights_dir`
83 | - `umap_nbr#_a#_b#.pkl`
84 |     - `embedding`: UMAP-reduced embeddings
85 |     - `labels`: class label of embeddings
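86 | 
87 | A minimal sketch of the PCA fit/transform cycle (illustration only -- the actual loop lives in `run_dim_reduction.py`; it assumes the pickled latent vectors form an (N, D) array, and the variance fraction is illustrative):
88 | 
89 | ```text
90 | import pickle
91 | import numpy as np
92 | from sklearn.decomposition import PCA
93 | 
94 | vectors = np.asarray(pickle.load(open('C5_latent_space_after.pkl', 'rb')))
95 | pca = PCA(n_components=0.5)        # keep enough components to explain 50% of the variance
96 | reduced = pca.fit_transform(vectors)
97 | pickle.dump(pca, open('pca_model.pkl', 'wb'))
98 | pickle.dump(reduced, open('C5_latent_space_after_PCAed.pkl', 'wb'))
99 | ```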
-------------------------------------------------------------------------------- /graphicalabstract_dynamorph.jpg: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/mehta-lab/dynamorph/b3321f4368002707fbe39d727bc5c23bd5e7e199/graphicalabstract_dynamorph.jpg
-------------------------------------------------------------------------------- /pipeline.jpg: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/mehta-lab/dynamorph/b3321f4368002707fbe39d727bc5c23bd5e7e199/pipeline.jpg
-------------------------------------------------------------------------------- /pipeline/__init__.py: --------------------------------------------------------------------------------
1 | # bchhun, {2020-02-21}
2 | 
3 | 
-------------------------------------------------------------------------------- /pipeline/preprocess.py: --------------------------------------------------------------------------------
1 | # bchhun, {2020-02-21}
2 | 
3 | import numpy as np
4 | import cv2
5 | from typing import Union
6 | import logging
7 | log = logging.getLogger(__name__)
8 | 
9 | 
10 | def read_image(file_path):
11 |     """
12 |     Read 2D grayscale image from file.
13 |     Checks the file extension for npy and loads the array if true. Otherwise
14 |     reads a regular image using OpenCV (png, tif, jpg, see OpenCV for supported
15 |     files) of any bit depth.
16 |     :param str file_path: Full path to image
17 |     :return array im: 2D image
18 |     :raise IOError if image can't be opened
19 |     """
20 |     if file_path[-3:] == 'npy':
21 |         im = np.load(file_path)
22 |     else:
23 |         im = cv2.imread(file_path, cv2.IMREAD_ANYDEPTH)
24 |         if im is None:
25 |             raise IOError('Image "{}" cannot be found.'.format(file_path))
26 |     return im
27 | 
28 | 
29 | def load_raw(fullpaths: list,
30 |              chans: list,
31 |              z_slice: int,
32 |              multipage: bool = True):
33 |     """Raw data loader
34 | 
35 |     This function takes a list of paths to an experiment folder and
36 |     loads specified site data into a numpy array.
37 | 
38 |     Output array will be of shape (n_frames, 3, 1, Y, X); channel indices
39 |     0, 1 and 2 (second dimension) are Phase, Retardance and Brightfield
40 | 
41 |     Args:
42 |         fullpaths (list):
43 |             list of full paths to singlepage or multipage tiffs
44 |         chans (list):
45 |             list of strings corresponding to channel names
46 |         z_slice: (int)
47 |             specific slice to extract if multiple exist
48 |         multipage (bool, optional): default=True
49 |             if folder contains stabilized multipage tiffs
50 |             only multipage tiff is supported now
51 | 
52 |     Returns:
53 |         np.array: numpy array as described above
54 | 
55 |     """
56 | 
57 |     # store list of every image shape in the dataset for later validation
58 |     shapes = []
59 | 
60 |     if not multipage:
61 |         log.info("single-page tiffs specified")
62 |         # load singlepage tiffs.  String parse assuming time series and z### format
63 |         for chan in chans:
64 |             # files maps (key:value) = (z_index, t_y_x array)
65 |             # files = []
66 |             # for z in z_indicies:
67 |             #     files.append([c for c in sorted(os.listdir(fullpath)) if chan in c and f"z{z:03d}" in c])
68 |             # files = np.array(files).flatten()
69 |             files = [c for c in fullpaths if chan in c.split('/')[-1] and f"z{z_slice:03d}" in c.split('/')[-1]]
70 |             files = sorted(files)
71 |             if not files:
72 |                 log.warning(f"no files with {chan} identified")
73 |                 continue
74 | 
75 |             # resulting shapes are in (t, y, x) order
76 |             if "Phase" in chan:
77 |                 phase = np.stack([read_image(f) for f in files])
78 |                 # phase = phase.reshape((len(z_indicies), -1, phase.shape[-2], phase.shape[-1]))
79 |                 shapes.append(phase.shape)
80 |             elif "Retardance" in chan:
81 |                 ret = np.stack([read_image(f) for f in files])
82 |                 # ret = ret.reshape((len(z_indicies), -1, ret.shape[-2], ret.shape[-1]))
83 |                 shapes.append(ret.shape)
84 |             elif "Brightfield" in chan:
85 |                 bf = np.stack([read_image(f) for f in files])
86 |                 # bf = bf.reshape((len(z_indicies), -1, bf.shape[-2], bf.shape[-1]))
87 |                 shapes.append(bf.shape)
88 |             else:
89 |                 log.warning(f'not implemented: {chan} parse from single page files')
90 | 
91 |     else:
92 |         log.info("multi-page tiffs specified")
93 |         # load stabilized multipage tiffs.
94 |         for chan in chans:
95 |             files = [c for c in fullpaths if chan in c.split('/')[-1] and '.tif' in c.split('/')[-1]]
96 |             files = sorted(files)
97 |             if not files:
98 |                 log.warning(f"no files with {chan} identified")
99 |                 continue
100 |             if len(files) > 1:
101 |                 log.warning(f"duplicate matches for {chan} in folder, skipping channel")
102 |                 continue
103 | 
104 |             if "Phase" in chan:
105 |                 # multi_tif_phase = 'img_Phase2D_stabilized.tif'
106 |                 _, phase = cv2.imreadmulti(files[0],
107 |                                            flags=cv2.IMREAD_ANYDEPTH)
108 |                 phase = np.array(phase)
109 |                 shapes.append(phase.shape)
110 |             if "Retardance" in chan:
111 |                 # multi_tif_retard = 'img__Retardance__stabilized.tif'
112 |                 _, ret = cv2.imreadmulti(files[0],
113 |                                          flags=cv2.IMREAD_ANYDEPTH)
114 |                 ret = np.array(ret)
115 |                 shapes.append(ret.shape)
116 |             if "Brightfield" in chan:
117 |                 # multi_tif_bf = 'img_Brightfield_computed_stabilized.tif'
118 |                 _, bf = cv2.imreadmulti(files[0],
119 |                                         flags=cv2.IMREAD_ANYDEPTH)
120 |                 bf = np.array(bf)
121 |                 shapes.append(bf.shape)
122 | 
123 |     # check that all shapes are the same
124 |     assert shapes.count(shapes[0]) == len(shapes), "channel stacks have mismatched shapes"
125 | 
126 |     # insert images into a composite array.  Composite always has 3 channels
127 |     n_frame, x_size, y_size = shapes[0][:3]
128 |     out = np.empty(shape=(n_frame, 3, 1, x_size, y_size))
129 |     log.info(f"writing channels ({chans}) to composite array")
130 |     for chan in chans:
131 |         try:
132 |             if "Phase" in chan:
133 |                 out[:, 0, 0] = phase
134 |             if "Retardance" in chan:
135 |                 out[:, 1, 0] = ret
136 |             if "Brightfield" in chan:
137 |                 out[:, 2, 0] = bf
138 |         except UnboundLocalError:
139 |             log.warning(f"channel {chan} was specified but never loaded, skipping")
140 | 
141 |     return out
142 | 
143 | 
144 | def adjust_range(arr):
145 |     """Check value range for all channels
146 |     *** currently does nothing but report mean and std ***
147 |     *** image z-scoring is done at a later stage ***
148 | 
149 |     To maintain stability, input arrays should be within:
150 |         phase channel: mean ~ 32767, std ~ 1600-2000
151 |         retardance channel: mean ~ 1400-1600, std ~ 1500-1800
152 | 
153 |     Args:
154 |         arr (np.array):
155 |             input data array
156 | 
157 |     Returns:
158 |         np.array: numpy array with value range adjusted
159 | 
160 |     """
161 |     log.info("checking channel value ranges (reporting only)")
162 | 
163 |     mean_c0 = arr[:, 0, 0].mean()
164 |     mean_c1 = arr[:, 1, 0].mean()
165 |     mean_c2 = arr[:, 2, 0].mean()
166 |     std_c0 = arr[:, 0, 0].std()
167 |     std_c1 = arr[:, 1, 0].std()
168 |     std_c2 = arr[:, 2, 0].std()
169 |     log.info("\tPhase: %d plus/minus %d" % (mean_c0, std_c0))
170 |     log.info("\tRetardance: %d plus/minus %d" % (mean_c1, std_c1))
171 |     log.info("\tBrightfield: %d plus/minus %d" % (mean_c2, std_c2))
172 |     #TODO: manually adjust range if input doesn't satisfy
173 |     return arr
174 | 
175 | 
176 | def write_raw_to_npy(site: Union[int, str],
177 |                      site_list: list,
178 |                      output: str,
179 |                      chans: list,
180 |                      z_slice: int,
181 |                      multipage: bool = True):
182 |     """Wrapper method for data loading
183 | 
184 |     This function takes a path to an experiment folder, loads specified
185 |     site data into a numpy array, and saves it under the specified output path.
186 | 
187 |     Args:
188 |         site: (int or str)
189 |             name of specific position/site being processed
190 |         site_list (list):
191 |             list of files for this position/site
192 |         output (str):
193 |             path to the output folder
194 |         chans (list):
195 |             list of strings corresponding to channel names
196 |         z_slice (int):
197 |             specific z slice to stack
198 |         multipage (bool, optional): default=True
199 |             if folder contains stabilized multipage tiffs
200 |             only multipage tiff is supported now
201 | 
202 | 
203 |     """
204 | 
205 |     raw = load_raw(site_list, chans, z_slice=z_slice, multipage=multipage)
206 |     raw_adjusted = adjust_range(raw)
207 | 
208 |     output_name = output + '/' + str(site) + '.npy'
209 |     log.info(f"saving image stack to {output_name}")
210 |     np.save(output_name, raw_adjusted)
211 |     return
212 | 
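213 | 
214 | if __name__ == '__main__':
215 |     # Minimal usage sketch (hypothetical paths) -- the pipeline normally drives
216 |     # this module through run_preproc.py and a yaml config.
217 |     import glob
218 |     tiffs = sorted(glob.glob('/path/to/raw/C5-Site_0/*.tif'))
219 |     write_raw_to_npy('C5-Site_0', tiffs, '/path/to/raw',
220 |                      chans=['Retardance', 'Phase2D', 'Brightfield'],
221 |                      z_slice=2, multipage=False)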
-------------------------------------------------------------------------------- /pipeline/segmentation.py: --------------------------------------------------------------------------------
1 | # bchhun, {2020-02-21}
2 | 
3 | import os
4 | import numpy as np
5 | from NNsegmentation.models import Segment
6 | from NNsegmentation.data import load_input, predict_whole_map
7 | from SingleCellPatch.instance_clustering import process_site_instance_segmentation
8 | from configs.config_reader import YamlReader
9 | import logging
10 | log = logging.getLogger(__name__)
11 | 
12 | 
13 | def segmentation(raw_folder_: str,
14 |                  supp_folder_: str,
15 |                  val_folder: str,
16 |                  sites: list,
17 |                  config_: YamlReader,
18 |                  **kwargs):
19 |     """ Wrapper method for semantic segmentation
20 | 
21 |     This method performs prediction on all specified sites included in the
22 |     input paths.
23 | 
24 |     Model weight path should be provided; if not, a default path will be used:
25 |         UNet: "NNsegmentation/temp_save_unsaturated/final.h5"
26 | 
27 |     Segmentation results and a sample segmentation image will be saved
28 |     in the summary folder as "*_NNProbabilities.npy"
29 | 
30 | 
31 |     Args:
32 |         raw_folder_ (str): folder for raw data, segmentation and
33 |             summarized results
34 |         supp_folder_ (str): folder for supplementary data
35 |         val_folder (str): folder for validation data (unused here)
36 |         sites (list of str): list of site names
37 |         config_ (YamlReader): parsed config object
38 |         n_classes (int, optional): number of prediction classes
39 |         window_size (int, optional): window size for segmentation model
40 |             prediction
41 |         batch_size (int, optional): batch size
42 |         num_pred_rnd (int, optional): number of extra prediction rounds
43 |             each round of supplementary prediction will be initiated with
44 |             different offset
45 | 
46 |     """
47 | 
48 |     weights = config_.segmentation.inference.weights
49 |     n_classes = config_.segmentation.inference.num_classes
50 |     channels = config_.segmentation.inference.channels
51 |     window_size = config_.segmentation.inference.window_size
52 |     batch_size = config_.segmentation.inference.batch_size
53 |     n_supp = config_.segmentation.inference.num_pred_rnd
54 | 
55 |     if config_.segmentation.inference.network == 'UNet':
56 |         model = Segment(input_shape=(len(channels), window_size, window_size),
57 |                         n_classes=n_classes)
58 |     else:
59 |         raise NotImplementedError(f"segmentation model {config_.segmentation.inference.network} not implemented")
60 | 
61 |     try:
62 |         if weights:
63 |             model.load(weights)
64 |         else:
65 |             model.load('NNsegmentation/temp_save_unsaturated/final.h5')
66 |     except Exception as ex:
67 |         log.error(ex)
68 |         raise ValueError("Error in loading UNet weights")
69 | 
70 |     for site in sites:
71 |         site_path = os.path.join(raw_folder_, '%s.npy' % site)
72 |         if not os.path.exists(site_path):
73 |             log.info("Site not found %s" % site_path)
74 |         else:
75 |             log.info("Predicting %s" % site_path)
76 |             try:
77 |                 # Generate semantic segmentation
78 |                 predict_whole_map(site_path,
79 |                                   model,
80 |                                   use_channels=np.array(channels).astype(int),
81 |                                   batch_size=batch_size,
82 |                                   n_supp=n_supp,
83 |                                   **kwargs)
84 |             except Exception as ex:
85 |                 log.error(ex)
86 |                 log.error("Error in predicting site %s" % site)
87 |     return
88 | 
89 | 
90 | def instance_segmentation(raw_folder: str,
91 |                           supp_folder: str,
92 |                           val_folder: str,
93 |                           sites: list,
94 |                           config_: YamlReader,
95 |                           rerun=False,
96 | 
97 |                           **kwargs):
98 |     """ Helper function for instance segmentation
99 | 
100 |     Wrapper method `process_site_instance_segmentation` will be called, which
101 |     loads "*_NNProbabilities.npy" files and performs instance segmentation.
102 | 
103 |     Results will be saved in the supplementary data folder, including:
104 |         "cell_positions.pkl": dict of cells in each frame (IDs and positions);
105 |         "cell_pixel_assignments.pkl": dict of pixel compositions of cells
106 |             in each frame;
107 |         "segmentation_*.png": image of instance segmentation results.
108 | 
109 |     Args:
110 |         raw_folder (str): folder for raw data, segmentation and summarized results
111 |         supp_folder (str): folder for supplementary data
112 |         sites (list of str): list of site names
113 |         config_ (YamlReader): parsed config object
114 |         rerun (bool): if True, reprocess sites that already have saved results
115 |     """
116 | 
117 |     for site in sites:
118 |         site_path = os.path.join(raw_folder, '%s.npy' % site)
119 |         site_segmentation_path = os.path.join(raw_folder,
120 |                                               '%s_NNProbabilities.npy' % site)
121 |         if not os.path.exists(site_path) or not os.path.exists(site_segmentation_path):
122 |             log.info("Site not found %s" % site_path)
123 |             continue
124 | 
125 |         log.info("Clustering %s" % site_path)
126 |         site_supp_files_folder = os.path.join(supp_folder,
127 |                                               '%s-supps' % site[:2],
128 |                                               '%s' % site)
129 | 
130 |         if os.path.exists(os.path.join(site_supp_files_folder, 'cell_pixel_assignments.pkl')) and not rerun:
131 |             log.info('Found previously saved instance clustering output in {}. Skip processing...'
132 |                      .format(site_supp_files_folder))
133 |             continue
134 |         elif not os.path.exists(site_supp_files_folder):
135 |             os.makedirs(site_supp_files_folder, exist_ok=True)
136 | 
137 |         process_site_instance_segmentation(site_path,
138 |                                            site_segmentation_path,
139 |                                            site_supp_files_folder,
140 |                                            **kwargs)
141 |     return
142 | 
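143 | 
144 | if __name__ == '__main__':
145 |     # Minimal sketch (hypothetical paths) with a stand-in for the YamlReader
146 |     # fields read above -- the pipeline normally drives this module through
147 |     # run_segmentation.py.
148 |     from types import SimpleNamespace
149 |     inference = SimpleNamespace(network='UNet', weights='/path/to/final.h5',
150 |                                 num_classes=3, channels=[0, 1, 2], window_size=256,
151 |                                 batch_size=8, num_pred_rnd=5)
152 |     cfg = SimpleNamespace(segmentation=SimpleNamespace(inference=inference))
153 |     segmentation('/path/to/raw', '/path/to/supp', None, ['C5-Site_0'], cfg)
154 |     instance_segmentation('/path/to/raw', '/path/to/supp', None, ['C5-Site_0'], cfg)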
-------------------------------------------------------------------------------- /pipeline/segmentation_validation.py: --------------------------------------------------------------------------------
1 | import pickle
2 | import numpy as np
3 | import os
4 | from PIL import Image, ImageFilter
5 | import matplotlib.pyplot as plt
6 | import skimage.io as io
7 | from configs.config_reader import YamlReader
8 | 
9 | 
10 | def find_rim(cell_positions):
11 |     masks = set(tuple(r) for r in cell_positions)
12 |     inner_masks = set((r[0]-1, r[1]) for r in masks) & \
13 |                   set((r[0]+1, r[1]) for r in masks) & \
14 |                   set((r[0], r[1]-1) for r in masks) & \
15 |                   set((r[0], r[1]+1) for r in masks)
16 |     edge_positions = np.array(list(masks - inner_masks))
17 |     return edge_positions
18 | 
19 | 
20 | def drawContour(m, s, c, RGB):
21 |     """Draw edges of contour 'c' from segmented image 's' onto 'm' in colour 'RGB'"""
22 |     # Fill contour "c" with white, make all else black
23 |     # thisContour = s.point(lambda p:p==c and 255)
24 |     # DEBUG: thisContour.save(f"interim{c}.png")
25 |     # thisContour = s.point(lambda p:p==c and 255)
26 |     thisContour = s.point(lambda x: 255 if x > 30 else False)
27 | 
28 |     # Find edges of this contour and make into Numpy array
29 |     thisEdges = thisContour.filter(ImageFilter.FIND_EDGES)
30 |     thisEdgesN = np.array(thisEdges)
31 | 
32 |     # Paint locations of found edges in colour "RGB" onto "m"
33 |     m[np.nonzero(thisEdgesN)] = RGB
34 |     return m
35 | 
36 | 
37 | def rescale_plot(arr1, filename, size=(1108, 1108), dpi=500):
38 |     plt.clf()
39 | 
40 |     if type(arr1) == np.ndarray:
41 |         # arr1 = auto_contrast_slice(arr1)
42 |         im = Image.fromarray(arr1)
43 |     else:
44 |         im = arr1
45 |     im = im.resize(size)
46 | 
47 |     ax = plt.axes([0, 0, 1, 1], frameon=False)
48 |     ax.get_xaxis().set_visible(False)
49 |     ax.get_yaxis().set_visible(False)
50 |     plt.autoscale(tight=True)
51 |     plt.imshow(np.array(im))
52 |     if type(arr1) == np.ndarray:
53 |         plt.clim(0, 0.9 * arr1.max())
54 |     plt.savefig(filename, bbox_inches='tight', pad_inches=0, dpi=dpi)
55 | 
56 | 
57 | # %%
58 | def load_and_plot(img_rgb, img_grey, output):
59 |     phase = Image.open(img_rgb).convert('L').convert('RGB')
60 |     segment = Image.open(img_grey).convert('L')
61 | 
62 |     phaseN = np.array(phase)
63 |     phaseN = drawContour(phaseN, segment, 0, (255, 0, 0))
64 |     Image.fromarray(phaseN).save(output)
65 | 
66 | 
67 | def segmentation_validation_michael(raw_folder_ : str,
68 |                                     supp_folder_: str,
69 |                                     val_folder_ : str,
70 |                                     sites : list,
71 |                                     config_: YamlReader,
72 |                                     **kwargs):
73 |     """
74 |     Render full-frame validation images of the instance segmentation.
75 |     :param raw_folder_: folder with raw `<site>.npy` stacks and probability maps
76 |     :param supp_folder_: folder with supplementary per-site data
77 |     :param val_folder_: unused; output goes to `<supp_folder_>/validation_images`
78 |     :param sites: list of site names
79 |     :param config_: YamlReader; `segmentation.seg_val_cat` is one of "mg", "nonmg", "both", "unfiltered"
80 |     """
81 | 
82 |     temp_folder = raw_folder_
83 |     supp_folder = supp_folder_
84 |     category = config_.segmentation.seg_val_cat
85 |     gpu_id = config_.segmentation.gpu_id
86 | 
87 |     if "NOVEMBER" in temp_folder:
88 |         date = "NOVEMBER"
89 |     elif "JANUARY_FAST" in temp_folder:
90 |         date = "JANUARY_FAST"
91 |     else:
92 |         date = "JANUARY"
93 | 
94 |     for site in sites:
95 |         print(f"building full frame validation for {site} from {temp_folder}")
96 | 
97 |         stack_path = os.path.join(temp_folder + '/' + site + '.npy')
98 |         raw_input_stack = np.load(stack_path)
99 | 
100 |         NN_predictions_stack = np.load(os.path.join(temp_folder, '%s_NNProbabilities.npy' % site))
101 |         cell_pixels = pickle.load(open(os.path.join(supp_folder, f"{site[0:2]}-supps/{site}/cell_pixel_assignments.pkl"), 'rb'))
102 | 
103 |         # option to include filtered positions
104 |         filtered_positions = pickle.load(
105 |             open(supp_folder + f"/{site[0:2]}-supps/{site}/cell_positions.pkl", 'rb'))
106 | 
107 |         stack = []
108 |         for t_point in range(len(raw_input_stack)):
109 |             print(f"\tprocessing t {t_point}")
110 |             mat = raw_input_stack[t_point, :, :, 0]
111 |             mat = np.stack([mat] * 3, 2)
112 | 
113 |             # this block represents rendering of FILTERED MG and nonMG cells
114 |             (mg_cell_positions, non_mg_cell_positions, other_cells) = filtered_positions[t_point]
115 | 
116 |             # this block represents rendering of MG and nonMG cells, but NOT filtered by size.
117 |             positions, inds = cell_pixels[t_point]
118 | 
119 |             if 'unfiltered' in category:
120 |                 for cell_ind in np.unique(inds):
121 |                     new_mat = _append_segmentation(positions, inds, cell_ind, NN_predictions_stack, t_point, mat)
122 |                     if new_mat is not None:
123 |                         mat = new_mat
124 |             elif 'both' in category:
125 |                 if mg_cell_positions is None:
126 |                     if non_mg_cell_positions is None:
127 |                         continue
128 |                     else:
129 |                         ids = non_mg_cell_positions
130 |                 else:
131 |                     if non_mg_cell_positions is None:
132 |                         ids = mg_cell_positions
133 |                     else:
134 |                         # ForkedPdb().set_trace()
135 |                         ids = [i for i, _ in mg_cell_positions+non_mg_cell_positions]
136 | 
137 |                 for both_cell_id in ids:
138 |                     new_mat = _append_segmentation(positions, inds, both_cell_id, NN_predictions_stack, t_point, mat)
139 |                     if new_mat is not None:
140 |                         mat = new_mat
141 | 
142 |             elif category == 'mg':  # exact match: 'nonmg' would also satisfy `'mg' in category`
143 |                 # ForkedPdb().set_trace()
144 |                 ids = [i for i, _ in mg_cell_positions]
145 |                 for mg_cell_id in ids:
146 |                     new_mat = _append_segmentation(positions, inds, mg_cell_id, NN_predictions_stack, t_point, mat)
147 |                     if new_mat is not None:
148 |                         mat = new_mat
149 |             elif category == 'nonmg':
150 |                 ids = [i for i, _ in non_mg_cell_positions]
151 |                 for non_mg_cell_id in ids:
152 |                     new_mat = _append_segmentation(positions, inds, non_mg_cell_id, NN_predictions_stack, t_point, mat)
153 |                     if new_mat is not None:
154 |                         mat = new_mat
155 |             else:
156 |                 raise NotImplementedError(f"rendering category of type {category} is not implemented")
157 | 
158 |             stack.append(mat)
159 | 
160 |         # tifffile.imwrite(target+'/'+f'{date}_{site}_predictions.tiff', np.stack(stack, 0))
161 |         # np.save(target+'/'+f'{date}_{site}_predictions.npy', np.stack(stack, 0))
162 | 
163 |         # using skimage.io to access tifffile on IBM machines
164 |         target = os.path.join(supp_folder, "validation_images")
165 |         os.makedirs(target, exist_ok=True)  # make sure the output folder exists
166 |         io.imsave(target+'/'+f'{date}_{site}_{gpu_id}_predictions.tif',
167 |                   np.stack(stack, 0).astype("uint16"),
168 |                   plugin='tifffile')
169 | 
170 | 
171 | def _append_segmentation(positions_, inds_, cell_id_, NN_predictions_stack_, t_point_, output_mat_):
172 |     """
173 |     adds boundary positions for a supplied cell
174 |     :param positions_: (N, 2) array of foreground pixel coordinates for this frame
175 |     :param inds_: array of cell IDs, one per foreground pixel
176 |     :param cell_id_: ID of the cell to outline (negative IDs are skipped)
177 |     :param NN_predictions_stack_: (T, Y, X, C) probability map stack
178 |     :param t_point_: time point index
179 |     :param output_mat_: RGB frame to draw the rim onto
180 |     :return: the updated frame, or None for negative cell IDs
181 |     """
182 |     if cell_id_ < 0:
183 |         return None
184 | 
185 |     cell_positions = positions_[np.where(inds_ == cell_id_)]
186 | 
187 |     outer_rim = find_rim(cell_positions)
188 |     mask_identities = NN_predictions_stack_[t_point_][(cell_positions[:, 0], cell_positions[:, 1])].mean(0)
189 |     if mask_identities[1] >
mask_identities[2]: 190 | c = 'b' 191 | output_mat_[(outer_rim[:, 0], outer_rim[:, 1])] = np.array([0, 65535, 0]).reshape((1, 3)) 192 | else: 193 | c = 'r' 194 | output_mat_[(outer_rim[:, 0], outer_rim[:, 1])] = np.array([65535, 0, 0]).reshape((1, 3)) 195 | return output_mat_ 196 | 197 | 198 | def segmentation_validation_bryant(paths, id): 199 | """ 200 | this approach uses the outputted .png segmentations (per frame) and stitches it back with the raw data using PIL 201 | 202 | :param paths: 203 | :return: 204 | """ 205 | 206 | temp_folder, supp_folder, target, sites = paths[0], paths[1], paths[2], paths[3] 207 | 208 | for site in sites: 209 | print(f"building full frame validation for {site} from {temp_folder}") 210 | 211 | stack_path = os.path.join(temp_folder + '/' + site + '.npy') 212 | segmentations_png_path = os.path.join(supp_folder + f"/{site[0:2]}-supps/{site}") 213 | 214 | raw_input_stack = np.load(stack_path) 215 | 216 | for tp in range(len(raw_input_stack)): 217 | seg = Image.open(segmentations_png_path + os.sep + f'segmentation_{tp}.png').convert('L') 218 | 219 | # site[t,:,:,0] is phase channel 220 | rescale_plot(raw_input_stack[tp, :, :, 0], target + f"/temp_phase_{id}.png") 221 | rescale_plot(seg, target + f"/temp_seg_{id}.png") 222 | 223 | if "NOVEMBER" in temp_folder: 224 | date = "NOVEMBER" 225 | else: 226 | date = "JAN_FAST" 227 | 228 | if not os.path.exists(target+'/'+date): 229 | os.makedirs(target+'/'+date) 230 | 231 | load_and_plot(target + f"/temp_phase_{id}.png", 232 | target + f"/temp_seg_{id}.png", 233 | target + f"/{date}/{date}_{site}_{tp}.png") 234 | 235 | 236 | def segmentation_validation_to_tiff(paths): 237 | """ 238 | paths is a tuple of: 239 | (target folder, date, sites) 240 | 241 | target folder is the EXPERIMENT folder (not subfolder) 242 | 243 | :param paths: 244 | :return: 245 | """ 246 | import tifffile as tf 247 | import imageio as io 248 | 249 | target, date, sites = paths[0], paths[1], paths[2] 250 | 251 | for site in sites: 252 | png_path = f"{target}/{date}/" 253 | 254 | matched = [file for file in os.listdir(png_path) if f"{date}_{site}" in file] 255 | smatched = sorted(matched) 256 | 257 | ref = io.imread(png_path+'/'+smatched[0]) 258 | x, y, c = ref.shape 259 | output = np.empty(shape=(len(smatched), x, y, c)) 260 | for idx, path in enumerate(smatched): 261 | frame = io.imread(png_path+'/'+path) 262 | output[idx] = frame 263 | 264 | io.mimwrite(png_path+f'/{date}_{site}_composite.tif', output) 265 | 266 | import sys 267 | import pdb 268 | 269 | 270 | class ForkedPdb(pdb.Pdb): 271 | """A Pdb subclass that may be used 272 | from a forked multiprocessing child 273 | 274 | """ 275 | def interaction(self, *args, **kwargs): 276 | _stdin = sys.stdin 277 | try: 278 | sys.stdin = open('/dev/stdin') 279 | pdb.Pdb.interaction(self, *args, **kwargs) 280 | finally: 281 | sys.stdin = _stdin 282 | -------------------------------------------------------------------------------- /pipeline/train_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import Dataset 4 | from typing import Callable 5 | from typing import Tuple 6 | 7 | 8 | class EarlyStopping: 9 | """Early stops the training if validation loss doesn't improve after a given patience. 
Adapted from
10 |     https://github.com/Bjarten/early-stopping-pytorch/blob/master/pytorchtools.py
11 |     """
12 | 
13 |     def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
14 |         """
15 |         Args:
16 |             patience (int): How long to wait after last time validation loss improved.
17 |                             Default: 7
18 |             verbose (bool): If True, prints a message for each validation loss improvement.
19 |                             Default: False
20 |             delta (float): Minimum change in the monitored quantity to qualify as an improvement.
21 |                             Default: 0
22 |             path (str): Path for the checkpoint to be saved to.
23 |                             Default: 'checkpoint.pt'
24 |             trace_func (function): trace print function.
25 |                             Default: print
26 |         """
27 |         self.patience = patience
28 |         self.verbose = verbose
29 |         self.counter = 0
30 |         self.best_score = None
31 |         self.early_stop = False
32 |         self.val_loss_min = np.inf
33 |         self.delta = delta
34 |         self.path = path
35 |         self.trace_func = trace_func
36 | 
37 |     def __call__(self, val_loss, model):
38 | 
39 |         score = -val_loss
40 | 
41 |         if self.best_score is None:
42 |             self.best_score = score
43 |             self.save_checkpoint(val_loss, model)
44 |         elif score < self.best_score + self.delta:
45 |             self.counter += 1
46 |             self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
47 |             if self.counter >= self.patience:
48 |                 self.early_stop = True
49 |         else:
50 |             self.best_score = score
51 |             self.save_checkpoint(val_loss, model)
52 |             self.counter = 0
53 | 
54 |     def save_checkpoint(self, val_loss, model):
55 |         '''Saves model when validation loss decreases.'''
56 |         if self.verbose:
57 |             self.trace_func(
58 |                 f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
59 |         torch.save(model.state_dict(), self.path)
60 |         self.val_loss_min = val_loss
61 | 
62 | 
63 | class TripletDataset(Dataset):
64 |     """TripletDataset
65 |     Adapted from https://github.com/TowardHumanizedInteraction/TripletTorch
66 |     The TripletDataset extends the standard Dataset provided by the pytorch
67 |     utils. It provides simple access to data with the possibility of returning
68 |     more than one sample per index based on the label.
69 |     Attributes
70 |     ----------
71 |     labels : np.ndarray
72 |         Array containing all the labels respectively to each data sample.
73 |         Labels need to provide a way to access a sample label by index.
74 |     data_fn : Callable
75 |         The data_fn provides access to sample data given its index in the
76 |         dataset. Providing a function instead of an array has been chosen
77 |         for preprocessing and other reasons.
78 |     size : int
79 |         Size gives the dataset size, number of samples.
80 |     n_sample: int
81 |         The value represents the number of samples per index. The other
82 |         samples will be chosen to have the same label as the selected one. This
83 |         allows to augment the number of possible valid triplets when used
84 |         with a triplet mining strategy.
85 |     """
86 | 
87 |     def __init__(
88 |             self: 'TripletDataset',
89 |             labels: np.ndarray,
90 |             data_fn: Callable,
91 |             n_sample: int,
92 |     ) -> None:
93 |         """Init
94 |         Parameters
95 |         ----------
96 |         labels : np.ndarray
97 |             Array containing all the labels respectively to each data
98 |             sample. Labels need to provide a way to access a sample label
99 |             by index.
100 |         data_fn : Callable
101 |             The data_fn provides access to sample data given its index in
102 |             the dataset. Providing a function instead of an array has been
103 |             chosen for preprocessing and other reasons.
104 |             (Note: the `size` attribute is derived from `labels`; it is
105 |             not an __init__ parameter.)
106 |         n_sample: int
107 |             The value represents the number of samples per index. The other
108 |             samples will be chosen to have the same label as the selected one.
109 |             This allows to augment the number of possible valid triplets
110 |             when used with a triplet mining strategy.
111 |         """
112 |         super(Dataset, self).__init__()
113 |         self.labels = labels
114 |         self.data_fn = data_fn
115 |         self.size = len(labels)
116 |         self.n_sample = n_sample
117 | 
118 |     def __len__(self: 'TripletDataset') -> int:
119 |         """Len
120 |         Returns
121 |         -------
122 |         size: int
123 |             Returns the size of the dataset, number of samples.
124 |         """
125 |         return self.size
126 | 
127 |     def __getitem__(self: 'TripletDataset', index: int) -> Tuple[np.ndarray]:
128 |         """GetItem
129 |         Parameters
130 |         ----------
131 |         index: int
132 |             Index of the sample to draw. The value should be less than the
133 |             dataset size and positive.
134 |         Returns
135 |         -------
136 |         labels: torch.Tensor
137 |             Returns the labels respectively to each of the samples drawn.
138 |             The first sample is the one at the selected index; the others
139 |             are selected randomly from the rest of the dataset.
140 |         data : torch.Tensor
141 |             Returns the data respectively to each of the samples drawn.
142 |             The first sample is the one at the selected index; the others
143 |             are selected randomly from the rest of the dataset.
144 |         Raises
145 |         ------
146 |         IndexError: If index is negative or greater than the dataset size.
147 |         """
148 |         if not (index >= 0 and index < len(self)):
149 |             raise IndexError(f'Index {index} is out of range [ 0, {len(self)} ]')
150 | 
151 |         label = np.array([self.labels[index]])
152 |         datum = np.array([self.data_fn(index)])
153 | 
154 |         if self.n_sample == 1:
155 |             return label, datum
156 | 
157 |         mask = self.labels == label
158 |         # mask[ index ] = False
159 |         mask = mask.astype(np.float32)
160 | 
161 |         indexes = mask.nonzero()[0]
162 |         indexes = np.random.choice(indexes, self.n_sample - 1, replace=True)
163 |         data = np.array([self.data_fn(i) for i in indexes])
164 | 
165 |         labels = np.repeat(label, self.n_sample)
166 |         data = np.concatenate((datum, data), axis=0)
167 | 
168 |         labels = torch.from_numpy(labels)
169 |         data = torch.from_numpy(data)
170 | 
171 |         return labels, data
172 | 
173 | 
174 | class ImageDataset(Dataset):
175 |     """Basic dataset class for inference
176 |     Attributes
177 |     ----------
178 |     data : np.ndarray
179 |         Array holding all sample images; samples are indexed directly,
180 |         so no accessor function is needed here (unlike TripletDataset
181 |         above).
182 |     """
183 | 
184 |     def __init__(
185 |             self: 'ImageDataset',
186 |             data: np.ndarray,
187 |     ) -> None:
188 | 
189 |         super(Dataset, self).__init__()
190 |         self.data = data
191 |         self.size = len(data)
192 | 
193 |     def __len__(self: 'ImageDataset') -> int:
194 |         """Len
195 |         Returns
196 |         -------
197 |         size: int
198 |             Returns the size of the dataset, number of samples.
199 |         """
200 |         return self.size
201 | 
202 |     def __getitem__(self: 'ImageDataset', index: int) -> np.ndarray:
203 |         """GetItem
204 |         Parameters
205 |         ----------
206 |         index: int
207 |             Index of the sample to draw. The value should be less than the
208 |             dataset size and positive.
209 |         Returns
210 |         -------
211 |         datum : np.ndarray
212 |             Returns the sample drawn at the selected index, wrapped in a
213 |             leading batch dimension of size 1. Unlike TripletDataset, no
214 |             labels and no extra samples are returned.
215 | 
216 | 
217 |         Raises
218 |         ------
219 |         IndexError: If index is negative or greater than the dataset size.
220 |         """
221 |         if not (index >= 0 and index < len(self)):
222 |             raise IndexError(f'Index {index} is out of range [ 0, {len(self)} ]')
223 |         datum = np.array([self.data[index]])
224 | 
225 |         return datum
226 | 
227 | 
228 | def zscore(input_image, channel_mean=None, channel_std=None):
229 |     """
230 |     Performs z-score normalization. Adds epsilon in denominator for robustness
231 | 
232 |     :param input_image: input image for intensity normalization
233 |     :return: z score normalized image
234 |     """
235 |     if channel_mean is None:
236 |         channel_mean = np.mean(input_image, axis=(0, 2, 3))
237 |     if channel_std is None:
238 |         channel_std = np.std(input_image, axis=(0, 2, 3))
239 |     channel_slices = []
240 |     for c in range(len(channel_mean)):
241 |         mean = channel_mean[c]
242 |         std = channel_std[c]
243 |         channel_slice = (input_image[:, c, ...] - mean) / \
244 |                         (std + np.finfo(float).eps)
245 |         # channel_slice = t.clamp(channel_slice, -1, 1)
246 |         channel_slices.append(channel_slice)
247 |     norm_img = np.stack(channel_slices, 1)
248 |     print('channel_mean:', channel_mean)
249 |     print('channel_std:', channel_std)
250 |     return norm_img
251 | 
252 | def zscore_patch(imgs):
253 |     """
254 |     Performs z-score normalization. Adds epsilon in denominator for robustness
255 | 
256 |     :param imgs: input image stack for intensity normalization
257 |     :return: z score normalized images
258 |     """
259 |     means = np.mean(imgs, axis=(2, 3))
260 |     stds = np.std(imgs, axis=(2, 3))
261 |     imgs_norm = []
262 |     for img_chan, channel_mean, channel_std in zip(imgs, means, stds):
263 |         channel_slices = []
264 |         for img, mean, std in zip(img_chan, channel_mean, channel_std):
265 |             channel_slice = (img - mean) / \
266 |                             (std + np.finfo(float).eps)
267 |             # channel_slice = t.clamp(channel_slice, -1, 1)
268 |             channel_slices.append(channel_slice)
269 |         channel_slices = np.stack(channel_slices)
270 |         imgs_norm.append(channel_slices)
271 |     imgs_norm = np.stack(imgs_norm)
272 |     # print('channel_mean:', channel_mean)
273 |     # print('channel_std:', channel_std)
274 |     return imgs_norm
-------------------------------------------------------------------------------- /plot_scripts/B4_temp.py: --------------------------------------------------------------------------------
1 | import pickle
2 | import numpy as np
3 | import pandas as pd
4 | import torch
5 | import cv2
6 | import matplotlib
7 | matplotlib.use('AGG')
8 | from matplotlib import cm
9 | import matplotlib.pyplot as plt
10 | import seaborn as sns
11 | from HiddenStateExtractor.vq_vae import VQ_VAE, CHANNEL_MAX, CHANNEL_VAR, CHANNEL_RANGE, prepare_dataset, rescale
12 | from sklearn.decomposition import PCA
13 | import scipy.stats; import statsmodels.api as sm  # used below; missing from the original imports
14 | sites = ['B4-Site_%d' % i for i in [0, 2, 3, 5, 6]]
15 | dats = pickle.load(open('./save_0005_bkp4.pkl', 'rb'))
16 | fs = pickle.load(open('./HiddenStateExtractor/file_paths_bkp.pkl', 'rb'))
17 | trajs = pickle.load(open('./HiddenStateExtractor/trajectory_in_inds.pkl', 'rb'))
18 | from matplotlib.ticker import NullLocator  # used below; missing from the original imports
19 | def enhance_contrast(mat, a=1.5, b=-10000): return cv2.addWeighted(mat, a, mat, 0, b)  # helper used below; replaces a duplicated `sites = ...` line
20 | B4_dats = pickle.load(open('./save_0005_bkp4_B4.pkl', 'rb'))
21 | B4_fs = sorted(B4_dats.keys())
22 | B4_dats = np.stack([B4_dats[f] for f in B4_fs], 0).reshape((len(B4_fs), -1))
23 | # B4_trajs = {}
24 | # B4_trajs_nonmg_ratio = {}
25 | # for site in sites:
26 | #   site_trajs = pickle.load(open('../data_temp/B4-supps/%s/cell_trajs.pkl' % site, 'rb'))[0]
27 | #   site_pixel_assignments =
pickle.load(open('../data_temp/B4-supps/%s/cell_pixel_assignments.pkl' % site, 'rb')) 28 | # site_segmentations = np.load('../data_temp/%s_NNProbabilities.npy' % site) 29 | # for i, t in enumerate(site_trajs): 30 | # name = '%s/%d' % (site, i) 31 | # B4_traj_ind = [] 32 | # B4_traj_nonmg_ratio = [] 33 | # for t_point in sorted(t.keys()): 34 | # a, b = site_pixel_assignments[t_point] 35 | # cell_id = t[t_point] 36 | # cell_ps = a[np.where(b == cell_id)] 37 | # cell_segs = np.stack([site_segmentations[t_point, l[0], l[1]] for l in cell_ps]) 38 | # cell_nonmg_ratio = (cell_segs[:, 2] > cell_segs[:, 1]).sum()/float(len(cell_ps)) 39 | # B4_traj_nonmg_ratio.append(cell_nonmg_ratio) 40 | # patch_name = '/data/michaelwu/data_temp/B4-supps/%s/%d_%d.png' % (site, t_point, t[t_point]) 41 | # B4_traj_ind.append(B4_fs.index(patch_name)) 42 | # B4_trajs[name] = B4_traj_ind 43 | # B4_trajs_nonmg_ratio[name] = B4_traj_nonmg_ratio 44 | # valid_ts = [] 45 | # for t in B4_trajs_nonmg_ratio: 46 | # r = np.quantile(B4_trajs_nonmg_ratio[t], 0.9) 47 | # if r < 0.2: 48 | # valid_ts.append(t) 49 | # B4_trajs = {t: B4_trajs[t] for t in valid_ts if len(B4_trajs[t]) > 30} 50 | B4_trajs = pickle.load(open('./HiddenStateExtractor/B4_trajectory_in_inds.pkl', 'rb')) 51 | 52 | pca = PCA(0.5) 53 | dats_ = pca.fit_transform(dats) 54 | B4_dats_ = pca.transform(B4_dats) 55 | 56 | ###################################################################### 57 | cs = [0, 1] 58 | input_shape = (128, 128) 59 | gpu = False 60 | B4_dataset = torch.load('../data_temp/B4_all_adjusted_static_patches.pt') 61 | B4_dataset = rescale(B4_dataset) 62 | model = VQ_VAE(alpha=0.0005, gpu=gpu) 63 | model.load_state_dict(torch.load('./HiddenStateExtractor/save_0005_bkp4.pt', map_location='cpu')) 64 | 65 | sample_fs = ['/data/michaelwu/data_temp/B4-supps/B4-Site_5/35_13.png', 66 | '/data/michaelwu/data_temp/B4-supps/B4-Site_0/149_82.png', 67 | '/data/michaelwu/data_temp/B4-supps/B4-Site_2/118_75.png', 68 | '/data/michaelwu/data_temp/B4-supps/B4-Site_5/151_13.png'] 69 | 70 | for i, f in enumerate(sample_fs): 71 | sample_ind = B4_fs.index(f) 72 | sample = B4_dataset[sample_ind:(sample_ind+1)][0] 73 | output = model(sample)[0] 74 | inp = sample.cpu().data.numpy() 75 | out = output.cpu().data.numpy() 76 | input_phase = (inp[0, 0] * 65535).astype('uint16') 77 | output_phase = (out[0, 0] * 65535).astype('uint16') 78 | input_retardance = (inp[0, 1] * 65535).astype('uint16') 79 | output_retardance = (out[0, 1] * 65535).astype('uint16') 80 | cv2.imwrite('/home/michaelwu/supp_fig8_B4_VAE_pair%d_input_phase.png' % i, enhance_contrast(input_phase, 1., -10000)) # Note dataset has been rescaled 81 | cv2.imwrite('/home/michaelwu/supp_fig8_B4_VAE_pair%d_output_phase.png' % i, enhance_contrast(output_phase, 1., -10000)) 82 | cv2.imwrite('/home/michaelwu/supp_fig8_B4_VAE_pair%d_input_retardance.png' % i, enhance_contrast(input_retardance, 2., 0.)) 83 | cv2.imwrite('/home/michaelwu/supp_fig8_B4_VAE_pair%d_output_retardance.png' % i, enhance_contrast(output_retardance, 2., 0.)) 84 | 85 | ###################################################################### 86 | 87 | plt.clf() 88 | sns.kdeplot(dats_[:, 0], dats_[:, 1], shade=True, cmap="Blues", n_levels=16) 89 | plt.xlim(-4, 4) 90 | plt.ylim(-3, 5) 91 | plt.savefig('/home/michaelwu/supp_fig8_PC1-2_wt.eps') 92 | plt.savefig('/home/michaelwu/supp_fig8_PC1-2_wt.png', dpi=300) 93 | 94 | plt.clf() 95 | sns.kdeplot(B4_dats_[:, 0], B4_dats_[:, 1], shade=True, cmap="Reds", n_levels=16) 96 | plt.xlim(-4, 4) 97 | plt.ylim(-3, 5) 98 
| plt.savefig('/home/michaelwu/supp_fig8_PC1-2_sti.eps', dpi=300) 99 | plt.savefig('/home/michaelwu/supp_fig8_PC1-2_sti.png', dpi=300) 100 | 101 | ###################################################################### 102 | 103 | # ts_of_I = [] 104 | # for t in B4_trajs: 105 | # t_dats_ = B4_dats_[np.array(B4_trajs[t])] 106 | # if np.std(t_dats_[:15, 0]) + np.std(t_dats_[:15, 1]) < 1.2 and \ 107 | # np.std(t_dats_[-15:, 0]) + np.std(t_dats_[-15:, 1]) < 1.2 and \ 108 | # np.square(np.mean(t_dats_[-15:, :2], 0) - np.mean(t_dats_[:15, :2], 0)).sum() > 9: 109 | # ts_of_I.append(t) 110 | 111 | # # ['B4-Site_0/2', 112 | # # 'B4-Site_0/18', 113 | # # 'B4-Site_2/212', 114 | # # 'B4-Site_2/258', 115 | # # 'B4-Site_3/12', 116 | # # 'B4-Site_6/10'] 117 | 118 | # for t in ts_of_I: 119 | # os.system('cp /data/michaelwu/data_temp/B4-supps/%s/traj_movies/cell_traj_%s.gif /data/michaelwu/temp_B4_sample_%s.gif' % (t.split('/')[0], t.split('/')[1], t.replace('/', '_'))) 120 | 121 | ###################################################################### 122 | 123 | # Substitute for supp fig 5 124 | 125 | traj_PC1_diffs = [] 126 | traj_PC2_diffs = [] 127 | base_PC1_diffs = [] 128 | base_PC2_diffs = [] 129 | for t in trajs: 130 | traj_PC1 = dats_[np.array(trajs[t])][:, 0] 131 | traj_PC2 = dats_[np.array(trajs[t])][:, 1] 132 | traj_PC1_diff = np.abs(traj_PC1[1:] - traj_PC1[:-1]) 133 | traj_PC2_diff = np.abs(traj_PC2[1:] - traj_PC2[:-1]) 134 | traj_PC1_diffs.append(traj_PC1_diff) 135 | traj_PC2_diffs.append(traj_PC2_diff) 136 | random_PC1 = dats_[np.random.choice(np.arange(dats_.shape[0]), (len(trajs[t]),), replace=False), 0] 137 | random_PC2 = dats_[np.random.choice(np.arange(dats_.shape[0]), (len(trajs[t]),), replace=False), 1] 138 | base_PC1_diffs.append(np.abs(random_PC1[1:] - random_PC1[:-1])) 139 | base_PC2_diffs.append(np.abs(random_PC2[1:] - random_PC2[:-1])) 140 | traj_PC1_diffs = np.concatenate(traj_PC1_diffs) 141 | traj_PC2_diffs = np.concatenate(traj_PC2_diffs) 142 | base_PC1_diffs = np.concatenate(base_PC1_diffs) 143 | base_PC2_diffs = np.concatenate(base_PC2_diffs) 144 | 145 | B4_traj_PC1_diffs = [] 146 | B4_traj_PC2_diffs = [] 147 | for t in B4_trajs: 148 | traj_PC1 = B4_dats_[np.array(B4_trajs[t])][:, 0] 149 | traj_PC2 = B4_dats_[np.array(B4_trajs[t])][:, 1] 150 | traj_PC1_diff = np.abs(traj_PC1[1:] - traj_PC1[:-1]) 151 | traj_PC2_diff = np.abs(traj_PC2[1:] - traj_PC2[:-1]) 152 | B4_traj_PC1_diffs.append(traj_PC1_diff) 153 | B4_traj_PC2_diffs.append(traj_PC2_diff) 154 | B4_traj_PC1_diffs = np.concatenate(B4_traj_PC1_diffs) 155 | B4_traj_PC2_diffs = np.concatenate(B4_traj_PC2_diffs) 156 | 157 | line_orig = np.histogram(traj_PC1_diffs, bins=np.arange(0, 8, 0.2), density=True) 158 | line_B4 = np.histogram(B4_traj_PC1_diffs, bins=np.arange(0, 8, 0.2), density=True) 159 | line_base = np.histogram(base_PC1_diffs, bins=np.arange(0, 8, 0.2), density=True) 160 | plt.clf() 161 | plt.bar(line_orig[1][:-1]+0.1-0.09, line_orig[0], width=0.06, color=cm.get_cmap('Blues')(0.6), label='Original Sites Trajectories') 162 | plt.bar(line_orig[1][:-1]+0.1-0.03, line_B4[0], width=0.06, color=cm.get_cmap('Reds')(0.6), label='B4 Sites Trajectories') 163 | plt.bar(line_orig[1][:-1]+0.1+0.03, line_base[0], width=0.06, color=cm.get_cmap('Greys')(0.5), label='Random Baseline') 164 | plt.legend(fontsize=16) 165 | plt.xlabel('PC1 diff', fontsize=16) 166 | plt.ylabel('Frequency', fontsize=16) 167 | plt.savefig('/home/michaelwu/supp_fig5_distri_PC1.eps') 168 | plt.savefig('/home/michaelwu/supp_fig5_distri_PC1.png', dpi=300) 169 | 
170 | line_orig = np.histogram(traj_PC2_diffs, bins=np.arange(0, 8, 0.2), density=True) 171 | line_B4 = np.histogram(B4_traj_PC2_diffs, bins=np.arange(0, 8, 0.2), density=True) 172 | line_base = np.histogram(base_PC2_diffs, bins=np.arange(0, 8, 0.2), density=True) 173 | plt.clf() 174 | plt.bar(line_orig[1][:-1]+0.1-0.09, line_orig[0], width=0.06, color=cm.get_cmap('Blues')(0.6), label='Original Sites Trajectories') 175 | plt.bar(line_orig[1][:-1]+0.1-0.03, line_B4[0], width=0.06, color=cm.get_cmap('Reds')(0.6), label='B4 Sites Trajectories') 176 | plt.bar(line_orig[1][:-1]+0.1+0.03, line_base[0], width=0.06, color=cm.get_cmap('Greys')(0.5), label='Random Baseline') 177 | plt.legend(fontsize=16) 178 | plt.xlabel('PC2 diff', fontsize=16) 179 | plt.ylabel('Frequency', fontsize=16) 180 | plt.savefig('/home/michaelwu/supp_fig5_distri_PC2.eps') 181 | plt.savefig('/home/michaelwu/supp_fig5_distri_PC2.png', dpi=300) 182 | 183 | ###################################################################### 184 | 185 | B4_trajs_positions = {} 186 | for site in sites: 187 | site_trajs, site_trajs_positions = pickle.load(open('../data_temp/B4-supps/%s/cell_trajs.pkl' % site, 'rb')) 188 | for i, t in enumerate(site_trajs): 189 | name = '%s/%d' % (site, i) 190 | if not name in B4_trajs: 191 | continue 192 | t_positions = site_trajs_positions[i] 193 | B4_trajs_positions[name] = t_positions 194 | 195 | traj_average_moving_distances = {} 196 | traj_PC1 = {} 197 | traj_PC2 = {} 198 | for t in B4_trajs: 199 | t_keys = sorted(B4_trajs_positions[t].keys()) 200 | dists = [] 201 | for t_point in range(len(t_keys) - 3): 202 | d = np.linalg.norm(B4_trajs_positions[t][t_keys[t_point+3]] - \ 203 | B4_trajs_positions[t][t_keys[t_point]], ord=2) #+3 to adjust for experiment settings 204 | dists.append(d) 205 | traj_average_moving_distances[t] = np.mean(dists) 206 | pc1s = [B4_dats_[ind, 0] for ind in B4_trajs[t]] 207 | pc2s = [B4_dats_[ind, 1] for ind in B4_trajs[t]] 208 | traj_PC1[t] = np.mean(pc1s) 209 | traj_PC2[t] = np.mean(pc2s) 210 | 211 | t_arrays = sorted(B4_trajs.keys()) 212 | df = pd.DataFrame({'PC1': [traj_PC1[t] for t in t_arrays], 213 | 'PC2': [traj_PC2[t] for t in t_arrays], 214 | 'dists': [np.log(traj_average_moving_distances[t] * 0.722222) for t in t_arrays]}) #0.72um/h for 1pixel/27min 215 | 216 | sns.set_style('white') 217 | bins_y = np.linspace(0.1, 4.3, 20) 218 | bins_x = np.linspace(-4, 4, 20) 219 | plt.clf() 220 | g = sns.JointGrid(x='PC1', y='dists', data=df, ylim=(0.1, 4.3), xlim=(-4, 4)) 221 | _ = g.ax_marg_x.hist(df['PC1'], bins=bins_x) 222 | _ = g.ax_marg_y.hist(df['dists'], bins=bins_y, orientation='horizontal') 223 | g.plot_joint(sns.kdeplot, cmap="Blues", shade=True) 224 | y_ticks = np.array([1.5, 3., 6., 12., 24., 48.]) 225 | g.ax_joint.set_yticks(np.log(y_ticks)) 226 | g.ax_joint.set_yticklabels(y_ticks) 227 | g.set_axis_labels('', '') 228 | plt.savefig('/home/michaelwu/supp_fig8_correlation_kde.eps') 229 | plt.savefig('/home/michaelwu/supp_fig8_correlation_kde.png', dpi=300) 230 | 231 | ###################################################################### 232 | 233 | MSD_length = 60 234 | MSD_min_length = 20 235 | 236 | traj_ensembles = [] 237 | for t in B4_trajs_positions: 238 | t_init = min(B4_trajs_positions[t].keys()) 239 | t_end = max(B4_trajs_positions[t].keys()) + 1 240 | for t_start in range(t_init, t_end - MSD_min_length): 241 | if t_start in B4_trajs_positions[t]: 242 | s_traj = {(t_now - t_start): B4_trajs_positions[t][t_now] \ 243 | for t_now in range(t_start, t_start + 
244 | traj_ensembles.append(s_traj)
245 | 
246 | traj_MSDs = {}
247 | traj_MSDs_trimmed = {}
248 | for i in range(MSD_length):
249 | s_dists = [np.square(t[i] - t[0]).sum() for t in traj_ensembles if i in t]
250 | traj_MSDs[i] = s_dists
251 | traj_MSDs_trimmed[i] = scipy.stats.trimboth(s_dists, 0.25)
252 | 
253 | x = np.arange(3, MSD_length) # start from lag 3 (27 min at 9 min/frame) for consistency with the distance estimates above
254 | y_bins = np.arange(0.9, 11.7, 0.6)
255 | density_map = np.zeros((MSD_length, len(y_bins) - 1))
256 | y = []
257 | for i in range(3, MSD_length):
258 | for d in traj_MSDs[i]:
259 | if d == 0:
260 | continue
261 | ind_bin = ((np.log(d) - y_bins) > 0).sum() - 1
262 | if ind_bin < density_map.shape[1] and ind_bin >= 0:
263 | density_map[i][ind_bin] += 1
264 | y.append((np.log(np.mean(traj_MSDs[i])) - 0.9)/(y_bins[1] - y_bins[0]))
265 | density_map = density_map/density_map.sum(1, keepdims=True)
266 | 
267 | def forceAspect(ax, aspect=1):
268 | im = ax.get_images()
269 | extent = im[0].get_extent()
270 | ax.set_aspect(abs((extent[1]-extent[0])/(extent[3]-extent[2]))/aspect)
271 | 
272 | sns.set_style('white')
273 | plt.clf()
274 | fig = plt.figure()
275 | ax = fig.add_subplot(121)
276 | ax.imshow(np.transpose(density_map), cmap='Reds', origin='lower', vmin=0.01, vmax=0.3, alpha=0.5)
277 | ax.plot(x, np.array(y) - 0.5, '.-', c='#ba4748') # -0.5 is the adjustment for imshow
278 | ax.set_xscale('log')
279 | xticks = np.array([0.5, 1, 2, 4, 8])
280 | xticks_positions = xticks / (9/60) # frame interval is 9 min, i.e. 0.15 h
281 | ax.set_xticks(xticks_positions)
282 | ax.set_xticklabels(xticks)
283 | ax.xaxis.set_minor_locator(NullLocator())
284 | yticks = np.array([0.5, 2, 8, 32, 128, 512, 2048])
285 | yticks_positions = (np.log(yticks / (0.325 * 0.325)) - 0.9)/(y_bins[1] - y_bins[0]) - 0.5 # same adjustment for imshow; 0.325 um/pixel
286 | ax.set_yticks(yticks_positions)
287 | ax.set_yticklabels(yticks)
288 | plt.savefig('/home/michaelwu/supp_fig8_B4_MSD.eps')
289 | plt.savefig('/home/michaelwu/supp_fig8_B4_MSD.png', dpi=300)
290 | 
291 | X = np.log(np.arange(1, 60))
292 | y_ = [np.mean(traj_MSDs[i]) for i in np.arange(1, 60)]
293 | y_ = np.log(np.array(y_))
294 | est = sm.OLS(y_, sm.add_constant(X)).fit()
295 | print(est.params)
296 | # [5.44530955 1.00921815] -- log-log slope ~1, i.e. near-diffusive motion
297 | 
--------------------------------------------------------------------------------
/plot_scripts/PC_samples.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import os
4 | import pickle
5 | import torch as t
6 | import h5py
7 | import pandas as pd
8 | from NNsegmentation.models import Segment
9 | from NNsegmentation.data import predict_whole_map
10 | from SingleCellPatch.extract_patches import within_range
11 | from pipeline.segmentation import instance_clustering
12 | from SingleCellPatch.generate_trajectories import frame_matching
13 | import matplotlib
14 | from matplotlib import cm
15 | matplotlib.use('AGG')
16 | import matplotlib.pyplot as plt
17 | from matplotlib.ticker import NullLocator
18 | import seaborn as sns
19 | import imageio
20 | from HiddenStateExtractor.vq_vae import VQ_VAE, CHANNEL_MAX, CHANNEL_VAR, prepare_dataset
21 | from HiddenStateExtractor.naive_imagenet import read_file_path, DATA_ROOT
22 | from HiddenStateExtractor.deprecated.morphology_clustering import select_clean_trajecteories, Kmean_on_short_trajs
23 | from HiddenStateExtractor.deprecated.movement_clustering import save_traj
24 | import statsmodels.api as sm
25 | import scipy
26 | 
27 | RAW_DATA_PATH = '/mnt/comp_micro/Projects/CellVAE/Combined'
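# Note: the four per-PC sections below repeat one pattern -- select patches from
# the top/bottom deciles of a PC (optionally restricting earlier PCs to their
# 40-60% band), then save static samples. A minimal sketch of a helper that
# could factor this out; the name and signature are suggestions, not existing
# repo API, and it closes over `fs` and `plot_patch` defined later in this file
# (resolved at call time):
def sample_pc_extremes(vals, keep_mask, out_dir, n=10, seed=123):
    vals = np.asarray(vals)
    kept = vals[np.asarray(keep_mask)]
    thr0, thr1 = np.quantile(kept, 0.1), np.quantile(kept, 0.9)
    low = [f for i, f in enumerate(fs) if keep_mask[i] and vals[i] < thr0]
    high = [f for i, f in enumerate(fs) if keep_mask[i] and vals[i] > thr1]
    np.random.seed(seed)
    for i, f in enumerate(np.random.choice(low, (n,), replace=False)):
        plot_patch(f, out_dir + '/sample_low_%d.png' % i)
    for i, f in enumerate(np.random.choice(high, (n,), replace=False)):
        plot_patch(f, out_dir + '/sample_high_%d.png' % i)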
28 | sites = ['D%d-Site_%d' % (i, j) for j in range(9) for i in range(3, 6)]
29 | 
30 | def enhance_contrast(mat, a=1.5, b=-10000):
31 | mat2 = cv2.addWeighted(mat, a, mat, 0, b) # linear remap: a * mat + b
32 | return mat2
33 | 
34 | def plot_patch(sample_path, out_path, boundary=False, channel=0):
35 | with h5py.File(sample_path, 'r') as f:
36 | mat = np.array(f['masked_mat'][:, :, channel].astype('uint16'))
37 | mask = np.array(f['masked_mat'][:, :, 2].astype('uint16')) # segmentation mask channel; note `boundary` is currently unused
38 | mat2 = enhance_contrast(mat, 1.5, -10000)
39 | cv2.imwrite(out_path, mat2)
40 | 
41 | feat = 'save_0005_before'
42 | fs = sorted(pickle.load(open('./HiddenStateExtractor/file_paths_bkp.pkl', 'rb')))
43 | trajs = pickle.load(open('./HiddenStateExtractor/trajectory_in_inds.pkl', 'rb'))
44 | dats_ = pickle.load(open('./HiddenStateExtractor/%s_PCA.pkl' % feat, 'rb'))
45 | sizes = pickle.load(open(DATA_ROOT + '/Data/EncodedSizes.pkl', 'rb'))
46 | ss = [sizes[f][0] for f in fs]
47 | 
48 | 
49 | PC1_vals = dats_[:, 0]
50 | PC1_range = (np.quantile(PC1_vals, 0.4), np.quantile(PC1_vals, 0.6))
51 | PC2_vals = dats_[:, 1]
52 | PC2_range = (np.quantile(PC2_vals, 0.4), np.quantile(PC2_vals, 0.6))
53 | 
54 | # PC1
55 | vals = dats_[:, 0]
56 | path = '/data/michaelwu/CellVAE/PC_samples/PC1'
57 | val_std = np.std(vals)
58 | 
59 | thr0 = np.quantile(vals, 0.1)
60 | thr1 = np.quantile(vals, 0.9)
61 | samples0 = [f for i, f in enumerate(fs) if vals[i] < thr0]
62 | samples1 = [f for i, f in enumerate(fs) if vals[i] > thr1]
63 | sample_ts = []
64 | for t in trajs:
65 | traj_PCs = np.array([vals[ind] for ind in trajs[t]])
66 | start = np.mean(traj_PCs[:3])
67 | end = np.mean(traj_PCs[-3:])
68 | traj_PC_diff = traj_PCs[1:] - traj_PCs[:-1]
69 | if np.abs(end - start) > 1.2 * val_std and np.median(traj_PC_diff) < 0.5 * val_std:
70 | sample_ts.append(t)
71 | 
72 | np.random.seed(123)
73 | for i, f in enumerate(np.random.choice(samples0, (10,), replace=False)):
74 | plot_patch(f, path + '/sample_low_%d.png' % i)
75 | for i, f in enumerate(np.random.choice(samples1, (10,), replace=False)):
76 | plot_patch(f, path + '/sample_high_%d.png' % i)
77 | for t in np.random.choice(sample_ts, (10,), replace=False):
78 | save_traj(t, path + '/sample_traj_%s.gif' % t.replace('/', '_'))
79 | 
80 | # PC2, controlling for PC1
81 | vals = dats_[:, 1]
82 | path = '/data/michaelwu/CellVAE/PC_samples/PC2'
83 | vals_filtered = [v for i, v in enumerate(vals) if PC1_range[0] < PC1_vals[i] < PC1_range[1]]
84 | val_std = np.std(vals_filtered)
85 | 
86 | thr0 = np.quantile(vals_filtered, 0.1)
87 | thr1 = np.quantile(vals_filtered, 0.9)
88 | samples0 = [f for i, f in enumerate(fs) if vals[i] < thr0 and PC1_range[0] < PC1_vals[i] < PC1_range[1]]
89 | samples1 = [f for i, f in enumerate(fs) if vals[i] > thr1 and PC1_range[0] < PC1_vals[i] < PC1_range[1]]
90 | sample_ts = []
91 | for t in trajs:
92 | traj_PCs = np.array([vals[ind] for ind in trajs[t]])
93 | start = np.mean(traj_PCs[:3])
94 | end = np.mean(traj_PCs[-3:])
95 | traj_PC_diff = traj_PCs[1:] - traj_PCs[:-1]
96 | if np.abs(end - start) > 1.2 * val_std and np.median(traj_PC_diff) < 0.5 * val_std:
97 | sample_ts.append(t)
98 | 
99 | np.random.seed(123)
100 | for i, f in enumerate(np.random.choice(samples0, (10,), replace=False)):
101 | plot_patch(f, path + '/sample_low_%d.png' % i)
102 | for i, f in enumerate(np.random.choice(samples1, (10,), replace=False)):
103 | plot_patch(f, path + '/sample_high_%d.png' % i)
104 | for t in np.random.choice(sample_ts, (10,), replace=False):
105 | save_traj(t, path + 
'/sample_traj_%s.gif' % t.replace('/', '_')) 106 | 107 | 108 | # PC3, controlling for PC1, PC2 109 | vals = dats_[:, 2] 110 | path = '/data/michaelwu/CellVAE/PC_samples/PC3' 111 | vals_filtered = [v for i, v in enumerate(vals) \ 112 | if PC1_range[0] < PC1_vals[i] < PC1_range[1] and PC2_range[0] < PC2_vals[i] < PC2_range[1]] 113 | val_std = np.std(vals_filtered) 114 | 115 | thr0 = np.quantile(vals_filtered, 0.1) 116 | thr1 = np.quantile(vals_filtered, 0.9) 117 | samples0 = [f for i, f in enumerate(fs) if vals[i] < thr0 and \ 118 | PC1_range[0] < PC1_vals[i] < PC1_range[1] and PC2_range[0] < PC2_vals[i] < PC2_range[1]] 119 | samples1 = [f for i, f in enumerate(fs) if vals[i] > thr1 and \ 120 | PC1_range[0] < PC1_vals[i] < PC1_range[1] and PC2_range[0] < PC2_vals[i] < PC2_range[1]] 121 | sample_ts = [] 122 | for t in trajs: 123 | traj_PCs = np.array([vals[ind] for ind in trajs[t]]) 124 | start = np.mean(traj_PCs[:3]) 125 | end = np.mean(traj_PCs[-3:]) 126 | traj_PC_diff = traj_PCs[1:] - traj_PCs[:-1] 127 | if np.abs(end - start) > 1.2 * val_std and np.median(traj_PC_diff) < 0.5 * val_std: 128 | sample_ts.append(t) 129 | 130 | np.random.seed(123) 131 | for i, f in enumerate(np.random.choice(samples0, (10,), replace=False)): 132 | plot_patch(f, path + '/sample_low_%d.png' % i) 133 | for i, f in enumerate(np.random.choice(samples1, (10,), replace=False)): 134 | plot_patch(f, path + '/sample_high_%d.png' % i) 135 | for t in np.random.choice(sample_ts, (10,), replace=False): 136 | save_traj(t, path + '/sample_traj_%s.gif' % t.replace('/', '_')) 137 | 138 | 139 | # PC4, controlling for PC1, PC2 140 | vals = dats_[:, 3] 141 | path = '/data/michaelwu/CellVAE/PC_samples/PC4' 142 | vals_filtered = [v for i, v in enumerate(vals) \ 143 | if PC1_range[0] < PC1_vals[i] < PC1_range[1] and PC2_range[0] < PC2_vals[i] < PC2_range[1]] 144 | val_std = np.std(vals_filtered) 145 | 146 | thr0 = np.quantile(vals_filtered, 0.1) 147 | thr1 = np.quantile(vals_filtered, 0.9) 148 | samples0 = [f for i, f in enumerate(fs) if vals[i] < thr0 and \ 149 | PC1_range[0] < PC1_vals[i] < PC1_range[1] and PC2_range[0] < PC2_vals[i] < PC2_range[1]] 150 | samples1 = [f for i, f in enumerate(fs) if vals[i] > thr1 and \ 151 | PC1_range[0] < PC1_vals[i] < PC1_range[1] and PC2_range[0] < PC2_vals[i] < PC2_range[1]] 152 | sample_ts = [] 153 | for t in trajs: 154 | traj_PCs = np.array([vals[ind] for ind in trajs[t]]) 155 | start = np.mean(traj_PCs[:3]) 156 | end = np.mean(traj_PCs[-3:]) 157 | traj_PC_diff = traj_PCs[1:] - traj_PCs[:-1] 158 | if np.abs(end - start) > 1.2 * val_std and np.median(traj_PC_diff) < 0.5 * val_std: 159 | sample_ts.append(t) 160 | 161 | np.random.seed(123) 162 | for i, f in enumerate(np.random.choice(samples0, (10,), replace=False)): 163 | plot_patch(f, path + '/sample_low_%d.png' % i) 164 | for i, f in enumerate(np.random.choice(samples1, (10,), replace=False)): 165 | plot_patch(f, path + '/sample_high_%d.png' % i) 166 | for t in np.random.choice(sample_ts, (10,), replace=False): 167 | save_traj(t, path + '/sample_traj_%s.gif' % t.replace('/', '_')) -------------------------------------------------------------------------------- /plot_scripts/plotting_cm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pickle 4 | import matplotlib 5 | matplotlib.use('AGG') 6 | import matplotlib.pyplot as plt 7 | import umap 8 | 9 | def zoom_axis(x, y, ax, zoom_cutoff=1): 10 | xlim = [np.percentile(x, zoom_cutoff), np.percentile(x, 100 - 
zoom_cutoff)] 11 | ylim = [np.percentile(y, zoom_cutoff), np.percentile(y, 100 - zoom_cutoff)] 12 | ax.set_xlim(left=xlim[0], right=xlim[1]) 13 | ax.set_ylim(bottom=ylim[0], top=ylim[1]) 14 | 15 | # train_dirs = ['/CompMicro/projects/cardiomyocytes/200721_CM_Mock_SPS_Fluor/20200721_CM_Mock_SPS/dnm_train', 16 | # '/CompMicro/projects/cardiomyocytes/20200722CM_LowMOI_SPS_Fluor/20200722 CM_LowMOI_SPS/dnm_train'] 17 | # train_dir = '/CompMicro/projects/cardiomyocytes/20200722CM_LowMOI_SPS_Fluor/20200722 CM_LowMOI_SPS/dnm_train' 18 | # input_dirs = ['/CompMicro/projects/cardiomyocytes/200721_CM_Mock_SPS_Fluor/20200721_CM_Mock_SPS/dnm_input_tstack/mock_matching_point2', 19 | # '/CompMicro/projects/cardiomyocytes/20200722CM_LowMOI_SPS_Fluor/20200722 CM_LowMOI_SPS/dnm_input_tstack/mock_matching_point2'] 20 | # input_dirs = ['/CompMicro/projects/cardiomyocytes/200721_CM_Mock_SPS_Fluor/20200721_CM_Mock_SPS/dnm_input_tstack/mock+low_moi_matching_point05', 21 | # '/CompMicro/projects/cardiomyocytes/20200722CM_LowMOI_SPS_Fluor/20200722 CM_LowMOI_SPS/dnm_input_tstack/mock+low_moi_matching_point05'] 22 | input_dirs = ['/CompMicro/projects/cardiomyocytes/200721_CM_Mock_SPS_Fluor/20200721_CM_Mock_SPS/dnm_input_tstack/mock_z32_nh16_nrh16_ne512_cc0.25', 23 | '/CompMicro/projects/cardiomyocytes/20200722CM_LowMOI_SPS_Fluor/20200722 CM_LowMOI_SPS/dnm_input_tstack/mock_z32_nh16_nrh16_ne512_cc0.25'] 24 | 25 | 26 | 27 | dats = [] 28 | pcas = [] 29 | labels = [] 30 | label = 0 31 | for input_dir in input_dirs: 32 | dat = pickle.load(open(os.path.join(input_dir, 'im_latent_space_after.pkl'), 'rb')) 33 | pca = pickle.load(open(os.path.join(input_dir, 'im_latent_space_after_PCAed.pkl'), 'rb')) 34 | # dats = pickle.load(open(os.path.join(input_path, 'im_latent_space.pkl'), 'rb')) 35 | # dats_ = pickle.load(open(os.path.join(input_path, 'im_latent_space_PCAed.pkl'), 'rb')) 36 | dats.append(dat) 37 | pcas.append(pca) 38 | labels += [label] * dat.shape[0] 39 | label += 1 40 | dats = np.concatenate(dats, axis=0) 41 | pcas = np.concatenate(pcas, axis=0) 42 | #%% 43 | plt.clf() 44 | zoom_cutoff = 1 45 | conditions = ['mock', 'infected'] 46 | fig, ax = plt.subplots() 47 | scatter = ax.scatter(pcas[:, 0], pcas[:, 1], s=7, c=labels, cmap='Paired', alpha=0.1) 48 | scatter.set_facecolor("none") 49 | zoom_axis(pcas[:, 0], pcas[:, 1], ax, zoom_cutoff=zoom_cutoff) 50 | legend1 = ax.legend(handles=scatter.legend_elements()[0], 51 | loc="upper right", title="condition", labels=conditions) 52 | ax.set_xlabel('PC 1') 53 | ax.set_ylabel('PC 2') 54 | plt.savefig(os.path.join(input_dir, 'PCA.png'), dpi=300) 55 | #%% 56 | # a_s = [1.58, 1, 1, 0.5] 57 | # b_s = [0.9, 0.9, 1.5, 1.5] 58 | a_s = [1.58] 59 | b_s = [0.9] 60 | n_nbrs = [15, 50, 200, 1000] 61 | n_rows = 2 62 | n_cols = 2 63 | # xlim = [-7, 7] 64 | # # ylim = [-7, 7] 65 | fig, ax = plt.subplots(n_rows, n_cols, squeeze=False) 66 | ax = ax.flatten() 67 | fig.set_size_inches((5 * n_cols, 5 * n_rows)) 68 | axis_count = 0 69 | # top and bottom % of data to cut off 70 | zoom_cutoff = 1 71 | for n_nbr in n_nbrs: 72 | for a, b in zip(a_s, b_s): 73 | # embedding, labels = pickle.load(open(os.path.join(input_dir, 'umap_{}_nbr.pkl'.format(n_nbr)), 'rb')) 74 | 75 | reducer = umap.UMAP(a=a, b=b, n_neighbors=n_nbr) 76 | embedding = reducer.fit_transform(dats) 77 | with open(os.path.join(input_dir, 'umap_{}_nbr.pkl'.format(n_nbr)), 'wb') as f: 78 | pickle.dump([embedding, labels], f) 79 | 80 | scatter = ax[axis_count].scatter(embedding[:, 0], embedding[:, 1], s=7, c=labels, 81 | facecolors='none', 
cmap='Paired', alpha=0.1)
82 | scatter.set_facecolor("none")
83 | ax[axis_count].set_title('n_neighbors={}'.format(n_nbr), fontsize=12)
84 | # ax[axis_count].set_title('a={}, b={}'.format(a, b), fontsize=12)
85 | zoom_axis(embedding[:, 0], embedding[:, 1], ax[axis_count], zoom_cutoff=zoom_cutoff)
86 | if axis_count == 0:
87 | legend1 = ax[axis_count].legend(handles=scatter.legend_elements()[0],
88 | loc="upper right", title="condition", labels=conditions)
89 | ax[axis_count].set_xlabel('UMAP 1')
90 | ax[axis_count].set_ylabel('UMAP 2')
91 | 
92 | axis_count += 1
93 | fig.savefig(os.path.join(input_dir, 'UMAP.png'),
94 | dpi=300, bbox_inches='tight')
95 | plt.close(fig)
--------------------------------------------------------------------------------
/plot_scripts/recon_loss.py:
--------------------------------------------------------------------------------
1 | from HiddenStateExtractor.vq_vae import *
2 | import torch
3 | 
4 | cs = [0, 1]
5 | cs_mask = [2, 3]
6 | input_shape = (128, 128)
7 | gpu = False
8 | path = '/mnt/comp_micro/Projects/CellVAE'
9 | 
10 | ### Load Data ###
11 | fs = pickle.load(open('./HiddenStateExtractor/file_paths_bkp.pkl', 'rb'))
12 | 
13 | dataset = torch.load('StaticPatchesAll.pt')
14 | dataset = rescale(dataset)
15 | B4_dataset = torch.load('../data_temp/B4_all_adjusted_static_patches.pt')
16 | B4_dataset = rescale(B4_dataset)
17 | 
18 | model = VQ_VAE(alpha=0.0005, gpu=gpu)
19 | model.load_state_dict(torch.load('./HiddenStateExtractor/save_0005_bkp4.pt', map_location='cpu'))
20 | 
21 | np.random.seed(123)
22 | r_losses = []
23 | for i in np.random.choice(np.arange(len(dataset)), (5000,), replace=False):
24 | sample = dataset[i:(i+1)][0]
25 | output, loss = model.forward(sample)
26 | r_loss = loss['recon_loss']
27 | r_losses.append(r_loss.data.cpu().numpy())
28 | 
29 | B4_r_losses = []
30 | for i in np.random.choice(np.arange(len(B4_dataset)), (5000,), replace=False):
31 | sample = B4_dataset[i:(i+1)][0]
32 | output, loss = model.forward(sample)
33 | r_loss = loss['recon_loss']
34 | B4_r_losses.append(r_loss.data.cpu().numpy())
35 | 
36 | # recon loss, original sites (r_losses): 0.00756±0.01691
37 | # recon loss, B4 sites (B4_r_losses): 0.00795±0.00617
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | -r ./requirements/default.txt
2 | 
--------------------------------------------------------------------------------
/requirements/default.txt:
--------------------------------------------------------------------------------
1 | matplotlib>=3.0.3
2 | scipy>=1.2.1
3 | opencv-python>=3.4.2.16
4 | scikit-image>=0.16.1
5 | scikit-learn>0.20.0
6 | tifffile>=0.15.1
7 | imageio>=2.6.0
8 | POT>=0.6.0
9 | h5py>=2.10.0
10 | tensorflow==2.1
11 | torch>=1.0.1
12 | torchvision==0.2.2
13 | segmentation_models==1.0.1
14 | pyyaml>=0.2.5
15 | umap-learn>=0.5.1
--------------------------------------------------------------------------------
/requirements/pwrai_docker.txt:
--------------------------------------------------------------------------------
1 | matplotlib>=3.0.3
2 | scipy>=1.2.1
3 | tifffile>=0.15.1
4 | imageio>=2.6.0
5 | keras==2.2.5
6 | segmentation_models==0.2.1
7 | umap-learn
--------------------------------------------------------------------------------
/run_VAE.py:
--------------------------------------------------------------------------------
1 | from pipeline.patch_VAE import assemble_VAE, process_VAE, trajectory_matching
2 | from SingleCellPatch.extract_patches import get_im_sites
3 | 
from torch.multiprocessing import Pool, Queue, Process 4 | import torch.multiprocessing as mp 5 | import os, sys 6 | import argparse 7 | from configs.config_reader import YamlReader 8 | 9 | 10 | class Worker(Process): 11 | def __init__(self, inputs, gpuid=0, method='assemble'): 12 | super().__init__() 13 | self.gpuid = gpuid 14 | self.inputs = inputs 15 | self.method = method 16 | 17 | def run(self): 18 | if self.method == 'assemble': 19 | # assemble_VAE(*self.inputs) 20 | #TODO: make "patch_type" part of the config 21 | assemble_VAE(*self.inputs, patch_type='mat') 22 | elif self.method == 'process': 23 | process_VAE(*self.inputs, gpu=self.gpuid) 24 | elif self.method == 'trajectory_matching': 25 | trajectory_matching(*self.inputs) 26 | 27 | 28 | def main(method_, raw_dir_, supp_dir_, config_): 29 | method = method_ 30 | 31 | inputs = raw_dir_ 32 | outputs = supp_dir_ 33 | weights = config_.latent_encoding.weights 34 | # channels = config_.inference.channels 35 | # network = config_.inference.model 36 | # gpu_id = config_.latent_encoding.gpu_ids 37 | gpus = config_.latent_encoding.gpu_ids 38 | gpu_count = len(gpus) 39 | 40 | # assert len(channels) > 0, "At least one channel must be specified" 41 | 42 | # todo file path checks can be done earlier 43 | # assemble needs raw (write file_paths/static_patches/adjusted_patches), and supp (read site-supps) 44 | if method == 'assemble': 45 | if not inputs: 46 | raise AttributeError("raw directory must be specified when method = assemble") 47 | if not outputs: 48 | raise AttributeError("supplementary directory must be specified when method = assemble") 49 | 50 | # process needs raw (load _file_paths), and target (torch weights) 51 | elif method == 'process': 52 | if not inputs: 53 | raise AttributeError("raw directory must be specified when method = process") 54 | # if type(weights) is not list: 55 | # weights = [weights] 56 | if not weights: 57 | raise AttributeError("pytorch VQ-VAE weights path must be specified when method = process") 58 | 59 | # trajectory matching needs raw (load file_paths, write trajectories), supp (load cell_traj) 60 | elif method == 'trajectory_matching': 61 | if not inputs: 62 | raise AttributeError("raw directory must be specified when method = trajectory_matching") 63 | if not outputs: 64 | raise AttributeError("supplementary directory must be specified when method = trajectory_matching") 65 | 66 | if config_.latent_encoding.fov: 67 | sites = config_.latent_encoding.fov 68 | else: 69 | # get all "XX-SITE_#" identifiers in raw data directory 70 | sites = get_im_sites(inputs) 71 | 72 | wells = set(s[:2] for s in sites) 73 | mp.set_start_method('spawn', force=True) 74 | 75 | # os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 76 | # os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(i) for i in gpu_ids]) 77 | # print("CUDA_VISIBLE_DEVICES=" + os.environ["CUDA_VISIBLE_DEVICES"]) 78 | for i, well in enumerate(wells): 79 | well_sites = [s for s in sites if s[:2] == well] 80 | args = (inputs, outputs, well_sites, config_) 81 | gpu_idx = i % gpu_count 82 | gpu_id = gpus[gpu_idx] 83 | p = Worker(args, gpuid=gpu_id, method=method) 84 | p.start() 85 | p.join() 86 | 87 | # for weight in weights: 88 | # print('Encoding using model {}'.format(weight)) 89 | # well_sites = [s for s in sites if s[:2] == well] 90 | # args = (inputs, outputs, channels, weight, well_sites, network) 91 | # p = Worker(args, gpuid=gpu, method=method) 92 | # p.start() 93 | # p.join() 94 | 95 | 96 | def parse_args(): 97 | """ 98 | Parse command line arguments for 
CLI.
99 | 
100 | :return: namespace containing the arguments passed.
101 | """
102 | parser = argparse.ArgumentParser()
103 | 
104 | parser.add_argument(
105 | '-m', '--method',
106 | type=str,
107 | required=False,
108 | choices=['assemble', 'process', 'trajectory_matching'],
109 | default='assemble',
110 | help="Method: one of 'assemble', 'process' or 'trajectory_matching'",
111 | )
112 | parser.add_argument(
113 | '-c', '--config',
114 | type=str,
115 | required=True,
116 | help='path to yaml configuration file'
117 | )
118 | return parser.parse_args()
119 | 
120 | 
121 | if __name__ == '__main__':
122 | arguments = parse_args()
123 | config = YamlReader()
124 | config.read_config(arguments.config)
125 | 
126 | # batch run
127 | for (raw_dir, supp_dir) in list(zip(config.latent_encoding.raw_dirs, config.latent_encoding.supp_dirs)):
128 | main(arguments.method, raw_dir, supp_dir, config)
129 | 
--------------------------------------------------------------------------------
/run_patch.py:
--------------------------------------------------------------------------------
1 | # bchhun, {2020-02-21}
2 | 
3 | from pipeline.patch_VAE import extract_patches, build_trajectories
4 | from multiprocessing import Pool, Queue, Process
5 | import os
6 | import numpy as np
7 | import argparse
8 | from configs.config_reader import YamlReader
9 | 
10 | 
11 | class Worker(Process):
12 | def __init__(self, inputs, cpu_id=0, method='extract_patches'):
13 | super().__init__()
14 | self.cpu_id = cpu_id
15 | self.inputs = inputs
16 | self.method = method
17 | 
18 | def run(self):
19 | if self.method == 'extract_patches':
20 | extract_patches(*self.inputs)
21 | elif self.method == 'build_trajectories':
22 | build_trajectories(*self.inputs)
23 | 
24 | 
25 | def main(method_, raw_dir_, supp_dir_, config_):
26 | 
27 | print("CLI arguments provided")
28 | raw = raw_dir_
29 | supp = supp_dir_
30 | method = method_
31 | fov = config_.patch.fov
32 | 
33 | n_cpus = config_.patch.num_cpus
34 | 
35 | # extract patches needs raw (NN probs, stack), supp (cell_positions, cell_pixel_assignments)
36 | if method == 'extract_patches':
37 | if not raw:
38 | raise AttributeError("raw directory must be specified when method = extract_patches")
39 | if not supp:
40 | raise AttributeError("supplementary directory must be specified when method = extract_patches")
41 | 
42 | # build trajectories needs supp (cell_positions, cell_pixel_assignments)
43 | elif method == 'build_trajectories':
44 | if not supp:
45 | raise AttributeError("supplementary directory must be specified when method = build_trajectories")
46 | 
47 | if fov:
48 | sites = fov
49 | else:
50 | # get all "XX-SITE_#" identifiers in raw data directory
51 | img_names = [file for file in os.listdir(raw) if (file.endswith(".npy")) & ('_NN' not in file)]
52 | sites = [os.path.splitext(img_name)[0] for img_name in img_names]
53 | sites = list(set(sites))
54 | # if probabilities and formatted stack exist
55 | segment_sites = [site for site in sites if os.path.exists(os.path.join(raw, "%s.npy" % site)) and \
56 | os.path.exists(os.path.join(raw, "%s_NNProbabilities.npy" % site))]
57 | if len(segment_sites) == 0:
58 | raise AttributeError("no sites found in raw directory with preprocessed data and matching NNProbabilities")
59 | 
60 | # partition sites evenly across CPU workers
61 | sep = np.linspace(0, len(segment_sites), n_cpus + 1).astype(int)
62 | 
63 | # TARGET is never used in either extract_patches or build_trajectory
64 | processes = []
65 | for i in range(n_cpus):
66 | _sites = segment_sites[sep[i]:sep[i + 1]]
67 | args = (raw, supp, _sites, config_)
68 | p = Worker(args, cpu_id=i, method=method)
69 | p.start()
70 | processes.append(p)
71 | for p in processes:
72 | p.join()
73 | 
74 | 
75 | def parse_args():
76 | """
77 | Parse command line arguments for CLI.
78 | 
79 | :return: namespace containing the arguments passed.
80 | """
81 | parser = argparse.ArgumentParser()
82 | 
83 | parser.add_argument(
84 | '-m', '--method',
85 | type=str,
86 | required=False,
87 | choices=['extract_patches', 'build_trajectories'],
88 | default='extract_patches',
89 | help="Method: one of 'extract_patches', 'build_trajectories'",
90 | )
91 | parser.add_argument(
92 | '-c', '--config',
93 | type=str,
94 | required=True,
95 | help='path to yaml configuration file'
96 | )
97 | 
98 | return parser.parse_args()
99 | 
100 | 
101 | if __name__ == '__main__':
102 | arguments = parse_args()
103 | config = YamlReader()
104 | config.read_config(arguments.config)
105 | 
106 | # batch run
107 | for (raw_dir, supp_dir) in list(zip(config.patch.raw_dirs, config.patch.supp_dirs)):
108 | main(arguments.method, raw_dir, supp_dir, config)
109 | 
--------------------------------------------------------------------------------
/run_preproc.py:
--------------------------------------------------------------------------------
1 | 
2 | # 1. check input: (n_frames * 2048 * 2048 * 2) channel 0 - phase, channel 1 - retardance
3 | # 2. adjust channel range
4 | # a. phase: 32767 plus/minus 1600~2000
5 | # b. retardance: 1400~1600 plus/minus 1500~1800
6 | # 3. save as '$SITE_NAME.npy' numpy array, dtype=uint16
7 | 
8 | from pipeline.preprocess import write_raw_to_npy
9 | import os
10 | import fnmatch
11 | import re
12 | 
13 | import argparse
14 | from configs.config_reader import YamlReader
15 | import logging
16 | log = logging.getLogger(__name__)
17 | 
18 | 
19 | def main(input_, output_, config_):
20 | """
21 | Using supplied config file parameters, prepare specified datasets for downstream analysis
22 | 
23 | :param input_: str
24 | Path to a single experiment
25 | :param output_: str
26 | Path to output directory for prepared datasets
27 | :param config_: YamlReader
28 | YamlReader object containing parsed configuration values
29 | :return:
30 | """
31 | 
32 | chans = config_.preprocess.channels
33 | multi = config_.preprocess.multipage
34 | z_slice = config_.preprocess.z_slice if config_.preprocess.z_slice else None
35 | fovs = config_.preprocess.fov
36 | 
37 | # === build list or dict of all sites we wish to process ===
38 | 
39 | # positions are identified by subfolder names
40 | if config_.preprocess.pos_dir:
41 | log.info("pos dir, identifying all subfolders")
42 | if fovs == 'all':
43 | sites = [site for site in os.listdir(input_) if os.path.isdir(os.path.join(input_, site))]
44 | elif type(fovs) is list:
45 | sites = [site for site in os.listdir(input_) if os.path.isdir(os.path.join(input_, site)) and site in fovs]
46 | else:
47 | raise NotImplementedError("FOV subfolder expected, or preprocess FOVs must be 'all' or list of positions")
48 | 
49 | # positions are identified by indices
50 | # assume files have name structure "t###_p###_z###"
51 | elif not config_.preprocess.pos_dir:
52 | log.info("no pos dir, identifying all files")
53 | sites = {}
54 | all_files = [f for f in os.listdir(input_)
55 | if os.path.isfile(os.path.join(input_, f)) and '_p' in f and '.tif' in f]
56 | 
57 | if fovs == 'all':
58 | log.info("fovs = all, grouping files by position index")
59 | # for every position index in the file, assign the image to a dict key
60 | while all_files:
61 | pos = next((int(p_idx.strip('p')) for p_idx in all_files[0].split('_') if p_idx.startswith('p')), None)
62 | if pos is None: # no position token in this file name; drop the file (note: 0 is a valid index)
63 | all_files.pop(0)
64 | elif pos in sites.keys():
65 | sites[pos].append(os.path.join(input_, all_files.pop(0)))
66 | else:
67 | sites[pos] = [os.path.join(input_, all_files.pop(0))]
68 | 
69 | elif type(fovs) is list:
70 | for fov in fovs:
71 | sites[fov] = [os.path.join(input_, f) for f in sorted(fnmatch.filter(all_files, f'*p{fov:03d}*'))]
72 | else:
73 | raise NotImplementedError("FOV index expected, or preprocess FOVs must be 'all' or list of positions")
74 | else:
75 | raise NotImplementedError("pos_dir must be boolean True/False")
76 | 
77 | # write sites
78 | for site in sorted(sites):
79 | if not os.path.exists(output_):
80 | os.makedirs(output_)
81 | 
82 | # site represents a position folder
83 | if type(site) is str:
84 | s_list = [os.path.join(input_, site, f) for f in sorted(os.listdir(os.path.join(input_, site)))]
85 | 
86 | # site represents a position index
87 | elif type(site) is int:
88 | s_list = sites[site]
89 | else:
90 | log.warning(f"unrecognized site key {site}; skipping")
91 | continue
92 | 
93 | write_raw_to_npy(site, s_list, output_, chans, z_slice, multipage=multi)
94 | 
95 | 
96 | def parse_args():
97 | """
98 | Parse command line arguments for CLI.
99 | 
100 | :return: namespace containing the arguments passed.
101 | """
102 | parser = argparse.ArgumentParser()
103 | parser.add_argument(
104 | '-c', '--config',
105 | type=str,
106 | required=True,
107 | help='path to yaml configuration file'
108 | )
109 | 
110 | return parser.parse_args()
111 | 
112 | 
113 | if __name__ == '__main__':
114 | arguments = parse_args()
115 | config = YamlReader()
116 | config.read_config(arguments.config)
117 | 
118 | for (src, target) in list(zip(config.preprocess.image_dirs, config.preprocess.target_dirs)):
119 | main(src, target, config)
120 | 
121 | 
--------------------------------------------------------------------------------
/run_segmentation.py:
--------------------------------------------------------------------------------
1 | # bchhun, {2020-02-21}
2 | 
3 | from pipeline.segmentation import segmentation, instance_segmentation
4 | from pipeline.segmentation_validation import segmentation_validation_michael
5 | from multiprocessing import Process
6 | import os
7 | import numpy as np
8 | import logging
9 | log = logging.getLogger(__name__)
10 | 
11 | import argparse
12 | from configs.config_reader import YamlReader
13 | 
14 | 
15 | class Worker(Process):
16 | def __init__(self, inputs, gpuid=0, method='segmentation'):
17 | super().__init__()
18 | self.gpuid = gpuid
19 | self.inputs = inputs
20 | self.method = method
21 | 
22 | def run(self):
23 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
24 | os.environ["CUDA_VISIBLE_DEVICES"] = str(self.gpuid)
25 | 
26 | if self.method == 'segmentation':
27 | log.info(f"running segmentation worker on {self.gpuid}")
28 | segmentation(*self.inputs)
29 | elif self.method == 'instance_segmentation':
30 | log.info("running instance segmentation")
31 | instance_segmentation(*self.inputs)
32 | elif self.method == 'segmentation_validation':
33 | segmentation_validation_michael(*self.inputs)
34 | 
35 | 
36 | def main(method_, raw_dir_, supp_dir_, val_dir_, config_):
37 | method = method_
38 | 
39 | inputs = raw_dir_
40 | outputs = supp_dir_
41 | gpus = config_.segmentation.inference.gpu_ids
42 | gpu_count = len(gpus)
43 | 
44 | assert len(config_.segmentation.inference.channels) > 0, "At least one channel must be specified"
45 | 
46 | # segmentation validation requires raw, supp, and validation definitions
47 | if method == 'segmentation_validation':
48 | if not val_dir_:
49 | raise AttributeError("validation directory must be specified when method=segmentation_validation")
50 | if not outputs:
51 | raise AttributeError("supplementary directory must be specified when method=segmentation_validation")
52 | 
53 | # segmentation requires raw (NNProb), and weights to be defined
54 | elif method == 'segmentation':
55 | if config_.segmentation.inference.weights is None:
56 | raise AttributeError("model weights path must be specified when method=segmentation")
57 | 
58 | # instance segmentation requires raw (stack, NNprob), supp (to write outputs) to be defined
59 | elif method == 'instance_segmentation':
60 | pass # nothing extra to validate here; raw and supp paths are used directly downstream
61 | else:
62 | raise AttributeError(f"method flag {method} not implemented")
63 | 
64 | # all methods require a list of sites to process
65 | if config_.segmentation.inference.fov:
66 | sites = config_.segmentation.inference.fov
67 | else:
68 | # get all "XX-SITE_#" identifiers in raw data directory
69 | img_names = [file for file in os.listdir(inputs) if (file.endswith(".npy")) & ('_NN' not in file)]
70 | sites = [os.path.splitext(img_name)[0] for img_name in img_names]
71 | sites = list(set(sites))
72 | 
73 | segment_sites = [site for site in sites if os.path.exists(os.path.join(inputs, "%s.npy" % site))]
74 | sep = np.linspace(0, len(segment_sites), gpu_count + 1).astype(int)
75 | 
76 | processes = []
77 | for i, gpu in enumerate(gpus):
78 | _sites = segment_sites[sep[i]:sep[i + 1]]
79 | args = (inputs, outputs, val_dir_, _sites, config_)
80 | process = Worker(args, gpuid=gpu, method=method)
81 | process.start()
82 | processes.append(process)
83 | for p in processes:
84 | p.join()
85 | 
86 | 
87 | def parse_args():
88 | """
89 | Parse command line arguments for CLI.
90 | 
91 | :return: namespace containing the arguments passed.
92 | """
93 | parser = argparse.ArgumentParser()
94 | 
95 | parser.add_argument(
96 | '-m', '--method',
97 | type=str,
98 | required=False,
99 | choices=['segmentation', 'instance_segmentation', 'segmentation_validation'],
100 | default='segmentation',
101 | help="Method: one of 'segmentation', 'instance_segmentation', or 'segmentation_validation'",
102 | )
103 | 
104 | parser.add_argument(
105 | '-c', '--config',
106 | type=str,
107 | required=True,
108 | help='path to yaml configuration file. Run_segmentation takes arguments from "inference" category'
109 | )
110 | 
111 | return parser.parse_args()
112 | 
113 | 
114 | if __name__ == '__main__':
115 | 
116 | arguments = parse_args()
117 | config = YamlReader()
118 | config.read_config(arguments.config)
119 | 
120 | # batch run
121 | for (raw_dir, supp_dir, val_dir) in list(zip(config.segmentation.inference.raw_dirs,
122 | config.segmentation.inference.supp_dirs,
123 | config.segmentation.inference.validation_dirs)):
124 | main(arguments.method, raw_dir, supp_dir, val_dir, config)
125 | 
--------------------------------------------------------------------------------
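# A minimal standalone sketch (not a repo file) of the site-sharding scheme that
# run_patch.py and run_segmentation.py both use: np.linspace over the site count
# yields contiguous, near-equal index ranges per worker. Site names below are
# made up for illustration.
import numpy as np

sites = ['C5-Site_%d' % i for i in range(7)]
n_workers = 3
sep = np.linspace(0, len(sites), n_workers + 1).astype(int)  # -> [0, 2, 4, 7]
for i in range(n_workers):
    print('worker %d ->' % i, sites[sep[i]:sep[i + 1]])  # shards of size 2, 2, 3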