├── .gitignore ├── README.md ├── hparams.json ├── requirements.txt └── src ├── __init__.py ├── data.py ├── evaluation.py ├── infer.py ├── loss.py ├── model.py ├── prepare_data.py ├── server.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | log/* 7 | logs/* 8 | logs_old/* 9 | dump/* 10 | data/images/* 11 | data/labels/* 12 | data/masks/* 13 | data/tf/* 14 | data/slices/* 15 | data/* 16 | #src/train.py 17 | 18 | # C extensions 19 | *.so 20 | 21 | # Distribution / packaging 22 | .Python 23 | build/ 24 | develop-eggs/ 25 | dist/ 26 | downloads/ 27 | eggs/ 28 | .eggs/ 29 | lib/ 30 | lib64/ 31 | parts/ 32 | sdist/ 33 | var/ 34 | wheels/ 35 | *.egg-info/ 36 | .installed.cfg 37 | *.egg 38 | 39 | # PyInstaller 40 | # Usually these files are written by a python script from a template 41 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | .hypothesis/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # celery beat schedule file 88 | celerybeat-schedule 89 | 90 | # SageMath parsed files 91 | *.sage.py 92 | 93 | # Environments 94 | .env 95 | .venv 96 | env/ 97 | venv/ 98 | ENV/ 99 | 100 | # Spyder project settings 101 | .spyderproject 102 | .spyproject 103 | 104 | # Rope project settings 105 | .ropeproject 106 | 107 | # mkdocs documentation 108 | /site 109 | 110 | # mypy 111 | .mypy_cache/ 112 | 113 | bin/projects 114 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NCCNet 2 | 3 | Template matching by normalized cross correlation (NCC) is widely used for finding image correspondences. We improve the robustness of this algorithm by transforming image features with "siamese" convolutional networks trained to maximize the contrast between NCC values of true and false matches. Our main technical contribution is a weakly supervised learning algorithm for training siamese networks. Unlike fully supervised approaches to metric learning, our method can improve upon vanilla NCC without being given locations of true matches during training. The improvement is quantified using patches of brain images from serial section electron microscopy. Relative to a parameter-tuned bandpass filter, siamese convolutional networks significantly reduce false matches. The improved accuracy of our method could be essential for connectomics, because emerging petascale datasets may require billions of template matches during assembly. Our method is also expected to generalize to other computer vision applications that use template matching to find image correspondences. 4 | 5 | The code depends on Cavelab (https://github.com/Loqsh/cavelab). 
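The training objective throughout this repository is the contrast between the best and second-best NCC peaks (see `loss` in `src/loss.py` and `doncc` in `src/prepare_data.py`). As a purely illustrative sketch of those two quantities, here is a standalone NumPy version; the repository itself computes NCC with Cavelab (`cv_normxcorr`, `batch_normxcorr`), so the function names below are not part of the actual pipeline:

```
import numpy as np

def normxcorr(image, template):
    """Dense normalized cross-correlation of a template over an image."""
    th, tw = template.shape
    t = (template - template.mean()) / (template.std() + 1e-8)
    out = np.zeros((image.shape[0] - th + 1, image.shape[1] - tw + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            patch = image[i:i + th, j:j + tw]
            out[i, j] = ((patch - patch.mean()) / (patch.std() + 1e-8) * t).mean()
    return out

def peak_contrast(ncc, radius=5):
    """r1 - r2: the best match minus the best match outside a small window
    around it. Training aims to increase this margin for true matches."""
    pos = np.unravel_index(ncc.argmax(), ncc.shape)
    masked = ncc.copy()
    masked[max(pos[0] - radius, 0):pos[0] + radius + 1,
           max(pos[1] - radius, 0):pos[1] + radius + 1] = -1.0
    return ncc[pos] - masked.max()
```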
6 | 
7 | ## Usage
8 | Start the Docker environment:
9 | ```
10 | git clone https://github.com/seung-lab/NCCNet
11 | nvidia-docker run -it --net=host \
12 |        -v $(pwd)/NCCNet:/projects/NCCNet \
13 |        davidbun/cavelab:latest bash
14 | ```
15 | 
16 | To generate data using CloudVolume, start Docker with your CloudVolume secrets mounted:
17 | ```
18 | nvidia-docker run -it --net=host \
19 |        -v $(pwd)/NCCNet:/projects/NCCNet \
20 |        -v /usr/people/$USER/.cloudvolume/secrets/:/root/.cloudvolume/secrets/ \
21 |        davidbun/cavelab:latest bash
22 | ```
23 | 
24 | Then change into the project folder and install the additional dependencies:
25 | ```
26 | cd /projects/NCCNet
27 | pip install -r requirements.txt
28 | ```
29 | 
30 | ## Data Collection
31 | Training requires pairs of images and templates stored as a TFRecords file at `data/tf/train.tfrecords` (see the parsing sketch below). If you are part of Seung Lab, an example file is available at `seungmount/research/davit/NCCNet/data/tf/bad_trainset_24000_612_324.tfrecords`.
32 | 
33 | To prepare your own data, open `hparams.json` and modify the `"preprocessing"` section:
34 | `cloud_src`, `cloud_mip`, and `features`. The most important settings are the feature sizes: `in_width` is the width at which each image is generated, while `width` is the crop size used during training.
35 | 
36 | ```
37 | "features": {
38 |     "data": {
39 |         "search_raw": {"in_width": 512, "width": 384, "depth":1},
40 |         "template_raw": {"in_width": 256, "width": 128, "depth":1}
41 |     }, "type": "dict", "description": "Structure of input features"}
42 | ```
43 | 
44 | Run the following commands to start the data collection process. You may also want to adjust the two additional parameters in the same section: `samples` (the number of examples to collect) and `r_delta` (the peak-contrast filter).
45 | ```
46 | mkdir data; mkdir data/tf; mkdir dump;
47 | python src/prepare_data.py
48 | ```
49 | To follow the progress of data collection, watch the `dump/{image, ncc, small_template, template}.jpg` files, which are updated as samples are accepted.
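Each record holds one search image and one template as raw `uint8` bytes under the keys written by `src/prepare_data.py`. As a sketch only (assuming TensorFlow 1.x and the default 512/256 `in_width` values above; the actual input pipeline is Cavelab's `tfdata`, and `parse_example` below is illustrative, not a function in this repo), a single record can be parsed like this:

```
import tensorflow as tf

def parse_example(serialized, s_width=512, t_width=256):
    # Feature keys match those written by src/prepare_data.py.
    example = tf.parse_single_example(serialized, features={
        'search_raw': tf.FixedLenFeature([], tf.string),
        'template_raw': tf.FixedLenFeature([], tf.string),
    })
    search = tf.reshape(tf.decode_raw(example['search_raw'], tf.uint8), [s_width, s_width])
    template = tf.reshape(tf.decode_raw(example['template_raw'], tf.uint8), [t_width, t_width])
    return search, template
```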
50 | 
51 | ## Train
52 | To train the model defined in `hparams.json`, run the following commands. The parameters in the JSON file are documented by their `description` fields.
53 | 
54 | ```
55 | mkdir logs
56 | python src/train.py
57 | ```
58 | 
59 | Logs and trained models will appear in the `logs/` folder. Change `name` in `hparams.json` for each experiment so that runs are logged under different names; reusing a name makes the run fail with `NameError: name 'raw_input' is not defined`.
60 | 
61 | 
62 | ## Log
63 | You can monitor the training process by running TensorBoard as a background process and opening `localhost:6006`:
64 | 
65 | ```
66 | tensorboard --logdir=logs
67 | ```
68 | 
69 | ## Inference
70 | To run inference, inspect `src/infer.py` and make sure all of its arguments (cloud paths, model directory, bounding box) are correct, then run:
71 | 
72 | ```
73 | python src/infer.py A B
74 | ```
75 | where `A` and `B` bound the range of slice indices to process (slices `A` through `B-1`). You will also need to put the correct bounding box dimensions in `infer.py`.
76 | 
77 | 
78 | ## Cite
79 | 
80 | ```
81 | @article{buniatyan2017deep,
82 |   title={Deep Learning Improves Template Matching by Normalized Cross Correlation},
83 |   author={Buniatyan, Davit and Macrina, Thomas and Ih, Dodam and Zung, Jonathan and Seung, H Sebastian},
84 |   journal={arXiv preprint arXiv:1705.08593},
85 |   year={2017}
86 | }
87 | ```
88 | 
--------------------------------------------------------------------------------
/hparams.json:
--------------------------------------------------------------------------------
1 | 
2 | {
3 |   "default": {
4 |     "name": {"data":"training_name_10", "type": "str", "description": "Name of the experiment"},
5 |     "hypothesis": {"data": "random initialization", "type":"str", "description": "What this experiment is meant to determine"},
6 |     "identity_init" :{"data": false, "type": "bool", "description":"Initialize as identity"},
7 |     "resize": {"data": 4, "type": "int", "description": "Resize images"},
8 |     "dilation_rate": {"data": 1, "type": "int", "description": "Global dilation rate"},
9 |     "aligned": {"data": 0, "type": "int", "description": "Define the data type" },
10 |     "linear": {"data": false, "type": "bool", "description": "Whether the convolution is linear" },
11 | 
12 | 
13 |     "radius": {"data": 5, "type": "int", "description": "Exclusion radius around the first peak when searching for the second" },
14 |     "mean_over_batch": {"data": true, "type": "bool", "description": "Take the mean over the batch, otherwise the min" },
15 |     "lambd": {"data": -0.5, "type": "float", "description": "Lambda for the mixed loss" },
16 |     "eps": {"data": 0.0001, "type": "float", "description": "Small constant for numerical stability" },
17 |     "loss_type": {"data": "dist", "type": "str", "description": "Loss format, either 'dist' or 'ratio'" },
18 |     "loss_form": {"data": "minus", "type": "str", "description": "Loss formula to minimize, one of {'minus', 'inverse', 'log'}" },
19 |     "softmax": {"data": false, "type": "bool", "description": "Use softmax"},
20 | 
21 |     "train_file": {"data":["data/tf/train.tfrecords"], "type": "str", "description": "Training dataset"},
22 |     "test_file": {"data":["data/tf/bad_trainset_24000_612_324.tfrecords"], "type": "str", "description": "Testing dataset"},
23 | 
24 |     "learning_rate": {"data": 0.00001, "type": "float", "description": "Learning rate"},
25 |     "momentum": {"data": 0.9, "type": "float", "description": "Learning momentum"},
26 |     "decay": {"data": 0.75, "type": "float", "description": "Learning rate decay factor"},
27 |     "decay_steps": {"data": 1000, "type": "int", "description": "Steps between learning rate decays"},
28 |     "steps": {"data": 200000, "type": "int", "description": "Number of steps to complete the training"},
29 |     "batch_size": {"data": 4, "type": "int", "description": "Batch size during training"},
30 |     "epoch_size": {"data": 16, "type": "int", "description": "Epoch size during training"},
31 |     "eval_batch_size": {"data": 2, "type": "int", "description": "Batch size during evaluation"},
32 |     "optimizer": {"data": "Adam", "type": "str", "description": "Optimizer name (Adam, Adagrad, etc.)"},
33 |     "loglevel": {"data": 50, "type": "int", "description": "TensorFlow log level"},
34 |     "output_layer": {"data": 8, "type": "int", "description": "Number of output channels of the UNet"},
35 |     "log_iterations": {"data": 100, "type": "int", "description": "Steps between TensorBoard log writes"},
36 |     "eval_iterations": {"data": 20000, "type": "int", "description": "Steps between evaluation runs"},
37 |     "resize_conv": {"data": true, "type":"bool","description": "Use resize convolutions instead of deconvolutions"},
38 | 
39 | 
40 |     "kernels_shape": {
41 |       "data": [[3,3,1,32],
42 |                [3,3,32,64],
43 |                [3,3,64,128],
44 |                [3,3,128,256]],
45 | 
46 |       "type": "array of int", "description": "Kernel description"},
47 | 
48 |     "testing_steps": {"data": 100, "type": "int", "description": "Number of testing steps"},
49 | 
50 |     "features": {
51 |       "data": {
52 |         "search_raw": {"in_width": 512, "width": 384, "depth": 1},
53 |         "template_raw": {"in_width": 512, "width": 128, "depth": 1}
54 |       }, "type": "dict", "description": "Structure of input features"},
55 | 
56 |     "augmentation":{
57 |       "data": {
58 |         "flipping": true,
59 |         "random_brightness": false,
60 |         "random_elastic_transform": false
61 |       }, "type": "dict", "description": "augmentation"}
62 |   },
63 | 
64 |   "preprocessing": {
65 |     "tfrecord_train_dest": {"data":"data/tf/train.tfrecords", "type": "str", "description": "Destination of the training set"},
66 | 
67 |     "cloud_src": {"data":"gs://neuroglancer/pinky100_v0/image_single_slices/", "type": "str", "description": "Cloud directory"},
68 |     "cloud_mip": {"data": 2, "type": "int", "description": "MIP level for Neuroglancer"},
69 |     "threads": {"data":1, "type": "int", "description": "Number of threads for data collection"},
70 |     "width": {"data":256, "type": "int", "description": "Width of the image"},
71 |     "scale": {"data":1, "type": "int", "description": "Scaling factor"},
72 |     "features": {
73 |       "data": {
74 |         "search_raw": {"in_width": 512, "width": 384, "depth":1},
75 |         "template_raw": {"in_width": 256, "width": 128, "depth":1}
76 |       }, "type": "dict", "description": "Structure of input features"},
77 |     "samples":{"data":10000, "type": "int", "description": "Number of training data samples"},
78 |     "r_delta":{"data": 0.2, "type": "float", "description": "Peak-contrast (r_delta) filter for collecting data"}
79 |   },
80 |   "evaluation":{
81 |     "batch_size": {"data": 1, "type": "int", "description": "Batch size during evaluation"},
82 |     "train_file": {"data":["data/tf/imagenet_hard_mined.tfrecords"], "type": "str", "description": "Evaluation dataset"},
83 |     "features": {
84 |       "data": {
85 |         "search_raw": {"in_width": 512, "width": 384, "depth": 1},
86 |         "template_raw": {"in_width": 256, "width": 384, "depth": 1}
87 |       }, "type": "dict", "description": "Structure of input features"},
88 | 
89 |     "augmentation":{
90 |       "data": {
91 |         "flipping": false,
92 |         "random_brightness": false,
93 |         "random_elastic_transform": false
94 |       }, "type": "dict", "description": "augmentation"}
95 | 
96 |   }
97 | }
98 | 
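Every parameter above is a record whose `data` field carries the actual value. Cavelab's `cl.hparams` loads one named section of this file and exposes each key as an attribute, which is how `src/train.py` and `src/prepare_data.py` read it. A minimal usage sketch (the printed values assume the defaults above; exact casting behavior is Cavelab's):

```
import cavelab as cl

hparams = cl.hparams(name="default")            # load the "default" section
print(hparams.batch_size)                       # 4, taken from the "data" field
print(hparams.features['search_raw']['width'])  # 384
```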
log level"}, 37 | "resize_conv": {"data": true, "type":"int","description": "use resize convolutions otherwise deconvolutions"}, 38 | 39 | 40 | "kernels_shape": { 41 | "data": [[3,3,1,32], 42 | [3,3,32,64], 43 | [3,3,64,128], 44 | [3,3,128,256]], 45 | 46 | "type": "array of int", "description": "Kernel description"}, 47 | 48 | "testing_steps": {"data": 100, "type": "int", "description": "testing_steps"}, 49 | 50 | "features": { 51 | "data": { 52 | "search_raw": {"in_width": 512, "width": 384, "depth": 1}, 53 | "template_raw": {"in_width": 512, "width": 128, "depth": 1} 54 | }, "type": "dict", "description": "Structure of input features"}, 55 | 56 | "augmentation":{ 57 | "data": { 58 | "flipping": true, 59 | "random_brightness": false, 60 | "random_elastic_transform": false 61 | }, "type": "dict", "description": "augmetation"} 62 | }, 63 | 64 | "preprocessing": { 65 | "tfrecord_train_dest": {"data":"data/tf/train.tfrecords", "type": "str", "description": "Destination of training set"}, 66 | 67 | "cloud_src": {"data":"gs://neuroglancer/pinky100_v0/image_single_slices/", "type": "str", "description": "Cloud directory"}, 68 | "cloud_mip": {"data": 2, "type": "str", "description": "MIP level for neuroglancer"}, 69 | "threads": {"data":1, "type": "str", "description": "Number of threads for data collection"}, 70 | "width": {"data":256, "type": "str", "description": "width of the image"}, 71 | "scale": {"data":1, "type": "str", "description": "scaling factor"}, 72 | "features": { 73 | "data": { 74 | "search_raw": {"in_width": 512, "width": 384, "depth":1}, 75 | "template_raw": {"in_width": 256, "width": 128, "depth":1} 76 | }, "type": "dict", "description": "Structure of input features"}, 77 | "samples":{"data":10000, "description": "Number of training data samples"}, 78 | "r_delta":{"data": 0.2, "description": "R_delta filter for collecting data"} 79 | }, 80 | "evaluation":{ 81 | "batch_size": {"data": 1, "type": "int", "description": "Batch size during training"}, 82 | "train_file": {"data":["data/tf/imagenet_hard_mined.tfrecords"], "type": "str", "description": "Testing dataset"}, 83 | "features": { 84 | "data": { 85 | "search_raw": {"in_width": 512, "width": 384, "depth": 1}, 86 | "template_raw": {"in_width": 256, "width": 384, "depth": 1} 87 | }, "type": "dict", "description": "Structure of input features"}, 88 | 89 | "augmentation":{ 90 | "data": { 91 | "flipping": false, 92 | "random_brightness": false, 93 | "random_elastic_transform": false 94 | }, "type": "dict", "description": "augmentation"} 95 | 96 | } 97 | } 98 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | xmltodict 2 | cloud-volume==0.22 3 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seung-lab/NCCNet/ec1b61d43156a6e8f1fd2fbbc435ed169c738200/src/__init__.py -------------------------------------------------------------------------------- /src/data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cavelab as cl 3 | 4 | # Data preprocessing 5 | similar = True 6 | class Data(): 7 | def __init__(self, hparams, random=True): 8 | self.batch_size = hparams.batch_size 9 | self.data = cl.tfdata(hparams.train_file, 10 | random_order = random, 11 | batch_size = 
11 |                               batch_size = self.batch_size,
12 |                               features=hparams.features,
13 |                               flipping=hparams.augmentation["flipping"],
14 |                               random_brightness=hparams.augmentation["random_brightness"],
15 |                               random_elastic_transform=hparams.augmentation["random_elastic_transform"])
16 |         self.similar = True
17 | 
18 |     def get_batch(self, switching=True):
19 | 
20 |         image, template = self.data.get_batch()
21 | 
22 |         #image = input['search_raw']
23 |         #template = input['template_raw']
24 |         #print(input)
25 |         #if not self.check_validity(image, template):
26 |         #    return self.get_batch()
27 | 
28 |         # Similar or dissimilar: batches alternate between positive and negative pairs
29 |         if not self.similar and switching:
30 |             search_space, template = self.dissimilar(image, template)
31 |         label = np.ones((self.batch_size),dtype=np.float) if self.similar else -1*np.ones((self.batch_size),dtype=np.float)
32 | 
33 |         if self.similar:
34 |             label = 2*(np.random.rand(self.batch_size)>0)-1 #FIXME add this variable - probability of misinterpreting the label
35 |         self.similar = not self.similar
36 | 
37 |         image = np.expand_dims(image, -1)
38 |         template = np.expand_dims(template, -1)
39 |         return image, template, label
40 | 
41 |     def dissimilar(self, images, templates):  # rotate templates by one within the batch so pairs no longer match
42 |         length = templates.shape[0]-1
43 |         temp = np.array(templates[0])
44 |         templates[0:length] = templates[1:length+1]
45 |         templates[length] = temp
46 |         return images, templates
47 | 
48 |     def check_validity(self, image, template):
49 |         t = np.array(template.shape)
50 |         if np.any(np.sum(image<0.01, axis=(1,2)) >= t[1]*t[2]) or image.shape[0]20:
95 |             #print('err')
96 |             #print(ncc[pos[0], pos[1]],ncc[locs[j,0], locs[j,1]])
97 |             cl.visual.save(imgs[j,:,:,0], "dump/eval/image")
98 |             cl.visual.save(tmp[:,:,0], "dump/eval/template")
99 |             cl.visual.save(ncc, "dump/eval/ncc")
100 |             person = input('next: ')
101 |             err += 1
102 |     return err
103 | 
104 | def process_ncc(image):
105 |     #step = 384
106 |     #W = step*4
107 |     #H = step*2
108 |     if image.shape[-1]==1:
109 |         image = np.repeat(image, 3, axis=2)
110 |     image = image/255.0
111 |     image_new = infer.process(image)
112 |     #image_new = cl.image_processing.read_without_borders_3d(np.expand_dims(image, -1),
113 |     #                                                        [0, 0, 0],
114 |     #                                                        [W, H, 3])
115 | 
116 |     return image_new
117 | 
118 | def process(imgs, tmps):
119 |     #print(imgs.mean())
120 |     #print(tmps.mean())
121 |     #print(imgs[:10,:10])
122 |     #exit()
123 |     imgs = imgs/(255.0)
124 |     tmps = tmps/(255.0)
125 | 
126 |     imgs_new = model.process({features["inputs"]: imgs}, [features["outputs"]])[0]
127 |     tmps_new = model.process({features["inputs"]: tmps}, [features["outputs"]])[0]
128 |     imgs_new = imgs_new[:,:s_width, :s_width, :]
129 |     tmps_new = tmps_new[:,:s_width, :s_width, :]
130 | 
131 |     return imgs_new, tmps_new
132 | 
133 | def evaluate(model_to_use=False):
134 |     err = np.zeros((size, 4))
135 |     count = 0
136 |     for i in range(size):
137 |         # Load data
138 |         imgs, tmps, locs = get_batch()
139 | 
140 |         #err[i, 0] += get_wrong_matches(imgs, imgs, locs)
141 |         err[i, 1] += get_wrong_matches(imgs, tmps, locs)
142 | 
143 |         imgs_t, tmps_t = process(imgs, tmps)
144 | 
145 |         #err[i, 2] += get_wrong_matches(imgs_t, imgs_t, locs)
146 |         err[i, 3] += get_wrong_matches(imgs_t, tmps_t, locs)
147 |         #exit()
148 |         count += BATCH_SIZE
149 | 
150 |         print(i, err[:i, :].mean(axis=0)/BATCH_SIZE)
151 | 
152 |     err = err/BATCH_SIZE
153 |     mean = err.mean(axis=0) # per batch
154 |     std = err[250:, :].std(axis=0) # per batch
155 | 
156 |     #print('raw_self', mean[0], "+-"+str(std[0]/np.sqrt(1250)))
157 |     print('raw', mean[1],"+-"+str(std[1]/np.sqrt(1250)))
mean[2],"+-"+str(std[2]/np.sqrt(1250))) 159 | print('NCCNet', mean[3],"+-"+str(std[3]/np.sqrt(1250))) 160 | 161 | if __name__ == '__main__': 162 | model = cl.Graph(directory=MODEL_DIR) 163 | evaluate(True) 164 | else: 165 | infer = cl.Infer(batch_size = 4, 166 | width = 512, 167 | n_threads = 8, 168 | model_directory = MODEL_DIR, 169 | cloud_volume = False, 170 | features=features, 171 | voxel_offset = (0, 0), 172 | crop_size=60, 173 | output_dim=output_dim)# FIXME Change according to the model 174 | -------------------------------------------------------------------------------- /src/infer.py: -------------------------------------------------------------------------------- 1 | from cavelab import Infer, Cloud, h5data 2 | import numpy as np 3 | import time 4 | import sys 5 | 6 | # "data/tf/pinky40/train_v3.tfrecords" 7 | 8 | 9 | CLOUD_SRC = 'gs://neuroglancer/drosophila_v0/image_v14_single_slices' 10 | CLOUD_DST = 'gs://neuroglancer/drosophila_v0/image_v14_single_slices/nccnet/' 11 | 12 | MODEL_DIR = 'logs/NCCNet_flyem/' 13 | features = { "inputs":"input/image:0", "outputs": "Pred/image_transformed:0"} 14 | features = { "inputs":"image:0", "outputs": "output/image:0"} 15 | 16 | input_volume = Cloud(CLOUD_SRC, mip=2, cache=False) 17 | output_volume = Cloud(CLOUD_DST, mip=2, cache=False) 18 | #h5 = h5data('data/slices/') 19 | #output_volume = h5.create_dataset('prealigned_2', shape=input_volume.shape[0:2], dtype='uint8') 20 | 21 | infer = Infer(batch_size = 8, 22 | width = 384, 23 | n_threads = 8, 24 | model_directory = MODEL_DIR, 25 | cloud_volume = True, 26 | features=features, 27 | voxel_offset = (0, 0), 28 | crop_size=60) 29 | 30 | #input_volume.vol.flush_cache() 31 | 32 | locations = [[(0,0,i), (65536,43744,i)] for i in xrange(int(sys.argv[1]), int(sys.argv[2]))] 33 | infer.process_by_superbatch(input_volume, output_volume, locations) 34 | -------------------------------------------------------------------------------- /src/loss.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import cavelab as cl 3 | import numpy as np 4 | 5 | def loss(p, similar, radius=10, eps = 0.001, name='loss'): 6 | 7 | #Get maximum p and mask the point 8 | p_max = tf.reduce_max(p, axis=[1,2], keep_dims=True) 9 | mask_p = tf.cast(p>p_max-tf.constant(eps), tf.float32) 10 | 11 | #Design the shape of the mask 12 | p_shape = tf.shape(p) 13 | mask = tf.ones([radius*2, radius*2, p_shape[3]], tf.float32) 14 | 15 | mask_p = tf.nn.dilation2d(mask_p, mask, [1,1,1,1], [1,1,1,1], 'SAME') 16 | mask_p = tf.to_float(mask_p)<=tf.constant(1, dtype='float32') 17 | mask_p = tf.cast(mask_p, dtype='float32') 18 | 19 | # Take care about second distance 20 | p_2 = tf.multiply(mask_p,p) 21 | p_max_2 = tf.reduce_max(p_2, axis=[1,2], keep_dims=True) 22 | 23 | mask_p = tf.multiply(mask_p, p) 24 | l = -(p_max-p_max_2) 25 | 26 | #Add more sanity checks 27 | #to_update = tf.logical_and(tf.reduce_all(tf.is_finite(l)), tf.greater(tf.reduce_max(l), -0.1)) 28 | 29 | full_loss = tf.identity(l, name='full_loss') 30 | 31 | l = tf.where(tf.reduce_max(similar)>0, l, tf.abs(p_max), name=None) 32 | l = tf.reduce_mean(l) 33 | 34 | p_max = tf.reduce_mean(p_max) 35 | p_max_2 = tf.reduce_mean(p_max_2) 36 | 37 | # Summarize into tensorboard 38 | with tf.name_scope(name): 39 | tf.summary.scalar('loss', l) 40 | tf.summary.scalar('distance', p_max-p_max_2) 41 | tf.summary.scalar('max', p_max) 42 | tf.summary.scalar('second_max', p_max_2) 43 | 44 | return l 45 | 46 | 47 | def 
47 | def binary_entropy(correlogram, similar, hparams, name="loss", alpha=10.0, threshold=0.05):
48 |     correlogram_power = tf.exp(alpha*correlogram)
49 |     smoothmax = tf.multiply(correlogram, correlogram_power)/tf.reduce_sum(correlogram_power)
50 |     smoothmax = tf.cast(smoothmax<threshold, tf.int32)
51 |     smoothmax = tf.where(correlogram>0.0,
52 |                          smoothmax,
53 |                          tf.zeros(tf.shape(smoothmax), dtype=tf.int32), name=None)
54 |     correlogram, kernel, b = cl.tf.layers.conv_one_by_one(correlogram, 2)
55 |     print(correlogram.get_shape())
56 |     print(smoothmax.get_shape())
57 |     loss, labels = cl.tf.loss.soft_cross_entropy(correlogram, smoothmax)
58 |     tf.summary.scalar('loss',loss)
59 |     return loss
--------------------------------------------------------------------------------
/src/model.py:
--------------------------------------------------------------------------------
1 | import cavelab as cl
2 | import loss
3 | import tensorflow as tf
4 | 
5 | def build(hparams):
6 | 
7 |     g = cl.tf.Graph()
8 | 
9 |     # Define inputs
10 |     g.image = tf.placeholder(tf.float32, shape=[hparams.batch_size, hparams.features['search_raw']['width'], hparams.features['search_raw']['width'],hparams.features['search_raw']['depth']], name='image')
11 |     g.template = tf.placeholder(tf.float32, shape=[hparams.batch_size, hparams.features['template_raw']['width'], hparams.features['template_raw']['width'],hparams.features['search_raw']['depth']], name='template')
12 |     g.similar = tf.placeholder(tf.float32, shape=[hparams.batch_size], name='similarity')
13 |     g.crop_coef = tf.placeholder(tf.int32, shape=[], name='crop_coef')
14 | 
15 |     # Add to metrics
16 |     cl.tf.metrics.image_summary(g.image[:,:,:,0], 'input/image')
17 |     cl.tf.metrics.image_summary(g.template[:,:,:,0], 'input/template')
18 | 
19 |     # Model
20 |     g.xs, g.ys = cl.models.SiameseFusionNet(g.image, g.template, hparams.resize_conv, hparams.kernels_shape)
21 |     g.ps = []
22 |     ls = 0
23 |     levels = len(g.xs)
24 |     for i in range(levels):
25 |         shape = hparams.kernels_shape[i]
26 |         shape = [shape[0], shape[1], shape[3], hparams.output_layer]
27 | 
28 |         g.xs[i], g.ys[i] = cl.tf.layers.conv_block_dual(g.xs[i], g.ys[i], shape, activation = tf.tanh)
29 |         g.xs[i], g.ys[i] = tf.identity(g.xs[i], name='output/image_'+str(i)), tf.identity(g.ys[i], name='output/template_'+str(i))
30 | 
31 |         crop = (levels-i)*10*g.crop_coef
32 |         t_crop = (levels-i)*5*g.crop_coef
33 | 
34 |         cl.tf.metrics.scalar(crop, name='crop/image')
35 |         cl.tf.metrics.scalar(t_crop, name='crop/template')
36 | 
37 |         ishp = g.xs[i].get_shape()
38 |         tshp = g.ys[i].get_shape()
39 | 
40 |         swidth = int(hparams.features['search_raw']['width']/(2**i))
41 |         twidth = int(hparams.features['template_raw']['width']/(2**i))
42 | 
43 |         g.cnn_image_big = g.xs[i][:, crop:swidth-crop, crop:swidth-crop, :]
44 |         g.cnn_templ_big = g.ys[i][:, t_crop:twidth-t_crop, t_crop:twidth-t_crop, :]
45 | 
46 |         # Save to TensorBoard
47 |         cl.tf.metrics.image_summary(tf.squeeze(g.xs[i][:, :, :,0]), 'pred/image_mip='+str(i))
48 |         cl.tf.metrics.image_summary(tf.squeeze(g.ys[i][:, :, :,0]), 'pred/template_mip='+str(i))
49 | 
50 |         if hparams.output_layer>1:
51 | 
52 |             image_feature = tf.transpose(tf.squeeze(g.xs[i][0, :, :, :]), perm=[2,0,1])
53 |             temp_feature = tf.transpose(tf.squeeze(g.ys[i][0, :, :, :]), perm=[2,0,1])
54 | 
55 |             cl.tf.metrics.image_summary(image_feature, 'features/image_mip='+str(i))
56 |             cl.tf.metrics.image_summary(temp_feature, 'features/template_mip='+str(i))
57 | 
58 |         # Loss
59 |         g.p = cl.tf.layers.batch_normxcorr(g.cnn_image_big, g.cnn_templ_big)
60 |         g.p = tf.reduce_mean(g.p, axis=3, keep_dims=True)
61 | 
62 |         #Resize
63 | if i>0 and False: 64 | shape = tf.shape(g.ps[0])#.get_shape() 65 | g.p = tf.image.resize_images(g.p, size=[264, 264], method=1, align_corners=True) 66 | g.ps.append(g.p) 67 | 68 | #g.p = tf.add_n(g.ps)/3 69 | cl.tf.metrics.image_summary(g.p[:,:,:,0], 'pred/normxcorr_large_mip='+str(i), resize=False) 70 | l = loss.loss(g.p, g.similar, radius=hparams.radius, name='loss_'+str(i)) 71 | ls += l 72 | 73 | # Output 74 | g.l = ls/3 75 | tf.summary.scalar('loss/loss', g.l) 76 | g.train_step = tf.train.AdamOptimizer(hparams.learning_rate).minimize(g.l) 77 | 78 | return g 79 | 80 | #ncc = cl.tf.base_layers.batch_normxcorr(g.cnn_image, g.cnn_templ) 81 | #ncc = tf.reduce_mean(ncc, axis=3, keep_dims=False) 82 | #cl.tf.metrics.image_summary(ncc, 'pred/normxcorr_large', resize=False) 83 | -------------------------------------------------------------------------------- /src/prepare_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Converts EM data to TFRecords file format with Example protos.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import argparse 22 | import os 23 | import sys 24 | 25 | import tensorflow as tf 26 | import numpy as np 27 | 28 | import cavelab as cl 29 | from os import listdir 30 | from os.path import isfile, join 31 | import random 32 | import scipy 33 | import xmltodict 34 | 35 | 36 | FLAGS = None 37 | 38 | #Load hyperparams 39 | hparams = cl.hparams(name="preprocessing") 40 | 41 | cloud = cl.Cloud(hparams.cloud_src, mip=hparams.cloud_mip, cache=False, bounded = False, fill_missing=True) 42 | shape = cloud.shape 43 | shape[2] = 800 44 | downsample = hparams.scale 45 | 46 | base = "/usr/people/davitb/.kaggle/competitions/imagenet-object-detection-from-video-challenge/ILSVRC/" 47 | 48 | suffix = "VID/train/" 49 | 50 | def get_bbox(path_to_xml): 51 | with open(path_to_xml) as fd: 52 | doc = xmltodict.parse(fd.read()) 53 | #import pdb; pdb.set_trace() 54 | bbox = doc['annotation']['object'] 55 | #print(bbox) 56 | if isinstance(bbox, list): 57 | bbox = bbox[0] 58 | return bbox['bndbox'] 59 | 60 | def get_random_path(suffix, path, n): 61 | if n==0: 62 | return path 63 | 64 | filelist = sorted([f for f in listdir(suffix) if not isfile(join(suffix, f))]) 65 | id = random.randint(0,len(filelist)-1) 66 | suffix = os.path.join(suffix, filelist[id]) 67 | path = os.path.join(path, filelist[id]) 68 | return get_random_path(suffix, path, n-1) 69 | 70 | def get_image_data(base, suffix, img_size=512, tmp_size=256): 71 | img_path = os.path.join(base, "Data") 72 | img_path = os.path.join(img_path, suffix) 73 | box_path = os.path.join(base, "Annotations") 74 | box_path = os.path.join(box_path, suffix) 75 | 76 | sample_path = 
get_random_path(img_path,"", 2) 77 | img_path = os.path.join(img_path, sample_path) 78 | box_path = os.path.join(box_path, sample_path) 79 | 80 | 81 | filelist = sorted([f for f in listdir(img_path) if isfile(join(img_path, f))]) 82 | xmllist = sorted([f for f in listdir(box_path) if isfile(join(box_path, f))]) 83 | 84 | id = random.randint(0,len(filelist)-6) 85 | #print(len(xmllist), len(filelist)) 86 | next = random.randint(1,3) 87 | image_xml = os.path.join(box_path, xmllist[id]) 88 | template_xml = os.path.join(box_path, xmllist[id+next]) 89 | image_src = os.path.join(img_path, filelist[id]) 90 | template_src = os.path.join(img_path, filelist[id+next]) 91 | 92 | image_bbox = get_bbox(image_xml) 93 | template_bbox = get_bbox(template_xml) 94 | 95 | img_cntr = [int((int(image_bbox['ymax'])+int(image_bbox['ymin']))/2), 96 | int((int(image_bbox['xmax'])+int(image_bbox['xmin']))/2)] 97 | 98 | tmp_cntr = [int((int(template_bbox['ymax'])+int(template_bbox['ymin']))/2), 99 | int((int(template_bbox['xmax'])+int(template_bbox['xmin']))/2)] 100 | 101 | image = scipy.ndimage.imread(image_src) 102 | template = scipy.ndimage.imread(template_src) 103 | 104 | image = cl.image_processing.read_without_borders_3d(np.expand_dims(image, -1), 105 | [int(img_cntr[0]-img_size/2), 106 | int(img_cntr[1]-img_size/2), 0], 107 | [img_size, img_size, 3]) 108 | 109 | template = cl.image_processing.read_without_borders_3d(np.expand_dims(template, -1), 110 | [int(tmp_cntr[0]-tmp_size/2), 111 | int(tmp_cntr[1]-tmp_size/2), 0], 112 | [tmp_size, tmp_size, 3]) 113 | #cl.visual.save(image, 'dump/image') 114 | #cl.visual.save(template, 'dump/template') 115 | 116 | return image, template 117 | 118 | 119 | def get_sample(s_size): 120 | x = np.floor(0.75*shape[0]*np.random.random(1)+shape[0]*0.1).astype(int) 121 | y = np.floor(0.75*shape[1]*np.random.random(1)+shape[1]*0.1).astype(int) 122 | z = np.floor(0.75*shape[2]*np.random.random(1)+shape[2]*0.1).astype(int)+1200 123 | 124 | scale_ratio = (1/float(downsample), 1/float(downsample)) 125 | image = cloud.vol[x:x+downsample*s_size, y:y+downsample*s_size, z:z+1] 126 | template = cloud.vol[x:x+downsample*s_size, y:y+downsample*s_size, z+1:z+2] 127 | image, template = image[:,:,:,0], template[:,:,:,0] 128 | 129 | #image = cl.image_processing.resize(image[:,:,0], ratio=scale_ratio, order=1) 130 | #template = cl.image_processing.resize(template[:,:,0], ratio=scale_ratibo, order=1) 131 | return image[:,:,0], template[:,:,0] 132 | 133 | def get_real_sample(): 134 | return 135 | 136 | def _int64_feature(value): 137 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 138 | 139 | 140 | def _bytes_feature(value): 141 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 142 | 143 | def doncc(image, template): 144 | ncc = cl.image_processing.cv_normxcorr(image, template) 145 | pos = np.array(np.unravel_index(ncc.argmax(), ncc.shape)) 146 | tmp = np.array(ncc) 147 | tmp[pos[0]-5:pos[0]+10, pos[1]-5:pos[1]+10] = 0 148 | pos2 = np.array(np.unravel_index(tmp.argmax(), ncc.shape)) 149 | r1 = ncc[pos[0], pos[1]] 150 | r2 = ncc[pos2[0], pos2[1]] 151 | r = r1-r2 152 | return r, pos, ncc 153 | 154 | def convert_to(hparams, num_examples): 155 | """Converts a dataset to tfrecords.""" 156 | 157 | s_rows = hparams.features['search_raw']['in_width'] 158 | t_rows = hparams.features['template_raw']['width'] 159 | 160 | filename = hparams.tfrecord_train_dest #os.path.join(hparams.data_dir, name + '.tfrecords') 161 | 162 | print('Writing', filename) 163 | writer = 
tf.python_io.TFRecordWriter(filename) 164 | index = 0 165 | 166 | while(index < num_examples): 167 | #if index%20 == 0: 168 | # print(str(100*index/float(num_examples))+"%") 169 | #Get images 170 | s, t = get_sample(s_rows) 171 | 172 | start = int(t.shape[0]/2-t_rows/2) 173 | end = start + t_rows 174 | 175 | temp = t[start:end, start:end] 176 | result, _, ncc = doncc(s, temp) 177 | #print(result) 178 | if(result < hparams.r_delta) and (result>0.01) : 179 | cl.visual.save(s/255, 'dump/image') 180 | cl.visual.save(t/255, 'dump/template') 181 | cl.visual.save(temp/255, 'dump/small_template') 182 | cl.visual.save(ncc, 'dump/ncc') 183 | print('done', index, result) 184 | #exit() 185 | search_raw = np.asarray(s, dtype=np.uint8).tostring() 186 | temp_raw = np.asarray(t, dtype=np.uint8).tostring() 187 | 188 | ex = tf.train.Example(features=tf.train.Features(feature={ 189 | 'search_raw': _bytes_feature(search_raw), 190 | 'template_raw': _bytes_feature(temp_raw),})) 191 | 192 | writer.write(ex.SerializeToString()) 193 | index += 1 194 | 195 | writer.close() 196 | 197 | 198 | def rbg_convert_to(hparams, num_examples): 199 | """Converts a dataset to tfrecords.""" 200 | 201 | s_rows = hparams.features['search_raw']['in_width'] 202 | t_rows = hparams.features['template_raw']['in_width'] 203 | print(s_rows, t_rows) 204 | filename = hparams.tfrecord_train_dest #os.path.join(hparams.data_dir, name + '.tfrecords') 205 | 206 | print('Writing', filename) 207 | writer = tf.python_io.TFRecordWriter(filename) 208 | index = 0 209 | while(index < num_examples): 210 | try: 211 | s, t = get_image_data(base, suffix, s_rows, t_rows) 212 | 213 | #print(s[:,:,0].shape) 214 | #print( temp[:,:,0].shape) 215 | #exit() 216 | 217 | result, _, ncc = doncc(s[:,:,0].astype(np.uint8), t[:,:,0].astype(np.uint8)) 218 | 219 | #print(result) 220 | if(result < 0.2) and (result>0.001): 221 | #cl.visual.save(s, 'dump/image') 222 | #cl.visual.save(t, 'dump/template') 223 | #cl.visual.save(ncc, 'dump/ncc') 224 | 225 | print('done', index) 226 | #exit() 227 | search_raw = np.asarray(s, dtype=np.uint8).tostring() 228 | temp_raw = np.asarray(t, dtype=np.uint8).tostring() 229 | 230 | ex = tf.train.Example(features=tf.train.Features(feature={ 231 | 'search_raw': _bytes_feature(search_raw), 232 | 'template_raw': _bytes_feature(temp_raw),})) 233 | 234 | writer.write(ex.SerializeToString()) 235 | index += 1 236 | except: 237 | print("error") 238 | 239 | writer.close() 240 | 241 | def main(unused_argv): 242 | 243 | # Convert to Examples and write the result to TFRecords. 
244 | #rbg_convert_to(hparams, 100000) 245 | convert_to(hparams, hparams.samples) 246 | #convert_to(data, hparams, 1000, 'test_1K') 247 | 248 | 249 | if __name__ == '__main__': 250 | tf.app.run(main=main, argv=[sys.argv[0]]) 251 | -------------------------------------------------------------------------------- /src/server.py: -------------------------------------------------------------------------------- 1 | from flask import Flask 2 | from scipy.io import loadmat, savemat 3 | from scipy.ndimage import imread 4 | from flask import request 5 | import json 6 | app = Flask(__name__) 7 | output_rect_path = '/tmp/output_rect.mat' 8 | import cavelab as cl 9 | from evaluation import doncc, process_ncc 10 | import numpy as np 11 | 12 | n_threads = 40 13 | 14 | from pathos.multiprocessing import ProcessPool, ThreadPool 15 | pool = ThreadPool(n_threads) 16 | 17 | @app.route('/process_NCC') 18 | def process_NCC(): 19 | frames_path = request.args.get('frames_path') 20 | frames = loadmat(frames_path)['frames'] 21 | rect_path = request.args.get('rect_path') 22 | init_rect = loadmat(rect_path)['rect'][0] 23 | 24 | out_rects = [] 25 | print(init_rect) 26 | W = int(init_rect[3]) 27 | H = int(init_rect[2]) 28 | 29 | image_data = imread(frames[0][0][0]) 30 | if len(image_data.shape) == 2: 31 | image_data = np.expand_dims(image_data, -1) 32 | template = cl.image_processing.read_without_borders_3d(np.expand_dims(image_data, -1), 33 | [init_rect[1], init_rect[0], 0], 34 | [W, H, 3]).astype(np.uint8) 35 | 36 | print(template.shape, template.dtype) 37 | #cl.visual.save(template, "dump/eval/image") 38 | for f in frames: 39 | image_path = f[0][0] 40 | image_data = imread(image_path) 41 | 42 | if len(image_data.shape) == 2: 43 | image_data = np.expand_dims(image_data, -1) 44 | 45 | ncc, _, pos, _ = doncc(image_data, template) 46 | 47 | template = cl.image_processing.read_without_borders_3d(np.expand_dims(image_data, -1), 48 | [pos[0], pos[1], 0], 49 | [W, H, 3]).astype(np.uint8) 50 | 51 | image_pred = cl.image_processing.read_without_borders_3d(np.expand_dims(image_data, -1), 52 | [int(pos[0]),int(pos[1]),0], 53 | [W, H, 3]) 54 | 55 | 56 | cl.visual.save(template, "dump/eval/image") 57 | print(pos) 58 | result_bbox = [pos[1], pos[0], H, W] 59 | out_rects.append(result_bbox) 60 | 61 | savemat(output_rect_path, {"out_rects": out_rects}) 62 | return output_rect_path 63 | 64 | 65 | @app.route('/process_NCCNet') 66 | def process_NCCNet(): 67 | frames_path = request.args.get('frames_path') 68 | frames = loadmat(frames_path)['frames'] 69 | rect_path = request.args.get('rect_path') 70 | init_rect = loadmat(rect_path)['rect'][0] 71 | 72 | out_rects = [] 73 | print(init_rect) 74 | W = int(init_rect[3]) 75 | H = int(init_rect[2]) 76 | 77 | image_data = imread(frames[0][0][0]) 78 | if len(image_data.shape) == 2: 79 | image_data = np.expand_dims(image_data, -1) 80 | #image_data = process_ncc(image_data) 81 | 82 | template = cl.image_processing.read_without_borders_3d(np.expand_dims(image_data, -1), 83 | [init_rect[1], init_rect[0], 0], 84 | [W, H, image_data.shape[2]]).astype(np.uint8) 85 | 86 | print(template.shape, template.dtype) 87 | images = pool.map(lambda f: imread(f[0][0]), frames) 88 | 89 | for image_data in images: 90 | if len(image_data.shape) == 2: 91 | image_data = np.expand_dims(image_data, -1) 92 | 93 | #image_data = process_ncc(image_data) 94 | ncc, _, pos, _ = doncc(image_data, template) 95 | 96 | # template = cl.image_processing.read_without_borders_3d(np.expand_dims(image_data, -1), 97 | # [pos[0], pos[1], 
0], 98 | # [W, H, image_data.shape[2]]).astype(np.uint8) 99 | 100 | # cl.visual.save(template, "dump/eval/image") 101 | print(pos) 102 | result_bbox = [pos[1], pos[0], H, W] 103 | out_rects.append(result_bbox) 104 | 105 | savemat(output_rect_path, {"out_rects": out_rects}) 106 | return output_rect_path 107 | 108 | 109 | 110 | 111 | def process_test(): 112 | frames_path = "/tmp/frames.mat" #request.args.get('frames_path') 113 | frames = loadmat(frames_path)['frames'] 114 | rect_path = "/tmp/rect.mat" #request.args.get('rect_path') 115 | init_rect = loadmat(rect_path)['rect'][0] 116 | 117 | out_rects = [] 118 | print(init_rect) 119 | W = int(init_rect[3]) 120 | H = int(init_rect[2]) 121 | 122 | image_data = imread(frames[0][0][0]) 123 | if len(image_data.shape) == 2: 124 | image_data = np.expand_dims(image_data, -1) 125 | image_data = process_ncc(image_data) 126 | 127 | template = cl.image_processing.read_without_borders_3d(np.expand_dims(image_data, -1), 128 | [init_rect[1], init_rect[0], 0], 129 | [W, H, image_data.shape[2]]).astype(np.uint8) 130 | 131 | print(template.shape, template.dtype) 132 | #cl.visual.save(template, "dump/eval/image") 133 | images = pool.map(lambda f: imread(f[0][0]), frames) 134 | 135 | for image_data in images: 136 | # image_path = f[0][0] 137 | # image_data = image#imread(image_path) 138 | 139 | if len(image_data.shape) == 2: 140 | image_data = np.expand_dims(image_data, -1) 141 | cl.visual.save(image_data[:,:,0], "dump/eval/image") 142 | image_data = process_ncc(image_data) 143 | print(image_data.shape, template.shape) 144 | 145 | ncc, _, pos, _ = doncc(image_data, template) 146 | 147 | # template = cl.image_processing.read_without_borders_3d(np.expand_dims(image_data, -1), 148 | # [pos[0], pos[1], 0], 149 | # [W, H, image_data.shape[2]]).astype(np.uint8) 150 | 151 | # image_pred = cl.image_processing.read_without_borders_3d(np.expand_dims(image_data, -1), 152 | # [int(pos[0]),int(pos[1]),0], 153 | # [W, H, 3]) 154 | 155 | 156 | cl.visual.save(ncc[:,:], "dump/eval/template") 157 | print(pos) 158 | result_bbox = init_rect #[pos[1], pos[0], H, W] 159 | out_rects.append(result_bbox) 160 | 161 | savemat(output_rect_path, {"out_rects": out_rects}) 162 | return output_rect_path 163 | 164 | 165 | 166 | if __name__ == '__main__': 167 | # if not os.path.exists('db.sqlite'): 168 | # db.create_all() 169 | process_test() 170 | #app.run(host='0.0.0.0', port=5000, debug=False) 171 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import cavelab as cl 2 | from model import build 3 | from data import Data 4 | from time import time 5 | import numpy as np 6 | #from evaluation import get_batch 7 | 8 | 9 | # April 28 10 | # Work on evaluation 11 | # - Setup validation dataset 12 | # - Explore object tracking 13 | # Work on benchmarking 14 | # - Connect Matlab to Python 15 | # - Run OTB and do tricks? 
16 | 
17 | #Loop
18 | # Get more data 10K -> 80K
19 | # Train Multilayer
20 | # Train Multiscale (given time)
21 | 
22 | def train(hparams):
23 |     # Data input
24 |     d = Data(hparams)
25 | 
26 |     model = build(hparams)
27 |     sess = cl.tf.global_session().get_sess()
28 |     cl.tf.global_session().add_log_writers('/projects/NCCNet2/NCCNet/logs/'+hparams.name+'/',
29 |                                            hparams=hparams,
30 |                                            clean_first=True)
31 |     test_data = []
32 |     crop_coef = 0
33 |     hardness = 0.25
34 |     save = True
35 | 
36 |     try:
37 |         for i in range(hparams.steps):
38 | 
39 |             a = time()
40 |             image, template, label = d.get_batch()
41 |             #print(image)
42 | 
43 |             image = image*255.0
44 |             template = template*255.0
45 | 
46 |             model_run = [model.train_step,
47 |                          model.l,
48 |                          model.cnn_image_big,
49 |                          model.cnn_templ_big,
50 |                          cl.tf.global_session().merged]
51 | 
52 |             feed_dict = {
53 |                 model.image: image,
54 |                 model.template: template,
55 |                 model.similar: label,
56 |                 model.crop_coef: crop_coef
57 |             }
58 | 
59 |             step = sess.run(model_run, feed_dict=feed_dict, run_metadata=cl.tf.global_session().run_metadata)
60 |             c = time()
61 | 
62 |             # Curriculum learning: raise the hardness threshold as training progresses
63 |             if abs(step[1])>hardness and label[0]>0 and crop_coef<5:
64 |                 #crop_coef += 1
65 |                 hardness += 0.05
66 | 
67 |             if i%hparams.log_iterations == 0:
68 |                 a1 = time()
69 |                 cl.tf.global_session().log_save(cl.tf.global_session().train_writer, step[-1], i)
70 |                 a2 = time()
71 | 
72 |             if i%hparams.eval_iterations == 0:
73 |                 b1 = time()
74 |                 evaluate(model, test_data, hparams.testing_steps, i)
75 |                 #try:
76 | 
77 |                 #except:
78 |                 #    pass
79 |                 b2 = time()
80 | 
81 |             if((step[1]==float("Inf") or step[1]==0) and save):  # dump inputs and features, then abort, if the loss diverges
82 |                 save = False
83 |                 for j in range(image.shape[0]):
84 |                     cl.visual.save(image[j], "dump/img/source/"+str(j))
85 |                     for k in range(4):
86 |                         cl.visual.save(step[2][j,:,:,k], "dump/img/f"+str(k)+"/"+str(j))
87 |                         cl.visual.save(step[3][j,:,:,k], "dump/tmp/f"+str(k)+"/"+str(j))
88 |                     cl.visual.save(template[j], "dump/tmp/template/"+str(j))
89 |                 raise Exception('error')
90 |             b = time()
91 |             print('iteration', i, format(b-a, '.2f'), format(step[1], '.2f'), np.mean(label))
92 |     except KeyboardInterrupt:
93 |         print('exiting')
94 |     finally:
95 |         cl.tf.global_session().model_save()
96 |         print('saved')
97 |         cl.tf.global_session().close_sess()
98 | 
99 | 
100 | # TODO: implement evaluation
101 | def evaluate(model, test_data, testing_steps, i):
102 |     return
103 | 
104 | if __name__ == "__main__":
105 |     hparams = cl.hparams(name="default")
106 |     train(hparams)
107 | 
--------------------------------------------------------------------------------