├── utils
│   ├── __init__.py
│   ├── datasets.py
│   ├── plotters.py
│   ├── transformer.py
│   ├── training.py
│   ├── models.py
│   ├── readers.py
│   ├── TrainNetwork.py
│   └── InferMovie.py
├── LICENSE
├── README.md
└── main.py
/utils/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /utils/datasets.py: -------------------------------------------------------------------------------- 1 | # Defines class items handling a dataset. 2 | class TrackingDataset: 3 | def __init__(self, train_file, valid_file, folder_prefix): 4 | self.train_list = open(train_file, 'r').read().splitlines() 5 | self.valid_list = open(valid_file, 'r').read().splitlines() 6 | self.train_images = [folder_prefix + '/Ref/' + train_item + '.png' for train_item in self.train_list] 7 | self.train_labels = [folder_prefix + '/Ell/' + train_item + '.txt' for train_item in self.train_list] 8 | self.train_seg = [folder_prefix + '/Seg/' + train_item + '.png' for train_item in self.train_list] 9 | self.train_size = len(self.train_list) 10 | self.valid_images = [folder_prefix + '/Ref/' + valid_item + '.png' for valid_item in self.valid_list] 11 | self.valid_labels = [folder_prefix + '/Ell/' + valid_item + '.txt' for valid_item in self.valid_list] 12 | self.valid_seg = [folder_prefix + '/Seg/' + valid_item + '.png' for valid_item in self.valid_list] 13 | self.valid_size = len(self.valid_list) 14 | 15 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /utils/plotters.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from .readers import means, scale 3 | import numpy as np 4 | 5 | # Plots the ellipse on the plot 6 | def plot_ellipse(plot, label, color = (255, 255, 0)): 7 | label = np.add(np.multiply(label, scale), means) 8 | labelAngle = -np.arctan2(label[4],label[5])*180/np.pi 9 | if label[2] > 0 and label[3] > 0: 10 | cv2.ellipse(plot,(np.float32(label[0]),np.float32(label[1])),(np.float32(label[2]/2.0),np.float32(label[3]/2.0)),labelAngle,0.0,360.0,color) 11 | # Direction 12 | # (x,y) to (x+cos(angle)*maj/2,y+sin(angle)*maj/2) 13 | cv2.line(plot,(np.float32(label[0]),np.float32(label[1])),(np.float32(label[0]+label[4]*label[3]/2.0),np.float32(label[1]+label[5]*label[3]/2.0)),color) 14 | return plot 15 | 16 | # Plots an xy hash on the plot 17 | def plot_xy(plot, label, color = (255, 255, 0), rescale = True): 18 | if rescale: 19 | label = np.add(np.multiply(label, scale), means) 20 | cv2.line(plot,(np.float32(label[0]-2), np.float32(label[1])),(np.float32(label[0]+2), np.float32(label[1])), color) 21 | cv2.line(plot,(np.float32(label[0]), np.float32(label[1]-2)),(np.float32(label[0]), np.float32(label[1]+2)), color) 22 | return plot 23 | 24 | # Merges the mask of a segmented image 25 | def plot_seg(plot, label, color = 1): 26 | plot[:,:,color] = -cv2.resize(label,(480,480))+1 27 | return plot 28 | 29 | def plot_image_seg(image, label): 30 | plot = cv2.cvtColor(image/255.0,cv2.COLOR_GRAY2RGB) 31 | plot = plot_seg(plot, label/255.0) 32 | cv2.imshow('PlotSeg',plot) 33 | cv2.waitKey() 34 | cv2.destroyAllWindows() 35 | 36 | def plot_image_seg_compare(image, label, label2): 37 | plot = cv2.cvtColor(image/255.0,cv2.COLOR_GRAY2RGB) 38 | plot = plot_seg(plot, label/255.0) 39 | plot = plot_seg(plot, label2/255.0, 2) 40 | cv2.imshow('PlotSeg',plot) 41 | cv2.waitKey() 42 | cv2.destroyAllWindows() 43 | 44 | # Plots the image with the given label 45 | def plot_image_labels(image, label): 46 | plot = cv2.cvtColor(image/255.0,cv2.COLOR_GRAY2RGB) 47 | plot = plot_ellipse(plot, label) 48 | cv2.imshow('PlotEllipse',plot) 49 | cv2.waitKey() 50 | cv2.destroyAllWindows() 51 | 52 | # Plots the image with the given label and label2 53 | def plot_image_labels_compare(image, label, label2): 54 | plot = cv2.cvtColor(image/255.0,cv2.COLOR_GRAY2RGB) 55 | plot = plot_ellipse(plot, label) # Cyan default 56 | plot = plot_ellipse(plot, label2, (255, 0, 255)) # Magenta 57 | cv2.imshow('PlotEllipse',plot) 58 | cv2.waitKey() 59 | cv2.destroyAllWindows() 60 | 61 | def plot_image_xy(image, label): 62 | plot = cv2.cvtColor(image/255.0,cv2.COLOR_GRAY2RGB) 63 | plot = plot_xy(plot, label) # Cyan default 64 | cv2.imshow('PlotXY',plot) 65 | cv2.waitKey() 66 | cv2.destroyAllWindows() 67 | 68 | def plot_image_xy_compare(image, label, label2): 69 | plot = cv2.resize(cv2.cvtColor(image/255.0,cv2.COLOR_GRAY2RGB),(480,480)) 70 | plot = plot_xy(plot, label) # Cyan default 71 | plot = plot_xy(plot, label2, (255, 0, 255)) # Magenta 72 | cv2.imshow('PlotXY',plot) 73 | cv2.waitKey() 74 | cv2.destroyAllWindows() 75 | 76 | def save_image_labels(image, label, filename): 77 | plot = cv2.resize(cv2.cvtColor(image,cv2.COLOR_GRAY2RGB),(480,480)) 78 | plot = plot_ellipse(plot, label) # Cyan default 79 | cv2.imwrite(filename, plot) 80 | 81 | def save_image_labels_compare(image, label, label2, filename): 82 | plot = cv2.cvtColor(image,cv2.COLOR_GRAY2RGB) 83 | 
plot = plot_ellipse(plot, label) # Cyan default 84 | plot = plot_ellipse(plot, label2, (255, 0, 255)) # Magenta 85 | cv2.imwrite(filename, plot) 86 | 87 | def save_image_xy(image, label, filename): 88 | plot = cv2.resize(cv2.cvtColor(image,cv2.COLOR_GRAY2RGB),(480,480)) 89 | plot = plot_xy(plot, label) # Cyan default 90 | cv2.imwrite(filename, plot) 91 | 92 | def save_image_xy_compare(image, label, label2, filename): 93 | plot = cv2.resize(cv2.cvtColor(image,cv2.COLOR_GRAY2RGB),(480,480)) 94 | plot = plot_xy(plot, label) # Cyan default 95 | plot = plot_xy(plot, label2, (255, 0, 255)) # Magenta 96 | cv2.imwrite(filename, plot) 97 | 98 | def save_image_seg(image, label, filename): 99 | plot = cv2.cvtColor(image/255.0,cv2.COLOR_GRAY2RGB) 100 | plot = plot_seg(plot, label/255.0) 101 | cv2.imwrite(filename, plot*255.0) 102 | 103 | def save_image_seg_compare(image, label, label2, filename): 104 | plot = cv2.cvtColor(image/255.0,cv2.COLOR_GRAY2RGB) 105 | plot = plot_seg(plot, label/255.0) 106 | plot = plot_seg(plot, label2/255.0, 2) 107 | cv2.imwrite(filename, plot*255.0) 108 | 109 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Environment Installation 3 | 4 | All testing in the paper was done in python 2.7.12 with tensorflow v 1.0.1. Compatibility has been tested for python 3.6 and tensorflow 1.4. 5 | 6 | Additional support libraries were used in the following versions: 7 | 8 | 1. Opencv 2.4.13 9 | 2. Numpy 1.13.1 10 | 3. Scipy 0.17.0 11 | 4. imageio 2.1.2 12 | 13 | While opencv, numpy, scipy, and imageio have fairly stable releases, tensorflow is still heavily under development. Details on installing older versions of tensorflow are available at the [tensorflow website](https://www.tensorflow.org/versions/). 14 | 15 | The following commands were used to install the environment. You can specify the versions listed above if needed: 16 | 17 | ``` 18 | pip install tensorflow-gpu==1.4 19 | pip install opencv-python 20 | pip install numpy 21 | pip install scipy 22 | pip install imageio 23 | ``` 24 | 25 | If you do not have the ffmpeg video reader installed, imageio can download it for you: 26 | 27 | Inside python... 28 | 29 | ``` 30 | import imageio 31 | imageio.plugins.ffmpeg.download() 32 | ``` 33 | 34 | # Neural Network 35 | 36 | ## Training 37 | 38 | 1. Download/Create a labeled dataset. 39 | 40 | 2. Create a training/validation split of the dataset. 41 | ``` 42 | find Ell -name '*.txt' | sed -e 's/Ell\///' -e 's/\.txt//' | shuf | split -l 16000 43 | ``` 44 | This will create 2 files: xaa and xab. xaa is the train set (of size 16000 frames), xab is the validation set (remainder, up to 16000). If you are using the training set from the paper, the file naming indicates the split. 45 | 46 | 3. Run the training. 47 | Exposed parameters are described in the program docs below. 48 | ``` 49 | python main.py --net_type segellreg --batch_size 50 Train --model construct_segellreg_v8 --log_dir Training_Segellreg --num_steps 100000 --learn_function gen_train_op_adam --train_list Train_Split.txt --valid_list Valid_Split.txt --dataset_folder --start_learn_rate 1e-5 50 | ``` 51 | 52 | ## Inference 53 | 54 | 1. Identify the video or list of videos that you wish to infer. 55 | 56 | 2. Run the inference code.
57 | ``` 58 | python main.py --net_type segellreg InferMany --model construct_segellreg_v8 --network_to_restore --input_movie_list --ellfit_output 59 | ``` 60 | 61 | --- 62 | 63 | # Program Docs 64 | 65 | ## Design Intention 66 | 67 | This software was designed for 480x480 monochromatic images and is untested for other image sizes. Due to the pooling and upsampling layers, this exact structure will only work with images whose dimensions are multiples of 96 pixels without adjusting the network layers. 68 | Functionally, the input images must be square and must all share the same shape. 69 | 70 | ## Usage Parameters 71 | 72 | This software has a large variety of parameters to edit at runtime. 73 | A brief description of these parameters has been encoded into the main file through Python's argparse library. 74 | To access this information, run the following commands: 75 | 76 | ``` 77 | python main.py --help 78 | python main.py Train --help 79 | python main.py Infer --help 80 | python main.py InferMany --help 81 | ``` 82 | 83 | ## Ellipse-fit scaling values 84 | 85 | For different environments, it is necessary to change the "means" and "scale" variables inside "utils/readers.py" to place the dataset roughly into the range of [0,1]. 86 | "means" accounts for an additive mean shift of the data. 87 | "scale" accounts for a multiplicative scaling of the data. 88 | The scaling equation (during training) is: rescaled\_ellfit = (img\_ellfit - means) / scale 89 | The reverse (during inference) is: img\_ellfit = (predicted\_ellfit * scale) + means 90 | 91 | While these changes are not important for the segmentation approach, they do substantially influence the performance of the other approaches. 92 | 93 | ## Network Types 94 | 95 | This code supports 3 main types of network structures: Segmentation-based Ellipse Regression (segellreg), Direct Ellipse Regression (ellreg), and Binned XY Prediction (binned). Additionally, the segmentation network without an angle predictor is included (seg). 96 | 97 | In this release, there is one network model definition for each solution type: 98 | 99 | 1. construct_segellreg_v8 (segellreg) 100 | 2. construct_ellreg_v3_resnet (ellreg) 101 | 3. construct_xybin_v1 (binned) 102 | 4. construct_segsoft_v5 (seg) 103 | 104 | 105 | ## Inference notes 106 | 107 | ### Object Not Present Handling 108 | 109 | The segmentation-based ellipse fit approach uses negative values if no mask is present. A quick check for negative major/minor axis lengths will identify frames in which the desired tracked object is not present. 110 | 111 | The other approaches do not use default values and as such may produce odd behavior when the tracked object is not present. For the binned approach, the network typically outputs a roughly uniform distribution of probable locations. For the regression approach, nonsense values are produced (such as values outside the expected range). 112 | 113 | ### Video Frame Synchronization 114 | 115 | This software uses the imageio-based ffmpeg frame reader. If you are not familiar with video encoding, there are a couple of known issues related to frame timestamps and flags.
116 | 117 | To avoid these issues, it is recommended that you remove timestamp information from your videos with the following ffmpeg command (for 30fps video): 118 | ``` 119 | ffmpeg -r 30 -i -c:v mpeg4 -q 0 -vsync drop 120 | ``` 121 | Command description: 122 | -r 30 --> Assume video is 30fps 123 | -i `` --> input movie 124 | -c:v mpeg4 --> use ffmpeg's mpeg4 video codec 125 | -q 0 --> get as close to input quality as possible 126 | -vsync drop --> remove timestamps from frames (and fill with framerate listed) 127 | 128 | ### Loading Models 129 | 130 | Tensorflow has a bit of a strange naming convention for model files. This code is designed to save models similar to the following: model.chkpt-[step\_number]. 131 | 132 | Each model is associated with 3 files: [model\_name].data-00000-of-00001, [model\_name].index, and [model\_name].meta. 133 | Tensorflow has additional model format information available at the [tensorflow website](https://www.tensorflow.org/extend/tool_developers/). 134 | 135 | ### Available Outputs 136 | 137 | Not all outputs are available for all network types. "Feature" layers don't exist for the segmentation approach and "Segmentation" videos can't be produced for a regression approach. If the output has not been implemented, the code will simply ignore the request for this output. 138 | The primary output used for downstream analyses is the ellipse-fit npy output (--ellfit_output). 139 | 140 | All outputs include: 141 | 142 | 1. ellfit_movie_output 143 | 2. affine_movie_output 144 | 3. crop_movie_output 145 | 4. ellfit_output 146 | 5. ellfit_features_output 147 | 6. seg_movie_output 148 | 149 | If no selected outputs are available for the network structure, the program will not run inference on the video (and will close without error). 150 | 151 | #### Ellipse-fit File 152 | 153 | The ellipse-fit output file contains 6 values per frame. The values are in the following order: 154 | 155 | 1. center_x (in pixels) 156 | 2. center_y (in pixels) 157 | 3. minor_axis_length (in pixels, half the width of the bounding rectangle) 158 | 4. major_axis_length (in pixels, half the height of the bounding rectangle) 159 | 5. sine of predicted angle (zero pointing down with positive angles going counter-clockwise) 160 | 6. cosine of predicted angle (zero pointing down with positive angles going counter-clockwise) 161 | 162 | For reading the binary numpy (npy) file, refer to the [supporting code](https://github.com/KumarLabJax/MouseTrackingExtras/tree/master/NPYReader).
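The values can also be loaded directly in Python with numpy. The sketch below is illustrative (the file name is a placeholder for whatever --ellfit_output wrote for your movie) and simply unpacks the six columns described above.

```
import numpy as np

# Placeholder file name; substitute the npy file written by --ellfit_output.
ellfit = np.load('my_movie_ellfit.npy')  # shape: (num_frames, 6)
center_x, center_y = ellfit[:, 0], ellfit[:, 1]
minor_axis, major_axis = ellfit[:, 2], ellfit[:, 3]
# Recover the predicted angle (in radians) from its sine and cosine components.
angle = np.arctan2(ellfit[:, 4], ellfit[:, 5])
```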
163 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import sys, getopt, os, re, argparse 2 | import utils.TrainNetwork as trainNet 3 | import utils.InferMovie as inferNet 4 | from utils import datasets 5 | from utils import models 6 | from utils import training 7 | import inspect 8 | import resource 9 | 10 | def main(argv): 11 | # Parse some selections 12 | possible_models = [x for x in inspect.getmembers(models)] 13 | possible_models = {x[0]:x[1] for x in possible_models if inspect.isfunction(getattr(models, x[0]))} 14 | possible_learn = [x for x in inspect.getmembers(training)] 15 | possible_learn = {x[0]:x[1] for x in possible_learn if inspect.isfunction(getattr(training, x[0]))} 16 | # Filter the learning functions for the ones that actually contain "train" in their name (omits summary builders, other helper functions) 17 | learn_functs = [re.search('.*train.*',x).group() for x in possible_learn.keys() if re.search('.*train.*',x)] 18 | possible_learn = { key:value for key,value in possible_learn.items() if key in learn_functs } 19 | 20 | # Start up the definitions of argument parsers 21 | parser = argparse.ArgumentParser(description='Run Tensorflow Graphs') 22 | # Separate out specific training/eval/inference parameters 23 | subparsers = parser.add_subparsers(title='mode', description='Type of processing to do with the network', help='Additional Help', dest='mode') 24 | # Add general arguments 25 | parser.add_argument('--net_type', help='Type of network model (default segellreg)', choices=['ellreg','segellreg','binned','seg'], default='segellreg') 26 | parser.add_argument('--batch_size', help='Batch size of the network (default 5)', type=int, default=5) 27 | parser.add_argument('--input_size', help='Frame input size of the network (default 480)', type=int, default=480) 28 | 29 | # Training parameters 30 | parser_train = subparsers.add_parser('Train', help='Training Parameters') 31 | parser_train.add_argument('--model', help='Network model to use', choices=possible_models.keys(), required=True) 32 | parser_train.add_argument('--network_to_restore', help='Network checkpoint to restore') 33 | parser_train.add_argument('--n_reader_threads', help='Number of CPU threads for fetching data (default 3)', default=3, type=int) 34 | parser_train.add_argument('--log_dir', help='Log folder', default='.') 35 | parser_train.add_argument('--train_list', help='File containing list of identifiers for training', required=True) 36 | parser_train.add_argument('--valid_list', help='File containing list of identifiers for validating', required=True) 37 | parser_train.add_argument('--dataset_folder', help='Root folder of the training set') 38 | parser_train.add_argument('--num_steps', help='Steps to take during training (default 500k)', type=int, default=500000) 39 | parser_train.add_argument('--start_learn_rate', help='Initial learning rate for training (default 5e-7)', type=float, default=5e-7) 40 | parser_train.add_argument('--epocs_per_lr_decay', help='Epochs per learn rate decay (default 5)', type=int, default=5) 41 | parser_train.add_argument('--decay_learn_rate', help='Decay learn rate (default constant)', dest='const_learn_rate', action='store_false', default=True) 42 | parser_train.add_argument('--learn_function', help='Learn function', choices=possible_learn.keys(), required=True) 43 | parser_train.add_argument('--aug_rot_max', help='Max small rotation augmentation
(degrees, train set)', type=float, default=2.5) 44 | parser_train.add_argument('--aug_trans_max', help='Max small translation augmentation (px, train set)', type=float, default=5.0) 45 | parser_train.add_argument('--bin_per_px', help='Multiplier for number of bins per pixel (default 10)', type=int, default=10) 46 | 47 | # Inference parameters 48 | parser_infer = subparsers.add_parser('Infer', help='Inference Parameters') 49 | parser_infer.add_argument('--model', help='Network model to use', choices=possible_models.keys(), required=True) 50 | parser_infer.add_argument('--bin_per_px', help='Multiplier for number of bins per pixel (Binned Network ONLY default 10)', type=int, default=10) 51 | parser_infer.add_argument('--network_to_restore', help='Network checkpoint to restore', required=True) 52 | parser_infer.add_argument('--input_movie', help='Input movie to evaluate') 53 | parser_infer.add_argument('--ellfit_movie_output', help='Output a movie with the plotted ellipse-prediction', action='store_true', default=False) 54 | parser_infer.add_argument('--affine_movie_output', help='Output cropped + centered + rotated movie', action='store_true', default=False) 55 | parser_infer.add_argument('--crop_movie_output', help='Output center-cropped movie (uses same affine_crop_dim)', action='store_true', default=False) 56 | parser_infer.add_argument('--affine_crop_dim', help='Cropped dimension for affine-transformed movie (default 112)', type=int, default=112) 57 | parser_infer.add_argument('--ellfit_output', help='Output ellipse-fit data file (npy)', action='store_true', default=False) 58 | parser_infer.add_argument('--ellfit_features_output', help='Output ellipse-fit feature data file (npy)', action='store_true', default=False) 59 | parser_infer.add_argument('--seg_movie_output', help='Output the segmentation mask as a movie', action='store_true', default=False) 60 | 61 | # Multiple Inference parameters 62 | parser_infermany = subparsers.add_parser('InferMany', help='Inference Parameters') 63 | parser_infermany.add_argument('--model', help='Network model to use', choices=possible_models.keys(), required=True) 64 | parser_infermany.add_argument('--bin_per_px', help='Multiplier for number of bins per pixel (Binned Network ONLY default 10)', type=int, default=10) 65 | parser_infermany.add_argument('--network_to_restore', help='Network checkpoint to restore', required=True) 66 | parser_infermany.add_argument('--input_movie_list', help='Text file containing line-by-line list of movies to process') 67 | parser_infermany.add_argument('--ellfit_movie_output', help='Output a movie with the plotted ellipse-prediction', action='store_true', default=False) 68 | parser_infermany.add_argument('--affine_movie_output', help='Output cropped + centered + rotated movie', action='store_true', default=False) 69 | parser_infermany.add_argument('--crop_movie_output', help='Output center-cropped movie (uses same affine_crop_dim)', action='store_true', default=False) 70 | parser_infermany.add_argument('--affine_crop_dim', help='Cropped dimension for affine-transformed movie (default 112)', type=int, default=112) 71 | parser_infermany.add_argument('--ellfit_output', help='Output ellipse-fit data file (npy)', action='store_true', default=False) 72 | parser_infermany.add_argument('--ellfit_features_output', help='Output ellipse-fit feature data file (npy)', action='store_true', default=False) 73 | parser_infermany.add_argument('--seg_movie_output', help='Output the segmentation mask as a movie', action='store_true', 
default=False) 74 | 75 | # Grab all the parsed arguments 76 | args = parser.parse_args() 77 | arg_dict = args.__dict__ 78 | arg_dict['model_construct_function'] = possible_models[args.model] 79 | # Other keyed selections... 80 | if 'learn_function' in arg_dict.keys() and arg_dict['learn_function'] is not None: 81 | arg_dict['learn_function'] = possible_learn[args.learn_function] 82 | 83 | # Prep the dataset 84 | if 'dataset_folder' in arg_dict.keys() and arg_dict['dataset_folder'] is not None: 85 | arg_dict['dataset'] = datasets.TrackingDataset(arg_dict['train_list'], arg_dict['valid_list'], arg_dict['dataset_folder']) 86 | elif 'train_list' in arg_dict.keys(): 87 | arg_dict['dataset'] = datasets.TrackingDataset(arg_dict['train_list'], arg_dict['valid_list'], '.') 88 | 89 | # Call the correct sub-parser and send the argument dictionary for further separation 90 | # Note that the keys are heavily dependent upon naming conventions... 91 | if args.mode == 'Train': 92 | trainNet.trainNetwork(arg_dict) 93 | elif args.mode == 'Infer' or args.mode == 'InferMany': 94 | inferNet.inferMovie(arg_dict) 95 | else: 96 | print('Could not understand commands:') 97 | print(args.__dict__) 98 | 99 | 100 | if __name__ == '__main__': 101 | main(sys.argv[1:]) 102 | 103 | -------------------------------------------------------------------------------- /utils/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | import tensorflow as tf 16 | 17 | 18 | def transformer(U, theta, out_size, name='SpatialTransformer', **kwargs): 19 | """Spatial Transformer Layer 20 | 21 | Implements a spatial transformer layer as described in [1]_. 22 | Based on [2]_ and edited by David Dao for Tensorflow. 23 | 24 | Parameters 25 | ---------- 26 | U : float 27 | The output of a convolutional net should have the 28 | shape [num_batch, height, width, num_channels]. 29 | theta: float 30 | The output of the 31 | localisation network should be [num_batch, 6]. 32 | out_size: tuple of two ints 33 | The size of the output of the network (height, width) 34 | 35 | References 36 | ---------- 37 | .. [1] Spatial Transformer Networks 38 | Max Jaderberg, Karen Simonyan, Andrew Zisserman, Koray Kavukcuoglu 39 | Submitted on 5 Jun 2015 40 | ..
[2] https://github.com/skaae/transformer_network/blob/master/transformerlayer.py 41 | 42 | Notes 43 | ----- 44 | To initialize the network to the identity transform init 45 | ``theta`` to : 46 | identity = np.array([[1., 0., 0.], 47 | [0., 1., 0.]]) 48 | identity = identity.flatten() 49 | theta = tf.Variable(initial_value=identity) 50 | 51 | """ 52 | 53 | def _repeat(x, n_repeats): 54 | with tf.variable_scope('_repeat'): 55 | rep = tf.transpose( 56 | tf.expand_dims(tf.ones(shape=tf.stack([n_repeats, ])), 1), [1, 0]) 57 | rep = tf.cast(rep, 'int32') 58 | x = tf.matmul(tf.reshape(x, (-1, 1)), rep) 59 | return tf.reshape(x, [-1]) 60 | 61 | def _interpolate(im, x, y, out_size): 62 | with tf.variable_scope('_interpolate'): 63 | # constants 64 | num_batch = tf.shape(im)[0] 65 | height = tf.shape(im)[1] 66 | width = tf.shape(im)[2] 67 | channels = tf.shape(im)[3] 68 | 69 | x = tf.cast(x, 'float32') 70 | y = tf.cast(y, 'float32') 71 | height_f = tf.cast(height, 'float32') 72 | width_f = tf.cast(width, 'float32') 73 | out_height = out_size[0] 74 | out_width = out_size[1] 75 | zero = tf.zeros([], dtype='int32') 76 | max_y = tf.cast(tf.shape(im)[1] - 1, 'int32') 77 | max_x = tf.cast(tf.shape(im)[2] - 1, 'int32') 78 | 79 | # scale indices from [-1, 1] to [0, width/height] 80 | x = (x + 1.0)*(width_f) / 2.0 81 | y = (y + 1.0)*(height_f) / 2.0 82 | 83 | # do sampling 84 | x0 = tf.cast(tf.floor(x), 'int32') 85 | x1 = x0 + 1 86 | y0 = tf.cast(tf.floor(y), 'int32') 87 | y1 = y0 + 1 88 | 89 | x0 = tf.clip_by_value(x0, zero, max_x) 90 | x1 = tf.clip_by_value(x1, zero, max_x) 91 | y0 = tf.clip_by_value(y0, zero, max_y) 92 | y1 = tf.clip_by_value(y1, zero, max_y) 93 | dim2 = width 94 | dim1 = width*height 95 | base = _repeat(tf.range(num_batch)*dim1, out_height*out_width) 96 | base_y0 = base + y0*dim2 97 | base_y1 = base + y1*dim2 98 | idx_a = base_y0 + x0 99 | idx_b = base_y1 + x0 100 | idx_c = base_y0 + x1 101 | idx_d = base_y1 + x1 102 | 103 | # use indices to lookup pixels in the flat image and restore 104 | # channels dim 105 | im_flat = tf.reshape(im, tf.stack([-1, channels])) 106 | im_flat = tf.cast(im_flat, 'float32') 107 | Ia = tf.gather(im_flat, idx_a) 108 | Ib = tf.gather(im_flat, idx_b) 109 | Ic = tf.gather(im_flat, idx_c) 110 | Id = tf.gather(im_flat, idx_d) 111 | 112 | # and finally calculate interpolated values 113 | x0_f = tf.cast(x0, 'float32') 114 | x1_f = tf.cast(x1, 'float32') 115 | y0_f = tf.cast(y0, 'float32') 116 | y1_f = tf.cast(y1, 'float32') 117 | wa = tf.expand_dims(((x1_f-x) * (y1_f-y)), 1) 118 | wb = tf.expand_dims(((x1_f-x) * (y-y0_f)), 1) 119 | wc = tf.expand_dims(((x-x0_f) * (y1_f-y)), 1) 120 | wd = tf.expand_dims(((x-x0_f) * (y-y0_f)), 1) 121 | output = tf.add_n([wa*Ia, wb*Ib, wc*Ic, wd*Id]) 122 | return output 123 | 124 | def _meshgrid(height, width): 125 | with tf.variable_scope('_meshgrid'): 126 | # This should be equivalent to: 127 | # x_t, y_t = np.meshgrid(np.linspace(-1, 1, width), 128 | # np.linspace(-1, 1, height)) 129 | # ones = np.ones(np.prod(x_t.shape)) 130 | # grid = np.vstack([x_t.flatten(), y_t.flatten(), ones]) 131 | x_t = tf.matmul(tf.ones(shape=tf.stack([height, 1])), 132 | tf.transpose(tf.expand_dims(tf.linspace(-1.0, 1.0, width), 1), [1, 0])) 133 | y_t = tf.matmul(tf.expand_dims(tf.linspace(-1.0, 1.0, height), 1), 134 | tf.ones(shape=tf.stack([1, width]))) 135 | 136 | x_t_flat = tf.reshape(x_t, (1, -1)) 137 | y_t_flat = tf.reshape(y_t, (1, -1)) 138 | 139 | ones = tf.ones_like(x_t_flat) 140 | grid = tf.concat([x_t_flat, y_t_flat, ones], 0) 141 | return grid 
142 | 143 | def _transform(theta, input_dim, out_size): 144 | with tf.variable_scope('_transform'): 145 | num_batch = tf.shape(input_dim)[0] 146 | height = tf.shape(input_dim)[1] 147 | width = tf.shape(input_dim)[2] 148 | num_channels = tf.shape(input_dim)[3] 149 | theta = tf.reshape(theta, (-1, 2, 3)) 150 | theta = tf.cast(theta, 'float32') 151 | 152 | # grid of (x_t, y_t, 1), eq (1) in ref [1] 153 | height_f = tf.cast(height, 'float32') 154 | width_f = tf.cast(width, 'float32') 155 | out_height = out_size[0] 156 | out_width = out_size[1] 157 | grid = _meshgrid(out_height, out_width) 158 | grid = tf.expand_dims(grid, 0) 159 | grid = tf.reshape(grid, [-1]) 160 | grid = tf.tile(grid, tf.stack([num_batch])) 161 | grid = tf.reshape(grid, tf.stack([num_batch, 3, -1])) 162 | 163 | # Transform A x (x_t, y_t, 1)^T -> (x_s, y_s) 164 | T_g = tf.matmul(theta, grid) 165 | x_s = tf.slice(T_g, [0, 0, 0], [-1, 1, -1]) 166 | y_s = tf.slice(T_g, [0, 1, 0], [-1, 1, -1]) 167 | x_s_flat = tf.reshape(x_s, [-1]) 168 | y_s_flat = tf.reshape(y_s, [-1]) 169 | 170 | input_transformed = _interpolate( 171 | input_dim, x_s_flat, y_s_flat, 172 | out_size) 173 | 174 | output = tf.reshape( 175 | input_transformed, tf.stack([num_batch, out_height, out_width, num_channels])) 176 | return output 177 | 178 | with tf.variable_scope(name): 179 | output = _transform(theta, U, out_size) 180 | return output 181 | 182 | 183 | def batch_transformer(U, thetas, out_size, name='BatchSpatialTransformer'): 184 | """Batch Spatial Transformer Layer 185 | 186 | Parameters 187 | ---------- 188 | 189 | U : float 190 | tensor of inputs [num_batch,height,width,num_channels] 191 | thetas : float 192 | a set of transformations for each input [num_batch,num_transforms,6] 193 | out_size : int 194 | the size of the output [out_height,out_width] 195 | 196 | Returns: float 197 | Tensor of size [num_batch*num_transforms,out_height,out_width,num_channels] 198 | """ 199 | with tf.variable_scope(name): 200 | num_batch, num_transforms = map(int, thetas.get_shape().as_list()[:2]) 201 | indices = [[i]*num_transforms for i in xrange(num_batch)] 202 | input_repeated = tf.gather(U, tf.reshape(indices, [-1])) 203 | return transformer(input_repeated, thetas, out_size) 204 | 205 | -------------------------------------------------------------------------------- /utils/training.py: -------------------------------------------------------------------------------- 1 | from .readers import means, scale 2 | import tensorflow as tf 3 | import tensorflow.contrib.slim as slim 4 | from .readers import ellreg_to_xyhot 5 | from .readers import atan2 6 | import scipy.ndimage.morphology as morph 7 | import numpy as np 8 | 9 | def gen_loss_ellreg(network_eval_batch, label_placeholder): 10 | loss = slim.losses.mean_squared_error(network_eval_batch, label_placeholder) 11 | # If angle should be ignored... 
12 | #loss = slim.losses.mean_squared_error(tf.slice(network_eval_batch,[0,0],[-1,4]), tf.slice(label_placeholder,[0,0],[-1,4])) 13 | errors = tf.multiply(tf.reduce_mean(tf.abs(tf.subtract(network_eval_batch, label_placeholder)), reduction_indices=0), scale) 14 | gt_angle = 0.5*atan2(tf.add(tf.multiply(tf.slice(label_placeholder,[0,4],[-1,1]), scale[4]), means[4]), tf.add(tf.multiply(tf.slice(label_placeholder,[0,5],[-1,1]), scale[5]), means[5])) 15 | test_angle = 0.5*atan2(tf.add(tf.multiply(tf.slice(network_eval_batch,[0,4],[-1,1]), scale[4]), means[4]), tf.add(tf.multiply(tf.slice(network_eval_batch,[0,5],[-1,1]), scale[5]), means[5])) 16 | angles = tf.reduce_mean(tf.abs(tf.subtract(test_angle,gt_angle)))*180/np.pi 17 | return loss, errors, angles 18 | 19 | def gen_loss_seg(network_eval_batch, label_placeholder): 20 | # Apply morphological filtering to the label 21 | filter1 = tf.expand_dims(tf.constant(morph.iterate_structure(morph.generate_binary_structure(2,1),5),dtype=tf.float32),-1) 22 | seg_morph = tf.nn.dilation2d(tf.nn.erosion2d(label_placeholder,filter1,[1,1,1,1],[1,1,1,1],"SAME"),filter1,[1,1,1,1],[1,1,1,1],"SAME") 23 | filter2 = tf.expand_dims(tf.constant(morph.iterate_structure(morph.generate_binary_structure(2,1),4),dtype=tf.float32),-1) 24 | seg_morph = tf.nn.erosion2d(tf.nn.dilation2d(seg_morph,filter2,[1,1,1,1],[1,1,1,1],"SAME"),filter2,[1,1,1,1],[1,1,1,1],"SAME") 25 | #seg_morph = label_placeholder 26 | 27 | # Create the 2 bins 28 | mouse_label = tf.to_float(tf.greater(seg_morph, 0.0)) 29 | background_label = tf.to_float(tf.equal(seg_morph, 0.0)) 30 | combined_label = tf.concat([mouse_label, background_label] ,axis=3) 31 | flat_combined_label = tf.reshape(combined_label, [-1, 2]) 32 | flat_network_eval = tf.reshape(network_eval_batch, [-1, 2]) 33 | loss = tf.losses.softmax_cross_entropy(flat_combined_label, flat_network_eval) 34 | # Could do something fancy with counting TP/FP/TN/FN based on a softmax/argmax between the 2 35 | errors = None 36 | return loss, errors 37 | 38 | def gen_loss_seg_nomorph(network_eval_batch, label_placeholder): 39 | # Create the 2 bins 40 | mouse_label = tf.to_float(tf.greater(label_placeholder, 0.0)) 41 | background_label = tf.to_float(tf.equal(label_placeholder, 0.0)) 42 | combined_label = tf.concat([mouse_label, background_label] ,axis=3) 43 | flat_combined_label = tf.reshape(combined_label, [-1, 2]) 44 | flat_network_eval = tf.reshape(network_eval_batch, [-1, 2]) 45 | loss = tf.losses.softmax_cross_entropy(flat_combined_label, flat_network_eval) 46 | # Could do something fancy with counting TP/FP/TN/FN based on a softmax/argmax between the 2 47 | errors = None 48 | return loss, errors 49 | 50 | 51 | def gen_loss_xyhot(network_eval_batch, label_placeholder, input_size, nbins): 52 | xhot_est, yhot_est = tf.unstack(network_eval_batch) 53 | xhot, yhot = ellreg_to_xyhot(label_placeholder, nbins, nbins/input_size) 54 | loss1 = tf.reduce_mean(-tf.reduce_sum(xhot * tf.log(xhot_est), reduction_indices=[1])) 55 | loss2 = tf.reduce_mean(-tf.reduce_sum(yhot * tf.log(yhot_est), reduction_indices=[1])) 56 | loss = tf.reduce_mean(loss1 + loss2) 57 | xerr = tf.reduce_mean(tf.abs(tf.subtract(tf.cast(tf.argmax(xhot_est, 1), tf.float32),tf.cast(tf.argmax(xhot,1),tf.float32)))) 58 | yerr = tf.reduce_mean(tf.abs(tf.subtract(tf.cast(tf.argmax(yhot_est, 1), tf.float32),tf.cast(tf.argmax(yhot,1),tf.float32)))) 59 | errors = tf.stack([xerr/nbins*input_size, yerr/nbins*input_size]) 60 | return loss, errors 61 | 62 | def gen_loss_rotate(rotations, 
label_placeholder): 63 | label = tf.slice(label_placeholder,[0,4],[-1,2]) 64 | loss = slim.losses.mean_squared_error(rotations, label) 65 | errors = tf.multiply(tf.reduce_mean(tf.abs(tf.subtract(rotations, label)), reduction_indices=0), tf.slice(scale,[4],[2])) 66 | return loss, errors 67 | 68 | # angle_probs is of size [batch, 4] 69 | # label_placeholder is of size [batch, 6] where [batch, 4] is sin(angle) and [batch,5] is cos(angle) 70 | def gen_loss_anglequadrant(angle_probs, label_placeholder): 71 | label_sin = tf.add(tf.multiply(tf.slice(label_placeholder,[0,4],[-1,1]), scale[4]), means[4]) 72 | label_cos = tf.add(tf.multiply(tf.slice(label_placeholder,[0,5],[-1,1]), scale[5]), means[5]) 73 | label_probs = tf.zeros_like(angle_probs)+[0.,0.,1.,0.] # Default to choice 3: Everything incorrect 74 | label_probs = tf.where(tf.squeeze(tf.greater(label_cos,np.sin(np.pi/4.))), tf.zeros_like(angle_probs)+[1.,0.,0.,0.], label_probs) # Choice 1: Everything is correct 75 | label_probs = tf.where(tf.squeeze(tf.greater(label_sin,np.sin(np.pi/4.))), tf.zeros_like(angle_probs)+[0.,1.,0.,0.], label_probs) # Choice 2: Fix when sin prediction < 0.707 76 | label_probs = tf.where(tf.squeeze(tf.less(label_sin,-np.sin(np.pi/4.))), tf.zeros_like(angle_probs)+[0.,0.,0.,1.], label_probs) # Choice 4: Fix when sin prediction > -0.707 77 | 78 | loss = tf.losses.softmax_cross_entropy(label_probs, angle_probs) 79 | return loss 80 | 81 | 82 | def gen_summary_ellreg(loss, errors, angle_errs, learn_rate): 83 | learn_rate_summary = tf.summary.scalar('training/learn_rate', learn_rate) 84 | valid_loss_summary = tf.summary.scalar('validation/losses/loss_ellfit', loss) 85 | valid_xerr_summary = tf.summary.scalar('validation/xy_error/xErr', errors[0]) 86 | valid_yerr_summary = tf.summary.scalar('validation/xy_error/yErr', errors[1]) 87 | valid_minerr_summary = tf.summary.scalar('validation/axis_error/minErr', errors[2]) 88 | valid_majerr_summary = tf.summary.scalar('validation/axis_error/majErr', errors[3]) 89 | valid_sinerr_summary = tf.summary.scalar('validation/dir_error/sinAngErr', errors[4]) 90 | valid_coserr_summary = tf.summary.scalar('validation/dir_error/cosAngErr', errors[5]) 91 | valid_angle_summary = tf.summary.scalar('validation/dir_error/degAngErr', angle_errs) 92 | validation_summary = tf.summary.merge([valid_loss_summary, valid_xerr_summary, valid_yerr_summary, valid_minerr_summary, valid_majerr_summary, valid_sinerr_summary, valid_coserr_summary, valid_angle_summary]) 93 | train_loss_summary = tf.summary.scalar('training/losses/loss_ellfit', loss) 94 | train_xerr_summary = tf.summary.scalar('training/xy_error/xErr', errors[0]) 95 | train_yerr_summary = tf.summary.scalar('training/xy_error/yErr', errors[1]) 96 | train_minerr_summary = tf.summary.scalar('training/axis_error/minErr', errors[2]) 97 | train_majerr_summary = tf.summary.scalar('training/axis_error/majErr', errors[3]) 98 | train_sinerr_summary = tf.summary.scalar('training/dir_error/sinAngErr', errors[4]) 99 | train_coserr_summary = tf.summary.scalar('training/dir_error/cosAngErr', errors[5]) 100 | train_angle_summary = tf.summary.scalar('training/dir_error/degAngErr', angle_errs) 101 | training_summary = tf.summary.merge([train_loss_summary, train_xerr_summary, train_yerr_summary, train_minerr_summary, train_majerr_summary, train_sinerr_summary, train_coserr_summary, learn_rate_summary, train_angle_summary]) 102 | return training_summary, validation_summary 103 | 104 | def gen_summary_seg(loss, errors, learn_rate): 105 | learn_rate_summary = 
tf.summary.scalar('training/learn_rate', learn_rate) 106 | valid_loss_summary = tf.summary.scalar('validation/losses/loss_seg', loss) 107 | validation_summary = tf.summary.merge([valid_loss_summary]) 108 | train_loss_summary = tf.summary.scalar('training/losses/loss', loss) 109 | training_summary = tf.summary.merge([train_loss_summary, learn_rate_summary]) 110 | return training_summary, validation_summary 111 | 112 | def gen_summary_xyhot(loss, errors, learn_rate): 113 | learn_rate_summary = tf.summary.scalar('training/learn_rate', learn_rate) 114 | valid_loss_summary = tf.summary.scalar('validation/losses/loss_xyhot', loss) 115 | valid_xerr_summary = tf.summary.scalar('validation/xy_error/xErr', errors[0]) 116 | valid_yerr_summary = tf.summary.scalar('validation/xy_error/yErr', errors[1]) 117 | validation_summary = tf.summary.merge([valid_loss_summary, valid_xerr_summary, valid_yerr_summary]) 118 | train_loss_summary = tf.summary.scalar('training/losses/loss', loss) 119 | train_xerr_summary = tf.summary.scalar('training/xy_error/xErr', errors[0]) 120 | train_yerr_summary = tf.summary.scalar('training/xy_error/yErr', errors[1]) 121 | training_summary = tf.summary.merge([train_loss_summary, train_xerr_summary, train_yerr_summary, learn_rate_summary]) 122 | return training_summary, validation_summary 123 | 124 | def gen_train_op_adam(loss, train_size, batch_size, global_step, init_learn_rate = 1e-3, num_epochs_per_decay = 50, const_learn_rate = False): 125 | num_batches_per_epoch = ((train_size) / batch_size) 126 | decay_steps = int(num_batches_per_epoch * num_epochs_per_decay) 127 | learning_rate_decay_factor = 0.15 128 | if const_learn_rate: 129 | learn_rate = init_learn_rate 130 | else: 131 | learn_rate = tf.train.exponential_decay(init_learn_rate, global_step, decay_steps, learning_rate_decay_factor, staircase=False) 132 | optimizer = tf.train.AdamOptimizer(learn_rate) 133 | # 134 | train_op = slim.learning.create_train_op(loss, optimizer) 135 | #train_op = optimizer.minimize(loss, global_step=global_step, colocate_gradients_with_ops=True) 136 | return learn_rate, train_op 137 | 138 | -------------------------------------------------------------------------------- /utils/models.py: -------------------------------------------------------------------------------- 1 | # My collection of available network models... 
2 | 3 | import tensorflow as tf 4 | import tensorflow.contrib.slim as slim 5 | from tensorflow.contrib.slim.nets import resnet_v2 6 | from tensorflow.contrib.slim.nets import resnet_utils 7 | from tensorflow.contrib.slim.nets import inception 8 | from tensorflow.contrib.slim.nets import vgg 9 | from .readers import means, scale, atan2 10 | import scipy.ndimage.morphology as morph 11 | import numpy as np 12 | 13 | # Concats x/y gradients along the depth dimension 14 | # Used in coordconvs 15 | # Note: You should call with tf.map_fn(lambda by_batch: concat_xygrad_2d(by_batch), input_tensor) 16 | def concat_xygrad_2d(input_tensor): 17 | input_shape = [int(x) for i,x in enumerate(input_tensor.get_shape())] 18 | xgrad = tf.reshape(tf.tile([tf.lin_space(0.0,1.0,input_shape[-2])],[input_shape[-3],1]),np.concatenate([input_shape[0:-1],[1]])) 19 | ygrad = tf.reshape(tf.tile(tf.reshape([tf.lin_space(0.0,1.0,input_shape[-3])],[input_shape[-3],1]),[1,input_shape[-2]]),np.concatenate([input_shape[0:-1],[1]])) 20 | return tf.concat([input_tensor, xgrad, ygrad], axis=-1) 21 | 22 | # Fits an ellipse from a mask 23 | # Assumes that the mask is of size [?,3], where [:,0] are x indices and [:,1] are y indices 24 | def fitEll(mask): 25 | locs = tf.cast(tf.slice(mask,[0,0],[-1,2]),tf.float32) 26 | translations = tf.reduce_mean(locs, 0) 27 | sqlocs = tf.square(locs) 28 | variance = tf.reduce_mean(sqlocs,0)-tf.square(translations) 29 | variance_xy = tf.reduce_mean(tf.reduce_prod(locs, 1),0)-tf.reduce_prod(translations,0) 30 | translations = tf.reverse(translations,[0]) # Note: Moment across X-values gives you y location, so need to reverse 31 | tmp1 = tf.reduce_sum(variance) 32 | tmp2 = tf.sqrt(tf.multiply(4.0,tf.pow(variance_xy,2))+tf.pow(tf.reduce_sum(tf.multiply(variance,[1.0,-1.0])),2)) 33 | eigA = tf.multiply(tf.sqrt((tmp1+tmp2)/2.0),4.0) 34 | eigB = tf.multiply(tf.sqrt((tmp1-tmp2)/2.0),4.0) 35 | angle = 0.5*atan2(2.0*variance_xy,tf.reduce_sum(tf.multiply(variance,[1.0,-1.0]))) # Radians 36 | ellfit = tf.stack([tf.slice(translations,[0],[1]),tf.slice(translations,[1],[1]),[eigB],[eigA],[tf.sin(angle)],[tf.cos(angle)]],1) 37 | return tf.reshape(tf.divide(tf.subtract(ellfit,means),scale),[-1]) 38 | 39 | # It appears that the issue for running this is due to nested loops in the optimizer (cannot train). 40 | # https://github.com/tensorflow/tensorflow/issues/3726 41 | # Both tf.where and tf.gather_nd use loops 42 | # This can be used during inference to get slightly better results (by changing the line in the fitEllFromSeg definition). 43 | def fitEll_weighted(mask, seg): 44 | locs_orig = tf.cast(tf.slice(mask,[0,0],[-1,2]),tf.float32) 45 | weights = tf.gather_nd(seg, mask) 46 | # Normalize to sum of 1 47 | weights_orig = tf.exp(tf.divide(weights,tf.reduce_sum(weights))) 48 | weights_orig = tf.divide(weights_orig,tf.reduce_sum(weights_orig)) 49 | weights = tf.reshape(tf.tile(weights_orig,[2]),[-1,2]) 50 | # This is the line that breaks it: 51 | locs = tf.multiply(locs_orig,weights) 52 | translations = tf.reduce_sum(locs, 0) # Note: Moment across X-values gives you y location, so need to reverse. 
This is changed on the return values (index 1, then index 0) 53 | sqlocs = tf.multiply(tf.square(locs_orig),weights) 54 | variance = tf.reduce_sum(sqlocs,0)-tf.square(translations) 55 | variance_xy = tf.reduce_sum(tf.reduce_prod(locs_orig, 1)*weights_orig,0)-tf.reduce_prod(translations,0) 56 | tmp1 = tf.reduce_sum(variance) 57 | tmp2 = tf.sqrt(tf.multiply(4.0,tf.pow(variance_xy,2))+tf.pow(tf.reduce_sum(tf.multiply(variance,[1.0,-1.0])),2)) 58 | eigA = tf.multiply(tf.sqrt((tmp1+tmp2)/2.0),4.0) 59 | eigB = tf.multiply(tf.sqrt((tmp1-tmp2)/2.0),4.0) 60 | angle = 0.5*atan2(2.0*variance_xy,tf.reduce_sum(tf.multiply(variance,[1.0,-1.0]))) # Radians 61 | ellfit = tf.stack([tf.slice(translations,[1],[1]),tf.slice(translations,[0],[1]),[eigB],[eigA],[tf.sin(angle)],[tf.cos(angle)]],1) 62 | return tf.reshape(tf.divide(tf.subtract(ellfit,means),scale),[-1]) 63 | 64 | # Safely applies the threshold to the mask and returns default values if no indices are classified as mouse 65 | def fitEllFromSeg(seg, node_act): 66 | mask = tf.where(tf.greater(seg,node_act)) 67 | # NOTE: See note on fitEll_weighted function definition 68 | #return tf.cond(tf.shape(mask)[0]>0, lambda: fitEll_weighted(mask, seg), lambda: tf.to_float([-1.0,-1.0,-1.0,-1.0,-1.0,-1.0])) 69 | return tf.cond(tf.shape(mask)[0]>0, lambda: fitEll(mask), lambda: tf.to_float([-1.0,-1.0,-1.0,-1.0,-1.0,-1.0])) 70 | 71 | 72 | ########################################################################## 73 | # Begin defining all available models 74 | ########################################################################## 75 | def construct_segellreg_v8(images, is_training): 76 | batch_norm_params = {'is_training': is_training, 'decay': 0.999, 'updates_collections': None, 'center': True, 'scale': True, 'trainable': True} 77 | # Normalize the image inputs (map_fn used to do a "per batch" calculation) 78 | norm_imgs = tf.map_fn(lambda img: tf.image.per_image_standardization(img), images) 79 | kern_size = [5,5] 80 | filter_size = 8 81 | with tf.variable_scope('SegmentEncoder'): 82 | with slim.arg_scope([slim.conv2d], 83 | activation_fn=tf.nn.relu, 84 | padding='SAME', 85 | weights_initializer=tf.truncated_normal_initializer(0.0, 0.01), 86 | weights_regularizer=slim.l2_regularizer(0.0005), 87 | normalizer_fn=slim.batch_norm, 88 | normalizer_params=batch_norm_params): 89 | c1 = slim.conv2d(norm_imgs, filter_size, kern_size) 90 | p1 = slim.max_pool2d(c1, [2,2], scope='pool1') #240x240 91 | c2 = slim.conv2d(p1, filter_size*2, kern_size) 92 | p2 = slim.max_pool2d(c2, [2,2], scope='pool2') #120x120 93 | c3 = slim.conv2d(p2, filter_size*4, kern_size) 94 | p3 = slim.max_pool2d(c3, [2,2], scope='pool3') #60x60 95 | c4 = slim.conv2d(p3, filter_size*8, kern_size) 96 | p4 = slim.max_pool2d(c4, [2,2], scope='pool4') # 30x30 97 | c5 = slim.conv2d(p4, filter_size*16, kern_size) 98 | p5 = slim.max_pool2d(c5, [2,2], scope='pool5') # 15x15 99 | c6 = slim.conv2d(p5, filter_size*32, kern_size) 100 | p6 = slim.max_pool2d(c6, [3,3], stride=3, scope='pool6') # 5x5 101 | c7 = slim.conv2d(p6, filter_size*64, kern_size) 102 | with tf.variable_scope('SegmentDecoder'): 103 | upscale = 2 # Undo the pools once at a time 104 | mynet = slim.conv2d_transpose(c7, filter_size*32, kern_size, stride=[3, 3], activation_fn=None) 105 | mynet = tf.add(mynet, c6) 106 | mynet = slim.conv2d_transpose(mynet, filter_size*16, kern_size, stride=[upscale, upscale], activation_fn=None) 107 | mynet = tf.add(mynet, c5) 108 | mynet = slim.conv2d_transpose(mynet, filter_size*8, kern_size, stride=[upscale, 
upscale], activation_fn=None) 109 | mynet = tf.add(mynet, c4) 110 | mynet = slim.conv2d_transpose(mynet, filter_size*4, kern_size, stride=[upscale, upscale], activation_fn=None) 111 | mynet = tf.add(mynet, c3) 112 | mynet = slim.conv2d_transpose(mynet, filter_size*2, kern_size, stride=[upscale, upscale], activation_fn=None) 113 | mynet = tf.add(mynet, c2) 114 | mynet = slim.conv2d_transpose(mynet, filter_size, kern_size, stride=[upscale, upscale], activation_fn=None) 115 | mynet = tf.add(mynet, c1) 116 | seg = slim.conv2d(mynet, 2, [1,1], scope='seg') 117 | with tf.variable_scope('Ellfit'): 118 | seg_morph = tf.slice(tf.nn.softmax(seg,-1),[0,0,0,0],[-1,-1,-1,1])-tf.slice(tf.nn.softmax(seg,-1),[0,0,0,1],[-1,-1,-1,1]) 119 | # And was kept here to just assist in the ellipse-fit for any unwanted noise 120 | filter1 = tf.expand_dims(tf.constant(morph.iterate_structure(morph.generate_binary_structure(2,1),4),dtype=tf.float32),-1) 121 | seg_morph = tf.nn.dilation2d(tf.nn.erosion2d(seg_morph,filter1,[1,1,1,1],[1,1,1,1],"SAME"),filter1,[1,1,1,1],[1,1,1,1],"SAME") 122 | filter2 = tf.expand_dims(tf.constant(morph.iterate_structure(morph.generate_binary_structure(2,1),5),dtype=tf.float32),-1) 123 | seg_morph = tf.nn.erosion2d(tf.nn.dilation2d(seg_morph,filter2,[1,1,1,1],[1,1,1,1],"SAME"),filter2,[1,1,1,1],[1,1,1,1],"SAME") 124 | node_act = tf.constant(0.0,dtype=tf.float32) 125 | # Fit the ellipse from the segmentation mask algorithmically 126 | ellfit = tf.map_fn(lambda mask: fitEllFromSeg(mask, node_act), seg_morph) 127 | with tf.variable_scope('AngleFix'): 128 | mynet = slim.conv2d(c7, 128, kern_size, activation_fn=tf.nn.relu, padding='SAME', weights_initializer=tf.truncated_normal_initializer(0.0, 0.01), weights_regularizer=slim.l2_regularizer(0.0005), normalizer_fn=slim.batch_norm, normalizer_params=batch_norm_params) 129 | mynet = slim.conv2d(mynet, 64, kern_size, activation_fn=tf.nn.relu, padding='SAME', weights_initializer=tf.truncated_normal_initializer(0.0, 0.01), weights_regularizer=slim.l2_regularizer(0.0005), normalizer_fn=slim.batch_norm, normalizer_params=batch_norm_params) 130 | mynet = slim.flatten(mynet) 131 | angle_bins = slim.fully_connected(mynet, 4, activation_fn=None, normalizer_fn=None, normalizer_params=None, scope='angle_bin') 132 | angles = tf.add(tf.multiply(tf.slice(ellfit, [0,4], [-1,2]), scale[4:5]), means[4:5]) # Extract angles to fix them 133 | sin_angles = tf.slice(angles,[0,0],[-1,1]) # Unmorph the sin(angles) 134 | ang_bins_max = tf.argmax(angle_bins,1) # Note: This is from 0-3, not 1-4 135 | angles = tf.where(tf.equal(ang_bins_max,2), -angles, angles) # Bin 3 always wrong 136 | angles = tf.where(tf.logical_and(tf.equal(ang_bins_max,1), tf.squeeze(tf.less(sin_angles, 0.0))), -angles, angles) # Bin 2 is wrong when sin(ang) < np.sin(np.pi/4.) ... Some bleedover, so < 0.0 137 | angles = tf.where(tf.logical_and(tf.equal(ang_bins_max,3), tf.squeeze(tf.greater(sin_angles, 0.0))), -angles, angles) # Bin 4 is wrong when sin(ang) > -np.sin(np.pi/4.) ... 
Some bleedover, so > 0.0 138 | angles = tf.divide(tf.subtract(angles, means[4:5]), scale[4:5]) 139 | original = tf.slice(ellfit,[0,0],[-1,4]) 140 | ellfit = tf.concat([original, angles],1) 141 | 142 | return seg, ellfit, angle_bins 143 | 144 | 145 | # XY binning for 480 x and 480 y bins 146 | def construct_xybin_v1(images, is_training, n_bins): 147 | batch_norm_params = {'is_training': is_training, 'decay': 0.8, 'updates_collections': None, 'center': True, 'scale': True, 'trainable': True} 148 | with slim.arg_scope([slim.conv2d], 149 | activation_fn=tf.nn.relu, 150 | padding='SAME', 151 | weights_initializer=tf.truncated_normal_initializer(0.0, 0.01), 152 | weights_regularizer=slim.l2_regularizer(0.0005), 153 | normalizer_fn=slim.batch_norm, 154 | normalizer_params=batch_norm_params): 155 | mynet = slim.repeat(images, 2, slim.conv2d, 16, [3,3], scope='conv1') 156 | mynet = slim.max_pool2d(mynet, [2,2], scope='pool1') 157 | mynet = slim.repeat(mynet, 2, slim.conv2d, 32, [3,3], scope='conv2') 158 | mynet = slim.max_pool2d(mynet, [2,2], scope='pool2') 159 | mynet = slim.repeat(mynet, 2, slim.conv2d, 64, [3,3], scope='conv3') 160 | mynet = slim.max_pool2d(mynet, [2,2], scope='pool3') 161 | mynet = slim.repeat(mynet, 2, slim.conv2d, 128, [3,3], scope='conv4') 162 | mynet = slim.max_pool2d(mynet, [2,2], scope='pool4') 163 | mynet = slim.repeat(mynet, 2, slim.conv2d, 256, [3,3], scope='conv5') 164 | mynet = slim.max_pool2d(mynet, [2,2], scope='pool5') 165 | features = slim.flatten(mynet, scope='flatten') 166 | with slim.arg_scope([slim.fully_connected], 167 | activation_fn=tf.nn.relu, 168 | weights_initializer=tf.truncated_normal_initializer(0.0, 0.01), 169 | weights_regularizer=slim.l2_regularizer(0.0005), 170 | normalizer_fn=slim.batch_norm, 171 | normalizer_params=batch_norm_params): 172 | # To add additional fully connected layers... 
173 | # Our tests showed no substantial difference 174 | #mynet = slim.fully_connected(mynet, 4096, scope='fc5') 175 | #mynet = slim.dropout(mynet, 0.5, scope='dropout5') 176 | #mynet = slim.fully_connected(mynet, 4096, scope='fc6') 177 | #mynet = slim.dropout(mynet, 0.5, scope='dropout6') 178 | xbins = slim.fully_connected(features, n_bins, activation_fn=None, scope='xbins') 179 | xbins = slim.softmax(xbins, scope='smx') 180 | ybins = slim.fully_connected(features, n_bins, activation_fn=None, scope='ybins') 181 | ybins = slim.softmax(ybins, scope='smy') 182 | mynet = tf.stack([xbins, ybins]) 183 | return mynet, features 184 | 185 | # Attempt to predict the ellipse-regression directly (using resnet_v2_200) 186 | def construct_ellreg_v3_resnet(images, is_training): 187 | batch_norm_params = {'is_training': is_training, 'decay': 0.8, 'updates_collections': None, 'center': True, 'scale': True, 'trainable': True} 188 | mynet, _ = resnet_v2.resnet_v2_200(images, None, is_training=is_training) 189 | features = tf.reshape(mynet, [-1, 2048]) 190 | with slim.arg_scope([slim.fully_connected], 191 | activation_fn=tf.nn.relu, 192 | weights_initializer=tf.truncated_normal_initializer(0.0, 0.01), 193 | weights_regularizer=slim.l2_regularizer(0.0005), 194 | normalizer_fn=slim.batch_norm, 195 | normalizer_params=batch_norm_params): 196 | mynet = slim.fully_connected(features, 6, activation_fn=None, normalizer_fn=None, normalizer_params=None, scope='outlayer') 197 | return mynet, features 198 | 199 | # Attempt to predict the ellipse-regression directly with coordinate convs 200 | def construct_ellreg_v4_resnet(images, is_training): 201 | batch_norm_params = {'is_training': is_training, 'decay': 0.8, 'updates_collections': None, 'center': True, 'scale': True, 'trainable': True} 202 | input_imgs = tf.map_fn(lambda by_batch: concat_xygrad_2d(by_batch), images) 203 | mynet, _ = resnet_v2.resnet_v2_200(input_imgs, None, is_training=is_training) 204 | features = tf.reshape(mynet, [-1, 2048]) 205 | with slim.arg_scope([slim.fully_connected], 206 | activation_fn=tf.nn.relu, 207 | weights_initializer=tf.truncated_normal_initializer(0.0, 0.01), 208 | weights_regularizer=slim.l2_regularizer(0.0005), 209 | normalizer_fn=slim.batch_norm, 210 | normalizer_params=batch_norm_params): 211 | mynet = slim.fully_connected(features, 6, activation_fn=None, normalizer_fn=None, normalizer_params=None, scope='outlayer') 212 | return mynet, features 213 | 214 | 215 | # Segmentation Only Network (no angle prediction) 216 | def construct_segsoft_v5(images, is_training): 217 | batch_norm_params = {'is_training': is_training, 'decay': 0.999, 'updates_collections': None, 'center': True, 'scale': True, 'trainable': True} 218 | # Normalize the image inputs (map_fn used to do a "per batch" calculation) 219 | norm_imgs = tf.map_fn(lambda img: tf.image.per_image_standardization(img), images) 220 | kern_size = [5,5] 221 | filter_size = 8 222 | # Run the segmentation net without pooling 223 | with tf.variable_scope('SegmentEncoder'): 224 | with slim.arg_scope([slim.conv2d], 225 | activation_fn=tf.nn.relu, 226 | padding='SAME', 227 | weights_initializer=tf.truncated_normal_initializer(0.0, 0.01), 228 | weights_regularizer=slim.l2_regularizer(0.0005), 229 | normalizer_fn=slim.batch_norm, 230 | normalizer_params=batch_norm_params): 231 | c1 = slim.conv2d(norm_imgs, filter_size, kern_size) 232 | p1 = slim.max_pool2d(c1, [2,2], scope='pool1') #240x240 233 | c2 = slim.conv2d(p1, filter_size*2, kern_size) 234 | p2 = slim.max_pool2d(c2, [2,2], 
scope='pool2') #120x120 235 | c3 = slim.conv2d(p2, filter_size*4, kern_size) 236 | p3 = slim.max_pool2d(c3, [2,2], scope='pool3') #60x60 237 | c4 = slim.conv2d(p3, filter_size*8, kern_size) 238 | p4 = slim.max_pool2d(c4, [2,2], scope='pool4') # 30x30 239 | c5 = slim.conv2d(p4, filter_size*16, kern_size) 240 | p5 = slim.max_pool2d(c5, [2,2], scope='pool5') # 15x15 241 | c6 = slim.conv2d(p5, filter_size*32, kern_size) 242 | p6 = slim.max_pool2d(c6, [3,3], stride=3, scope='pool6') # 5x5 243 | c7 = slim.conv2d(p6, filter_size*64, kern_size) 244 | with tf.variable_scope('SegmentDecoder'): 245 | upscale = 2 # Undo the pools once at a time 246 | mynet = slim.conv2d_transpose(c7, filter_size*32, kern_size, stride=[3, 3], activation_fn=None) 247 | mynet = tf.add(mynet, c6) 248 | mynet = slim.conv2d_transpose(mynet, filter_size*16, kern_size, stride=[upscale, upscale], activation_fn=None) 249 | mynet = tf.add(mynet, c5) 250 | mynet = slim.conv2d_transpose(mynet, filter_size*8, kern_size, stride=[upscale, upscale], activation_fn=None) 251 | mynet = tf.add(mynet, c4) 252 | mynet = slim.conv2d_transpose(mynet, filter_size*4, kern_size, stride=[upscale, upscale], activation_fn=None) 253 | mynet = tf.add(mynet, c3) 254 | mynet = slim.conv2d_transpose(mynet, filter_size*2, kern_size, stride=[upscale, upscale], activation_fn=None) 255 | mynet = tf.add(mynet, c2) 256 | mynet = slim.conv2d_transpose(mynet, filter_size, kern_size, stride=[upscale, upscale], activation_fn=None) 257 | mynet = tf.add(mynet, c1) 258 | seg = slim.conv2d(mynet, 2, [1,1], scope='seg') 259 | return seg 260 | -------------------------------------------------------------------------------- /utils/readers.py: -------------------------------------------------------------------------------- 1 | # Handles generation of the data readers and manipulators 2 | # Also includes the re-scaling values (to optionally be applied) 3 | 4 | import tensorflow as tf 5 | import numpy as np 6 | from .transformer import * 7 | 8 | # Reshaping Ellipse Labels 9 | means = [0., 0., 10., 10., -1., -1.] # All positive, min/maj axis mins are 16.97/20.75 10 | scale = [480., 480., 60., 120., 2., 2.] 
# min/maj axis are 45.9/106.3 max 11 | 12 | ################################################ 13 | # Base reading and augmentation primitives 14 | # Other functions are based on these 15 | 16 | # Returns the angle in radians 17 | def atan2(y, x, epsilon=1.0e-12): 18 | # Add a small number to all zeros, to avoid division by zero: 19 | x = tf.where(tf.equal(x, 0.0), x+epsilon, x) 20 | y = tf.where(tf.equal(y, 0.0), y+epsilon, y) 21 | 22 | angle = tf.where(tf.greater(x,0.0), tf.atan(y/x), tf.zeros_like(x)) 23 | angle = tf.where(tf.logical_and(tf.less(x,0.0), tf.greater_equal(y,0.0)), tf.atan(y/x) + np.pi, angle) 24 | angle = tf.where(tf.logical_and(tf.less(x,0.0), tf.less(y,0.0)), tf.atan(y/x) - np.pi, angle) 25 | angle = tf.where(tf.logical_and(tf.equal(x,0.0), tf.greater(y,0.0)), 0.5*np.pi * tf.ones_like(x), angle) 26 | angle = tf.where(tf.logical_and(tf.equal(x,0.0), tf.less(y,0.0)), -0.5*np.pi * tf.ones_like(x), angle) 27 | angle = tf.where(tf.logical_and(tf.equal(x,0.0), tf.equal(y,0.0)), tf.zeros_like(x), angle) 28 | return angle 29 | 30 | # Reads a single input image 31 | def read_image(filename, input_size): 32 | image_contents = tf.read_file(filename) 33 | image = tf.image.decode_png(image_contents, channels=1) 34 | image = tf.image.resize_images(image, [input_size, input_size]) 35 | return image 36 | 37 | # Augments a single input image 38 | def augment_image(image, input_size, noise_std=5.0, bright_percent=0.05, contrast_percent=0.05): 39 | image = tf.add(image, tf.random_normal([input_size, input_size, 1], stddev=noise_std)) # Random Noise 40 | image = tf.image.random_brightness(image, bright_percent) 41 | image = tf.image.random_contrast(image, 1.0-contrast_percent, 1.0+contrast_percent) 42 | return image 43 | 44 | # Reads in a single ellipse fit 45 | def read_ellipse(filename): 46 | ellfit = tf.read_file(filename) 47 | record_defaults = [[0.0],[0.0],[0.0],[0.0],[0.0]] 48 | ellfit = tf.string_to_number(tf.string_split([ellfit],delimiter='\t').values) 49 | ellfit = tf.stack([ellfit[0], ellfit[1], ellfit[2], ellfit[3], tf.sin(ellfit[4]*np.pi/180.0), tf.cos(ellfit[4]*np.pi/180.0)]) 50 | ellfit = tf.div(tf.subtract(ellfit, means), scale) 51 | return ellfit 52 | 53 | # X/Y/Diag mirroring 54 | # Applies to the reference, ellipse-fit, and segmentation 55 | def rand_flip_input(reference, ellfit=None, seg=None): 56 | # Random transformations 57 | # 0-7 (or 8-fold increase) 58 | # 0 = normal, 1 = h, 2 = v, 3 = h + v, 4 = t, 5 = t + h, 6 = t + v, 7 = t + h + v 59 | # Note: randnum%2 == 1 for HF operation 60 | # randnum/2%2 == 1 for VF operation 61 | # randnum/4%2 == 1 for T operation 62 | randnum = tf.random_uniform([1], minval=0, maxval=8, dtype=tf.int32)[0] 63 | # Horizontal flipping 64 | reference = tf.cond(tf.mod(randnum,2)<1, lambda: tf.identity(reference), lambda: tf.image.flip_left_right(reference)) 65 | if ellfit is not None: 66 | ellfit = tf.cond(tf.mod(randnum,2)<1, lambda: tf.identity(ellfit), lambda: tf.stack([tf.subtract(1.,tf.unstack(ellfit)[0]), tf.unstack(ellfit)[1], tf.unstack(ellfit)[2], tf.unstack(ellfit)[3], tf.subtract(1.,tf.unstack(ellfit)[4]), tf.unstack(ellfit)[5]])) # -sin for horizontal flip 67 | if seg is not None: 68 | seg = tf.cond(tf.mod(randnum,2)<1, lambda: tf.identity(seg), lambda: tf.image.flip_left_right(seg)) 69 | 70 | # Vertical flipping 71 | randnum = tf.div(randnum,2) 72 | reference = tf.cond(tf.mod(randnum,2)<1, lambda: tf.identity(reference), lambda: tf.image.flip_up_down(reference)) 73 | if ellfit is not None: 74 | ellfit = 
tf.cond(tf.mod(randnum,2)<1, lambda: tf.identity(ellfit), lambda: tf.stack([tf.unstack(ellfit)[0], tf.subtract(1.,tf.unstack(ellfit)[1]), tf.unstack(ellfit)[2], tf.unstack(ellfit)[3], tf.unstack(ellfit)[4], tf.subtract(1.,tf.unstack(ellfit)[5])])) #-cos for vertical flip 75 | if seg is not None: 76 | seg = tf.cond(tf.mod(randnum,2)<1, lambda: tf.identity(seg), lambda: tf.image.flip_up_down(seg)) 77 | 78 | # Transpose 79 | randnum = tf.div(randnum,2) 80 | reference = tf.cond(tf.mod(randnum,2)<1, lambda: tf.identity(reference), lambda: tf.image.transpose_image(reference)) 81 | if ellfit is not None: 82 | ellfit = tf.cond(tf.mod(randnum,2)<1, lambda: tf.identity(ellfit), lambda: tf.stack([tf.unstack(ellfit)[1], tf.unstack(ellfit)[0], tf.unstack(ellfit)[2], tf.unstack(ellfit)[3], tf.unstack(ellfit)[5], tf.unstack(ellfit)[4]])) # sin/cos reversed for transpose 83 | if seg is not None: 84 | seg = tf.cond(tf.mod(randnum,2)<1, lambda: tf.identity(seg), lambda: tf.image.transpose_image(seg)) 85 | return reference, ellfit, seg 86 | 87 | # Rotation and Translation augmentations 88 | # Applies to the reference, ellipse-fit, and segmentation 89 | def shift_augment(reference, ellfit=None, seg=None, max_trans=15.0, max_rot=5.0): 90 | # For segmentation-only rotate around the middle + random noise 91 | if ellfit is None: 92 | ellfit = [0.5,0.5,0.5,0.5,0.5,0.5] 93 | # Generate the random values 94 | randTransX = tf.div(tf.subtract(tf.random_uniform([1], minval=-1.0, maxval=1.0)[0]*max_trans*2,means[0]),scale[0]) 95 | randTransY = tf.div(tf.subtract(tf.random_uniform([1], minval=-1.0, maxval=1.0)[0]*max_trans*2,means[1]),scale[1]) 96 | randRot = tf.random_uniform([1], minval=-1.0, maxval=1.0)[0]*max_rot*np.pi/180.0 97 | # Enforce some boundaries for keeping the mouse on in view... 
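# The four tf.cond calls that follow flip the sign of the random translation whenever it would
# push the normalized ellipse centre (ellfit[0], ellfit[1]) outside the [0, 1] range, keeping the
# animal in view after the warp. Recall that labels are stored normalized as (value - means) / scale
# (see read_ellipse above), so the centre coordinates live roughly in [0, 1].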
98 | randTransX = tf.cond(tf.unstack(ellfit)[0]-randTransX>0, lambda: tf.identity(randTransX), lambda: tf.multiply(-1.0, randTransX)) 99 | randTransX = tf.cond(tf.unstack(ellfit)[0]-randTransX<1, lambda: tf.identity(randTransX), lambda: tf.multiply(-1.0, randTransX)) 100 | randTransY = tf.cond(tf.unstack(ellfit)[1]-randTransY>0, lambda: tf.identity(randTransY), lambda: tf.multiply(-1.0, randTransY)) 101 | randTransY = tf.cond(tf.unstack(ellfit)[1]-randTransY<1, lambda: tf.identity(randTransY), lambda: tf.multiply(-1.0, randTransY)) 102 | 103 | # Apply the transform to the reference 104 | reference = tf.reshape(reference, [1, tf.shape(reference)[0], tf.shape(reference)[1], 1]) 105 | alpha = tf.cos(randRot) 106 | beta = tf.sin(randRot) 107 | xloc = tf.subtract(tf.multiply(tf.unstack(ellfit)[0]-randTransX/2.0,2.0),1.0) # Range of -1 to 1 108 | yloc = tf.subtract(tf.multiply(tf.unstack(ellfit)[1]-randTransY/2.0,2.0),1.0) # Range of -1 to 1 109 | # This is essentially [opencv's rotate matrix]+[[0,0,xtrans],[0,0,ytrans]] 110 | affine_trans = [[alpha, beta,tf.multiply(1.0-alpha,xloc)-tf.multiply(beta,yloc)+randTransX], [-beta, alpha,tf.multiply(beta,xloc)+tf.multiply(1.0-alpha,yloc)+randTransY]] 111 | reference = transformer(reference, affine_trans, (tf.shape(reference)[1],tf.shape(reference)[2])) 112 | reference = tf.reshape(reference, [tf.shape(reference)[1],tf.shape(reference)[2], 1]) 113 | 114 | # Apply the transform to the seg 115 | if seg is not None: 116 | seg = tf.reshape(seg, [1, tf.shape(seg)[0], tf.shape(seg)[1], 1]) 117 | seg = transformer(seg, affine_trans, (tf.shape(seg)[1],tf.shape(seg)[2])) 118 | seg = tf.reshape(seg, [tf.shape(seg)[1],tf.shape(seg)[2], 1]) 119 | 120 | # Edit the ellipse fit values 121 | angle = atan2(tf.add(tf.multiply(tf.unstack(ellfit)[4], scale[4]), means[4]), tf.add(tf.multiply(tf.unstack(ellfit)[5], scale[5]), means[5])) # In radians 122 | ellfit = tf.stack([tf.unstack(ellfit)[0]-randTransX/2.0, tf.unstack(ellfit)[1]-randTransY/2.0, tf.unstack(ellfit)[2], tf.unstack(ellfit)[3], tf.div(tf.subtract(tf.sin(angle-randRot), means[4]), scale[4]), tf.div(tf.subtract(tf.cos(angle-randRot), means[5]), scale[5])]) 123 | return reference, ellfit, seg 124 | 125 | ################################################ 126 | # For all readers, the following information is constant... 
127 | # input_queue[0] = image file path 128 | # input_queue[1] = label file path 129 | 130 | # Reads the image and segmentation values 131 | def read_image_and_seg(input_queue, input_size): 132 | seg = read_image(input_queue[1], input_size) 133 | image = read_image(input_queue[0], input_size) 134 | image, _, seg = rand_flip_input(image, seg=seg) 135 | return image, seg 136 | 137 | # Reads the image and segmentation values 138 | def read_augment_image_and_seg(input_queue, input_size, max_trans=120.0, max_rot=45.0): 139 | image, seg = read_image_and_seg(input_queue, input_size) 140 | image = augment_image(image, input_size) 141 | image, _, seg = shift_augment(image, seg=seg, max_trans=max_trans, max_rot=max_rot) 142 | return image, seg 143 | 144 | # Reads the image and ellipse regression values 145 | def read_image_and_ellreg(input_queue, input_size): 146 | ellfit = read_ellipse(input_queue[1]) 147 | image = read_image(input_queue[0], input_size) 148 | image, ellfit, _ = rand_flip_input(image, ellfit) 149 | return image, ellfit 150 | 151 | # Reads and augments the image and ellipse regression values 152 | def read_augment_image_and_ellreg(input_queue, input_size): 153 | ellfit = read_ellipse(input_queue[1]) 154 | image = read_image(input_queue[0], input_size) 155 | image, ellfit, _ = rand_flip_input(image, ellfit) 156 | image = augment_image(image, input_size) 157 | return image, ellfit 158 | 159 | # Reads and augments the image and ellipse regression values 160 | def read_augment_image_and_ellreg_v2(input_queue, input_size, max_trans=120.0, max_rot=45.0): 161 | ellfit = read_ellipse(input_queue[1]) 162 | image = read_image(input_queue[0], input_size) 163 | image, ellfit, _ = rand_flip_input(image, ellfit) 164 | image = augment_image(image, input_size) 165 | image, ellfit, _ = shift_augment(image, ellfit, max_trans=max_trans, max_rot=max_rot) 166 | return image, ellfit 167 | 168 | # Reads all 3 169 | def read_image_and_seg_and_ellreg(input_queue, input_size): 170 | image = read_image(input_queue[0], input_size) 171 | seg = read_image(input_queue[1], input_size) 172 | ellfit = read_ellipse(input_queue[2]) 173 | return image, seg, ellfit 174 | 175 | # Reads all 3 + augmentation 176 | def read_augment_image_and_seg_and_ellreg(input_queue, input_size, max_trans=120.0, max_rot=45.0): 177 | image, seg, ellfit = read_image_and_seg_and_ellreg(input_queue, input_size) 178 | image, ellfit, seg = rand_flip_input(image, ellfit, seg=seg) 179 | image, ellfit, seg = shift_augment(image, ellfit, seg, max_trans=max_trans, max_rot=max_rot) 180 | image = augment_image(image, input_size) 181 | return image, seg, ellfit 182 | 183 | ################################################ 184 | # Prepares the datasets into two callable generator tensors 185 | # Returns image, label tensor generators 186 | def get_train_batch_ellreg(dataset, read_threads, batch_size, input_size): 187 | inputs = dataset.train_images 188 | inputs2 = dataset.train_labels 189 | input_queue = tf.train.slice_input_producer([inputs, inputs2], shuffle=True) 190 | example_list = [read_augment_image_and_ellreg(input_queue, input_size) for _ in range(read_threads)] 191 | shapes = [[input_size,input_size,1],[6]] 192 | min_after_dequeue = 100 # Always have 100 extra in the queue 193 | capacity = min_after_dequeue + 5 * batch_size 194 | image_batch, label_batch = tf.train.shuffle_batch_join(example_list, batch_size=batch_size, shapes=shapes, capacity=capacity, min_after_dequeue=min_after_dequeue) 195 | return image_batch, label_batch 196 | 197 | # 
Prepares the datasets into two callable generator tensors 198 | # Returns image, label tensor generators 199 | # Uses affine transformation augmentation 200 | def get_train_batch_ellreg_v2(dataset, read_threads, batch_size, input_size, max_rot = 45., max_trans = 120.): 201 | inputs = dataset.train_images 202 | inputs2 = dataset.train_labels 203 | input_queue = tf.train.slice_input_producer([inputs, inputs2], shuffle=True) 204 | example_list = [read_augment_image_and_ellreg_v2(input_queue, input_size, max_rot = max_rot, max_trans = max_trans) for _ in range(read_threads)] 205 | shapes = [[input_size,input_size,1],[6]] 206 | min_after_dequeue = 100 # Always have 100 extra in the queue 207 | capacity = min_after_dequeue + 5 * batch_size 208 | image_batch, label_batch = tf.train.shuffle_batch_join(example_list, batch_size=batch_size, shapes=shapes, capacity=capacity, min_after_dequeue=min_after_dequeue) 209 | return image_batch, label_batch 210 | 211 | # Prepares the datasets into two callable generator tensors 212 | # Returns image, label tensor generators 213 | def get_eval_batch_ellreg(dataset, read_threads, batch_size, input_size): 214 | inputs = dataset.valid_images 215 | inputs2 = dataset.valid_labels 216 | input_queue = tf.train.slice_input_producer([inputs, inputs2], shuffle=True) 217 | example_list = [read_image_and_ellreg(input_queue, input_size) for _ in range(read_threads)] 218 | shapes = [[input_size,input_size,1],[6]] 219 | min_after_dequeue = 100 # Always have 100 extra in the queue 220 | capacity = min_after_dequeue + 5 * batch_size 221 | image_batch, label_batch = tf.train.shuffle_batch_join(example_list, batch_size=batch_size, shapes=shapes, capacity=capacity, min_after_dequeue=min_after_dequeue) 222 | return image_batch, label_batch 223 | 224 | # Prepares the datasets into two callable generator tensors 225 | # Returns image, label tensor generators 226 | def get_train_batch_segellreg(dataset, read_threads, batch_size, input_size, max_rot = 45., max_trans = 120.): 227 | inputs = dataset.train_images 228 | inputs2 = dataset.train_seg 229 | inputs3 = dataset.train_labels 230 | input_queue = tf.train.slice_input_producer([inputs, inputs2, inputs3], shuffle=True) 231 | example_list = [read_augment_image_and_seg_and_ellreg(input_queue, input_size, max_trans, max_rot) for _ in range(read_threads)] 232 | shapes = [[input_size,input_size,1],[input_size,input_size,1],[6]] 233 | min_after_dequeue = 100 # Always have 100 extra in the queue 234 | capacity = min_after_dequeue + 5 * batch_size 235 | image_batch, seg_batch, ellfit_batch = tf.train.shuffle_batch_join(example_list, batch_size=batch_size, shapes=shapes, capacity=capacity, min_after_dequeue=min_after_dequeue) 236 | return image_batch, seg_batch, ellfit_batch 237 | 238 | def get_valid_batch_segellreg(dataset, read_threads, batch_size, input_size): 239 | inputs = dataset.valid_images 240 | inputs2 = dataset.valid_seg 241 | inputs3 = dataset.valid_labels 242 | input_queue = tf.train.slice_input_producer([inputs, inputs2, inputs3], shuffle=True) 243 | example_list = [read_image_and_seg_and_ellreg(input_queue, input_size) for _ in range(read_threads)] 244 | shapes = [[input_size,input_size,1],[input_size,input_size,1],[6]] 245 | min_after_dequeue = 100 # Always have 100 extra in the queue 246 | capacity = min_after_dequeue + 5 * batch_size 247 | image_batch, seg_batch, ellfit_batch = tf.train.shuffle_batch_join(example_list, batch_size=batch_size, shapes=shapes, capacity=capacity, min_after_dequeue=min_after_dequeue) 248 | return 
image_batch, seg_batch, ellfit_batch 249 | 250 | # Prepares the datasets into two callable generator tensors 251 | # Returns image, label tensor generators 252 | def get_train_batch_seg(dataset, read_threads, batch_size, input_size, max_trans = 0.0, max_rot = 0.0): 253 | inputs = dataset.train_images 254 | inputs2 = dataset.train_seg 255 | input_queue = tf.train.slice_input_producer([inputs, inputs2], shuffle=True) 256 | example_list = [read_augment_image_and_seg(input_queue, input_size, max_trans = max_trans, max_rot = max_rot) for _ in range(read_threads)] 257 | shapes = [[input_size,input_size,1],[input_size,input_size,1]] 258 | min_after_dequeue = 100 # Always have 100 extra in the queue 259 | capacity = min_after_dequeue + 5 * batch_size 260 | image_batch, label_batch = tf.train.shuffle_batch_join(example_list, batch_size=batch_size, shapes=shapes, capacity=capacity, min_after_dequeue=min_after_dequeue) 261 | return image_batch, label_batch 262 | 263 | def get_eval_batch_seg(dataset, read_threads, batch_size, input_size): 264 | inputs = dataset.valid_images 265 | inputs2 = dataset.valid_seg 266 | input_queue = tf.train.slice_input_producer([inputs, inputs2], shuffle=True) 267 | example_list = [read_image_and_seg(input_queue, input_size) for _ in range(read_threads)] 268 | shapes = [[input_size,input_size,1],[input_size,input_size,1]] 269 | min_after_dequeue = 100 # Always have 100 extra in the queue 270 | capacity = min_after_dequeue + 5 * batch_size 271 | image_batch, label_batch = tf.train.shuffle_batch_join(example_list, batch_size=batch_size, shapes=shapes, capacity=capacity, min_after_dequeue=min_after_dequeue) 272 | return image_batch, label_batch 273 | 274 | # Convert ellreg input to a one-hot x/y input 275 | def ellreg_to_xyhot(ellreg, nbins = 4800, bins2px = 10): 276 | label = tf.add(tf.multiply(ellreg, scale), means) 277 | xhot = tf.one_hot(tf.to_int32(label[:,0]*bins2px), nbins) 278 | yhot = tf.one_hot(tf.to_int32(label[:,1]*bins2px), nbins) 279 | return xhot, yhot 280 | 281 | 282 | -------------------------------------------------------------------------------- /utils/TrainNetwork.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 3 | import tensorflow as tf 4 | import tensorflow.contrib.slim as slim 5 | import numpy as np 6 | import cv2 7 | import time 8 | from datetime import datetime 9 | import sys 10 | from .plotters import * 11 | from .readers import * 12 | from .transformer import * 13 | from .models import * 14 | from .datasets import * 15 | from .training import * 16 | 17 | def trainEllregNetwork(arg_dict): 18 | sess = tf.Session() 19 | ########################################## 20 | with tf.variable_scope('Input_Variables'): 21 | image_placeholder = tf.placeholder(tf.float32, [arg_dict['batch_size'], arg_dict['input_size'], arg_dict['input_size'], 1]) 22 | label_placeholder = tf.placeholder(tf.float32, [arg_dict['batch_size'], 6]) 23 | is_training = tf.placeholder(tf.bool, [], name='is_training') 24 | ########################################## 25 | with tf.variable_scope('Network'): 26 | print('Constructing model...') 27 | network_eval_batch, _ = arg_dict['model_construct_function'](image_placeholder, is_training) 28 | with tf.variable_scope('Loss'): 29 | print('Adding loss function...') 30 | loss, errors, angle_errs = gen_loss_ellreg(network_eval_batch, label_placeholder) 31 | ########################################## 32 | with tf.variable_scope('Input_Decoding'): 33 | 
print('Populating input queues...') 34 | image_valid_batch, label_valid_batch = get_eval_batch_ellreg(arg_dict['dataset'], arg_dict['n_reader_threads'], arg_dict['batch_size'], arg_dict['input_size']) 35 | image_train_batch, label_train_batch = get_train_batch_ellreg_v2(arg_dict['dataset'], arg_dict['n_reader_threads'], arg_dict['batch_size'], arg_dict['input_size'], max_rot = arg_dict['aug_rot_max'], max_trans = arg_dict['aug_trans_max']) 36 | print('Starting input threads...') 37 | coord = tf.train.Coordinator() 38 | threads = tf.train.start_queue_runners(sess=sess, coord=coord) 39 | ########################################## 40 | global_step = tf.Variable(0, name='global_step', trainable=False) 41 | with tf.variable_scope('Optimizer'): 42 | print('Initializing optimizer...') 43 | learn_rate, train_op = arg_dict['learn_function'](loss, arg_dict['dataset'].train_size, arg_dict['batch_size'], global_step, arg_dict['start_learn_rate'], arg_dict['epocs_per_lr_decay'], const_learn_rate=arg_dict['const_learn_rate']) 44 | ########################################## 45 | with tf.variable_scope('Saver'): 46 | print('Generating summaries and savers...') 47 | training_summary, validation_summary = gen_summary_ellreg(loss, errors, angle_errs, learn_rate) 48 | summary_writer = tf.summary.FileWriter(arg_dict['log_dir'], sess.graph) 49 | saver = tf.train.Saver(slim.get_variables_to_restore(), max_to_keep=2) 50 | ########################################## 51 | print('Initializing model...') 52 | sess.run(tf.global_variables_initializer()) 53 | if 'network_to_restore' in arg_dict.keys() and arg_dict['network_to_restore'] is not None: 54 | saver.restore(sess,arg_dict['network_to_restore']) 55 | print('Beginning training...') 56 | for step in range(0, arg_dict['num_steps']): 57 | start_time = time.time() 58 | img_batch, label_batch = sess.run([image_train_batch, label_train_batch]) 59 | _, train_loss, summary_output, cur_step = sess.run(fetches=[train_op, loss, training_summary, global_step], feed_dict={image_placeholder: img_batch, label_placeholder: label_batch, is_training: True}) 60 | duration = time.time() - start_time 61 | if (step+1) % 10 == 0: # CMDline updates every 10 steps 62 | examples_per_sec = arg_dict['batch_size'] / duration 63 | sec_per_batch = float(duration) 64 | format_str = ('%s: step %d, loss = %.4f (%.1f examples/sec; %.3f sec/batch)') 65 | print (format_str % (datetime.now(), cur_step, train_loss, examples_per_sec, sec_per_batch)) 66 | if (step+1) % 100 == 0: # Tensorboard updates values every 100 steps 67 | summary_writer.add_summary(summary_output, cur_step) 68 | img_batch, label_batch = sess.run([image_valid_batch, label_valid_batch]) 69 | summary_output = sess.run(fetches=[validation_summary], feed_dict={image_placeholder: img_batch, label_placeholder: label_batch, is_training: False})[0] 70 | summary_writer.add_summary(summary_output, cur_step) 71 | if (step+1) % 1000 == 0: # Save model every 1k steps 72 | checkpoint_path = os.path.join(arg_dict['log_dir'], 'model.ckpt') 73 | saver.save(sess, checkpoint_path, global_step=cur_step) 74 | 75 | # Save model after training is terminated... 
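# The queue-runner threads started in Input_Decoding are left running when training ends;
# a possible clean shutdown (sketch only, not part of the original flow) would be:
#   coord.request_stop()
#   coord.join(threads)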
76 | checkpoint_path = os.path.join(arg_dict['log_dir'], 'model.ckpt') 77 | saver.save(sess, checkpoint_path, global_step=cur_step) 78 | 79 | 80 | def trainSegEllfitNetwork(arg_dict): 81 | sess = tf.Session() 82 | ########################################## 83 | with tf.variable_scope('Input_Variables'): 84 | image_placeholder = tf.placeholder(tf.float32, [arg_dict['batch_size'], arg_dict['input_size'], arg_dict['input_size'], 1]) 85 | seg_label_placeholder = tf.placeholder(tf.float32, [arg_dict['batch_size'], arg_dict['input_size'], arg_dict['input_size'], 1]) 86 | ellfit_label_placeholder = tf.placeholder(tf.float32, [arg_dict['batch_size'], 6]) 87 | is_training = tf.placeholder(tf.bool, [], name='is_training') 88 | ########################################## 89 | with tf.variable_scope('Network'): 90 | print('Constructing model...') 91 | network_eval_batch, ellfit_eval_batch, angle_fix_batch = arg_dict['model_construct_function'](image_placeholder, is_training) 92 | with tf.variable_scope('Loss'): 93 | print('Adding loss function...') 94 | loss1, errors1 = gen_loss_seg(network_eval_batch, seg_label_placeholder) 95 | # Alternative loss if we don't want morphological filtering 96 | #loss1, errors1 = gen_loss_seg_nomorph(network_eval_batch, seg_label_placeholder) 97 | loss2, errors2, angle_errs = gen_loss_ellreg(ellfit_eval_batch, ellfit_label_placeholder) 98 | loss3 = gen_loss_anglequadrant(angle_fix_batch, ellfit_label_placeholder) 99 | loss = tf.reduce_mean([loss1, loss2, loss3]) 100 | ########################################## 101 | with tf.variable_scope('Input_Decoding'): 102 | print('Populating input queues...') 103 | image_valid_batch, seg_valid_batch, label_valid_batch = get_valid_batch_segellreg(arg_dict['dataset'], arg_dict['n_reader_threads'], arg_dict['batch_size'], arg_dict['input_size']) 104 | image_train_batch, seg_train_batch, label_train_batch = get_train_batch_segellreg(arg_dict['dataset'], arg_dict['n_reader_threads'], arg_dict['batch_size'], arg_dict['input_size'], max_trans = arg_dict['aug_trans_max'], max_rot = arg_dict['aug_rot_max']) 105 | print('Starting input threads...') 106 | coord = tf.train.Coordinator() 107 | threads = tf.train.start_queue_runners(sess=sess, coord=coord) 108 | ########################################## 109 | global_step = tf.Variable(0, name='global_step', trainable=False) 110 | with tf.variable_scope('Optimizer'): 111 | print('Initializing optimizer...') 112 | learn_rate, train_op = arg_dict['learn_function'](loss, arg_dict['dataset'].train_size, arg_dict['batch_size'], global_step, arg_dict['start_learn_rate'], arg_dict['epocs_per_lr_decay'], const_learn_rate=arg_dict['const_learn_rate']) 113 | ########################################## 114 | with tf.variable_scope('Saver'): 115 | print('Generating summaries and savers...') 116 | training_summary1, validation_summary1 = gen_summary_seg(loss1, errors1, learn_rate) 117 | training_summary2, validation_summary2 = gen_summary_ellreg(loss2, errors2, angle_errs, learn_rate) 118 | training_summary = tf.summary.merge([training_summary1, training_summary2]) 119 | validation_summary = tf.summary.merge([validation_summary1, validation_summary2]) 120 | summary_writer = tf.summary.FileWriter(arg_dict['log_dir'], sess.graph) 121 | saver = tf.train.Saver(slim.get_variables_to_restore(), max_to_keep=2) 122 | ########################################## 123 | print('Initializing model...') 124 | sess.run(tf.global_variables_initializer()) 125 | if 'network_to_restore' in arg_dict.keys() and 
arg_dict['network_to_restore'] is not None: 126 | saver.restore(sess,arg_dict['network_to_restore']) 127 | print('Beginning training...') 128 | for step in range(0, arg_dict['num_steps']): 129 | start_time = time.time() 130 | img_batch, seg_label_batch, ellfit_label_batch = sess.run([image_train_batch, seg_train_batch, label_train_batch]) 131 | _, train_loss, summary_output, cur_step, ellfit_errs = sess.run(fetches=[train_op, loss, training_summary, global_step, errors2], feed_dict={image_placeholder: img_batch, seg_label_placeholder: seg_label_batch, ellfit_label_placeholder: ellfit_label_batch, is_training: True}) 132 | duration = time.time() - start_time 133 | if (step+1) % 10 == 0: # CMDline updates every 10 steps 134 | examples_per_sec = arg_dict['batch_size'] / duration 135 | sec_per_batch = float(duration) 136 | format_str = ('%s: step %d, loss=%.4f, xerr=%.2f, yerr=%.2f (%.1f examples/sec)') 137 | print (format_str % (datetime.now(), cur_step, train_loss, ellfit_errs[0], ellfit_errs[1], examples_per_sec)) 138 | if (step+1) % 100 == 0: # Tensorboard updates values every 100 steps 139 | summary_writer.add_summary(summary_output, cur_step) 140 | img_batch, seg_label_batch, ellfit_label_batch = sess.run([image_valid_batch, seg_valid_batch, label_valid_batch]) 141 | summary_output = sess.run(fetches=[validation_summary], feed_dict={image_placeholder: img_batch, seg_label_placeholder: seg_label_batch, ellfit_label_placeholder: ellfit_label_batch, is_training: False})[0] 142 | summary_writer.add_summary(summary_output, cur_step) 143 | if (step+1) % 1000 == 0: # Save model every 1k steps 144 | checkpoint_path = os.path.join(arg_dict['log_dir'], 'model.ckpt') 145 | saver.save(sess, checkpoint_path, global_step=cur_step) 146 | # Save model after training is terminated... 
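# A later run can resume from the checkpoint written below by setting arg_dict['network_to_restore']
# to the saved path, e.g. os.path.join(arg_dict['log_dir'], 'model.ckpt-1000')
# (illustrative step number; any checkpoint saved by this loop works).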
147 | checkpoint_path = os.path.join(arg_dict['log_dir'], 'model.ckpt') 148 | saver.save(sess, checkpoint_path, global_step=cur_step) 149 | 150 | 151 | def trainBinnedNetwork(arg_dict): 152 | sess = tf.Session() 153 | ########################################## 154 | with tf.variable_scope('Input_Variables'): 155 | image_placeholder = tf.placeholder(tf.float32, [arg_dict['batch_size'], arg_dict['input_size'], arg_dict['input_size'], 1]) 156 | label_placeholder = tf.placeholder(tf.float32, [arg_dict['batch_size'], 6]) 157 | is_training = tf.placeholder(tf.bool, [], name='is_training') 158 | ########################################## 159 | with tf.variable_scope('Network'): 160 | print('Constructing model...') 161 | network_eval_batch, _ = arg_dict['model_construct_function'](image_placeholder, is_training, int(arg_dict['input_size']*arg_dict['bin_per_px'])) 162 | with tf.variable_scope('Loss'): 163 | print('Adding loss function...') 164 | loss, errors = gen_loss_xyhot(network_eval_batch, label_placeholder, arg_dict['input_size'], int(arg_dict['input_size']*arg_dict['bin_per_px'])) 165 | ########################################## 166 | with tf.variable_scope('Input_Decoding'): 167 | print('Populating input queues...') 168 | image_valid_batch, label_valid_batch = get_eval_batch_ellreg(arg_dict['dataset'], arg_dict['n_reader_threads'], arg_dict['batch_size'], arg_dict['input_size']) 169 | image_train_batch, label_train_batch = get_train_batch_ellreg_v2(arg_dict['dataset'], arg_dict['n_reader_threads'], arg_dict['batch_size'], arg_dict['input_size'], max_rot = arg_dict['aug_rot_max'], max_trans = arg_dict['aug_trans_max']) 170 | print('Starting input threads...') 171 | coord = tf.train.Coordinator() 172 | threads = tf.train.start_queue_runners(sess=sess, coord=coord) 173 | ########################################## 174 | global_step = tf.Variable(0, name='global_step', trainable=False) 175 | with tf.variable_scope('Optimizer'): 176 | print('Initializing optimizer...') 177 | learn_rate, train_op = arg_dict['learn_function'](loss, arg_dict['dataset'].train_size, arg_dict['batch_size'], global_step, arg_dict['start_learn_rate'], arg_dict['epocs_per_lr_decay'], const_learn_rate=arg_dict['const_learn_rate']) 178 | ########################################## 179 | with tf.variable_scope('Saver'): 180 | print('Generating summaries and savers...') 181 | training_summary, validation_summary = gen_summary_xyhot(loss, errors, learn_rate) 182 | summary_writer = tf.summary.FileWriter(arg_dict['log_dir'], sess.graph) 183 | saver = tf.train.Saver(slim.get_variables_to_restore(), max_to_keep=2) 184 | ########################################## 185 | print('Initializing model...') 186 | sess.run(tf.global_variables_initializer()) 187 | if 'network_to_restore' in arg_dict.keys() and arg_dict['network_to_restore'] is not None: 188 | saver.restore(sess,arg_dict['network_to_restore']) 189 | print('Beginning training...') 190 | for step in range(0, arg_dict['num_steps']): 191 | start_time = time.time() 192 | img_batch, label_batch = sess.run([image_train_batch, label_train_batch]) 193 | _, train_loss, summary_output, cur_step, bin_errs = sess.run(fetches=[train_op, loss, training_summary, global_step, errors], feed_dict={image_placeholder: img_batch, label_placeholder: label_batch, is_training: True}) 194 | duration = time.time() - start_time 195 | if (step+1) % 10 == 0: # CMDline updates every 10 steps 196 | examples_per_sec = arg_dict['batch_size'] / duration 197 | sec_per_batch = float(duration) 198 | format_str = 
('%s: step %d, loss=%.4f, xerr=%.2f, yerr=%.2f (%.1f examples/sec)') 199 | print (format_str % (datetime.now(), cur_step, train_loss, bin_errs[0], bin_errs[1], examples_per_sec)) 200 | if (step+1) % 100 == 0: # Tensorboard updates values every 100 steps 201 | summary_writer.add_summary(summary_output, cur_step) 202 | img_batch, label_batch = sess.run([image_valid_batch, label_valid_batch]) 203 | summary_output = sess.run(fetches=[validation_summary], feed_dict={image_placeholder: img_batch, label_placeholder: label_batch, is_training: False})[0] 204 | summary_writer.add_summary(summary_output, cur_step) 205 | if (step+1) % 1000 == 0: # Save model every 1k steps 206 | checkpoint_path = os.path.join(arg_dict['log_dir'], 'model.ckpt') 207 | saver.save(sess, checkpoint_path, global_step=cur_step) 208 | # Save model after training is terminated... 209 | checkpoint_path = os.path.join(arg_dict['log_dir'], 'model.ckpt') 210 | saver.save(sess, checkpoint_path, global_step=cur_step) 211 | 212 | 213 | def trainSegSoftNetwork(arg_dict): 214 | sess = tf.Session() 215 | ########################################## 216 | with tf.variable_scope('Input_Variables'): 217 | image_placeholder = tf.placeholder(tf.float32, [arg_dict['batch_size'], arg_dict['input_size'], arg_dict['input_size'], 1]) 218 | label_placeholder = tf.placeholder(tf.float32, [arg_dict['batch_size'], arg_dict['input_size'], arg_dict['input_size'], 1]) 219 | is_training = tf.placeholder(tf.bool, [], name='is_training') 220 | ########################################## 221 | with tf.variable_scope('Network'): 222 | print('Constructing model...') 223 | network_eval_batch = arg_dict['model_construct_function'](image_placeholder, is_training) 224 | with tf.variable_scope('Loss'): 225 | print('Adding loss function...') 226 | loss, errors = gen_loss_seg_nomorph(network_eval_batch, label_placeholder) 227 | ########################################## 228 | with tf.variable_scope('Input_Decoding'): 229 | print('Populating input queues...') 230 | image_valid_batch, label_valid_batch = get_eval_batch_seg(arg_dict['dataset'], arg_dict['n_reader_threads'], arg_dict['batch_size'], arg_dict['input_size']) 231 | image_train_batch, label_train_batch = get_train_batch_seg(arg_dict['dataset'], arg_dict['n_reader_threads'], arg_dict['batch_size'], arg_dict['input_size'], max_trans = arg_dict['aug_trans_max'], max_rot = arg_dict['aug_rot_max']) 232 | print('Starting input threads...') 233 | coord = tf.train.Coordinator() 234 | threads = tf.train.start_queue_runners(sess=sess, coord=coord) 235 | ########################################## 236 | global_step = tf.Variable(0, name='global_step', trainable=False) 237 | with tf.variable_scope('Optimizer'): 238 | print('Initializing optimizer...') 239 | learn_rate, train_op = arg_dict['learn_function'](loss, arg_dict['dataset'].train_size, arg_dict['batch_size'], global_step, arg_dict['start_learn_rate'], arg_dict['epocs_per_lr_decay'], const_learn_rate=arg_dict['const_learn_rate']) 240 | ########################################## 241 | with tf.variable_scope('Saver'): 242 | print('Generating summaries and savers...') 243 | training_summary, validation_summary = gen_summary_seg(loss, errors, learn_rate) 244 | summary_writer = tf.summary.FileWriter(arg_dict['log_dir'], sess.graph) 245 | saver = tf.train.Saver(slim.get_variables_to_restore(), max_to_keep=2) 246 | ########################################## 247 | print('Initializing model...') 248 | sess.run(tf.global_variables_initializer()) 249 | if 'network_to_restore' 
in arg_dict.keys() and arg_dict['network_to_restore'] is not None: 250 | saver.restore(sess,arg_dict['network_to_restore']) 251 | # 252 | for step in range(0, arg_dict['num_steps']): 253 | start_time = time.time() 254 | img_batch, label_batch = sess.run([image_train_batch, label_train_batch]) 255 | _, train_loss, summary_output, cur_step = sess.run(fetches=[train_op, loss, training_summary, global_step], feed_dict={image_placeholder: img_batch, label_placeholder: label_batch, is_training: True}) 256 | duration = time.time() - start_time 257 | if (step+1) % 10 == 0: # CMDline updates every 10 steps 258 | examples_per_sec = arg_dict['batch_size'] / duration 259 | sec_per_batch = float(duration) 260 | format_str = ('%s: step %d, loss = %.4f (%.1f examples/sec; %.3f sec/batch)') 261 | print (format_str % (datetime.now(), cur_step, train_loss, examples_per_sec, sec_per_batch)) 262 | if (step+1) % 100 == 0: # Tensorboard updates values every 100 steps 263 | summary_writer.add_summary(summary_output, cur_step) 264 | img_batch, label_batch = sess.run([image_valid_batch, label_valid_batch]) 265 | summary_output = sess.run(fetches=[validation_summary], feed_dict={image_placeholder: img_batch, label_placeholder: label_batch, is_training: False})[0] 266 | summary_writer.add_summary(summary_output, cur_step) 267 | if (step+1) % 1000 == 0: # Save model every 1k steps 268 | checkpoint_path = os.path.join(arg_dict['log_dir'], 'model.ckpt') 269 | saver.save(sess, checkpoint_path, global_step=cur_step) 270 | 271 | # Save model after training is terminated... 272 | checkpoint_path = os.path.join(arg_dict['log_dir'], 'model.ckpt') 273 | saver.save(sess, checkpoint_path, global_step=cur_step) 274 | 275 | def trainNetwork(arg_dict): 276 | if arg_dict['net_type'] == 'segellreg': 277 | trainSegEllfitNetwork(arg_dict) 278 | elif arg_dict['net_type'] == 'ellreg': 279 | trainEllregNetwork(arg_dict) 280 | elif arg_dict['net_type'] == 'binned': 281 | trainBinnedNetwork(arg_dict) 282 | elif arg_dict['net_type'] == 'seg': 283 | trainSegSoftNetwork(arg_dict) 284 | 285 | -------------------------------------------------------------------------------- /utils/InferMovie.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 3 | import tensorflow as tf 4 | import tensorflow.contrib.slim as slim 5 | import numpy as np 6 | import cv2 7 | import time 8 | from datetime import datetime 9 | import sys 10 | from .plotters import * 11 | from .readers import * 12 | from .transformer import * 13 | from .models import * 14 | from .datasets import * 15 | from .training import * 16 | import imageio 17 | from time import time 18 | 19 | # Processes the input_movie using the network 20 | # Ellreg-based net 21 | def processMovie(input_movie, network, outputs): 22 | # Setup some non-modifiable values... 23 | ellfit_movie_append = '_ellfit' 24 | affine_movie_append = '_affine' 25 | crop_movie_append = '_crop' 26 | ellfit_output_append = '_ellfit' 27 | ellfit_feature_outputs_append = '_features' 28 | 29 | # Set up the output streams... 
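# 'outputs' is expected to hold boolean switches ('ell_mov', 'aff_mov', 'crop_mov', 'ell_file',
# 'ell_features') plus 'affine_crop_dim'; a writer or file is opened below only for the outputs
# that are enabled.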
30 | stillReading = False 31 | reader = imageio.get_reader(input_movie) 32 | if outputs['ell_mov']: 33 | writer_ellfit = imageio.get_writer(input_movie[:-4]+ellfit_movie_append+'.avi', fps=reader.get_meta_data()['fps'], codec='mpeg4', quality=10) 34 | stillReading = True 35 | if outputs['aff_mov']: 36 | writer_affine = imageio.get_writer(input_movie[:-4]+affine_movie_append+'.avi', fps=reader.get_meta_data()['fps'], codec='mpeg4', quality=10) 37 | stillReading = True 38 | if outputs['crop_mov']: 39 | writer_crop = imageio.get_writer(input_movie[:-4]+crop_movie_append+'.avi', fps=reader.get_meta_data()['fps'], codec='mpeg4', quality=10) 40 | stillReading = True 41 | if outputs['ell_file']: 42 | file_ellfit = open(input_movie[:-4]+ellfit_output_append+'.npz', 'wb') 43 | stillReading = True 44 | if outputs['ell_features']: 45 | file_features = open(input_movie[:-4]+ellfit_feature_outputs_append+'.npz', 'wb') 46 | stillReading = True 47 | 48 | # Start processing the data 49 | im_iter = reader.iter_data() 50 | framenum = 0 51 | while(stillReading): 52 | start_time = time() 53 | frames = [] 54 | framenum = framenum + 1 * network['batch_size'] 55 | if framenum % 1000 == 0: 56 | print("Frame: " + str(framenum)) 57 | for i in range(network['batch_size']): 58 | try: 59 | #frame = cv2.cvtColor(np.uint8(next(im_iter)), cv2.COLOR_BGR2GRAY) 60 | frame = np.uint8(next(im_iter)) 61 | #frames.append(np.resize(frame, (network['input_size'], network['input_size'], 1))) 62 | frames.append(np.resize(frame[:,:,1], (network['input_size'], network['input_size'], 1))) 63 | except StopIteration: 64 | stillReading = False 65 | break 66 | except RuntimeError: 67 | stillReading = False 68 | break 69 | if framenum % 1000 == 0: 70 | print('Batch Assembled in: ' + str(time()-start_time)) 71 | start_time = time() 72 | if stillReading: 73 | if outputs['ell_features']: 74 | result, result_unscaled, features = network['sess'].run(fetches=[network['network_eval_batch'], network['ellfit'], network['final_features']], feed_dict={network['image_placeholder']: frames, network['is_training']: False}) 75 | else: 76 | result, result_unscaled = network['sess'].run(fetches=[network['network_eval_batch'], network['ellfit']], feed_dict={network['image_placeholder']: frames, network['is_training']: False}) 77 | if framenum % 1000 == 0: 78 | print('Batch Processed in: ' + str(time()-start_time)) 79 | start_time = time() 80 | # Sequentially save the data 81 | for i in range(network['batch_size']): 82 | # Save the outputs and save only if the outfile pattern was identified 83 | if outputs['ell_mov']: 84 | plot = cv2.cvtColor(frames[i],cv2.COLOR_GRAY2RGB) 85 | result_temp = plot_ellipse(plot, result[i], (255, 0, 0)) 86 | writer_ellfit.append_data(result_temp.astype('u1')) 87 | if outputs['aff_mov']: 88 | plot = cv2.cvtColor(frames[i],cv2.COLOR_GRAY2RGB) 89 | angle = np.arctan2(result_unscaled[i,5],result_unscaled[i,4])*180/np.pi 90 | affine_mat = np.float32([[1,0,-result_unscaled[i,0]+outputs['affine_crop_dim']],[0,1,-result_unscaled[i,1]+outputs['affine_crop_dim']]]) 91 | plot = cv2.warpAffine(plot, affine_mat, (outputs['affine_crop_dim']*2,outputs['affine_crop_dim']*2)); 92 | affine_mat = cv2.getRotationMatrix2D((outputs['affine_crop_dim'],outputs['affine_crop_dim']),angle,1.); 93 | plot = cv2.warpAffine(plot, affine_mat, (outputs['affine_crop_dim']*2,outputs['affine_crop_dim']*2)); 94 | affine_mat = np.float32([[1,0,-outputs['affine_crop_dim']/2],[0,1,-outputs['affine_crop_dim']/2]]); 95 | plot = cv2.warpAffine(plot, affine_mat, 
(outputs['affine_crop_dim'],outputs['affine_crop_dim'])); 96 | writer_affine.append_data(plot.astype('u1')) 97 | if outputs['crop_mov']: 98 | plot = cv2.cvtColor(frames[i],cv2.COLOR_GRAY2RGB) 99 | angle = 0 100 | affine_mat = np.float32([[1,0,-result_unscaled[i,0]+outputs['affine_crop_dim']],[0,1,-result_unscaled[i,1]+outputs['affine_crop_dim']]]) 101 | plot = cv2.warpAffine(plot, affine_mat, (outputs['affine_crop_dim']*2,outputs['affine_crop_dim']*2)); 102 | affine_mat = cv2.getRotationMatrix2D((outputs['affine_crop_dim'],outputs['affine_crop_dim']),angle,1.); 103 | plot = cv2.warpAffine(plot, affine_mat, (outputs['affine_crop_dim']*2,outputs['affine_crop_dim']*2)); 104 | affine_mat = np.float32([[1,0,-outputs['affine_crop_dim']/2],[0,1,-outputs['affine_crop_dim']/2]]); 105 | plot = cv2.warpAffine(plot, affine_mat, (outputs['affine_crop_dim'],outputs['affine_crop_dim'])); 106 | writer_crop.append_data(plot.astype('u1')) 107 | if outputs['ell_file']: 108 | np.save(file_ellfit, result_unscaled[i,:], allow_pickle=False) 109 | if outputs['ell_features']: 110 | np.save(file_features, features[i,:], allow_pickle=False) 111 | if framenum % 1000 == 0: 112 | print('Batch Saved in: ' + str(time()-start_time)) 113 | 114 | if outputs['ell_file']: 115 | file_ellfit.close() 116 | if outputs['ell_features']: 117 | file_features.close() 118 | 119 | 120 | # Processes the input_movie using the network 121 | # Segellreg-based net 122 | def processMovie_v2(input_movie, network, outputs): 123 | # Setup some non-modifiable values... 124 | ellfit_movie_append = '_ellfit' 125 | affine_movie_append = '_affine' 126 | crop_movie_append = '_crop' 127 | ellfit_output_append = '_ellfit' 128 | ellfit_feature_outputs_append = '_features' 129 | seg_movie_append = '_seg' 130 | 131 | # Set up the output streams... 
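# Same output switches as processMovie, plus a 'seg_mov' flag that opens an extra writer for the
# per-frame segmentation mask video.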
132 | stillReading = False 133 | reader = imageio.get_reader(input_movie) 134 | if outputs['ell_mov']: 135 | writer_ellfit = imageio.get_writer(input_movie[:-4]+ellfit_movie_append+'.avi', fps=reader.get_meta_data()['fps'], codec='mpeg4', quality=10) 136 | stillReading = True 137 | if outputs['aff_mov']: 138 | writer_affine = imageio.get_writer(input_movie[:-4]+affine_movie_append+'.avi', fps=reader.get_meta_data()['fps'], codec='mpeg4', quality=10) 139 | stillReading = True 140 | if outputs['crop_mov']: 141 | writer_crop = imageio.get_writer(input_movie[:-4]+crop_movie_append+'.avi', fps=reader.get_meta_data()['fps'], codec='mpeg4', quality=10) 142 | stillReading = True 143 | if outputs['ell_file']: 144 | file_ellfit = open(input_movie[:-4]+ellfit_output_append+'.npz', 'wb') 145 | stillReading = True 146 | if outputs['ell_features']: 147 | file_features = open(input_movie[:-4]+ellfit_feature_outputs_append+'.npz', 'wb') 148 | stillReading = True 149 | if outputs['seg_mov']: 150 | writer_seg = imageio.get_writer(input_movie[:-4]+seg_movie_append+'.avi', fps=reader.get_meta_data()['fps'], codec='mpeg4', quality=10) 151 | stillReading = True 152 | 153 | # Start processing the data 154 | im_iter = reader.iter_data() 155 | framenum = 0 156 | while(stillReading): 157 | start_time = time() 158 | frames = [] 159 | framenum = framenum + 1 * network['batch_size'] 160 | if framenum % 1000 == 0: 161 | print("Frame: " + str(framenum)) 162 | for i in range(network['batch_size']): 163 | try: 164 | #frame = cv2.cvtColor(np.uint8(next(im_iter)), cv2.COLOR_BGR2GRAY) 165 | frame = np.uint8(next(im_iter)) 166 | #frames.append(np.resize(frame, (network['input_size'], network['input_size'], 1))) 167 | frames.append(np.resize(frame[:,:,1], (network['input_size'], network['input_size'], 1))) 168 | except StopIteration: 169 | stillReading = False 170 | break 171 | except RuntimeError: 172 | stillReading = False 173 | break 174 | if framenum % 1000 == 0: 175 | print('Batch Assembled in: ' + str(time()-start_time)) 176 | start_time = time() 177 | if stillReading: 178 | result, result_unscaled, result_seg = network['sess'].run(fetches=[network['network_eval_batch'], network['ellfit'], network['seg']], feed_dict={network['image_placeholder']: frames, network['is_training']: False}) 179 | if framenum % 1000 == 0: 180 | print('Batch Processed in: ' + str(time()-start_time)) 181 | start_time = time() 182 | # Sequentially save the data 183 | for i in range(network['batch_size']): 184 | # Save the outputs and save only if the outfile pattern was identified 185 | if outputs['ell_mov']: 186 | plot = cv2.cvtColor(frames[i],cv2.COLOR_GRAY2RGB) 187 | result_temp = plot_ellipse(plot, result[i], (255, 0, 0)) 188 | writer_ellfit.append_data(result_temp.astype('u1')) 189 | if outputs['aff_mov']: 190 | plot = cv2.cvtColor(frames[i],cv2.COLOR_GRAY2RGB) 191 | angle = np.arctan2(result_unscaled[i,5],result_unscaled[i,4])*180/np.pi 192 | affine_mat = np.float32([[1,0,-result_unscaled[i,0]+outputs['affine_crop_dim']],[0,1,-result_unscaled[i,1]+outputs['affine_crop_dim']]]) 193 | plot = cv2.warpAffine(plot, affine_mat, (outputs['affine_crop_dim']*2,outputs['affine_crop_dim']*2)); 194 | affine_mat = cv2.getRotationMatrix2D((outputs['affine_crop_dim'],outputs['affine_crop_dim']),angle,1.); 195 | plot = cv2.warpAffine(plot, affine_mat, (outputs['affine_crop_dim']*2,outputs['affine_crop_dim']*2)); 196 | affine_mat = np.float32([[1,0,-outputs['affine_crop_dim']/2],[0,1,-outputs['affine_crop_dim']/2]]); 197 | plot = cv2.warpAffine(plot, 
affine_mat, (outputs['affine_crop_dim'],outputs['affine_crop_dim'])); 198 | writer_affine.append_data(plot.astype('u1')) 199 | if outputs['crop_mov']: 200 | plot = cv2.cvtColor(frames[i],cv2.COLOR_GRAY2RGB) 201 | angle = 0 202 | affine_mat = np.float32([[1,0,-result_unscaled[i,0]+outputs['affine_crop_dim']],[0,1,-result_unscaled[i,1]+outputs['affine_crop_dim']]]) 203 | plot = cv2.warpAffine(plot, affine_mat, (outputs['affine_crop_dim']*2,outputs['affine_crop_dim']*2)); 204 | affine_mat = cv2.getRotationMatrix2D((outputs['affine_crop_dim'],outputs['affine_crop_dim']),angle,1.); 205 | plot = cv2.warpAffine(plot, affine_mat, (outputs['affine_crop_dim']*2,outputs['affine_crop_dim']*2)); 206 | affine_mat = np.float32([[1,0,-outputs['affine_crop_dim']/2],[0,1,-outputs['affine_crop_dim']/2]]); 207 | plot = cv2.warpAffine(plot, affine_mat, (outputs['affine_crop_dim'],outputs['affine_crop_dim'])); 208 | writer_crop.append_data(plot.astype('u1')) 209 | if outputs['ell_file']: 210 | np.save(file_ellfit, result_unscaled[i,:], allow_pickle=False) 211 | if outputs['ell_features']: 212 | np.save(file_features, features[i,:], allow_pickle=False) 213 | if outputs['seg_mov']: 214 | seg_output = result_seg[i,:,:,:] 215 | seg_output = seg_output[:,:,0]/np.sum(seg_output,2) 216 | #seg_output = seg_output[:,:,0]/np.sum(seg_output,2)-seg_output[:,:,1]/np.sum(seg_output,2) 217 | #seg_output = seg_output+0.25 218 | #seg_output[seg_output<1e-6] = 0 219 | #seg_output[seg_output>1.0] = 1.0 220 | writer_seg.append_data((254*seg_output).astype('u1')) 221 | if framenum % 1000 == 0: 222 | print('Batch Saved in: ' + str(time()-start_time)) 223 | 224 | if outputs['ell_file']: 225 | file_ellfit.close() 226 | if outputs['ell_features']: 227 | file_features.close() 228 | 229 | 230 | # Processes the input_movie using the network 231 | # Binned-based net 232 | def processMovie_v3(input_movie, network, outputs): 233 | # Setup some non-modifiable values... 234 | ellfit_movie_append = '_xyplot' 235 | crop_movie_append = '_crop' 236 | 237 | # Set up the output streams... 
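# The binned network only produces x/y bin scores, so only the 'ell_mov' (crosshair overlay) and
# 'crop_mov' outputs apply here.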
238 | stillReading = False 239 | reader = imageio.get_reader(input_movie) 240 | if outputs['ell_mov']: 241 | writer_ellfit = imageio.get_writer(input_movie[:-4]+ellfit_movie_append+'.avi', fps=reader.get_meta_data()['fps'], codec='mpeg4', quality=10) 242 | stillReading = True 243 | if outputs['crop_mov']: 244 | writer_crop = imageio.get_writer(input_movie[:-4]+crop_movie_append+'.avi', fps=reader.get_meta_data()['fps'], codec='mpeg4', quality=10) 245 | stillReading = True 246 | 247 | # Start processing the data 248 | im_iter = reader.iter_data() 249 | framenum = 0 250 | while(stillReading): 251 | start_time = time() 252 | frames = [] 253 | framenum = framenum + 1 * network['batch_size'] 254 | if framenum % 1000 == 0: 255 | print("Frame: " + str(framenum)) 256 | for i in range(network['batch_size']): 257 | try: 258 | #frame = cv2.cvtColor(np.uint8(next(im_iter)), cv2.COLOR_BGR2GRAY) 259 | frame = np.uint8(next(im_iter)) 260 | #frames.append(np.resize(frame, (network['input_size'], network['input_size'], 1))) 261 | frames.append(np.resize(frame[:,:,1], (480, 480, 1))) 262 | except StopIteration: 263 | stillReading = False 264 | break 265 | except RuntimeError: 266 | stillReading = False 267 | break 268 | if framenum % 1000 == 0: 269 | print('Batch Assembled in: ' + str(time()-start_time)) 270 | start_time = time() 271 | if stillReading: 272 | xhot, yhot = network['sess'].run(fetches=[network['xhot_est'], network['yhot_est']], feed_dict={network['image_placeholder']: frames, network['is_training']: False}) 273 | if framenum % 1000 == 0: 274 | print('Batch Processed in: ' + str(time()-start_time)) 275 | start_time = time() 276 | # Sequentially save the data 277 | for i in range(network['batch_size']): 278 | if outputs['ell_mov']: 279 | plot = cv2.cvtColor(frames[i],cv2.COLOR_GRAY2RGB) 280 | # Place crosshair on predicted location... 
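# Bin indices are converted back to pixels as argmax(bins) / bin_per_px; e.g. with 10 bins per
# pixel (the default in ellreg_to_xyhot), bin 2400 maps back to x = 240 px.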
281 | cv2.line(plot,(np.float32(np.argmax(xhot,1)[i]/network['bin_per_px']-2), np.float32(np.argmax(yhot,1)[i]/network['bin_per_px'])),(np.float32(np.argmax(xhot,1)[i]/network['bin_per_px']+2), np.float32(np.argmax(yhot,1)[i]/network['bin_per_px'])), (255, 0, 0)) 282 | cv2.line(plot,(np.float32(np.argmax(xhot,1)[i]/network['bin_per_px']), np.float32(np.argmax(yhot,1)[i]/network['bin_per_px']-2)),(np.float32(np.argmax(xhot,1)[i]/network['bin_per_px']), np.float32(np.argmax(yhot,1)[i]/network['bin_per_px']+2)), (255, 0, 0)) 283 | writer_ellfit.append_data(plot.astype('u1')) 284 | if outputs['crop_mov']: 285 | plot = cv2.cvtColor(frames[i],cv2.COLOR_GRAY2RGB) 286 | angle = 0 287 | affine_mat = np.float32([[1,0,-np.argmax(xhot,1)[i]/network['bin_per_px']+outputs['affine_crop_dim']],[0,1,-np.argmax(yhot,1)[i]/network['bin_per_px']+outputs['affine_crop_dim']]]) 288 | plot = cv2.warpAffine(plot, affine_mat, (outputs['affine_crop_dim']*2,outputs['affine_crop_dim']*2)); 289 | affine_mat = cv2.getRotationMatrix2D((outputs['affine_crop_dim'],outputs['affine_crop_dim']),angle,1.); 290 | plot = cv2.warpAffine(plot, affine_mat, (outputs['affine_crop_dim']*2,outputs['affine_crop_dim']*2)); 291 | affine_mat = np.float32([[1,0,-outputs['affine_crop_dim']/2],[0,1,-outputs['affine_crop_dim']/2]]); 292 | plot = cv2.warpAffine(plot, affine_mat, (outputs['affine_crop_dim'],outputs['affine_crop_dim'])); 293 | writer_crop.append_data(plot.astype('u1')) 294 | if framenum % 1000 == 0: 295 | print('Batch Saved in: ' + str(time()-start_time)) 296 | 297 | # Processes the input_movie using the network 298 | # Segmentation ONLY based net 299 | def processSegSoftMovie(input_movie, network, outputs): 300 | # Setup some non-modifiable values... 301 | seg_movie_append = '_seg' 302 | 303 | # Set up the output streams... 
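# The segmentation-only network writes a single output: the 'seg_mov' mask video; no ellipse or
# crop outputs are produced.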
304 | stillReading = False 305 | reader = imageio.get_reader(input_movie) 306 | if outputs['seg_mov']: 307 | writer_seg = imageio.get_writer(input_movie[:-4]+seg_movie_append+'.avi', fps=reader.get_meta_data()['fps'], codec='mpeg4', quality=10) 308 | stillReading = True 309 | 310 | # Start processing the data 311 | im_iter = reader.iter_data() 312 | framenum = 0 313 | while(stillReading): 314 | start_time = time() 315 | frames = [] 316 | framenum = framenum + 1 * network['batch_size'] 317 | if framenum % 1000 == 0: 318 | print("Frame: " + str(framenum)) 319 | for i in range(network['batch_size']): 320 | try: 321 | frame = np.uint8(next(im_iter)) 322 | frames.append(np.resize(frame[:,:,1], (network['input_size'], network['input_size'], 1))) 323 | except StopIteration: 324 | stillReading = False 325 | break 326 | except RuntimeError: 327 | stillReading = False 328 | break 329 | if framenum % 1000 == 0: 330 | print('Batch Assembled in: ' + str(time()-start_time)) 331 | start_time = time() 332 | if stillReading: 333 | result_seg = network['sess'].run(fetches=[network['seg']], feed_dict={network['image_placeholder']: frames, network['is_training']: False})[0] 334 | if framenum % 1000 == 0: 335 | print('Batch Processed in: ' + str(time()-start_time)) 336 | start_time = time() 337 | # Sequentially save the data 338 | for i in range(network['batch_size']): 339 | # Save the outputs and save only if the outfile pattern was identified 340 | if outputs['seg_mov']: 341 | seg_output = result_seg[i,:,:] 342 | writer_seg.append_data((254*seg_output[:,:]).astype('u1')) 343 | if framenum % 1000 == 0: 344 | print('Batch Saved in: ' + str(time()-start_time)) 345 | 346 | 347 | def inferEllregNetwork(arg_dict): 348 | start_time = time() 349 | sess = tf.Session() 350 | ########################################## 351 | with tf.variable_scope('Input_Variables'): 352 | image_placeholder = tf.placeholder(tf.float32, [arg_dict['batch_size'], arg_dict['input_size'], arg_dict['input_size'], 1]) 353 | is_training = tf.placeholder(tf.bool, [], name='is_training') 354 | ########################################## 355 | with tf.variable_scope('Network'): 356 | print('Constructing model...') 357 | network_eval_batch, final_features = arg_dict['model_construct_function'](image_placeholder, is_training) 358 | ellfit = tf.add(tf.multiply(network_eval_batch, scale), means) 359 | ########################################## 360 | global_step = tf.Variable(0, name='global_step', trainable=False) 361 | ########################################## 362 | with tf.variable_scope('Saver'): 363 | print('Generating summaries and savers...') 364 | saver = tf.train.Saver(slim.get_variables_to_restore(), max_to_keep=2) 365 | ########################################## 366 | print('Initializing model...') 367 | sess.run(tf.global_variables_initializer()) 368 | if 'network_to_restore' in arg_dict.keys() and arg_dict['network_to_restore'] is not None: 369 | saver.restore(sess,arg_dict['network_to_restore']) 370 | 371 | # Pack the parameters into a dictionary 372 | network = {'sess':sess, 'batch_size':arg_dict['batch_size'], 'input_size':arg_dict['input_size'], 'network_eval_batch':network_eval_batch, 'ellfit':ellfit, 'final_features':final_features, 'image_placeholder':image_placeholder, 'is_training':is_training} 373 | outputs = {'ell_mov':arg_dict['ellfit_movie_output'], 'aff_mov':arg_dict['affine_movie_output'], 'crop_mov':arg_dict['crop_movie_output'], 'ell_file':arg_dict['ellfit_output'], 'ell_features':arg_dict['ellfit_features_output'], 
'affine_crop_dim':arg_dict['affine_crop_dim']} 374 | # Process a single movie 375 | time_duration = time()-start_time 376 | print('Initializing Network Duration: ' + str(time_duration)) 377 | processMovie(arg_dict['input_movie'], network, outputs) 378 | 379 | def inferEllregNetwork_Loop(arg_dict): 380 | start_time = time() 381 | sess = tf.Session() 382 | with tf.variable_scope('Input_Variables'): 383 | image_placeholder = tf.placeholder(tf.float32, [arg_dict['batch_size'], arg_dict['input_size'], arg_dict['input_size'], 1]) 384 | is_training = tf.placeholder(tf.bool, [], name='is_training') 385 | with tf.variable_scope('Network'): 386 | print('Constructing model...') 387 | network_eval_batch, final_features = arg_dict['model_construct_function'](image_placeholder, is_training) 388 | ellfit = tf.add(tf.multiply(network_eval_batch, scale), means) 389 | global_step = tf.Variable(0, name='global_step', trainable=False) 390 | with tf.variable_scope('Saver'): 391 | print('Generating summaries and savers...') 392 | saver = tf.train.Saver(slim.get_variables_to_restore(), max_to_keep=2) 393 | print('Initializing model...') 394 | sess.run(tf.global_variables_initializer()) 395 | if 'network_to_restore' in arg_dict.keys() and arg_dict['network_to_restore'] is not None: 396 | saver.restore(sess,arg_dict['network_to_restore']) 397 | 398 | # Pack the parameters into a dictionary 399 | network = {'sess':sess, 'batch_size':arg_dict['batch_size'], 'input_size':arg_dict['input_size'], 'network_eval_batch':network_eval_batch, 'ellfit':ellfit, 'final_features':final_features, 'image_placeholder':image_placeholder, 'is_training':is_training} 400 | outputs = {'ell_mov':arg_dict['ellfit_movie_output'], 'aff_mov':arg_dict['affine_movie_output'], 'crop_mov':arg_dict['crop_movie_output'], 'ell_file':arg_dict['ellfit_output'], 'ell_features':arg_dict['ellfit_features_output'], 'affine_crop_dim':arg_dict['affine_crop_dim']} 401 | 402 | time_duration = time()-start_time 403 | print('Initializing Network Duration: ' + str(time_duration)) 404 | 405 | # Process multiple movies 406 | f = open(arg_dict['input_movie_list']) 407 | lines = f.read().split('\n') 408 | lines = lines[0:-1] # Remove the last split '' string 409 | for input_movie in lines: 410 | processMovie(input_movie, network, outputs) 411 | 412 | def inferSegEllregNetwork(arg_dict): 413 | start_time = time() 414 | sess = tf.Session() 415 | ########################################## 416 | with tf.variable_scope('Input_Variables'): 417 | image_placeholder = tf.placeholder(tf.float32, [arg_dict['batch_size'], arg_dict['input_size'], arg_dict['input_size'], 1]) 418 | is_training = tf.placeholder(tf.bool, [], name='is_training') 419 | ########################################## 420 | with tf.variable_scope('Network'): 421 | print('Constructing model...') 422 | seg_eval_batch, network_eval_batch, _ = arg_dict['model_construct_function'](image_placeholder, is_training) 423 | ellfit = tf.add(tf.multiply(network_eval_batch, scale), means) 424 | ########################################## 425 | global_step = tf.Variable(0, name='global_step', trainable=False) 426 | ########################################## 427 | with tf.variable_scope('Saver'): 428 | print('Generating summaries and savers...') 429 | saver = tf.train.Saver(slim.get_variables_to_restore(), max_to_keep=2) 430 | ########################################## 431 | print('Initializing model...') 432 | sess.run(tf.global_variables_initializer()) 433 | if 'network_to_restore' in arg_dict.keys() and 

def inferSegEllregNetwork(arg_dict):
    start_time = time()
    sess = tf.Session()
    ##########################################
    with tf.variable_scope('Input_Variables'):
        image_placeholder = tf.placeholder(tf.float32, [arg_dict['batch_size'], arg_dict['input_size'], arg_dict['input_size'], 1])
        is_training = tf.placeholder(tf.bool, [], name='is_training')
    ##########################################
    with tf.variable_scope('Network'):
        print('Constructing model...')
        seg_eval_batch, network_eval_batch, _ = arg_dict['model_construct_function'](image_placeholder, is_training)
        ellfit = tf.add(tf.multiply(network_eval_batch, scale), means)
    ##########################################
    global_step = tf.Variable(0, name='global_step', trainable=False)
    ##########################################
    with tf.variable_scope('Saver'):
        print('Generating summaries and savers...')
        saver = tf.train.Saver(slim.get_variables_to_restore(), max_to_keep=2)
    ##########################################
    print('Initializing model...')
    sess.run(tf.global_variables_initializer())
    if 'network_to_restore' in arg_dict.keys() and arg_dict['network_to_restore'] is not None:
        saver.restore(sess,arg_dict['network_to_restore'])

    # Pack the parameters into a dictionary
    # Force never to save the features...
    network = {'sess':sess, 'batch_size':arg_dict['batch_size'], 'input_size':arg_dict['input_size'], 'network_eval_batch':network_eval_batch, 'ellfit':ellfit, 'final_features':seg_eval_batch, 'image_placeholder':image_placeholder, 'is_training':is_training, 'seg':seg_eval_batch}
    outputs = {'ell_mov':arg_dict['ellfit_movie_output'], 'aff_mov':arg_dict['affine_movie_output'], 'crop_mov':arg_dict['crop_movie_output'], 'ell_file':arg_dict['ellfit_output'], 'ell_features':False, 'affine_crop_dim':arg_dict['affine_crop_dim'], 'seg_mov':arg_dict['seg_movie_output']}
    # Process a single movie
    time_duration = time()-start_time
    print('Initializing Network Duration: ' + str(time_duration))
    print('Processing ' + arg_dict['input_movie'])
    processMovie_v2(arg_dict['input_movie'], network, outputs)


def inferSegEllregNetwork_loop(arg_dict):
    start_time = time()
    sess = tf.Session()
    ##########################################
    with tf.variable_scope('Input_Variables'):
        image_placeholder = tf.placeholder(tf.float32, [arg_dict['batch_size'], arg_dict['input_size'], arg_dict['input_size'], 1])
        is_training = tf.placeholder(tf.bool, [], name='is_training')
    ##########################################
    with tf.variable_scope('Network'):
        print('Constructing model...')
        seg_eval_batch, network_eval_batch, _ = arg_dict['model_construct_function'](image_placeholder, is_training)
        ellfit = tf.add(tf.multiply(network_eval_batch, scale), means)
    ##########################################
    global_step = tf.Variable(0, name='global_step', trainable=False)
    ##########################################
    with tf.variable_scope('Saver'):
        print('Generating summaries and savers...')
        saver = tf.train.Saver(slim.get_variables_to_restore(), max_to_keep=2)
    ##########################################
    print('Initializing model...')
    sess.run(tf.global_variables_initializer())
    if 'network_to_restore' in arg_dict.keys() and arg_dict['network_to_restore'] is not None:
        saver.restore(sess,arg_dict['network_to_restore'])

    # Pack the parameters into a dictionary
    # Force never to save the features...
    network = {'sess':sess, 'batch_size':arg_dict['batch_size'], 'input_size':arg_dict['input_size'], 'network_eval_batch':network_eval_batch, 'ellfit':ellfit, 'final_features':seg_eval_batch, 'image_placeholder':image_placeholder, 'is_training':is_training, 'seg':seg_eval_batch}
    outputs = {'ell_mov':arg_dict['ellfit_movie_output'], 'aff_mov':arg_dict['affine_movie_output'], 'crop_mov':arg_dict['crop_movie_output'], 'ell_file':arg_dict['ellfit_output'], 'ell_features':False, 'affine_crop_dim':arg_dict['affine_crop_dim'], 'seg_mov':arg_dict['seg_movie_output']}
    time_duration = time()-start_time
    print('Initializing Network Duration: ' + str(time_duration))

    # Process multiple movies
    f = open(arg_dict['input_movie_list'])
    lines = f.read().split('\n')
    lines = lines[0:-1]  # Remove the last split '' string
    for input_movie in lines:
        processMovie_v2(input_movie, network, outputs)


def inferBinnedNetwork(arg_dict):
    start_time = time()
    sess = tf.Session()
    ##########################################
    with tf.variable_scope('Input_Variables'):
        image_placeholder = tf.placeholder(tf.float32, [arg_dict['batch_size'], arg_dict['input_size'], arg_dict['input_size'], 1])
        is_training = tf.placeholder(tf.bool, [], name='is_training')
    ##########################################
    with tf.variable_scope('Network'):
        print('Constructing model...')
        network_eval_batch, _ = arg_dict['model_construct_function'](image_placeholder, is_training, int(arg_dict['input_size']*arg_dict['bin_per_px']))
        xhot_est, yhot_est = tf.unstack(network_eval_batch)
    ##########################################
    global_step = tf.Variable(0, name='global_step', trainable=False)
    ##########################################
    with tf.variable_scope('Saver'):
        print('Generating summaries and savers...')
        saver = tf.train.Saver(slim.get_variables_to_restore(), max_to_keep=2)
    ##########################################
    print('Initializing model...')
    sess.run(tf.global_variables_initializer())
    if 'network_to_restore' in arg_dict.keys() and arg_dict['network_to_restore'] is not None:
        saver.restore(sess,arg_dict['network_to_restore'])

    # Pack the parameters into a dictionary
    network = {'sess':sess, 'batch_size':arg_dict['batch_size'], 'input_size':arg_dict['input_size'], 'bin_per_px':arg_dict['bin_per_px'], 'image_placeholder':image_placeholder, 'is_training':is_training, 'xhot_est':xhot_est, 'yhot_est':yhot_est}
    outputs = {'ell_mov':arg_dict['ellfit_movie_output'], 'crop_mov':arg_dict['crop_movie_output'], 'affine_crop_dim':arg_dict['affine_crop_dim']}
    # Process a single movie
    time_duration = time()-start_time
    print('Initializing Network Duration: ' + str(time_duration))
    print('Processing ' + arg_dict['input_movie'])
    processMovie_v3(arg_dict['input_movie'], network, outputs)
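
# xhot_est and yhot_est are per-axis one-hot style estimates over
# input_size * bin_per_px bins. As an illustration only (an assumption about the
# decoding, not necessarily what processMovie_v3 does), a bin index can be mapped
# back to a pixel coordinate by dividing out the bin density:
#
#   x_bins = sess.run(xhot_est, feed_dict={image_placeholder: frames, is_training: False})
#   x_px = np.argmax(x_bins, axis=-1) / arg_dict['bin_per_px']
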

def inferBinnedNetwork_loop(arg_dict):
    start_time = time()
    sess = tf.Session()
    ##########################################
    with tf.variable_scope('Input_Variables'):
        image_placeholder = tf.placeholder(tf.float32, [arg_dict['batch_size'], arg_dict['input_size'], arg_dict['input_size'], 1])
        is_training = tf.placeholder(tf.bool, [], name='is_training')
    ##########################################
    with tf.variable_scope('Network'):
        print('Constructing model...')
        network_eval_batch, _ = arg_dict['model_construct_function'](image_placeholder, is_training, int(arg_dict['input_size']*arg_dict['bin_per_px']))
        xhot_est, yhot_est = tf.unstack(network_eval_batch)
    ##########################################
    global_step = tf.Variable(0, name='global_step', trainable=False)
    ##########################################
    with tf.variable_scope('Saver'):
        print('Generating summaries and savers...')
        saver = tf.train.Saver(slim.get_variables_to_restore(), max_to_keep=2)
    ##########################################
    print('Initializing model...')
    sess.run(tf.global_variables_initializer())
    if 'network_to_restore' in arg_dict.keys() and arg_dict['network_to_restore'] is not None:
        saver.restore(sess,arg_dict['network_to_restore'])

    # Pack the parameters into a dictionary
    network = {'sess':sess, 'batch_size':arg_dict['batch_size'], 'input_size':arg_dict['input_size'], 'bin_per_px':arg_dict['bin_per_px'], 'image_placeholder':image_placeholder, 'is_training':is_training, 'xhot_est':xhot_est, 'yhot_est':yhot_est}
    outputs = {'ell_mov':arg_dict['ellfit_movie_output'], 'crop_mov':arg_dict['crop_movie_output'], 'affine_crop_dim':arg_dict['affine_crop_dim']}
    time_duration = time()-start_time
    print('Initializing Network Duration: ' + str(time_duration))

    # Process multiple movies
    f = open(arg_dict['input_movie_list'])
    lines = f.read().split('\n')
    lines = lines[0:-1]  # Remove the last split '' string
    for input_movie in lines:
        processMovie_v3(input_movie, network, outputs)


def inferSegSoftNetwork(arg_dict):
    start_time = time()
    sess = tf.Session()
    ##########################################
    with tf.variable_scope('Input_Variables'):
        image_placeholder = tf.placeholder(tf.float32, [arg_dict['batch_size'], arg_dict['input_size'], arg_dict['input_size'], 1])
        is_training = tf.placeholder(tf.bool, [], name='is_training')
    ##########################################
    with tf.variable_scope('Network'):
        print('Constructing model...')
        seg_eval_batch = arg_dict['model_construct_function'](image_placeholder, is_training)
        seg_eval_batch = tf.nn.softmax(seg_eval_batch)[:,:,:,0]  # Only grab the "Mouse" channel
    ##########################################
    global_step = tf.Variable(0, name='global_step', trainable=False)
    ##########################################
    with tf.variable_scope('Saver'):
        print('Generating summaries and savers...')
        saver = tf.train.Saver(slim.get_variables_to_restore(), max_to_keep=2)
    ##########################################
    print('Initializing model...')
    sess.run(tf.global_variables_initializer())
    if 'network_to_restore' in arg_dict.keys() and arg_dict['network_to_restore'] is not None:
        saver.restore(sess,arg_dict['network_to_restore'])

    # Pack the parameters into a dictionary
    # Force never to save the features...
    network = {'sess':sess, 'batch_size':arg_dict['batch_size'], 'input_size':arg_dict['input_size'], 'image_placeholder':image_placeholder, 'is_training':is_training, 'seg':seg_eval_batch}
    outputs = {'seg_mov':arg_dict['seg_movie_output']}
    # Process a single movie
    time_duration = time()-start_time
    print('Initializing Network Duration: ' + str(time_duration))

    processSegSoftMovie(arg_dict['input_movie'], network, outputs)


def inferSegSoftNetwork_loop(arg_dict):
    start_time = time()
    sess = tf.Session()
    ##########################################
    with tf.variable_scope('Input_Variables'):
        image_placeholder = tf.placeholder(tf.float32, [arg_dict['batch_size'], arg_dict['input_size'], arg_dict['input_size'], 1])
        is_training = tf.placeholder(tf.bool, [], name='is_training')
    ##########################################
    with tf.variable_scope('Network'):
        print('Constructing model...')
        seg_eval_batch = arg_dict['model_construct_function'](image_placeholder, is_training)
        seg_eval_batch = tf.nn.softmax(seg_eval_batch)[:,:,:,0]  # Only grab the "Mouse" channel
    ##########################################
    global_step = tf.Variable(0, name='global_step', trainable=False)
    ##########################################
    with tf.variable_scope('Saver'):
        print('Generating summaries and savers...')
        saver = tf.train.Saver(slim.get_variables_to_restore(), max_to_keep=2)
    ##########################################
    print('Initializing model...')
    sess.run(tf.global_variables_initializer())
    if 'network_to_restore' in arg_dict.keys() and arg_dict['network_to_restore'] is not None:
        saver.restore(sess,arg_dict['network_to_restore'])

    # Pack the parameters into a dictionary
    # Force never to save the features...
    network = {'sess':sess, 'batch_size':arg_dict['batch_size'], 'input_size':arg_dict['input_size'], 'image_placeholder':image_placeholder, 'is_training':is_training, 'seg':seg_eval_batch}
    outputs = {'seg_mov':arg_dict['seg_movie_output']}
    time_duration = time()-start_time
    print('Initializing Network Duration: ' + str(time_duration))

    # Process multiple movies
    f = open(arg_dict['input_movie_list'])
    lines = f.read().split('\n')
    lines = lines[0:-1]  # Remove the last split '' string
    for input_movie in lines:
        processSegSoftMovie(input_movie, network, outputs)


# Parses the argument dictionary to select the correct processing functions
def inferMovie(arg_dict):
    if arg_dict['net_type'] == 'segellreg':
        if 'input_movie_list' in arg_dict.keys():
            inferSegEllregNetwork_loop(arg_dict)
        else:
            inferSegEllregNetwork(arg_dict)
    elif arg_dict['net_type'] == 'ellreg':
        if 'input_movie_list' in arg_dict.keys():
            inferEllregNetwork_Loop(arg_dict)
        else:
            inferEllregNetwork(arg_dict)
    elif arg_dict['net_type'] == 'binned':
        if 'input_movie_list' in arg_dict.keys():
            inferBinnedNetwork_loop(arg_dict)
        else:
            inferBinnedNetwork(arg_dict)
    elif arg_dict['net_type'] == 'seg':
        if 'input_movie_list' in arg_dict.keys():
            inferSegSoftNetwork_loop(arg_dict)
        else:
            inferSegSoftNetwork(arg_dict)
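

# Minimal usage sketch (illustrative only): the keys below are the ones read by
# the 'segellreg' path above. The movie path, checkpoint path, numeric values,
# and the model constructor are placeholder assumptions; the real constructors
# live in the models module of this repository.
#
#   arg_dict = {
#       'net_type': 'segellreg',
#       'input_movie': '/data/videos/example_movie.avi',       # hypothetical path
#       'batch_size': 1,
#       'input_size': 480,                                      # example value
#       'model_construct_function': my_seg_ellreg_model,        # hypothetical constructor
#       'network_to_restore': '/data/checkpoints/model.ckpt',   # hypothetical checkpoint
#       'ellfit_movie_output': False,
#       'affine_movie_output': False,
#       'crop_movie_output': False,
#       'ellfit_output': True,
#       'affine_crop_dim': 112,                                  # example value
#       'seg_movie_output': True,
#   }
#   inferMovie(arg_dict)
--------------------------------------------------------------------------------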