├── README.md
├── Train.py
├── __init__.py
├── evaluations
├── Clefts.py
├── NeuronIds.py
├── SynapticPartners.py
├── __init__.py
├── border_mask.py
├── rand.py
├── synaptic_partners.py
└── voi.py
├── img
├── *Filtered Mask.png
├── *Visualize Boundary.png
├── 6p.png
├── acc.png
├── err.png
├── loss.png
├── n.py
├── res window.png
└── rot.png
├── io
├── CremiFile.py
└── __init__.py
├── models
├── Resnet.py
└── Resnet_3.py
└── type
├── Annotations.py
└── Volume.py
/README.md:
--------------------------------------------------------------------------------
1 | # CREMIchallenge2017 - Neuron Image Segmentation Task
2 | This is a tentative experiment on solving the automated neuron segmentation task using deep learning methods, specifically residual networks.
3 |
4 | See the original challenge post at https://cremi.org and the leaderboard at https://cremi.org/leaderboard/.
5 |
6 | Use the Train.py file to train the models.
7 |
8 |
9 | ## Experimental Results
10 | Best at epoch 100:
11 |
12 | **Acc: 98.88%;**
13 |
14 | **Loss: 0.0298**
15 |
16 |
17 |
18 |
19 |
20 | ### Approaches
21 | - For this task, I trained a 2-way classifier to classify the central pixel of a 127*127 sample as boundary or non-boundary. A 2-way softmax layer was applied before the output of the network.
22 | - Reproduced and used the **residual network method** (original: https://arxiv.org/abs/1512.03385, implementation on GitHub: https://github.com/gcr/torch-residual-networks). This gave a great boost in classification results.
23 |
24 | - (see plot below) In preliminary experiments, using a 5-7-5 window for the three conv layers in the bottleneck block of the residual net (training on 127*127 samples, green line) outperformed the originally proposed 1-3-1 structure (gray line) by a large margin, so the experiments reported above were all trained with the 5-7-5 configuration (a sketch of this block is given below). The positions of the batch normalization and dropout layers in the block were also changed.
25 |
26 |
27 |
28 | - **Selectively choose training samples from raw (see figure below)**: pixels in the yellow area (**3X dilated boundary**) are never chosen; only green and purple pixels (true boundary, background) are selected into training batches (see the mask sketch below).
29 | - **Random rotation techniques**: various augmentation approaches were explored, including random rotations of +/-60 and +/-30 degrees applied to 33.33% or 50% of the samples in each batch. Random +/-60 degree rotations on 50% of samples (see figure below) performed best so far (a rotation sketch is given below).
30 |
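A minimal sketch of the modified bottleneck block described above (PyTorch; the channel counts, dropout rate, and the exact BN/dropout placement here are illustrative assumptions, not the exact code in models/Resnet_3.py):

```python
import torch.nn as nn
import torch.nn.functional as F

class Bottleneck575(nn.Module):
    """Bottleneck residual block with 5-7-5 conv windows instead of 1-3-1."""
    def __init__(self, channels, mid_channels, p_drop=0.5):
        super(Bottleneck575, self).__init__()
        # 'same' padding keeps the spatial size: pad = (kernel - 1) // 2
        self.conv1 = nn.Conv2d(channels, mid_channels, 5, padding=2)
        self.conv2 = nn.Conv2d(mid_channels, mid_channels, 7, padding=3)
        self.conv3 = nn.Conv2d(mid_channels, channels, 5, padding=2)
        self.bn1 = nn.BatchNorm2d(mid_channels)
        self.bn2 = nn.BatchNorm2d(mid_channels)
        self.drop = nn.Dropout2d(p_drop)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.drop(out)
        out = self.conv3(out)
        return F.relu(out + x)  # shortcut: H(x) + x
```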
31 |
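A sketch of the sampling-mask logic behind the selective sampling (mirroring `mask_filtered` in Train.py, where true boundary pixels are labeled 100, the dilated halo 200, and background 0; the number of dilation iterations is a tunable assumption):

```python
import numpy as np
from scipy import ndimage

def sampling_mask(boundary, iterations=3):
    """boundary: boolean array, True on true boundary pixels.

    Returns a float mask: 0 = background (selectable),
    100 = true boundary (selectable), 200 = dilated halo (never sampled).
    """
    halo = ndimage.binary_dilation(boundary, iterations=iterations)
    return (200.0 * halo - 100.0 * boundary).astype('float32')
```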
32 |
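And a sketch of the per-patch random rotation, shown with the best setting found (+/-60 degrees on 50% of patches); the `random_rotation` in Train.py uses +/-30 degrees on roughly a third of patches:

```python
import random
from scipy.ndimage.interpolation import rotate

def random_rotation(patch, p=0.5, max_deg=60.0):
    """With probability p, rotate a 2D patch by a random angle in
    [-max_deg, max_deg] degrees; otherwise return an unmodified copy."""
    if random.random() < p:
        angle = random.uniform(-max_deg, max_deg)
        return rotate(patch, angle, order=1, reshape=False)
    return patch.copy()
```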
33 |
34 | ### Future work
35 |
36 | - The neighborhood of the boundaries was avoided in this experiment; however, boundary pixels from other organelles (intracellular organelles) should also be avoided, since these pixels can easily be mistaken for target neuron boundaries when they are not. One approach to this challenge is to pre-train a network to recognize these intracellular boundaries and filter out such pixels when creating training batches for the segmentation task.
37 |
38 | - The raw data is originally a 3D image of size 125 * 1250 * 1250. I started by treating each of the 125 layers in depth as an independent sample and trained my network on 2D sections. However, in later stages of the experiments (which I was not able to reach due to the time limit of my project), the third dimension should be considered to address the correlation between neuron pixels across depth.
39 |
40 | ## Dependencies
41 |
42 | * python
43 | * pytorch
44 | * numpy
45 | * matplotlib
46 |
--------------------------------------------------------------------------------
/Train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.optim as optim
5 | import torch.legacy.nn as L
6 | from torch.autograd import Variable
7 |
8 | from torch.utils.data import Dataset, DataLoader, TensorDataset
9 | import torchvision
10 | from torchvision import transforms, datasets
11 | import torchvision.models as models
12 | import numpy as np
13 | from tempfile import TemporaryFile
14 | import matplotlib
15 | matplotlib.use('Agg')
16 | import matplotlib.pyplot as plt
17 | import scipy.ndimage as ndimage
18 | from scipy.ndimage.interpolation import rotate
19 | import time
20 | import os
21 | import random
22 |
23 | from Annotations import *
24 | from Volume import *
25 | from CremiFile import *
26 | from voi import voi
27 | from rand import adapted_rand
28 | plt.style.use('ggplot')
29 |
30 | from models.Resnet_3 import *
31 | from models.Resnet import *
32 |
33 |
34 |
35 | print('')
36 | print('DATASET LOADING ...')
37 | print('')
38 | emdset = CremiFile('sample_B_20160501.hdf', 'r')
39 |
40 | #Check the content of the datafile
41 | print "Has raw: " + str(emdset.has_raw())
42 | print "Has neuron ids: " + str(emdset.has_neuron_ids())
43 | print "Has clefts: " + str(emdset.has_clefts())
44 | print "Has annotations: " + str(emdset.has_annotations())
45 |
46 |
47 | #Read volume and annotation
48 | raw = emdset.read_raw()
49 | neuron_ids = emdset.read_neuron_ids()
50 | clefts = emdset.read_clefts()
51 | annotations = emdset.read_annotations()
52 |
53 | print("")
54 | print "Read raw: " + str(raw) + \
55 | ", resolution " + str(raw.resolution) + \
56 | ", offset " + str(raw.offset) + \
57 | ", data size " + str(raw.data.shape) + \
58 | ("" if raw.comment == None else ", comment \"" + raw.comment + "\"")
59 |
60 | print "Read neuron_ids: " + str(neuron_ids) + \
61 | ", resolution " + str(neuron_ids.resolution) + \
62 | ", offset " + str(neuron_ids.offset) + \
63 | ", data size " + str(neuron_ids.data.shape) + \
64 | ("" if neuron_ids.comment == None else ", comment \"" + neuron_ids.comment + "\"")
65 |
66 | print "Read clefts: " + str(clefts) + \
67 | ", resolution " + str(clefts.resolution) + \
68 | ", offset " + str(clefts.offset) + \
69 | ", data size " + str(clefts.data.shape) + \
70 | ("" if clefts.comment == None else ", comment \"" + clefts.comment + "\"")
71 |
72 |
73 |
74 |
75 | def mask_filtered(raw, neuron_ids):
76 | """
77 |     Image boundary dilation.
78 |     Compute a mask at each depth marking the un-selectable dilated (7X) boundary pixels (labeled as value 200.),
79 |     the selectable background (0.) and actual boundary (100.) pixels
80 |
81 | return(numpy array): mask of shape 125,1250,1250
82 | """
83 | print ''
84 | print ''
85 |     print 'building mask-7X from raw dataset...'
86 | since = time.time()
87 |
88 | d, h, w = raw.data.shape
89 | mask = np.empty([d, h, w]).astype('float32')
90 | for i in range(d):
91 | for j in range(h):
92 | for k in range(w):
93 | pixel = neuron_ids.data[i, j, k]
94 | if check_boundary(pixel, i, j, k, neuron_ids):
95 | mask[i, j, k] = 100
96 | else:
97 | mask[i, j, k] = 0
98 | if (i + 1) % 1 == 0:
99 | print str(0.8 * (i + 1)) + '% done'
100 |
101 | mask_dilated = ndimage.binary_dilation(mask, iterations=7).astype(mask.dtype)
102 | mask_filtered = 200 * mask_dilated - mask
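    # Resulting labels: 200 = dilated halo (excluded from sampling),
    # 100 = true boundary pixels, 0 = background.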
103 |
104 | filter_time = time.time()
105 | time_elapsed = filter_time - since
106 | print('Mask complete in {:.0f}m {:.0f}s'.format(
107 | time_elapsed // 60, time_elapsed % 60))
108 |
109 |
110 | print 'save to maskfile7X.npy'
111 | np.save('maskfile7X.npy', mask_filtered)
112 | print 'saved'
113 |
114 |
115 | def check_boundary(pixel, x, y, z, neuron_ids):
116 | """
117 | Check if a pixel at position (x,y,z) is labeled
118 | as boundary/non-boundary in neuron_ids.
119 |
120 | return(boolean): boundary
121 | """
122 | max_z = neuron_ids.data.shape[2] - 1
123 | max_y = neuron_ids.data.shape[1] - 1
124 | a = neuron_ids.data[x, y, z - 1] if z > 0 else pixel
125 | b = neuron_ids.data[x, y, z + 1] if z < max_z else pixel
126 | c = neuron_ids.data[x, y - 1, z] if y > 0 else pixel
127 | d = neuron_ids.data[x, y + 1, z] if y < max_y else pixel
128 | e = neuron_ids.data[x, y - 1, z - 1] if (y > 0 and z > 0) else pixel
129 | f = neuron_ids.data[x, y - 1, z + 1] if (y > 0 and z < max_z) else pixel
130 | g = neuron_ids.data[x, y + 1, z - 1] if (y < max_y and z > 0) else pixel
131 | h = neuron_ids.data[x, y + 1, z + 1] if (y < max_y and z < max_z) else pixel
132 |
133 | neighbors = [a, b, c, d, e, f, g, h]
134 | boundary = False
135 | for neighbor in neighbors:
136 | if pixel != neighbor:
137 | boundary = True
138 |
139 | return boundary
140 |
141 |
142 | # Seed a random number generator
143 | #seed = 24102016
144 | #rng = np.random.RandomState(seed)
145 | def random_rotation(inputs):
146 | """Randomly rotates a subset of images in a batch.
147 | reference: https://github.com/CSTR-Edinburgh/mlpractical/blob/mlp2016-7/master/notebooks/05_Non-linearities_and_regularisation.ipynb
148 |
149 | * chooses 30-50% of the images in the batch at random
150 | * for each image in the 30-50% chosen, rotates the image by a random angle in [-60, 60]
151 | * returns a new array of size (129, 129) in which the rows corresponding to the 25% chosen images are the vectors corresponding to the new randomly rotated images, while the remaining rows correspond to the original images.
152 | Args:
153 | inputs: Input image batch, an array of shape (129, 129).
154 |
155 | Returns:
156 | An array of shape (129, 129) corresponding to a copy
157 | of the original `inputs` array that was randomly selected
158 | to be rotated by a random angle. The original `inputs`
159 | array should not be modified.
160 | """
161 |
162 | new_ims = np.zeros(inputs.shape).astype('float32')
163 | indices = random.randint(-1,1)
164 | angles = random.uniform(-30., 30.)
165 | if indices == 0:
166 | rotate(inputs, angles, output = new_ims, order=1, reshape=False)
167 |
168 | return new_ims
169 |
170 |
171 |
172 |
173 | #mask_filtered(raw, neuron_ids) # used only when the first time of training
174 | mask = np.load('maskfile7X.npy')  # the mask saved by mask_filtered() above
175 | print ''
176 | print 'mask loaded'
177 |
178 | class NeuronSegmenDataset(Dataset):
179 | """Raw pixel and its label.
180 |     Dataset split into an 80,000-sample training set and a 20,000-sample validation set.
181 | """
182 |
183 | def __init__(self, raw, neuron_ids, mask, phase, transform=None):
184 | """
185 | Args:
186 | raw(Volume): raw
187 | neuron_ids(Volume): neuron segmentation labels
188 | mask(numpy ndarray): filtered mask
189 | phase(String): 'train' or 'val'
190 | transform(callable, optional): optional data augmentation to be applied
191 | """
192 |
193 | self.phase = phase
194 | self.raw = raw
195 | self.neuron_ids = neuron_ids
196 | self.mask = mask
197 | self.transform = transform
198 |
199 | def __len__(self):
200 | """ length of the dataset """
201 | if self.phase == 'train':
202 | x = 80000
203 | else:
204 | x = 20000
205 |
206 | return x
207 |
208 | def __getitem__(self, idx):
209 | """
210 |         Return a 129*129 patch centered on each raw pixel;
211 |         positive if boundary pixel, negative if non-boundary pixel.
212 | """
213 | depth = self.raw.data.shape[0]
214 | size = self.raw.data.shape[1]
215 |
216 |
217 | while True:
218 | d = random.randint(0, depth - 1)
219 | h = random.randint(64, size - 65)
220 | w = random.randint(64, size - 65)
221 | ids_pixel = self.neuron_ids.data[d, h, w]
222 | pixel = self.mask[d, h, w]
223 |
224 | if idx % 2 == 0: #control half samples to be boundary pixels
225 | if pixel == 100.:
226 | raw_batch = self.raw.data[d][h - 64:h + 65, w - 64:w + 65].astype(
227 | 'float32') # crop a 129*129 patch
228 |
229 |
230 | if self.transform:
231 | raw_batch = self.transform(raw_batch)
232 |
233 | raw_batch = raw_batch.reshape([1, 129, 129])
234 | raw_batch = torch.from_numpy(raw_batch)
235 | sample = (raw_batch, 0)
236 |
237 | break
238 |             elif pixel == 0.: # the other half as non-boundary pixels
239 |                 raw_batch = self.raw.data[d][h - 64:h + 65, w - 64:w + 65].astype(
240 |                     'float32') # crop a 129*129 patch
241 |
242 |                 if self.transform: # augment before reshaping, as in the boundary branch
243 |                     raw_batch = self.transform(raw_batch)
244 |
245 |                 raw_batch = raw_batch.reshape([1, 129, 129])
246 |                 raw_batch = torch.from_numpy(raw_batch)
247 |                 sample = (raw_batch, 1)
248 |
249 | break
250 | return sample
251 |
252 |
253 | batch_size = 100
254 | emdset_seg = {x: NeuronSegmenDataset(raw, neuron_ids, mask, x, transform=random_rotation)
255 | for x in ['train', 'val']}
256 | emdset_loaders = {x: DataLoader(emdset_seg[x], batch_size=batch_size, shuffle=True)
257 | for x in ['train', 'val']}
258 | dset_sizes = {x: len(emdset_seg[x]) for x in ['train', 'val']}
259 | dset_classes = ['boundary', 'non-b']
260 | use_gpu = torch.cuda.is_available()
261 |
262 | print "Load num of batches: train " + str(len(emdset_loaders['train'])) + \
263 | ' validation ' + str(len(emdset_loaders['val']))
264 |
265 | print ('done')
266 | print ('')
267 |
268 |
269 | class ConvNet(nn.Module):
270 | def __init__(self, D_out, kernel= 3, window =2, padding=1):
271 | super(ConvNet, self).__init__()
272 | self.conv1 = nn.Conv2d(1, 32, kernel, padding=padding)
273 | self.conv2 = nn.Conv2d(32, 32, kernel, padding=padding)
274 | self.conv3 = nn.Conv2d(32, 64, kernel, padding=padding)
275 | self.conv4 = nn.Conv2d(64, 64, kernel, padding=padding)
276 | self.conv5 = nn.Conv2d(64, 128, kernel, padding=padding)
277 | self.conv6 = nn.Conv2d(128, 128, kernel, padding=padding)
278 | self.conv7 = nn.Conv2d(128, 256, kernel, padding=padding)
279 | self.conv8 = nn.Conv2d(256, 256, kernel, padding=padding)
280 | self.conv9 = nn.Conv2d(256, 512, kernel, padding=padding)
281 | self.conv10 = nn.Conv2d(512, 512, kernel, padding=padding)
282 | self.pool = nn.MaxPool2d(window)
283 | self.linear1 = nn.Linear(2*2*256, 256)
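        # NOTE: the 2*2*256 flatten assumes 33*33 inputs (33 -> 16 -> 8 -> 4 -> 2
        # after four 2x2 max-pools); conv9, conv10 and linear2 are unused in forward().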
284 | self.linear2 = nn.Linear(512, 256)
285 | self.linear3 = nn.Linear(256, 128)
286 | self.linear4 = nn.Linear(128, D_out)
287 |
288 | def forward(self, x):
289 | x = F.relu(self.conv1(x))
290 | print "conv 1: " + str(x.data.size())
291 | x = F.relu(self.conv2(x))
292 | print "conv 2: " + str(x.data.size())
293 | x = self.pool(x)
294 | print "pool 1: " + str(x.data.size())
295 |
296 | x = F.relu(self.conv3(x))
297 | print "conv 3: " + str(x.data.size())
298 | x = F.relu(self.conv4(x))
299 | print "conv 4: " + str(x.data.size())
300 | x = self.pool(x)
301 | print "pool 2: " + str(x.data.size())
302 |
303 | x = F.relu(self.conv5(x))
304 | print "conv 5: " + str(x.data.size())
305 | x = F.relu(self.conv6(x))
306 | print "conv 6: " + str(x.data.size())
307 | x = self.pool(x)
308 | print "pool 3: " + str(x.data.size())
309 |
310 | x = F.relu(self.conv7(x))
311 | print "conv 7: " + str(x.data.size())
312 | x = F.relu(self.conv8(x))
313 | print "conv 8: " + str(x.data.size())
314 | x = self.pool(x)
315 | print "pool 4: " + str(x.data.size())
316 |
317 | x = x.view(-1, 2*2*256)
318 | x = F.relu(self.linear1(x))
319 |
320 | x = F.relu(self.linear3(x))
321 | x = F.relu(self.linear4(x))
322 |
323 | return x
324 |
325 |
326 |
327 | def exp_lr_scheduler(optimizer, epoch, init_lr=0.001, lr_decay_epoch=10):
328 | """Decay learning rate by a factor of 0.1 every 10 epochs"""
329 | lr = init_lr * (0.1**(epoch // lr_decay_epoch))
330 | for param_group in optimizer.param_groups:
331 | param_group['lr'] = lr
332 |
333 | return optimizer
334 |
335 |
336 | def piecewise_scheduler(optimizer, epoch):
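    """Halve the learning rate every 50 epochs."""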
337 |     if epoch % 50 == 0:
338 | for param_group in optimizer.param_groups:
339 | lr = param_group['lr'] / 2
340 | param_group['lr'] = lr
341 |
342 | return optimizer
343 |
344 |
345 |
346 | def train_model (model, criterion, optimizer, lr_scheduler=None, num_epochs=100):
347 | since = time.time()
348 | train_voi_split = np.zeros(num_epochs)
349 | train_voi_merge = np.zeros(num_epochs)
350 | train_rand = np.zeros(num_epochs)
351 |
352 | # iterate over epoch
353 | for epoch in range(num_epochs):
354 |         print ('Epoch {}/{}'.format(epoch+1, num_epochs))
355 | print ('-' * 10)
356 |
357 |
358 | # train and validation set
359 | for phase in ['train', 'val']:
360 | if phase == 'train':
361 | if lr_scheduler:
362 | optimizer = lr_scheduler(optimizer, epoch + 1)
363 | model.train(True)
364 |             else:
365 |                 model.train(False)  # evaluation mode for the validation phase
366 |
367 | running_loss = 0.
368 | running_accuracy = 0.
369 | total = 0
370 |
371 | # iterate over each batch
372 | for i, data in enumerate(emdset_loaders[phase]):
373 | inputs, labels = data
374 | if use_gpu:
375 | model = model.cuda()
376 | inputs, labels = Variable(inputs.cuda()), \
377 | Variable(labels.cuda())
378 | else:
379 | inputs, labels = Variable(inputs), Variable(labels)
380 |
381 |
382 | optimizer.zero_grad() # clean gradients in buffer
383 |
384 | outputs = model(inputs)
385 | _, predicted = torch.max(outputs.data, 1)
386 | loss = criterion(outputs, labels)
387 |
388 |
389 | if phase == 'train':
390 | loss.backward()
391 | optimizer.step()
392 |
393 |
394 | running_loss += loss.data[0]
395 | running_accuracy += (predicted == labels.data).sum()
396 |
397 | # visualize random patches
398 | visualize_pred = True
399 | tt = visualize_pred and epoch == num_epochs-1 and phase == 'val' \
400 | and i+1 == len(emdset_loaders['val'])
401 | if tt:
402 | print('visualizing...')
403 | images_so_far = 0
404 | fig = plt.figure()
405 | for j in [6, 15, 38, 41, 86, 99]:
406 | images_so_far += 1
407 | ax = fig.add_subplot(3, 2, images_so_far)
408 | ax.axis('off')
409 | ax.set_title('Pred: {},\n Labeled: {}'.format(dset_classes[int(predicted.cpu().numpy()[j])],
410 | dset_classes[labels.data[j]]))
411 | ax.imshow(inputs.cpu().data[j].view(129,129).numpy())
412 | fig.savefig('6p.png')
413 | print 'done and saved to 6p.png'
414 |
415 |
416 | # normalize by number of batches
417 | running_loss /= (i + 1)
418 | running_accuracy = 100 * running_accuracy / dset_sizes[phase]
419 |
420 | # print statistics
421 | if epoch % 1 == 0:
422 | print('{} Loss: {:.4f} Acc: {:.4f}'.format(
423 | phase, running_loss, running_accuracy
424 | ))
425 | # print "\tvoi split : " + str(train_voi_split)
426 | # print "\tvoi merge : " + str(train_voi_merge)
427 | # print "\tadapted RAND: " + str(train_rand)
428 |
429 | # Visualize the model. raw, labeled, predicted (optional)
430 | visualize = False
431 | if visualize:
432 | print('')
433 | print('Begin to visualize model..')
434 | visualize_model(model)
435 |
436 |         time_elapsed = time.time() - since
437 | print('Visualizing complete in {:.0f}m {:.0f}s'.format(
438 | time_elapsed // 60, time_elapsed % 60))
439 |
440 |
441 |
442 | def visualize_model(model, i = 80, s = 300):
443 | """ Args:
444 | model: model
445 | i: depth of the raw image to visualize
446 | s: crop the 1250*1250 image to the size of s*s
447 | """
448 | fig = plt.figure(figsize=(15,6))
449 |
450 | ax_ori = fig.add_subplot(1,3,1)
451 | ax_lab = fig.add_subplot(1,3,2)
452 | ax_pred = fig.add_subplot(1,3,3)
453 |
454 | ax_ori.imshow(raw.data[i][0:s, 0:s])
455 | ax_ori.set_title('raw')
456 | ax_lab.imshow(neuron_ids.data[i][0:s, 0:s])
457 | ax_lab.set_title('labeled')
458 |
459 |
460 | preds = np.empty([s*s])
461 | for j in range(s*s):
462 | pixel = raw.data[i][j/s, j%s]
463 | input = np.random.uniform(-10000, 0, (1, 1, 33, 33)).astype('float32') ## boundary patch: positive
464 | input[0, 0, 16, 16] = pixel
465 | input = torch.from_numpy(input)
466 |
467 | model.train(False)
468 | if use_gpu:
469 | model = model.cuda()
470 | input = Variable(input.cuda())
471 | else:
472 | input = Variable(input)
473 |
474 | outputs = model(input)
475 | _, pred = torch.max(outputs.data, 1)
476 | pred = pred.cpu().numpy()
477 | if pred[0] == 0:
478 | preds[j] = 20000
479 | else:
480 | preds[j] = 100
481 |
482 | if j == 30000:
483 | print '1/3 done'
484 | if j == 60000:
485 | print '2/3 done'
486 |
487 | print preds.reshape(s, s)
488 | ax_pred.imshow(preds.reshape((s,s)))
489 | ax_pred.set_title('predicted')
490 |
491 | ax_lab.axis('off')
492 | ax_ori.axis('off')
493 | ax_pred.axis('off')
494 |
495 | plt.show()
496 | fig.savefig('vi.png')
497 | print('saved as vi.png')
498 |
499 |
500 |
501 |
502 |
503 |
504 | # Train
505 | #-----------------------------------------------------------------------
506 | num_classes = 2
507 | num_epochs = 100
508 | #model = ConvNet(num_classes )
509 | #model = DeepResNet18(num_classes)
510 | #model = DeepResNet34(num_classes )
511 | model = DeepResNet50(num_classes)
512 | #model = DeepResNet101(num_classes )
513 | optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0001)
514 | #optimizer = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[2,4,6,8,10], gamma=0.5)
515 | criterion = nn.CrossEntropyLoss()
516 | print('')
517 | print('START TRAINING ...')
518 | print(time.time())
519 | print('ResNet50. 33% 30deg lr50')
520 | train_model(model, criterion, optimizer, lr_scheduler=piecewise_scheduler, num_epochs=num_epochs)
521 |
522 |
523 |
524 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | from Annotations import *
2 | from Volume import *
3 | from NeuronIds import *
4 | from border_mask import *
5 | from CremiFile import *
6 |
7 |
--------------------------------------------------------------------------------
/evaluations/Clefts.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy import ndimage
3 |
4 | class Clefts:
5 |
6 | def __init__(self, test, truth):
7 |
8 | test_clefts = test
9 | truth_clefts = truth
10 |
11 | self.test_clefts_mask = np.equal(test_clefts.data, 0xffffffffffffffff)
12 | self.truth_clefts_mask = np.equal(truth_clefts.data, 0xffffffffffffffff)
13 |
14 | self.test_clefts_edt = ndimage.distance_transform_edt(self.test_clefts_mask, sampling=test_clefts.resolution)
15 | self.truth_clefts_edt = ndimage.distance_transform_edt(self.truth_clefts_mask, sampling=truth_clefts.resolution)
16 |
17 | def count_false_positives(self, threshold = 200):
18 |
19 | mask1 = np.invert(self.test_clefts_mask)
20 | mask2 = self.truth_clefts_edt > threshold
21 | false_positives = self.truth_clefts_edt[np.logical_and(mask1, mask2)]
22 | return false_positives.size
23 |
24 | def count_false_negatives(self, threshold = 200):
25 |
26 | mask1 = np.invert(self.truth_clefts_mask)
27 | mask2 = self.test_clefts_edt > threshold
28 | false_negatives = self.test_clefts_edt[np.logical_and(mask1, mask2)]
29 | return false_negatives.size
30 |
31 | def acc_false_positives(self):
32 |
33 | mask = np.invert(self.test_clefts_mask)
34 | false_positives = self.truth_clefts_edt[mask]
35 | stats = {
36 | 'mean': np.mean(false_positives),
37 | 'std': np.std(false_positives),
38 | 'max': np.amax(false_positives),
39 | 'count': false_positives.size,
40 | 'median': np.median(false_positives)}
41 | return stats
42 |
43 | def acc_false_negatives(self):
44 |
45 | mask = np.invert(self.truth_clefts_mask)
46 | false_negatives = self.test_clefts_edt[mask]
47 | stats = {
48 | 'mean': np.mean(false_negatives),
49 | 'std': np.std(false_negatives),
50 | 'max': np.amax(false_negatives),
51 | 'count': false_negatives.size,
52 | 'median': np.median(false_negatives)}
53 | return stats
54 |
55 |
56 |
--------------------------------------------------------------------------------
/evaluations/NeuronIds.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from border_mask import create_border_mask
3 | from voi import voi
4 | from rand import adapted_rand
5 |
6 | class NeuronIds:
7 |
8 | def __init__(self, groundtruth, border_threshold = None):
9 | """Create a new evaluation object for neuron ids against the provided ground truth.
10 |
11 | Parameters
12 | ----------
13 |
14 | groundtruth: Volume
15 | The ground truth volume containing neuron ids.
16 |
17 | border_threshold: None or float, in world units
18 | Pixels within `border_threshold` to a label border in the
19 | same section will be assigned to background and ignored during
20 | the evaluation.
21 | """
22 |
23 | assert groundtruth.resolution[1] == groundtruth.resolution[2], \
24 | "x and y resolutions of ground truth are not the same (%f != %f)" % \
25 | (groundtruth.resolution[1], groundtruth.resolution[2])
26 |
27 | self.groundtruth = groundtruth
28 | self.border_threshold = border_threshold
29 |
30 | if self.border_threshold:
31 |
32 | print "Computing border mask..."
33 |
34 | self.gt = np.zeros(groundtruth.data.shape, dtype=np.uint64)
35 | create_border_mask(
36 | groundtruth.data,
37 | self.gt,
38 | float(border_threshold)/groundtruth.resolution[1],
39 | np.uint64(-1))
40 | else:
41 | self.gt = np.array(self.groundtruth.data).copy()
42 |
43 | # current voi and rand implementations don't work with np.uint64(-1) as
44 | # background label, so we make it 0 here and bump all other labels
45 | self.gt += 1
46 |
47 | def voi(self, segmentation):
48 |
49 | assert list(segmentation.data.shape) == list(self.groundtruth.data.shape)
50 | assert list(segmentation.resolution) == list(self.groundtruth.resolution)
51 |
52 | print "Computing VOI..."
53 |
54 | return voi(np.array(segmentation.data), self.gt, ignore_groundtruth = [0])
55 |
56 | def adapted_rand(self, segmentation):
57 |
58 | assert list(segmentation.data.shape) == list(self.groundtruth.data.shape)
59 | assert list(segmentation.resolution) == list(self.groundtruth.resolution)
60 |
61 | print "Computing RAND..."
62 |
63 | return adapted_rand(np.array(segmentation.data), self.gt)
64 |
--------------------------------------------------------------------------------
/evaluations/SynapticPartners.py:
--------------------------------------------------------------------------------
1 | from synaptic_partners import synaptic_partners_fscore
2 |
3 | class SynapticPartners:
4 |
5 | def __init__(self, matching_threshold = 400):
6 |
7 | self.matching_threshold = matching_threshold
8 |
9 | def fscore(self, rec_annotations, gt_annotations, gt_segmentation, all_stats = False):
10 |
11 | return synaptic_partners_fscore(rec_annotations, gt_annotations, gt_segmentation, self.matching_threshold, all_stats)
12 |
--------------------------------------------------------------------------------
/evaluations/__init__.py:
--------------------------------------------------------------------------------
1 | from Clefts import *
2 | from NeuronIds import *
3 | from SynapticPartners import *
4 | from border_mask import *
5 |
--------------------------------------------------------------------------------
/evaluations/border_mask.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import numpy as np
3 | import scipy
4 |
5 | def create_border_mask(input_data, target, max_dist, background_label, axis=0):
6 | """
7 | Overlay a border mask with background_label onto input data.
8 | A pixel is part of a border if one of its 4-neighbors has different label.
9 |
10 | Parameters
11 | ----------
12 | input_data : h5py.Dataset or numpy.ndarray - Input data containing neuron ids
13 |     target : h5py.Dataset or numpy.ndarray - Target which input data overlayed with border mask is written into.
14 | max_dist : int or float - Maximum distance from border for pixels to be included into the mask.
15 | background_label : int - Border mask will be overlayed using this label.
16 | axis : int - Axis of iteration (perpendicular to 2d images for which mask will be generated)
17 | """
18 | sl = [slice(None) for d in xrange(len(target.shape))]
19 |
20 | for z in xrange(target.shape[axis]):
21 | sl[ axis ] = z
22 | border = create_border_mask_2d(input_data[tuple(sl)], max_dist)
23 | target_slice = input_data[tuple(sl)] if isinstance(input_data,h5py.Dataset) else np.copy(input_data[tuple(sl)])
24 | target_slice[border] = background_label
25 | target[tuple(sl)] = target_slice
26 |
27 | def create_and_write_masked_neuron_ids(in_file, out_file, max_dist, background_label, overwrite=False):
28 | """
29 | Overlay a border mask with background_label onto input data loaded from in_file and write into out_file.
30 | A pixel is part of a border if one of its 4-neighbors has different label.
31 |
32 | Parameters
33 | ----------
34 | in_file : CremiFile - Input file containing neuron ids
35 | out_file : CremiFile - Output file which input data overlayed with border mask is written into.
36 | max_dist : int or float - Maximum distance from border for pixels to be included into the mask.
37 | background_label : int - Border mask will be overlayed using this label.
38 | overwrite : bool - Overwrite existing data in out_file (True) or do nothing if data is present in out_file (False).
39 | """
40 | if ( not in_file.has_neuron_ids() ) or ( (not overwrite) and out_file.has_neuron_ids() ):
41 | return
42 |
43 | neuron_ids, resolution, offset, comment = in_file.read_neuron_ids()
44 | comment = ('' if comment is None else comment + ' ') + 'Border masked with max_dist=%f' % max_dist
45 |
46 | path = "/volumes/labels/neuron_ids"
47 | group_path = "/".join( path.split("/")[:-1] )
48 | ds_name = path.split("/")[-1]
49 | if ( out_file.has_neuron_ids() ):
50 | del out_file.h5file[path]
51 | if (group_path not in out_file.h5file):
52 | out_file.h5file.create_group(group_path)
53 |
54 | group = out_file.h5file[group_path]
55 | target = group.create_dataset(ds_name, shape=neuron_ids.shape, dtype=neuron_ids.dtype)
56 | target.attrs["resolution"] = resolution
57 | target.attrs["comment"] = comment
58 | if offset != (0.0, 0.0, 0.0):
59 | target.attrs["offset"] = offset
60 |
61 | create_border_mask(neuron_ids, target, max_dist, background_label)
62 |
63 | def create_border_mask_2d(image, max_dist):
64 | """
65 | Create binary border mask for image.
66 | A pixel is part of a border if one of its 4-neighbors has different label.
67 |
68 | Parameters
69 | ----------
70 | image : numpy.ndarray - Image containing integer labels.
71 | max_dist : int or float - Maximum distance from border for pixels to be included into the mask.
72 |
73 | Returns
74 | -------
75 | mask : numpy.ndarray - Binary mask of border pixels. Same shape as image.
76 | """
77 | max_dist = max(max_dist, 0)
78 |
79 | padded = np.pad(image, 1, mode='edge')
80 |
81 | border_pixels = np.logical_and(
82 | np.logical_and( image == padded[:-2, 1:-1], image == padded[2:, 1:-1] ),
83 | np.logical_and( image == padded[1:-1, :-2], image == padded[1:-1, 2:] )
84 | )
85 |
86 | distances = scipy.ndimage.distance_transform_edt(
87 | border_pixels,
88 | return_distances=True,
89 | return_indices=False
90 | )
91 |
92 | return distances <= max_dist
93 |
94 |
--------------------------------------------------------------------------------
/evaluations/rand.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 |
3 | import numpy as np
4 | import scipy.sparse as sparse
5 |
6 | # Evaluation code courtesy of Juan Nunez-Iglesias, taken from
7 | # https://github.com/janelia-flyem/gala/blob/master/gala/evaluate.py
8 |
9 | def adapted_rand(seg, gt, all_stats=False):
10 | """Compute Adapted Rand error as defined by the SNEMI3D contest [1]
11 |
12 | Formula is given as 1 - the maximal F-score of the Rand index
13 | (excluding the zero component of the original labels). Adapted
14 | from the SNEMI3D MATLAB script, hence the strange style.
15 |
16 | Parameters
17 | ----------
18 | seg : np.ndarray
19 | the segmentation to score, where each value is the label at that point
20 | gt : np.ndarray, same shape as seg
21 | the groundtruth to score against, where each value is a label
22 | all_stats : boolean, optional
23 | whether to also return precision and recall as a 3-tuple with rand_error
24 |
25 | Returns
26 | -------
27 | are : float
28 | The adapted Rand error; equal to $1 - \frac{2pr}{p + r}$,
29 | where $p$ and $r$ are the precision and recall described below.
30 | prec : float, optional
31 | The adapted Rand precision. (Only returned when `all_stats` is ``True``.)
32 | rec : float, optional
33 | The adapted Rand recall. (Only returned when `all_stats` is ``True``.)
34 |
35 | References
36 | ----------
37 | [1]: http://brainiac2.mit.edu/SNEMI3D/evaluation
38 | """
39 | # segA is truth, segB is query
40 | segA = np.ravel(gt)
41 | segB = np.ravel(seg)
42 | n = segA.size
43 |
44 | n_labels_A = np.amax(segA) + 1
45 | n_labels_B = np.amax(segB) + 1
46 |
47 | ones_data = np.ones(n)
48 |
49 | p_ij = sparse.csr_matrix((ones_data, (segA[:], segB[:])), shape=(n_labels_A, n_labels_B))
50 |
51 | a = p_ij[1:n_labels_A,:]
52 | b = p_ij[1:n_labels_A,1:n_labels_B]
53 | c = p_ij[1:n_labels_A,0].todense()
54 | d = b.multiply(b)
55 |
56 | a_i = np.array(a.sum(1))
57 | b_i = np.array(b.sum(0))
58 |
59 | sumA = np.sum(a_i * a_i)
60 | sumB = np.sum(b_i * b_i) + (np.sum(c) / n)
61 | sumAB = np.sum(d) + (np.sum(c) / n)
62 |
63 | precision = sumAB / sumB
64 | recall = sumAB / sumA
65 |
66 | fScore = 2.0 * precision * recall / (precision + recall)
67 | are = 1.0 - fScore
68 |
69 | if all_stats:
70 | return (are, precision, recall)
71 | else:
72 | return are
73 |
--------------------------------------------------------------------------------
/evaluations/synaptic_partners.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | from munkres import Munkres
3 | import numpy as np
4 |
5 | def synaptic_partners_fscore(rec_annotations, gt_annotations, gt_segmentation, matching_threshold = 400, all_stats = False):
6 | """Compute the f-score of the found synaptic partners.
7 |
8 | Parameters
9 | ----------
10 |
11 | rec_annotations: Annotations, containing found synaptic partners
12 |
13 | gt_annotations: Annotations, containing ground truth synaptic partners
14 |
15 | gt_segmentation: Volume, ground truth neuron segmentation
16 |
17 | matching_threshold: float, world units
18 | Euclidean distance threshold to consider two annotations a potential
19 |         match. Annotations that are `matching_threshold` or more units apart
20 | from each other are not considered as potential matches.
21 |
22 | all_stats: boolean, optional
23 | Whether to also return precision, recall, FP, FN, and matches as a 6-tuple with f-score
24 |
25 | Returns
26 | -------
27 |
28 | fscore: float
29 | The f-score of the found synaptic partners.
30 | precision: float, optional
31 | recall: float, optional
32 | fp: int, optional
33 | fn: int, optional
34 | filtered_matches: list of tuples, optional
35 | The indices of the matches with matching costs.
36 | """
37 |
38 | # get cost matrix
39 | costs = cost_matrix(rec_annotations, gt_annotations, gt_segmentation, matching_threshold)
40 |
41 | # match using Hungarian method
42 | print "Finding cost-minimal matches..."
43 | munkres = Munkres()
44 | matches = munkres.compute(costs.copy()) # have to copy, because munkres changes the cost matrix...
45 |
46 | filtered_matches = [ (i,j, costs[i][j]) for (i,j) in matches if costs[i][j] <= matching_threshold ]
47 | print str(len(filtered_matches)) + " matches found"
48 |
49 | # unmatched in rec = FP
50 | fp = len(rec_annotations.pre_post_partners) - len(filtered_matches)
51 |
52 | # unmatched in gt = FN
53 | fn = len(gt_annotations.pre_post_partners) - len(filtered_matches)
54 |
55 | # all ground truth elements - FN = TP
56 | tp = len(gt_annotations.pre_post_partners) - fn
57 |
58 | precision = float(tp)/(tp + fp)
59 | recall = float(tp)/(tp + fn)
60 | fscore = 2.0*precision*recall/(precision + recall)
61 |
62 | if all_stats:
63 | return (fscore, precision, recall, fp, fn, filtered_matches)
64 | else:
65 | return fscore
66 |
67 | def cost_matrix(rec, gt, gt_segmentation, matching_threshold):
68 |
69 | print "Computing matching costs..."
70 |
71 | rec_locations = pre_post_locations(rec, gt_segmentation)
72 | gt_locations = pre_post_locations(gt, gt_segmentation)
73 |
74 | rec_labels = pre_post_labels(rec_locations, gt_segmentation)
75 | gt_labels = pre_post_labels(gt_locations, gt_segmentation)
76 |
77 | size = max(len(rec_locations), len(gt_locations))
78 | costs = np.zeros((size, size), dtype=np.float)
79 | costs[:] = 2*matching_threshold
80 | num_potential_matches = 0
81 | for i in range(len(rec_locations)):
82 | for j in range(len(gt_locations)):
83 | c = cost(rec_locations[i], gt_locations[j], rec_labels[i], gt_labels[j], matching_threshold)
84 | costs[i,j] = c
85 | if c <= matching_threshold:
86 | num_potential_matches += 1
87 |
88 | print str(num_potential_matches) + " potential matches found"
89 |
90 | return costs
91 |
92 | def pre_post_locations(annotations, gt_segmentation):
93 | """Get the locations of the annotations relative to the ground truth offset."""
94 |
95 | locations = annotations.locations()
96 | shift = sub(annotations.offset, gt_segmentation.offset)
97 |
98 | return [
99 | (add(annotations.get_annotation(pre_id)[1], shift), add(annotations.get_annotation(post_id)[1], shift)) for (pre_id, post_id) in annotations.pre_post_partners
100 | ]
101 |
102 | def pre_post_labels(locations, segmentation):
103 |
104 | return [ (segmentation[pre], segmentation[post]) for (pre, post) in locations ]
105 |
106 |
107 | def cost(pre_post_location1, pre_post_location2, labels1, labels2, matching_threshold):
108 |
109 | max_cost = 2*matching_threshold
110 |
111 | # pairs do not link the same segments
112 | if labels1 != labels2:
113 | return max_cost
114 |
115 | pre_dist = distance(pre_post_location1[0], pre_post_location2[0])
116 | post_dist = distance(pre_post_location1[1], pre_post_location2[1])
117 |
118 | if pre_dist > matching_threshold or post_dist > matching_threshold:
119 | return max_cost
120 |
121 | return 0.5*(pre_dist + post_dist)
122 |
123 | def distance(a, b):
124 | return np.linalg.norm(np.array(list(a))-np.array(list(b)))
125 |
126 | def add(a, b):
127 | return tuple([a[d] + b[d] for d in range(len(b))])
128 |
129 | def sub(a, b):
130 | return tuple([a[d] - b[d] for d in range(len(b))])
131 |
--------------------------------------------------------------------------------
/evaluations/voi.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 |
3 | # Evaluation code courtesy of Juan Nunez-Iglesias, taken from
4 | # https://github.com/janelia-flyem/gala/blob/master/gala/evaluate.py
5 |
6 | import numpy as np
7 | import scipy.sparse as sparse
8 |
9 | def voi(reconstruction, groundtruth, ignore_reconstruction=[], ignore_groundtruth=[0]):
10 | """Return the conditional entropies of the variation of information metric. [1]
11 |
12 | Let X be a reconstruction, and Y a ground truth labelling. The variation of
13 | information between the two is the sum of two conditional entropies:
14 |
15 | VI(X, Y) = H(X|Y) + H(Y|X).
16 |
17 | The first one, H(X|Y), is a measure of oversegmentation, the second one,
18 | H(Y|X), a measure of undersegmentation. These measures are referred to as
19 | the variation of information split or merge error, respectively.
20 |
21 | Parameters
22 | ----------
23 |     reconstruction : np.ndarray, int type, arbitrary shape
24 |         A candidate segmentation.
25 |     groundtruth : np.ndarray, int type, same shape as `reconstruction`
26 |         The ground truth segmentation.
27 |     ignore_reconstruction, ignore_groundtruth : list of int, optional
28 | Any points having a label in this list are ignored in the evaluation.
29 | By default, only the label 0 in the ground truth will be ignored.
30 |
31 | Returns
32 | -------
33 | (split, merge) : float
34 | The variation of information split and merge error, i.e., H(X|Y) and H(Y|X)
35 |
36 | References
37 | ----------
38 | [1] Meila, M. (2007). Comparing clusterings - an information based
39 | distance. Journal of Multivariate Analysis 98, 873-895.
40 | """
41 | (hyxg, hxgy) = split_vi(reconstruction, groundtruth, ignore_reconstruction, ignore_groundtruth)
42 | return (hxgy, hyxg)
43 |
44 | def split_vi(x, y=None, ignore_x=[0], ignore_y=[0]):
45 | """Return the symmetric conditional entropies associated with the VI.
46 |
47 | The variation of information is defined as VI(X,Y) = H(X|Y) + H(Y|X).
48 | If Y is the ground-truth segmentation, then H(Y|X) can be interpreted
49 | as the amount of under-segmentation of Y and H(X|Y) is then the amount
50 | of over-segmentation. In other words, a perfect over-segmentation
51 | will have H(Y|X)=0 and a perfect under-segmentation will have H(X|Y)=0.
52 |
53 | If y is None, x is assumed to be a contingency table.
54 |
55 | Parameters
56 | ----------
57 | x : np.ndarray
58 | Label field (int type) or contingency table (float). `x` is
59 | interpreted as a contingency table (summing to 1.0) if and only if `y`
60 | is not provided.
61 | y : np.ndarray of int, same shape as x, optional
62 | A label field to compare to `x`.
63 | ignore_x, ignore_y : list of int, optional
64 | Any points having a label in this list are ignored in the evaluation.
65 | Ignore 0-labeled points by default.
66 |
67 | Returns
68 | -------
69 | sv : np.ndarray of float, shape (2,)
70 | The conditional entropies of Y|X and X|Y.
71 |
72 | See Also
73 | --------
74 | vi
75 | """
76 | _, _, _ , hxgy, hygx, _, _ = vi_tables(x, y, ignore_x, ignore_y)
77 | # false merges, false splits
78 | return np.array([hygx.sum(), hxgy.sum()])
79 |
80 | def vi_tables(x, y=None, ignore_x=[0], ignore_y=[0]):
81 | """Return probability tables used for calculating VI.
82 |
83 | If y is None, x is assumed to be a contingency table.
84 |
85 | Parameters
86 | ----------
87 | x, y : np.ndarray
88 | Either x and y are provided as equal-shaped np.ndarray label fields
89 | (int type), or y is not provided and x is a contingency table
90 | (sparse.csc_matrix) that may or may not sum to 1.
91 | ignore_x, ignore_y : list of int, optional
92 | Rows and columns (respectively) to ignore in the contingency table.
93 | These are labels that are not counted when evaluating VI.
94 |
95 | Returns
96 | -------
97 | pxy : sparse.csc_matrix of float
98 | The normalized contingency table.
99 | px, py, hxgy, hygx, lpygx, lpxgy : np.ndarray of float
100 | The proportions of each label in `x` and `y` (`px`, `py`), the
101 | per-segment conditional entropies of `x` given `y` and vice-versa, the
102 | per-segment conditional probability p log p.
103 | """
104 | if y is not None:
105 | pxy = contingency_table(x, y, ignore_x, ignore_y)
106 | else:
107 | cont = x
108 | total = float(cont.sum())
109 | # normalize, since it is an identity op if already done
110 | pxy = cont / total
111 |
112 | # Calculate probabilities
113 | px = np.array(pxy.sum(axis=1)).ravel()
114 | py = np.array(pxy.sum(axis=0)).ravel()
115 | # Remove zero rows/cols
116 | nzx = px.nonzero()[0]
117 | nzy = py.nonzero()[0]
118 | nzpx = px[nzx]
119 | nzpy = py[nzy]
120 | nzpxy = pxy[nzx, :][:, nzy]
121 |
122 | # Calculate log conditional probabilities and entropies
123 | lpygx = np.zeros(np.shape(px))
124 | lpygx[nzx] = xlogx(divide_rows(nzpxy, nzpx)).sum(axis=1)
125 | # \sum_x{p_{y|x} \log{p_{y|x}}}
126 | hygx = -(px*lpygx) # \sum_x{p_x H(Y|X=x)} = H(Y|X)
127 |
128 | lpxgy = np.zeros(np.shape(py))
129 | lpxgy[nzy] = xlogx(divide_columns(nzpxy, nzpy)).sum(axis=0)
130 | hxgy = -(py*lpxgy)
131 |
132 | return [pxy] + list(map(np.asarray, [px, py, hxgy, hygx, lpygx, lpxgy]))
133 |
134 | def contingency_table(seg, gt, ignore_seg=[0], ignore_gt=[0], norm=True):
135 | """Return the contingency table for all regions in matched segmentations.
136 |
137 | Parameters
138 | ----------
139 | seg : np.ndarray, int type, arbitrary shape
140 | A candidate segmentation.
141 | gt : np.ndarray, int type, same shape as `seg`
142 | The ground truth segmentation.
143 | ignore_seg : list of int, optional
144 | Values to ignore in `seg`. Voxels in `seg` having a value in this list
145 | will not contribute to the contingency table. (default: [0])
146 | ignore_gt : list of int, optional
147 | Values to ignore in `gt`. Voxels in `gt` having a value in this list
148 | will not contribute to the contingency table. (default: [0])
149 | norm : bool, optional
150 | Whether to normalize the table so that it sums to 1.
151 |
152 | Returns
153 | -------
154 | cont : scipy.sparse.csc_matrix
155 | A contingency table. `cont[i, j]` will equal the number of voxels
156 | labeled `i` in `seg` and `j` in `gt`. (Or the proportion of such voxels
157 | if `norm=True`.)
158 | """
159 | segr = seg.ravel()
160 | gtr = gt.ravel()
161 | ignored = np.zeros(segr.shape, np.bool)
162 | data = np.ones(len(gtr))
163 | for i in ignore_seg:
164 | ignored[segr == i] = True
165 | for j in ignore_gt:
166 | ignored[gtr == j] = True
167 | data[ignored] = 0
168 | cont = sparse.coo_matrix((data, (segr, gtr))).tocsc()
169 | if norm:
170 | cont /= float(cont.sum())
171 | return cont
172 |
173 | def divide_columns(matrix, row, in_place=False):
174 | """Divide each column of `matrix` by the corresponding element in `row`.
175 |
176 | The result is as follows: out[i, j] = matrix[i, j] / row[j]
177 |
178 | Parameters
179 | ----------
180 | matrix : np.ndarray, scipy.sparse.csc_matrix or csr_matrix, shape (M, N)
181 | The input matrix.
182 |     row : a 1D np.ndarray, shape (N,)
183 |         The row dividing `matrix`.
184 | in_place : bool (optional, default False)
185 | Do the computation in-place.
186 |
187 | Returns
188 | -------
189 | out : same type as `matrix`
190 |         The result of the column-wise division.
191 | """
192 | if in_place:
193 | out = matrix
194 | else:
195 | out = matrix.copy()
196 | if type(out) in [sparse.csc_matrix, sparse.csr_matrix]:
197 | if type(out) == sparse.csc_matrix:
198 | convert_to_csc = True
199 | out = out.tocsr()
200 | else:
201 | convert_to_csc = False
202 | row_repeated = np.take(row, out.indices)
203 | nz = out.data.nonzero()
204 | out.data[nz] /= row_repeated[nz]
205 | if convert_to_csc:
206 | out = out.tocsc()
207 | else:
208 | out /= row[np.newaxis, :]
209 | return out
210 |
211 | def divide_rows(matrix, column, in_place=False):
212 | """Divide each row of `matrix` by the corresponding element in `column`.
213 |
214 | The result is as follows: out[i, j] = matrix[i, j] / column[i]
215 |
216 | Parameters
217 | ----------
218 | matrix : np.ndarray, scipy.sparse.csc_matrix or csr_matrix, shape (M, N)
219 | The input matrix.
220 | column : a 1D np.ndarray, shape (M,)
221 | The column dividing `matrix`.
222 | in_place : bool (optional, default False)
223 | Do the computation in-place.
224 |
225 | Returns
226 | -------
227 | out : same type as `matrix`
228 | The result of the row-wise division.
229 | """
230 | if in_place:
231 | out = matrix
232 | else:
233 | out = matrix.copy()
234 | if type(out) in [sparse.csc_matrix, sparse.csr_matrix]:
235 | if type(out) == sparse.csr_matrix:
236 | convert_to_csr = True
237 | out = out.tocsc()
238 | else:
239 | convert_to_csr = False
240 | column_repeated = np.take(column, out.indices)
241 | nz = out.data.nonzero()
242 | out.data[nz] /= column_repeated[nz]
243 | if convert_to_csr:
244 | out = out.tocsr()
245 | else:
246 | out /= column[:, np.newaxis]
247 | return out
248 |
249 | def xlogx(x, out=None, in_place=False):
250 | """Compute x * log_2(x).
251 |
252 | We define 0 * log_2(0) = 0
253 |
254 | Parameters
255 | ----------
256 | x : np.ndarray or scipy.sparse.csc_matrix or csr_matrix
257 | The input array.
258 | out : same type as x (optional)
259 | If provided, use this array/matrix for the result.
260 | in_place : bool (optional, default False)
261 | Operate directly on x.
262 |
263 | Returns
264 | -------
265 | y : same type as x
266 | Result of x * log_2(x).
267 | """
268 | if in_place:
269 | y = x
270 | elif out is None:
271 | y = x.copy()
272 | else:
273 | y = out
274 | if type(y) in [sparse.csc_matrix, sparse.csr_matrix]:
275 | z = y.data
276 | else:
277 | z = y
278 | nz = z.nonzero()
279 | z[nz] *= np.log2(z[nz])
280 | return y
281 |
--------------------------------------------------------------------------------
/img/*Filtered Mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/celisun/Deep_learning_automated__neuron_segmentation/68bc9adb2e19ad2ca53ac72236176d274c07e3c7/img/*Filtered Mask.png
--------------------------------------------------------------------------------
/img/*Visualize Boundary.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/celisun/Deep_learning_automated__neuron_segmentation/68bc9adb2e19ad2ca53ac72236176d274c07e3c7/img/*Visualize Boundary.png
--------------------------------------------------------------------------------
/img/6p.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/celisun/Deep_learning_automated__neuron_segmentation/68bc9adb2e19ad2ca53ac72236176d274c07e3c7/img/6p.png
--------------------------------------------------------------------------------
/img/acc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/celisun/Deep_learning_automated__neuron_segmentation/68bc9adb2e19ad2ca53ac72236176d274c07e3c7/img/acc.png
--------------------------------------------------------------------------------
/img/err.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/celisun/Deep_learning_automated__neuron_segmentation/68bc9adb2e19ad2ca53ac72236176d274c07e3c7/img/err.png
--------------------------------------------------------------------------------
/img/loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/celisun/Deep_learning_automated__neuron_segmentation/68bc9adb2e19ad2ca53ac72236176d274c07e3c7/img/loss.png
--------------------------------------------------------------------------------
/img/n.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/img/res window.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/celisun/Deep_learning_automated__neuron_segmentation/68bc9adb2e19ad2ca53ac72236176d274c07e3c7/img/res window.png
--------------------------------------------------------------------------------
/img/rot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/celisun/Deep_learning_automated__neuron_segmentation/68bc9adb2e19ad2ca53ac72236176d274c07e3c7/img/rot.png
--------------------------------------------------------------------------------
/io/CremiFile.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import numpy as np
3 | from Annotations import *
4 | from Volume import *
5 |
6 | class CremiFile(object):
7 |
8 | def __init__(self, filename, mode):
9 |
10 | self.h5file = h5py.File(filename, mode)
11 |
12 | if mode == "w" or mode == "a":
13 | self.h5file["/"].attrs["file_format"] = "0.2"
14 |
15 | def __create_group(self, group):
16 |
17 | path = "/"
18 | for d in group.split("/"):
19 | path += d + "/"
20 | try:
21 | self.h5file.create_group(path)
22 | except ValueError:
23 | pass
24 |
25 | def __create_dataset(self, path, data, dtype, compression = None):
26 | """Wrapper around h5py's create_dataset. Creates the group, if not
27 | existing. Deletes a previous dataset, if existing and not compatible.
28 | Otherwise, replaces the dataset.
29 | """
30 |
31 | group = "/".join(path.split("/")[:-1])
32 | ds_name = path.split("/")[-1]
33 |
34 | self.__create_group(group)
35 |
36 | if ds_name in self.h5file[group]:
37 |
38 | ds = self.h5file[path]
39 | if ds.dtype == dtype and ds.shape == np.array(data).shape:
40 | print "overwriting existing dataset"
41 | self.h5file[path][:] = data[:]
42 | return
43 |
44 | del self.h5file[path]
45 |
46 | self.h5file.create_dataset(path, data=data, dtype=dtype, compression=compression)
47 |
48 | def write_volume(self, volume, ds_name, dtype):
49 |
50 | self.__create_dataset(ds_name, data=volume.data, dtype=dtype, compression="gzip")
51 | self.h5file[ds_name].attrs["resolution"] = volume.resolution
52 | if volume.comment is not None:
53 | self.h5file[ds_name].attrs["comment"] = str(volume.comment)
54 | if tuple(volume.offset) != (0.0, 0.0, 0.0):
55 | self.h5file[ds_name].attrs["offset"] = volume.offset
56 |
57 | def read_volume(self, ds_name):
58 |
59 | volume = Volume(self.h5file[ds_name])
60 |
61 | volume.resolution = self.h5file[ds_name].attrs["resolution"]
62 | if "offset" in self.h5file[ds_name].attrs:
63 | volume.offset = self.h5file[ds_name].attrs["offset"]
64 | if "comment" in self.h5file[ds_name].attrs:
65 | volume.comment = self.h5file[ds_name].attrs["comment"]
66 |
67 | return volume
68 |
69 | def __has_volume(self, ds_name):
70 |
71 | return ds_name in self.h5file
72 |
73 | def write_raw(self, raw):
74 | """Write a raw volume.
75 | """
76 |
77 | self.write_volume(raw, "/volumes/raw", np.uint8)
78 |
79 | def write_neuron_ids(self, neuron_ids):
80 | """Write a volume of segmented neurons.
81 | """
82 |
83 | self.write_volume(neuron_ids, "/volumes/labels/neuron_ids", np.uint64)
84 |
85 | def write_clefts(self, clefts):
86 | """Write a volume of segmented synaptic clefts.
87 | """
88 |
89 | self.write_volume(clefts, "/volumes/labels/clefts", np.uint64)
90 |
91 | def write_annotations(self, annotations):
92 | """Write pre- and post-synaptic site annotations.
93 | """
94 |
95 | if len(annotations.ids()) == 0:
96 | return
97 |
98 | self.__create_group("/annotations")
99 | if tuple(annotations.offset) != (0.0, 0.0, 0.0):
100 | self.h5file["/annotations"].attrs["offset"] = annotations.offset
101 |
102 | self.__create_dataset("/annotations/ids", data=annotations.ids(), dtype=np.uint64)
103 | self.__create_dataset("/annotations/types", data=annotations.types(), dtype=h5py.special_dtype(vlen=unicode), compression="gzip")
104 | self.__create_dataset("/annotations/locations", data=annotations.locations(), dtype=np.double)
105 |
106 | if len(annotations.comments) > 0:
107 | self.__create_dataset("/annotations/comments/target_ids", data=annotations.comments.keys(), dtype=np.uint64)
108 | self.__create_dataset("/annotations/comments/comments", data=annotations.comments.values(), dtype=h5py.special_dtype(vlen=unicode))
109 |
110 | if len(annotations.pre_post_partners) > 0:
111 | self.__create_dataset("/annotations/presynaptic_site/partners", data=annotations.pre_post_partners, dtype=np.uint64)
112 |
113 | def has_raw(self):
114 | """Check if this file contains a raw volume.
115 | """
116 | return self.__has_volume("/volumes/raw")
117 |
118 | def has_neuron_ids(self):
119 | """Check if this file contains neuron ids.
120 | """
121 | return self.__has_volume("/volumes/labels/neuron_ids")
122 |
123 | def has_neuron_ids_confidence(self):
124 | """Check if this file contains confidence information about neuron ids.
125 | """
126 | return self.__has_volume("/volumes/labels/neuron_ids_confidence")
127 |
128 | def has_clefts(self):
129 | """Check if this file contains synaptic clefts.
130 | """
131 | return self.__has_volume("/volumes/labels/clefts")
132 |
133 | def has_annotations(self):
134 | """Check if this file contains synaptic partner annotations.
135 | """
136 | return "/annotations" in self.h5file
137 |
138 | def has_segment_annotations(self):
139 | """Check if this file contains segment annotations.
140 | """
141 | return "/annotations" in self.h5file
142 |
143 | def read_raw(self):
144 | """Read the raw volume.
145 | Returns a Volume.
146 | """
147 |
148 | return self.read_volume("/volumes/raw")
149 |
150 | def read_neuron_ids(self):
151 | """Read the volume of segmented neurons.
152 | Returns a Volume.
153 | """
154 |
155 | return self.read_volume("/volumes/labels/neuron_ids")
156 |
157 | def read_neuron_ids_confidence(self):
158 | """Read confidence information about neuron ids.
159 | Returns Confidences.
160 | """
161 |
162 | confidences = Confidences(num_levels=2)
163 | if not self.has_neuron_ids_confidence():
164 | return confidences
165 |
166 | data = self.h5file["/volumes/labels/neuron_ids_confidence"]
167 | i = 0
168 | while i < len(data):
169 | level = data[i]
170 | i += 1
171 | num_ids = data[i]
172 | i += 1
173 | confidences.add_all(level, data[i:i+num_ids])
174 | i += num_ids
175 |
176 | return confidences
177 |
178 | def read_clefts(self):
179 | """Read the volume of segmented synaptic clefts.
180 | Returns a Volume.
181 | """
182 |
183 | return self.read_volume("/volumes/labels/clefts")
184 |
185 | def read_annotations(self):
186 | """Read pre- and post-synaptic site annotations.
187 | """
188 |
189 | annotations = Annotations()
190 |
191 | if not "/annotations" in self.h5file:
192 | return annotations
193 |
194 | offset = (0.0, 0.0, 0.0)
195 | if "offset" in self.h5file["/annotations"].attrs:
196 | offset = self.h5file["/annotations"].attrs["offset"]
197 | annotations.offset = offset
198 |
199 | ids = self.h5file["/annotations/ids"]
200 | types = self.h5file["/annotations/types"]
201 | locations = self.h5file["/annotations/locations"]
202 | for i in range(len(ids)):
203 | annotations.add_annotation(ids[i], types[i], locations[i])
204 |
205 | if "comments" in self.h5file["/annotations"]:
206 | ids = self.h5file["/annotations/comments/target_ids"]
207 | comments = self.h5file["/annotations/comments/comments"]
208 | for (id, comment) in zip(ids, comments):
209 | annotations.add_comment(id, comment)
210 |
211 | if "presynaptic_site/partners" in self.h5file["/annotations"]:
212 | pre_post = self.h5file["/annotations/presynaptic_site/partners"]
213 | for (pre, post) in pre_post:
214 | annotations.set_pre_post_partners(pre, post)
215 |
216 | return annotations
217 |
218 |     def close(self):
219 |         """Close the underlying HDF5 file.
220 |         """
221 |         self.h5file.close()
221 |
--------------------------------------------------------------------------------
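
A minimal usage sketch for the `CremiFile` reader above (the file name is a placeholder for any CREMI training volume; the constructor is assumed to take a path and an h5py-style mode, and the import path depends on how the repo sits on `PYTHONPATH`):

```python
from io.CremiFile import CremiFile  # assumed import path; adjust to your layout

cremi = CremiFile("sample_A_20160501.hdf", "r")  # placeholder file name

if cremi.has_raw():
    raw = cremi.read_raw()      # a Volume
    print(raw.data.shape)       # e.g. (125, 1250, 1250) for the CREMI samples
    print(raw.resolution)       # world units (nm) per voxel

if cremi.has_neuron_ids():
    neuron_ids = cremi.read_neuron_ids()  # a Volume of uint64 segment labels

if cremi.has_annotations():
    annotations = cremi.read_annotations()
    for id in annotations.ids():
        print(annotations.get_annotation(id))  # (type, location)

cremi.close()
```
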
/io/__init__.py:
--------------------------------------------------------------------------------
1 | from CremiFile import *
2 |
--------------------------------------------------------------------------------
/models/Resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.legacy.nn as L
5 | from torch.autograd import Variable
6 | def residualLayer2(conv2d1, norm2d, input, nChannels, nOutChannels=False, stride=1, conv2d2=False):
7 | """ Deep Residual Network
8 | https://github.com/gcr/torch-residual-networks
9 |
10 | giving stack of 2 layers as a block providing shortcuts."""
11 |
12 |
13 |     if not nOutChannels:
14 |         nOutChannels = nChannels  # default: keep the channel count
15 |     if not conv2d2:
16 |         conv2d2 = conv2d1  # default: reuse the same conv for both layers
17 |
18 | # part 1: conv
19 | net = conv2d1(input)
20 | net = norm2d(net) # learnable parameters
21 | net = F.relu(net)
22 | net = conv2d2(net)
23 |
24 |
25 | # part 2: identity / skip connection
26 | skip = input
27 | if stride > 1: # optional downsampling
28 | skip = L.SpatialAveragePooling(1, 1, stride, stride).forward(skip.cpu().data)
29 | skip = Variable(skip.cuda())
30 | if nOutChannels > nChannels: # optional padding
31 | skip = L.Padding(1, (nOutChannels - nChannels), 3).forward(skip.cpu().data)
32 | skip = Variable(skip.cuda())
33 | elif nOutChannels < nChannels: # optional narrow
34 | skip = L.Narrow(2, 1, nOutChannels).forward(skip.cpu().data)
35 | skip = Variable(skip.cuda())
36 |
37 |
38 | # H(x) + x
39 | net = norm2d(net)
40 | #print "skip: " + str(skip.data.size())
41 | #print "net: " + str(net.data.size())
42 | net = torch.add(skip, net)
43 | # net = F.relu(net) # relu here ? see: http://www.gitxiv.com/comments/7rffyqcPLirEEsmpX
44 | #net = norm2d(net) # ==========================BN after add or before ???
45 |
46 | return net
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 | class DeepResNet18(nn.Module):
55 | def __init__(self, D_out, kernel=3, padding=1):
56 | super(DeepResNet18, self).__init__()
57 | self.conv1 = nn.Conv2d(1, 32, kernel, padding=padding)
58 | self.conv2 = nn.Conv2d(32, 32, kernel, padding=padding)
59 |         self.conv3 = nn.Conv2d(32, 64, kernel, stride=2, padding=padding)
60 |         self.conv4 = nn.Conv2d(64, 64, kernel, padding=padding)
61 |         self.conv5 = nn.Conv2d(64, 128, kernel, stride=2, padding=padding)
62 |         self.con6 = nn.Conv2d(128, 128, kernel, padding=padding)
64 | self.conv8 = nn.Conv2d(256, 256, kernel, padding=padding)
65 | self.norm1 = nn.BatchNorm2d(32)
66 | self.norm2 = nn.BatchNorm2d(64)
67 | self.norm3 = nn.BatchNorm2d(128)
68 | self.norm4 = nn.BatchNorm2d(256)
69 |         self.linear = nn.Linear(256, D_out)
70 |
71 | def forward(self, x):
72 | # ----> 1, 33, 33
73 | x = F.relu(self.norm1(self.conv1(x)))
74 |
75 | # ----> 32, 33, 33 First Group 2X
76 | for i in range(2): x = residualLayer2(self.conv2, self.norm1, x, 32)
77 |
78 | # ----> 64, 17, 17 Second Group 2X
79 | x = residualLayer2(self.conv3, self.norm2, x, 32, 64, stride=2, conv2d2=self.conv4)
80 | for i in range(2-1): x = residualLayer2(self.conv4, self.norm2, x, 64)
81 |
82 | # ----> 128, 9, 9 Third Group 2X
83 | x = residualLayer2(self.conv5, self.norm3, x, 64, 128, stride=2, conv2d2=self.conv6)
84 | for i in range(2-1): x = residualLayer2(self.conv6, self.norm3, x, 128)
85 |
86 | # ----> 256, 5, 5 Fourth Group 2X
87 | x = residualLayer2(self.conv7, self.norm4, x, 128, 256, stride=2, conv2d2=self.conv8)
88 | for i in range(2-1): x = residualLayer2(self.conv8, self.norm4, x, 256)
89 |
90 | # ----> 256, 5, 5 Pooling, Linear, Softmax
91 | x = nn.AvgPool2d(5,5)(x)
92 | x = x.view(-1, 256)
93 | x = self.linear(x)
94 |
95 |
96 | return x
97 |
98 |
99 | class DeepResNet34(nn.Module):
100 | def __init__(self, D_out, kernel=5, padding=2):
101 | super(DeepResNet34, self).__init__()
102 | self.conv1 = nn.Conv2d(1, 32, kernel, padding=padding)
103 | self.conv2 = nn.Conv2d(32, 32, kernel, padding=padding)
104 |         self.conv3 = nn.Conv2d(32, 64, kernel, stride=2, padding=padding)
105 |         self.conv4 = nn.Conv2d(64, 64, kernel, padding=padding)
106 |         self.conv5 = nn.Conv2d(64, 128, kernel, stride=2, padding=padding)
107 |         self.conv6 = nn.Conv2d(128, 128, kernel, padding=padding)
108 |         self.conv7 = nn.Conv2d(128, 256, kernel, stride=2, padding=padding)
109 | self.conv8 = nn.Conv2d(256, 256, kernel, padding=padding)
110 | self.norm1 = nn.BatchNorm2d(32)
111 | self.norm2 = nn.BatchNorm2d(64)
112 | self.norm3 = nn.BatchNorm2d(128)
113 | self.norm4 = nn.BatchNorm2d(256)
114 |         self.linear = nn.Linear(256, D_out)
115 | self.pool = nn.MaxPool2d(3, stride=2)
116 |
117 | def forward(self, x):
118 |         # ------> 1, 127, 127  (the final AvgPool2d(8, 8) implies 127x127 inputs)
119 |         x = F.relu(self.norm1(self.conv1(x)))
120 |         x = self.pool(x) # ================= max pooling ??
121 |         # ------> 32, 63, 63
122 |         for i in range(3): x = residualLayer2(self.conv2, self.norm1, x, 32)
123 |         # ------> 32, 63, 63
124 | x = residualLayer2(self.conv3, self.norm2, x, 32, 64, stride=2, conv2d2=self.conv4)
125 | for i in range(4-1): x = residualLayer2(self.conv4, self.norm2, x, 64)
126 |
127 | x = residualLayer2(self.conv5, self.norm3, x, 64, 128, stride=2, conv2d2=self.conv6)
128 | for i in range(6-1): x = residualLayer2(self.conv6, self.norm3, x, 128)
129 |
130 | x = residualLayer2(self.conv7, self.norm4, x, 128, 256, stride=2, conv2d2=self.conv8)
131 | for i in range(3-1): x = residualLayer2(self.conv8, self.norm4, x, 256)
132 |
133 | x = nn.AvgPool2d(8,8)(x)
134 | x = x.view(-1, 256)
135 | x = self.linear(x)
136 |
137 |
138 | return x
139 |
--------------------------------------------------------------------------------
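
A quick smoke test for the two basic-block models above, sketching the input geometry they expect. The 33x33 patch size follows the shape comments in `DeepResNet18.forward`; 127x127 for `DeepResNet34` follows the README's sample size and the final `AvgPool2d(8, 8)`, which needs an 8x8 feature map. Because the shortcut path in `residualLayer2` hard-codes `.cuda()`, this sketch assumes a CUDA-enabled build:

```python
import torch
from torch.autograd import Variable
from models.Resnet import DeepResNet18, DeepResNet34  # assumed import path

# DeepResNet18: 1-channel 33x33 patches -> 2 logits (boundary / non-boundary)
net18 = DeepResNet18(D_out=2).cuda()
x = Variable(torch.randn(4, 1, 33, 33).cuda())  # batch of 4 grayscale patches
print(net18(x).size())  # (4, 2)

# DeepResNet34: one max-pool before the residual groups,
# so spatial sizes go 127 -> 63 -> 32 -> 16 -> 8
net34 = DeepResNet34(D_out=2).cuda()
print(net34(Variable(torch.randn(4, 1, 127, 127).cuda())).size())  # (4, 2)
```

One consequence of this construction worth noting: each group reuses the same `nn.Conv2d` modules for all of its repeated blocks (e.g. `self.conv2` for every block of the first group), so the blocks within a group share weights, unlike the reference ResNet where every block has its own parameters.
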
/models/Resnet_3.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.legacy.nn as L
5 | from torch.autograd import Variable
6 | 
7 | class DeepResNet101(nn.Module):
8 |
9 | """using bottle-neck building block """
10 |
11 |
12 | def __init__(self, D_out, kernel=7, padding=3):
13 | super(DeepResNet101, self).__init__()
14 | self.conv1 = nn.Conv2d(1, 32, kernel, padding=padding)
15 | self.conv2_ = nn.Conv2d(32, 32, 1, stride=2)
16 | self.conv2 = nn.Conv2d(128, 32, 5, padding=2)
17 | self.conv3 = nn.Conv2d(32, 32, kernel, padding=padding)
18 | self.conv4 = nn.Conv2d(32, 128, 5, padding=2)
19 | self.conv5_ = nn.Conv2d(128, 64, 1, stride=2)
20 | self.conv5 = nn.Conv2d(256, 64, 5, padding=2)
21 | self.conv6 = nn.Conv2d(64, 64, kernel, padding=padding)
22 | self.conv7 = nn.Conv2d(64, 256, 5, padding=2)
23 | self.conv8_ = nn.Conv2d(256, 128, 1, stride=2)
24 | self.conv8 = nn.Conv2d(512, 128, 5, padding=2)
25 | self.conv9 = nn.Conv2d(128, 128, kernel, padding=padding)
26 | self.conv10 = nn.Conv2d(128, 512, 5, padding=2)
27 | self.conv11_ = nn.Conv2d(512, 256, 1, stride=2)
28 | self.conv11 = nn.Conv2d(1024, 256, 5, padding=2)
29 | self.conv12 = nn.Conv2d(256, 256, kernel, padding=padding)
30 | self.conv13 = nn.Conv2d(256, 1024, 5, padding=2)
31 |
32 | self.norm1 = nn.BatchNorm2d(32)
33 | self.norm2 = nn.BatchNorm2d(64)
34 | self.norm3 = nn.BatchNorm2d(128)
35 | self.norm4 = nn.BatchNorm2d(256)
36 | self.norm5 = nn.BatchNorm2d(512)
37 | self.norm6 = nn.BatchNorm2d(1024)
38 | self.linear1 = nn.Linear(1024, 512)
39 | self.linear2 = nn.Linear(512, 128)
40 |         self.linear3 = nn.Linear(128, D_out)
41 | self.pool = nn.MaxPool2d(3, stride=2)
42 |
43 | def forward(self, x):
44 | # ----> 1, 129, 129
45 | x = F.relu(self.norm1(self.conv1(x)))
46 | x = self.pool(x) # max pooling ? 3x3 s=2
47 |
48 | # ----> 32, 64, 64 First Group
49 | x = residualLayer3(x, self.conv2_, self.conv3, self.conv4, self.norm1, self.norm3, 32, 32, 128, stride=2)
50 | for i in range(3-1): x = residualLayer3(x, self.conv2, self.conv3, self.conv4, self.norm1, self.norm3, 128, 32, 128)
51 |
52 | x = residualLayer3(x, self.conv5_, self.conv6, self.conv7, self.norm2, self.norm4, 128, 64, 256, stride=2)
53 | for i in range(8-1): x = residualLayer3(x, self.conv5, self.conv6, self.conv7, self.norm2, self.norm4, 256, 64, 256)
54 |
55 | x = residualLayer3(x, self.conv8_, self.conv9, self.conv10, self.norm3, self.norm5, 256, 128, 512, stride=2)
56 | for i in range(36-1): x = residualLayer3(x, self.conv8, self.conv9, self.conv10, self.norm3, self.norm5, 512, 128, 512)
57 |
58 | x = residualLayer3(x, self.conv11_, self.conv12, self.conv13, self.norm4, self.norm6, 512, 256, 1024, stride=2)
59 | for i in range(3-1): x = residualLayer3(x, self.conv11, self.conv12, self.conv13, self.norm4, self.norm6, 1024, 256, 1024)
60 |
61 | # ----> 1024, 4, 4 Pooling, Linear, Softmax
62 | x = nn.AvgPool2d(4,4)(x)
63 | x = x.view(-1, 1024)
64 | x = self.linear1(x)
65 | x = F.dropout(x) # ==============================
66 | x = self.linear2(x)
67 | x = self.linear3(x)
68 |
69 | return x
70 |
71 |
72 |
73 |
74 |
75 | class DeepResNet50(nn.Module):
76 |
77 | """using bottle-neck building block """
78 |
79 |
80 |
81 | def __init__(self, D_out, kernel=7, padding=3): #=============== conv window size 5/22 9:24pm
82 | super(DeepResNet50, self).__init__()
83 | self.conv1 = nn.Conv2d(1, 32, kernel, padding=padding)
84 | self.conv2_ = nn.Conv2d(32, 32, 1, stride=2)
85 | self.conv2 = nn.Conv2d(128, 32, 5, padding=2)
86 | self.conv3 = nn.Conv2d(32, 32, kernel, padding=padding)
87 | self.conv4 = nn.Conv2d(32, 128, 5, padding=2)
88 | self.conv5_ = nn.Conv2d(128, 64, 1, stride=2)
89 | self.conv5 = nn.Conv2d(256, 64, 5, padding=2)
90 | self.conv6 = nn.Conv2d(64, 64, kernel, padding=padding)
91 | self.conv7 = nn.Conv2d(64, 256, 5, padding=2)
92 | self.conv8_ = nn.Conv2d(256, 128, 1, stride=2)
93 | self.conv8 = nn.Conv2d(512, 128, 5, padding=2)
94 | self.conv9 = nn.Conv2d(128, 128, kernel, padding=padding)
95 | self.conv10 = nn.Conv2d(128, 512, 5, padding=2)
96 | self.conv11_ = nn.Conv2d(512, 256, 1, stride=2)
97 | self.conv11 = nn.Conv2d(1024, 256, 5, padding=2)
98 | self.conv12 = nn.Conv2d(256, 256, kernel, padding=padding)
99 | self.conv13 = nn.Conv2d(256, 1024, 5, padding=2)
100 |
101 | self.norm1 = nn.BatchNorm2d(32)
102 | self.norm2 = nn.BatchNorm2d(64)
103 | self.norm3 = nn.BatchNorm2d(128)
104 | self.norm4 = nn.BatchNorm2d(256)
105 | self.norm5 = nn.BatchNorm2d(512)
106 | self.norm6 = nn.BatchNorm2d(1024)
107 | self.linear1 = nn.Linear(1024, 512)
108 | self.linear2 = nn.Linear(512, 128)
109 |         self.linear3 = nn.Linear(128, D_out)
110 | self.pool = nn.MaxPool2d(3, stride=2)
111 |
112 |
113 | def forward(self, x):
114 | # ----> 1, 129, 129
115 | x = F.relu(self.norm1(self.conv1(x)))
116 | x = self.pool(x) # ================= max pooling ? better without here
117 |
118 |
119 | # ----> 32, 64, 64 First Group
120 | x = residualLayer3(x, self.conv2_, self.conv3, self.conv4, self.norm1, self.norm3, 32, 32, 128, stride=2)
121 | for i in range(3-1): x = residualLayer3(x, self.conv2, self.conv3, self.conv4, self.norm1, self.norm3, 128, 32, 128)
122 |
123 | # ----> 128, 32, 32 Second Group
124 | x = residualLayer3(x, self.conv5_, self.conv6, self.conv7, self.norm2, self.norm4, 128, 64, 256, stride=2)
125 | for i in range(4-1): x = residualLayer3(x, self.conv5, self.conv6, self.conv7, self.norm2, self.norm4, 256, 64, 256)
126 |
127 | # ----> 256, 16, 16 Third Group
128 | x = residualLayer3(x, self.conv8_, self.conv9, self.conv10, self.norm3, self.norm5, 256, 128, 512, stride=2)
129 | for i in range(6-1): x = residualLayer3(x, self.conv8, self.conv9, self.conv10, self.norm3, self.norm5, 512, 128, 512)
130 |
131 | # ----> 512, 8,8 Fourth Group
132 | x = residualLayer3(x, self.conv11_, self.conv12, self.conv13, self.norm4, self.norm6, 512, 256, 1024, stride=2)
133 | for i in range(3-1): x = residualLayer3(x, self.conv11, self.conv12, self.conv13, self.norm4, self.norm6, 1024, 256, 1024)
134 |
135 | # ----> 1024, 4, 4 Pooling, Linear, Softmax
136 | x = nn.AvgPool2d(4,4)(x)
137 | x = x.view(-1, 1024)
138 | x = self.linear1(x)
139 | x = F.dropout(x) # ==============================
140 | x = self.linear2(x)
141 | x = self.linear3(x)
142 |
143 | return x
144 |
145 |
146 | # bottleneck residual block: a stack of three conv layers with an identity shortcut
147 | def residualLayer3(input, conv2d1, conv2d2, conv2d3, norm2d1, norm2d2, inChannels, hiddenChannels, outChannels, stride=1):
148 |     net = conv2d1(input) # reduce conv (1x1 in the strided variant, 5x5 otherwise)
149 | net = norm2d1(net)
150 | net = F.relu(net)
151 | net = F.dropout(net) # ========================== dropout within blocks ???? 8.21 9pm
152 |
153 |     net = conv2d2(net) # middle kxk conv (5x5 or 7x7 in these models)
154 | net = norm2d1(net)
155 | net = F.relu(net)
156 | net = F.dropout(net) # ========================== dropout ???? 8.21 9pm
157 |
158 | net = conv2d3(net) # 1x1
159 |
160 | skip = input
161 | #print "input: " + str(skip.data.size())
162 | if stride > 1:
163 | skip = L.SpatialAveragePooling(1, 1, stride, stride).forward(skip.cpu().data)
164 | skip = Variable(skip.cuda())
165 | if outChannels > inChannels:
166 | skip = L.Padding(1, (outChannels - inChannels), 3).forward(skip.cpu().data)
167 | skip = Variable(skip.cuda())
168 | elif outChannels < inChannels:
169 | skip = L.Narrow(2, 1, outChannels).forward(skip.cpu().data)
170 | skip = Variable(skip.cuda())
171 |
172 | #net = norm2d2(net)
173 | #print "skip: " + str(skip.data.size())
174 | #print "net: " + str(net.data.size())
175 | net = norm2d2(torch.add(skip, net)) # ==========================BN after add or before ???
176 | net = F.dropout(net) # ========================== dropout ????
177 | return net
178 |
179 |
180 |
--------------------------------------------------------------------------------
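
The same kind of smoke test for the bottleneck models, assuming 1-channel 129x129 patches as in the `forward` shape comments (again, the `.cuda()` calls inside `residualLayer3` make a GPU mandatory, and the import path is an assumption):

```python
import torch
from torch.autograd import Variable
from models.Resnet_3 import DeepResNet50  # assumed import path

net = DeepResNet50(D_out=2).cuda()
x = Variable(torch.randn(2, 1, 129, 129).cuda())  # spatial: 129 -> 64 -> 32 -> 16 -> 8 -> 4
print(net(x).size())  # (2, 2) after the 1024 -> 512 -> 128 -> 2 linear head
```

As in `Resnet.py`, the repeated bottleneck blocks of each group reuse the same conv modules, so the blocks within a group share weights.
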
/type/Annotations.py:
--------------------------------------------------------------------------------
1 | class Annotations:
2 |
3 | def __init__(self, offset = (0.0, 0.0, 0.0)):
4 |
5 | self.__types = {}
6 | self.__locations = {}
7 | self.comments = {}
8 | self.pre_post_partners = []
9 | self.offset = offset
10 |
11 |     def __check(self, id):
12 |         if id not in self.__types:
13 |             raise KeyError("there is no annotation with id " + str(id))
14 |
15 | def add_annotation(self, id, type, location):
16 | """Add a new annotation.
17 |
18 | Parameters
19 | ----------
20 |
21 | id: int
22 | The ID of the new annotation.
23 |
24 | type: string
25 | A string denoting the type of the annotation. Use
26 | "presynaptic_site" or "postsynaptic_site" for pre- and
27 | post-synaptic annotations, respectively.
28 |
29 | location: tuple, float
30 | The location of the annotation, relative to the offset.
31 | """
32 |
33 | self.__types[id] = type.encode('utf8')
34 | self.__locations[id] = location
35 |
36 | def add_comment(self, id, comment):
37 | """Add a comment to an annotation.
38 | """
39 |
40 | self.__check(id)
41 | self.comments[id] = comment.encode('utf8')
42 |
43 | def set_pre_post_partners(self, pre_id, post_id):
44 | """Mark two annotations as pre- and post-synaptic partners.
45 | """
46 |
47 | self.__check(pre_id)
48 | self.__check(post_id)
49 | self.pre_post_partners.append((pre_id, post_id))
50 |
51 | def ids(self):
52 | """Get the ids of all annotations.
53 | """
54 |
55 | return self.__types.keys()
56 |
57 | def types(self):
58 | """Get the types of all annotations.
59 | """
60 |
61 | return self.__types.values()
62 |
63 | def locations(self):
64 | """Get the locations of all annotations. Locations are in world units,
65 | relative to the offset.
66 | """
67 |
68 | return self.__locations.values()
69 |
70 | def get_annotation(self, id):
71 | """Get the type and location of an annotation by its id.
72 | """
73 |
74 | self.__check(id)
75 | return (self.__types[id], self.__locations[id])
76 |
--------------------------------------------------------------------------------
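
A short usage sketch for `Annotations` (the ids and coordinates are made up for illustration; the import path is an assumption):

```python
from type.Annotations import Annotations  # assumed import path

a = Annotations()
a.add_annotation(1, "presynaptic_site", (40.0, 120.0, 80.0))
a.add_annotation(2, "postsynaptic_site", (40.0, 130.0, 90.0))
a.add_comment(1, "location uncertain")
a.set_pre_post_partners(1, 2)

print(a.ids())               # [1, 2] (dict key order)
print(a.get_annotation(2))   # ('postsynaptic_site', (40.0, 130.0, 90.0))
print(a.pre_post_partners)   # [(1, 2)]
```
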
/type/Volume.py:
--------------------------------------------------------------------------------
1 | class Volume():
2 |
3 | def __init__(self, data, resolution = (1.0, 1.0, 1.0), offset = (0.0, 0.0, 0.0), comment = ""):
4 | self.data = data
5 | self.resolution = resolution
6 | self.offset = offset
7 | self.comment = comment
8 |
9 | def __getitem__(self, location):
10 | """Get the closest value of this volume to the given location. The
11 | location is in world units, relative to the volumes offset.
12 |
13 | This method takes into account the resolution of the volume. An
14 | IndexError exception is raised if the location is not contained in this
15 | volume.
16 |
17 | To access the raw pixel values, use the `data` attribute.
18 | """
19 |
20 |         i = tuple([ int(round(location[d]/self.resolution[d])) for d in range(len(location)) ])
21 |
22 | if min(i) >= 0:
23 | try:
24 | return self.data[i]
25 | except IndexError as e:
26 | raise IndexError("location " + str(location) + " does not lie inside volume: " + str(e))
27 |
28 | raise IndexError("location " + str(location) + " does not lie inside volume")
29 |
30 | def __setitem__(self, location, value):
31 | """Set the closest value of this volume to the given location. The
32 | location is in world units, relative to the volumes offset.
33 |
34 | This method takes into account the resolution of the volume. An
35 | IndexError exception is raised if the location is not contained in this
36 | volume.
37 |
38 |         To set raw pixel values directly, use the `data` attribute.
39 | """
40 |
41 |         i = tuple([ int(round(location[d]/self.resolution[d])) for d in range(len(location)) ])
42 |
43 | if min(i) >= 0:
44 | try:
45 | self.data[i] = value
46 | return
47 | except IndexError as e:
48 | raise IndexError("location " + str(location) + " does not lie inside volume: " + str(e))
49 |
50 | raise IndexError("location " + str(location) + " does not lie inside volume")
51 |
52 |
53 |
--------------------------------------------------------------------------------
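
And for `Volume`, a sketch of world-unit indexing with the CREMI voxel size (40 nm between sections, 4 nm in-plane; the import path is an assumption):

```python
import numpy as np
from type.Volume import Volume  # assumed import path

data = np.zeros((5, 10, 10), dtype=np.uint8)  # tiny stand-in for the 125x1250x1250 raw
vol = Volume(data, resolution=(40.0, 4.0, 4.0))  # (z, y, x) nm per voxel

vol[(80.0, 8.0, 12.0)] = 255    # world units -> voxel (80/40, 8/4, 12/4) = (2, 2, 3)
print(vol[(80.0, 8.0, 12.0)])   # 255
print(vol.data[2, 2, 3])        # same voxel through raw access
```
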