├── Makefile
├── setup.cfg
├── transfer
│   ├── __init__.py
│   ├── images_to_array.py
│   ├── pre_model.py
│   ├── xception.py
│   ├── resnet50.py
│   ├── input.py
│   ├── server.py
│   ├── inception_v3.py
│   ├── augment_arrays.py
│   ├── predict_model.py
│   ├── __main__.py
│   ├── model.py
│   └── project.py
├── LICENSE.txt
├── setup.py
├── .gitignore
└── README.md

/Makefile:
--------------------------------------------------------------------------------
1 | distribute:
2 | 	rm -rf dist
3 | 	python setup.py sdist
4 | 	python setup.py bdist_wheel
5 | 	twine upload dist/*
6 | 
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [bdist_wheel]
2 | # This flag says that the code is written to work on both Python 2 and Python
3 | # 3. If at all possible, it is good practice to do this. If you cannot, you
4 | # will need to generate wheels for each Python version that you support.
5 | universal=1
6 | 
--------------------------------------------------------------------------------
/transfer/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = []
2 | 
3 | import pkgutil
4 | import inspect
5 | 
6 | for loader, name, is_pkg in pkgutil.walk_packages(__path__):
7 |     module = loader.find_module(name).load_module(name)
8 | 
9 |     for name, value in inspect.getmembers(module):
10 |         if name.startswith('__'):
11 |             continue
12 | 
13 |         globals()[name] = value
14 |         __all__.append(name)
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2017 Matthew A Sochor
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | 
3 | setup(
4 |     name="transfer",
5 |     version="0.40",
6 |     description="Transfer learning for deep image classification",
7 |     packages=find_packages(),
8 | 
9 |     # Runtime dependencies: ensure these get installed
10 |     # or upgraded on the target machine
11 |     install_requires=['numpy',
12 |                       'keras',
13 |                       'pyyaml',
14 |                       'tqdm',
15 |                       'pandas',
16 |                       'opencv-python',
17 |                       'termcolor',
18 |                       'colorama',
19 |                       'flask',
20 |                       'flask_jsonpify',
21 |                       'flask_restful',
22 |                       'matplotlib',
23 |                       'pillow',
24 |                       'h5py',
25 |                       'scikit-learn',
26 |                       'seaborn'],
27 |     python_requires='>=3',
28 | 
29 |     # metadata for upload to PyPI
30 |     author="Matthew Sochor",
31 |     author_email="matthew.sochor@gmail.com",
32 |     license="MIT",
33 |     keywords="keras transfer learning resnet deep neural net image classification command line",
34 |     url="http://github.com/matthew-sochor/transfer",  # project home page
35 |     download_url="http://github.com/matthew-sochor/transfer",  # download location
36 | 
37 |     # could also include long_description, classifiers, etc.
38 |     entry_points = {
39 |         'console_scripts': [
40 |             'transfer = transfer.__main__:main'
41 |         ]
42 |     }
43 | )
44 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | jobs/*
2 | 
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 | 
8 | # C extensions
9 | *.so
10 | 
11 | # Distribution / packaging
12 | .Python
13 | env/
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | 
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # dotenv 85 | .env 86 | 87 | # virtualenv 88 | .venv 89 | venv/ 90 | ENV/ 91 | 92 | # Spyder project settings 93 | .spyderproject 94 | .spyproject 95 | 96 | # Rope project settings 97 | .ropeproject 98 | 99 | # mkdocs documentation 100 | /site 101 | 102 | # mypy 103 | .mypy_cache/ 104 | -------------------------------------------------------------------------------- /transfer/images_to_array.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | from tqdm import tqdm 5 | import numpy as np 6 | from keras.preprocessing.image import load_img 7 | 8 | 9 | def images_to_array(project): 10 | 11 | categories = [d for d in os.listdir(project['img_path']) if os.path.isdir(os.path.join(project['img_path'],d))] 12 | project['categories'] = categories 13 | img_dim = project['img_dim'] * project['img_size'] 14 | print('Converting images to array') 15 | category_rounds = val_images_to_array(project['img_path'], project['path'], img_dim, project['categories']) 16 | 17 | project['is_array'] = True 18 | project['category_rounds'] = category_rounds 19 | return project 20 | 21 | 22 | def val_images_to_array(img_path, source_path, img_dim, categories): 23 | 24 | array_path = os.path.join(source_path, 'array') 25 | shutil.rmtree(array_path,ignore_errors=True) 26 | os.makedirs(array_path) 27 | 28 | print('Iterating over all categories: ', categories) 29 | category_lengths = [] 30 | for category_idx, category in enumerate(categories): 31 | print('categories:', category) 32 | category_path = os.path.join(img_path, category) 33 | img_files = sorted(os.listdir(category_path)) 34 | category_lengths.append(len(img_files)) 35 | for img_idx, img_file in tqdm(enumerate(img_files)): 36 | this_img_path = os.path.join(category_path, img_file) 37 | img = load_img(this_img_path, target_size=(img_dim, img_dim)) 38 | 39 | img_name = '{}-img-{}-{}'.format(img_idx, category, category_idx) 40 | label_name = '{}-label-{}-{}'.format(img_idx, category, category_idx) 41 | 42 | label = np.eye(len(categories), dtype = np.float32)[category_idx] 43 | 44 | img_array_path = os.path.join(array_path, img_name) 45 | img_label_path = os.path.join(array_path, label_name) 46 | 47 | np.save(img_array_path, img) 48 | np.save(img_label_path, label) 49 | category_lengths = np.array(category_lengths) / sum(category_lengths) 50 | category_lengths = list(category_lengths / max(category_lengths)) 51 | category_rounds = {cat: min(int(np.round(1 / l)), 10) for cat, l in zip(categories, category_lengths)} 52 | return category_rounds 53 | -------------------------------------------------------------------------------- /transfer/pre_model.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | import numpy as np 5 | from keras.applications.resnet50 import preprocess_input as resnet_preprocess_input 6 | from keras.applications.xception import preprocess_input as xception_preprocess_input 7 | from keras.applications.inception_v3 import preprocess_input as inception_v3_preprocess_input 8 | from tqdm import tqdm 9 | 10 | from transfer.resnet50 import get_resnet_pre_model 11 | from transfer.xception import get_xception_pre_model 12 | from transfer.inception_v3 import get_inception_v3_pre_model 13 | 14 | def gen_array_from_dir(array_dir): 15 | array_files = sorted(os.listdir(array_dir)) 16 | 17 | array_names = list(filter(lambda x: r'-img-' in x, array_files)) 18 | label_names = list(filter(lambda x: r'-label-' in x, array_files)) 19 | 20 | assert len(array_names) == len(label_names) 21 | 22 | for arr_name, lab_name in zip(array_names, label_names): 23 | X = np.load(os.path.join(array_dir, arr_name)) 24 | Y = np.load(os.path.join(array_dir, lab_name)) 25 | yield X, Y, arr_name, lab_name 26 | 27 | 28 | def pre_model(project): 29 | 30 | img_dim = project['img_dim'] * project['img_size'] 31 | print('Predicting pre-model') 32 | val_pre_model(project['path'], 'augmented', img_dim, project['architecture']) 33 | 34 | project['is_pre_model'] = True 35 | return project 36 | 37 | 38 | def val_pre_model(source_path, folder, img_dim, architechture): 39 | 40 | array_path = os.path.join(source_path, folder) 41 | pre_model_path = os.path.join(source_path, 'pre_model') 42 | shutil.rmtree(pre_model_path,ignore_errors=True) 43 | os.makedirs(pre_model_path) 44 | 45 | if architechture == 'resnet50': 46 | popped, pre_model = get_resnet_pre_model(img_dim) 47 | elif architechture == 'xception': 48 | popped, pre_model = get_xception_pre_model(img_dim) 49 | else: 50 | popped, pre_model = get_inception_v3_pre_model(img_dim) 51 | 52 | for (array, label, array_name, label_name) in tqdm(gen_array_from_dir(array_path)): 53 | if architechture == 'resnet50': 54 | array = resnet_preprocess_input(array[np.newaxis].astype(np.float32)) 55 | elif architechture == 'xception': 56 | array = xception_preprocess_input(array[np.newaxis].astype(np.float32)) 57 | else: 58 | array = inception_v3_preprocess_input(array[np.newaxis].astype(np.float32)) 59 | array_pre_model = np.squeeze(pre_model.predict(array, batch_size=1)) 60 | 61 | array_name = array_name.split('.')[0] 62 | label_name = label_name.split('.')[0] 63 | 64 | img_pre_model_path = os.path.join(pre_model_path, array_name) 65 | label_pre_model_path = os.path.join(pre_model_path, label_name) 66 | 67 | np.save(img_pre_model_path, array_pre_model) 68 | np.save(label_pre_model_path, label) 69 | -------------------------------------------------------------------------------- /transfer/xception.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | from keras.layers import Input, Activation, Conv2D, AveragePooling2D, Flatten, BatchNormalization, Dropout, Dense 5 | from keras.models import Model 6 | from keras import layers 7 | from keras.applications.xception import Xception 8 | 9 | 10 | def pop_layer(model, count=1): 11 | if not model.outputs: 12 | raise Exception('Sequential model cannot be popped: model is empty.') 13 | 14 | popped = [model.layers.pop() for i in range(count)] 15 | 16 | if not model.layers: 17 | model.outputs = [] 18 | model.inbound_nodes = [] 19 | 
model.outbound_nodes = [] 20 | else: 21 | model.layers[-1].outbound_nodes = [] 22 | model.outputs = [model.layers[-1].output] 23 | 24 | model.container_nodes = sorted([l.name for l in model.layers]) 25 | model.built = True 26 | 27 | return popped, model 28 | 29 | 30 | def get_xception_model(img_dim): 31 | array_input = Input(shape=(img_dim, img_dim, 3)) 32 | xception = Xception(include_top=True, 33 | weights='imagenet', 34 | input_tensor=array_input, 35 | pooling='avg') 36 | return xception 37 | 38 | 39 | def get_xception_pre_model(img_dim): 40 | xception = get_xception_model(img_dim) 41 | popped, pre_model = pop_layer(xception, 8) 42 | return popped, pre_model 43 | 44 | 45 | def get_xception_pre_post_model(img_dim, conv_dim, number_categories, model_weights = None): 46 | 47 | popped, pre_model = get_xception_pre_model(img_dim) 48 | 49 | input_dims = (conv_dim, conv_dim, 1024) 50 | # Take last 8 layers from xception with their starting weights! 51 | x_in = Input(shape = input_dims) 52 | x = popped[7](x_in) 53 | x = popped[6](x) 54 | x = Activation('relu', name='block14_sepconv1_act')(x) 55 | 56 | x = popped[4](x) 57 | x = popped[3](x) 58 | x = Activation('relu', name='block14_sepconv2_act')(x) 59 | 60 | x = popped[1](x) 61 | x = Dense(number_categories, activation = 'softmax', name='predictions')(x) 62 | 63 | post_model = Model(x_in, x) 64 | 65 | if model_weights is not None: 66 | print('Loading model weights:', model_weights) 67 | post_model.load_weights(model_weights) 68 | 69 | return pre_model, post_model 70 | 71 | 72 | def get_xception_final_model(img_dim, conv_dim, number_categories, weights, is_final): 73 | 74 | if is_final: 75 | pre_post_weights = None 76 | else: 77 | pre_post_weights = weights 78 | pre_model, post_model = get_xception_pre_post_model(img_dim, conv_dim, number_categories, model_weights = pre_post_weights) 79 | x_in = Input(shape = (img_dim, img_dim, 3)) 80 | x = pre_model(x_in) 81 | x = post_model(x) 82 | final_model = Model(x_in, x) 83 | if is_final: 84 | print('Loading model weights:', weights) 85 | final_model.load_weights(weights) 86 | return final_model 87 | -------------------------------------------------------------------------------- /transfer/resnet50.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | from keras.layers import Input, Activation, Conv2D, AveragePooling2D, Flatten, BatchNormalization, Dropout, Dense 5 | from keras.models import Model 6 | from keras import layers 7 | from keras.applications.resnet50 import ResNet50 8 | 9 | 10 | def pop_layer(model, count=1): 11 | if not model.outputs: 12 | raise Exception('Sequential model cannot be popped: model is empty.') 13 | 14 | popped = [model.layers.pop() for i in range(count)] 15 | 16 | if not model.layers: 17 | model.outputs = [] 18 | model.inbound_nodes = [] 19 | model.outbound_nodes = [] 20 | else: 21 | model.layers[-1].outbound_nodes = [] 22 | model.outputs = [model.layers[-1].output] 23 | 24 | model.container_nodes = sorted([l.name for l in model.layers]) 25 | model.built = True 26 | 27 | return popped, model 28 | 29 | 30 | def get_resnet_model(img_dim): 31 | array_input = Input(shape=(img_dim, img_dim, 3)) 32 | resnet = ResNet50(include_top=False, 33 | weights='imagenet', 34 | input_tensor=array_input, 35 | pooling='avg') 36 | return resnet 37 | 38 | 39 | def get_resnet_pre_model(img_dim): 40 | resnet = get_resnet_model(img_dim) 41 | popped, pre_model = pop_layer(resnet, 12) 42 | return popped, pre_model 43 | 44 | 45 | 
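# How the pre-model is used (an illustrative sketch only; nothing in this
# comment is executed by the package): `pre_model` holds the frozen early
# ResNet layers, and training runs it exactly once per augmented image,
# caching the activations as .npy files (see transfer/pre_model.py). Roughly:
#
#     popped, pre_model = get_resnet_pre_model(224)   # 224 px = the low-res img_dim
#     features = pre_model.predict(batch)             # batch: preprocessed (n, 224, 224, 3)
#     np.save('features.npy', features)               # cached for every later round
#
# Later rounds then fit only the small post-model below on those cached
# features, which is why re-training is fast: the expensive convolutional
# stack is never re-evaluated.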
def get_resnet_pre_post_model(img_dim, conv_dim, number_categories, model_weights = None): 46 | 47 | popped, pre_model = get_resnet_pre_model(img_dim) 48 | 49 | input_dims = (conv_dim, conv_dim, 2048) 50 | # Take last 12 layers from resnet 50 with their starting weights! 51 | x_in = Input(shape = input_dims) 52 | 53 | x = popped[11](x_in) 54 | x = popped[10](x) 55 | x = Activation('relu')(x) 56 | 57 | x = popped[8](x) 58 | x = popped[7](x) 59 | x = Activation('relu')(x) 60 | 61 | x = popped[5](x) 62 | x = popped[4](x) 63 | 64 | x = layers.add([x, x_in]) 65 | x = Activation('relu')(x) 66 | mid_model = Model(x_in, x) 67 | 68 | x_in_2 = Input(shape = input_dims) 69 | 70 | x = AveragePooling2D((7, 7), name = 'avg_pool')(x_in_2) 71 | x = Flatten()(x) 72 | x = Dense(number_categories, activation = 'softmax')(x) 73 | 74 | end_model = Model(x_in_2, x) 75 | 76 | x_in_3 = Input(shape = input_dims) 77 | x = mid_model(x_in_3) 78 | x = end_model(x) 79 | post_model = Model(x_in_3, x) 80 | 81 | if model_weights is not None: 82 | print('Loading model weights:', model_weights) 83 | post_model.load_weights(model_weights) 84 | 85 | return pre_model, post_model 86 | 87 | 88 | def get_resnet_final_model(img_dim, conv_dim, number_categories, weights, is_final): 89 | 90 | if is_final: 91 | pre_post_weights = None 92 | else: 93 | pre_post_weights = weights 94 | pre_model, post_model = get_resnet_pre_post_model(img_dim, conv_dim, number_categories, model_weights = pre_post_weights) 95 | x_in = Input(shape = (img_dim, img_dim, 3)) 96 | x = pre_model(x_in) 97 | x = post_model(x) 98 | final_model = Model(x_in, x) 99 | if is_final: 100 | print('Loading model weights:', weights) 101 | final_model.load_weights(weights) 102 | return final_model 103 | -------------------------------------------------------------------------------- /transfer/input.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | from colorama import init 5 | from termcolor import colored 6 | 7 | def int_input(message, low, high, show_range = True): 8 | ''' 9 | Ask a user for a int input between two values 10 | 11 | args: 12 | message (str): Prompt for user 13 | low (int): Low value, user entered value must be > this value to be accepted 14 | high (int): High value, user entered value must be < this value to be accepted 15 | show_range (boolean, Default True): Print hint to user the range 16 | 17 | returns: 18 | int_in (int): Input integer 19 | ''' 20 | 21 | int_in = low - 1 22 | while (int_in < low) or (int_in > high): 23 | if show_range: 24 | suffix = ' (integer between ' + str(low) + ' and ' + str(high) + ')' 25 | else: 26 | suffix = '' 27 | inp = input('Enter a ' + message + suffix + ': ') 28 | if re.match('^-?[0-9]+$', inp) is not None: 29 | int_in = int(inp) 30 | else: 31 | print(colored('Must be an integer, try again!', 'red')) 32 | return int_in 33 | 34 | 35 | def float_input(message, low, high): 36 | ''' 37 | Ask a user for a float input between two values 38 | 39 | args: 40 | message (str): Prompt for user 41 | low (float): Low value, user entered value must be > this value to be accepted 42 | high (float): High value, user entered value must be < this value to be accepted 43 | 44 | returns: 45 | float_in (int): Input float 46 | ''' 47 | 48 | float_in = low - 1.0 49 | while (float_in < low) or (float_in > high): 50 | inp = input('Enter a ' + message + ' (float between ' + str(low) + ' and ' + str(high) + '): ') 51 | if re.match('^([0-9]*[.])?[0-9]+$', inp) is not None: 52 | 
float_in = float(inp) 53 | else: 54 | print(colored('Must be a float, try again!', 'red')) 55 | return float_in 56 | 57 | def bool_input(message): 58 | ''' 59 | Ask a user for a boolean input 60 | 61 | args: 62 | message (str): Prompt for user 63 | 64 | returns: 65 | bool_in (boolean): Input boolean 66 | ''' 67 | 68 | while True: 69 | suffix = ' (true or false): ' 70 | inp = input(message + suffix) 71 | if inp.lower() == 'true': 72 | return True 73 | elif inp.lower() == 'false': 74 | return False 75 | else: 76 | print(colored('Must be either true or false, try again!', 'red')) 77 | 78 | def str_input(message, inputs = None): 79 | 80 | user_str = None 81 | while user_str is None: 82 | inp = input(message) 83 | if inputs is None: 84 | user_str = inp 85 | elif inp in inputs: 86 | user_str = inp 87 | else: 88 | print(colored('Invalid input, should be one of:', 'red')) 89 | print(inputs) 90 | return user_str 91 | 92 | 93 | def model_input(project): 94 | 95 | print('Select model weights:') 96 | print('[0] best weights: ', colored(os.path.split(project['best_weights'][0])[-1], 'cyan')) 97 | print('[1] last weights: ', colored(os.path.split(project['last_weights'][0])[-1], 'cyan')) 98 | 99 | model_choice = int_input('choice', 0, 1, show_range = False) 100 | weights = ['best_weights', 'last_weights'][model_choice] 101 | 102 | return weights 103 | 104 | 105 | def model_individual_input(project, weights): 106 | 107 | print('What individual weights would you like to use?') 108 | print('[-1]', colored('Use all weights', 'green')) 109 | for i in range(len(project[weights])): 110 | print('[' + str(i) + ']', colored(os.path.split(project[weights][i])[-1], 'cyan')) 111 | 112 | model_choice = int_input('choice', -1, len(project[weights]), show_range = False) 113 | if model_choice == -1: 114 | return None 115 | return model_choice 116 | -------------------------------------------------------------------------------- /transfer/server.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from flask import Flask, request 4 | from flask_restful import Resource, Api, reqparse 5 | from flask_jsonpify import jsonify 6 | from colorama import init 7 | from termcolor import colored 8 | import numpy as np 9 | 10 | from transfer.resnet50 import get_resnet_final_model 11 | from transfer.xception import get_xception_final_model 12 | from transfer.inception_v3 import get_inception_v3_final_model 13 | from transfer.predict_model import prep_from_image, gen_from_directory, multi_predict 14 | 15 | def start_server(project, weights): 16 | 17 | app = Flask(__name__) 18 | api = Api(app) 19 | 20 | parser = reqparse.RequestParser() 21 | parser.add_argument('img_path', type = str) 22 | 23 | img_dim = 224 * project['img_size'] 24 | conv_dim = 7 * project['img_size'] 25 | models = [] 26 | for weight in project[weights]: 27 | if project['architecture'] == 'resnet50': 28 | models.append(get_resnet_final_model(img_dim, conv_dim, project['number_categories'], weight, project['is_final'])) 29 | elif project['architecture'] == 'xception': 30 | models.append(get_xception_final_model(img_dim, conv_dim, project['number_categories'], weight, project['is_final'])) 31 | else: 32 | models.append(get_inception_v3_final_model(img_dim, conv_dim, project['number_categories'], weight, project['is_final'])) 33 | 34 | class Predict(Resource): 35 | def post(self): 36 | args = parser.parse_args(strict = True) 37 | img_path = os.path.expanduser(args['img_path']) 38 | if os.path.isfile(img_path): 39 | if 
img_path.lower().find('.png') > 0 or img_path.lower().find('.jpg') > 0 or img_path.lower().find('.jpeg') > 0: 40 | aug_gen = prep_from_image(img_path, img_dim, project['augmentations']) 41 | pred_list, predicted = multi_predict(aug_gen, models, project['architecture']) 42 | pred_list = [[float(p) for p in pred] for pred in list(pred_list)] 43 | result = {'weights': project[weights], 44 | 'image_path': img_path, 45 | 'predicted': project['categories'][np.argmax(predicted)], 46 | 'classes': project['categories'], 47 | 'class_predictions': pred_list} 48 | 49 | return jsonify(result) 50 | else: 51 | return 'File must be a jpeg or png: ' + args['img_path'] 52 | elif os.path.isdir(img_path): 53 | result = [] 54 | 55 | for aug_gen, file_name in gen_from_directory(img_path, img_dim, project): 56 | pred_list, predicted = multi_predict(aug_gen, models, project['architecture']) 57 | pred_list = [[float(p) for p in pred] for pred in list(pred_list)] 58 | result.append({'weights': project[weights], 59 | 'image_path': file_name, 60 | 'predicted': project['categories'][np.argmax(predicted)], 61 | 'classes': project['categories'], 62 | 'class_predictions': pred_list}) 63 | if len(result) > 0: 64 | return jsonify(result) 65 | else: 66 | return 'No images found in directory: ' + args['img_path'] 67 | 68 | else: 69 | return 'Image does not exist locally: ' + args['img_path'] 70 | 71 | 72 | api.add_resource(Predict, '/predict') 73 | print('') 74 | print('To predict a local image, simply:') 75 | print('') 76 | print(colored('curl http://localhost:' + str(project['api_port']) + '/predict -d "img_path=/path/to/your/img.png" -X POST', 'green')) 77 | print('') 78 | print('or') 79 | print('') 80 | print(colored('curl http://localhost:' + str(project['api_port']) + '/predict -d "img_path=/path/to/your/img_dir" -X POST', 'green')) 81 | print('') 82 | app.run(port = project['api_port']) 83 | -------------------------------------------------------------------------------- /transfer/inception_v3.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | from keras.layers import Input, Activation, concatenate, GlobalAveragePooling2D, Dense 5 | from keras.models import Model 6 | from keras import layers 7 | from keras.applications.inception_v3 import InceptionV3 8 | 9 | 10 | def pop_layer(model, count=1): 11 | if not model.outputs: 12 | raise Exception('Sequential model cannot be popped: model is empty.') 13 | 14 | popped = [model.layers.pop() for i in range(count)] 15 | 16 | if not model.layers: 17 | model.outputs = [] 18 | model.inbound_nodes = [] 19 | model.outbound_nodes = [] 20 | else: 21 | model.layers[-1].outbound_nodes = [] 22 | model.outputs = [model.layers[-1].output] 23 | 24 | model.container_nodes = sorted([l.name for l in model.layers]) 25 | model.built = True 26 | 27 | return popped, model 28 | 29 | 30 | def get_inception_v3_model(img_dim): 31 | array_input = Input(shape=(img_dim, img_dim, 3)) 32 | inception_v3 = InceptionV3(include_top=True, 33 | weights='imagenet', 34 | input_tensor=array_input) 35 | return inception_v3 36 | 37 | 38 | def get_inception_v3_pre_model(img_dim): 39 | inception_v3 = get_inception_v3_model(img_dim) 40 | popped, pre_model = pop_layer(inception_v3, 33) 41 | return popped, pre_model 42 | 43 | 44 | def get_inception_v3_pre_post_model(img_dim, conv_dim, number_categories, model_weights = None): 45 | 46 | popped, pre_model = get_inception_v3_pre_model(img_dim) 47 | 48 | input_dims = (conv_dim, conv_dim, 2048) 49 | 
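# NOTE: pop_layer() pops from the END of the layer list, so popped[0] is the
# layer that was originally last (the old ImageNet 'predictions' head) and
# popped[32] is the earliest layer removed. The calls below therefore rebuild
# the final inception block back-to-front on top of the cached "mixed9"
# activations, reusing the popped layers and their pre-trained weights; the
# old ImageNet top is replaced by a fresh softmax sized to `number_categories`.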
# Take last 33 layers from inception_v3 with their starting weights! 50 | mixed_9 = Input(shape = input_dims) 51 | 52 | branch1x1 = popped[18](mixed_9) 53 | branch1x1 = popped[12](branch1x1) 54 | branch1x1 = popped[6](branch1x1) 55 | 56 | branch3x3 = popped[29](mixed_9) 57 | branch3x3 = popped[27](branch3x3) 58 | branch3x3 = popped[25](branch3x3) 59 | 60 | branch3x3_1 = popped[23](branch3x3) 61 | branch3x3_1 = popped[17](branch3x3_1) 62 | branch3x3_1 = popped[11](branch3x3_1) 63 | 64 | branch3x3_2 = popped[22](branch3x3) 65 | branch3x3_2 = popped[16](branch3x3_2) 66 | branch3x3_2 = popped[10](branch3x3_2) 67 | 68 | branch3x3 = concatenate([branch3x3_1, branch3x3_2], axis=3, name='mixed9_1') 69 | 70 | branch3x3dbl = popped[32](mixed_9) 71 | branch3x3dbl = popped[31](branch3x3dbl) 72 | branch3x3dbl = popped[30](branch3x3dbl) 73 | branch3x3dbl = popped[28](branch3x3dbl) 74 | branch3x3dbl = popped[26](branch3x3dbl) 75 | branch3x3dbl = popped[24](branch3x3dbl) 76 | 77 | branch3x3dbl_1 = popped[21](branch3x3dbl) 78 | branch3x3dbl_1 = popped[15](branch3x3dbl_1) 79 | branch3x3dbl_1 = popped[9](branch3x3dbl_1) 80 | 81 | branch3x3dbl_2 = popped[20](branch3x3dbl) 82 | branch3x3dbl_2 = popped[14](branch3x3dbl_2) 83 | branch3x3dbl_2 = popped[8](branch3x3dbl_2) 84 | 85 | branch3x3dbl = concatenate([branch3x3dbl_1, branch3x3dbl_2], axis=3, name='concatenate_4') 86 | 87 | branch_pool = popped[19](mixed_9) 88 | branch_pool = popped[13](branch_pool) 89 | branch_pool = popped[7](branch_pool) 90 | branch_pool = popped[3](branch_pool) 91 | 92 | x = concatenate([branch1x1, branch3x3, branch3x3dbl, branch_pool], axis=3, name='mixed10') 93 | 94 | x = GlobalAveragePooling2D(name='avg_pool')(x) 95 | x = Dense(number_categories, activation='softmax', name='predictions')(x) 96 | 97 | post_model = Model(mixed_9, x) 98 | if model_weights is not None: 99 | print('Loading model weights:', model_weights) 100 | post_model.load_weights(model_weights) 101 | 102 | return pre_model, post_model 103 | 104 | 105 | def get_inception_v3_final_model(img_dim, conv_dim, number_categories, weights, is_final): 106 | 107 | if is_final: 108 | pre_post_weights = None 109 | else: 110 | pre_post_weights = weights 111 | pre_model, post_model = get_inception_v3_pre_post_model(img_dim, conv_dim, number_categories, model_weights = pre_post_weights) 112 | x_in = Input(shape = (img_dim, img_dim, 3)) 113 | x = pre_model(x_in) 114 | x = post_model(x) 115 | final_model = Model(x_in, x) 116 | if is_final: 117 | print('Loading model weights:', weights) 118 | final_model.load_weights(weights) 119 | return final_model 120 | -------------------------------------------------------------------------------- /transfer/augment_arrays.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | from tqdm import tqdm 5 | import numpy as np 6 | from keras.preprocessing.image import ImageDataGenerator 7 | 8 | 9 | def gen_arrays_from_dir(array_dir): 10 | array_files = sorted(os.listdir(array_dir)) 11 | array_names = list(filter(lambda x: r'-img-' in x, array_files)) 12 | label_names = list(filter(lambda x: r'-label-' in x, array_files)) 13 | 14 | assert len(array_names) == len(label_names) 15 | 16 | for array_name, label_name in zip(array_names, label_names): 17 | array = np.load(os.path.join(array_dir, array_name)) 18 | label = np.load(os.path.join(array_dir, label_name)) 19 | yield array, label, label_name 20 | 21 | 22 | def gen_augment_arrays(array, label, augmentations, rounds = 1): 23 | if 
augmentations is None: 24 | yield array, label 25 | else: 26 | 27 | auggen = ImageDataGenerator(featurewise_center = augmentations['featurewise_center'], 28 | samplewise_center = augmentations['samplewise_center'], 29 | featurewise_std_normalization = augmentations['featurewise_std_normalization'], 30 | samplewise_std_normalization = augmentations['samplewise_std_normalization'], 31 | zca_whitening = augmentations['zca_whitening'], 32 | rotation_range = augmentations['rotation_range'], 33 | width_shift_range = augmentations['width_shift_range'], 34 | height_shift_range = augmentations['height_shift_range'], 35 | shear_range = augmentations['shear_range'], 36 | zoom_range = augmentations['zoom_range'], 37 | channel_shift_range = augmentations['channel_shift_range'], 38 | fill_mode = augmentations['fill_mode'], 39 | cval = augmentations['cval'], 40 | horizontal_flip = augmentations['horizontal_flip'], 41 | vertical_flip = augmentations['vertical_flip'], 42 | rescale = augmentations['rescale']) 43 | 44 | array_augs, label_augs = next(auggen.flow(np.tile(array[np.newaxis], 45 | (rounds * augmentations['rounds'], 1, 1, 1)), 46 | np.tile(label[np.newaxis], 47 | (rounds * augmentations['rounds'], 1)), 48 | batch_size=rounds * augmentations['rounds'])) 49 | 50 | for array_aug, label_aug in zip(array_augs, label_augs): 51 | yield array_aug, label_aug 52 | 53 | 54 | def augment_arrays(project): 55 | 56 | array_path = os.path.join(project['path'], 'array') 57 | augmented_path = os.path.join(project['path'], 'augmented') 58 | shutil.rmtree(augmented_path,ignore_errors=True) 59 | os.makedirs(augmented_path) 60 | 61 | if project['augmentations'] is None: 62 | print('No augmentations selected: copying train arrays as is.') 63 | files = os.listdir(array_path) 64 | for file in tqdm(files): 65 | shutil.copy(os.path.join(array_path, file),augmented_path) 66 | 67 | else: 68 | print('Generating image augmentations:') 69 | 70 | for img_idx, (array, label, label_name) in tqdm(enumerate(gen_arrays_from_dir(array_path))): 71 | split_label_name = '-'.join(label_name.split('-')[2:-1]) 72 | for aug_idx, (array_aug, label_aug) in enumerate(gen_augment_arrays(array, label, project['augmentations'], project['category_rounds'][split_label_name])): 73 | cat_idx = np.argmax(label_aug) 74 | cat = project['categories'][cat_idx] 75 | img_name = '{}-{:02d}-img-{}-{}'.format(img_idx, aug_idx, 76 | cat, cat_idx) 77 | label_name = '{}-{:02d}-label-{}-{}'.format(img_idx, aug_idx, 78 | cat, cat_idx) 79 | aug_path = os.path.join(augmented_path, img_name) 80 | label_path = os.path.join(augmented_path, label_name) 81 | np.save(aug_path, array_aug) 82 | np.save(label_path, label_aug) 83 | 84 | project['is_augmented'] = True 85 | return project 86 | -------------------------------------------------------------------------------- /transfer/predict_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | from keras.layers import Input 5 | from keras.layers.core import Lambda 6 | from keras.models import Model 7 | from keras.preprocessing.image import load_img 8 | from keras.applications.resnet50 import preprocess_input as resnet_preprocess_input 9 | from keras.applications.xception import preprocess_input as xception_preprocess_input 10 | from keras.applications.inception_v3 import preprocess_input as inception_v3_preprocess_input 11 | import matplotlib.pyplot as plt 12 | import tensorflow as tf 13 | import pandas as pd 14 | from tqdm import tqdm 15 | import 
numpy as np 16 | from colorama import init 17 | from termcolor import colored 18 | 19 | from transfer.resnet50 import get_resnet_final_model 20 | from transfer.xception import get_xception_final_model 21 | from transfer.inception_v3 import get_inception_v3_final_model 22 | from transfer.augment_arrays import gen_augment_arrays 23 | 24 | 25 | def prep_from_image(file_name, img_dim, augmentations): 26 | img = np.array(load_img(file_name, target_size = (img_dim, img_dim, 3))) 27 | 28 | return gen_augment_arrays(img, np.array([]), augmentations) 29 | 30 | 31 | def gen_from_directory(directory, img_dim, project): 32 | file_names = [os.path.join(dp, f) for dp, dn, fn in os.walk(directory) for f in fn] 33 | 34 | for file_name in file_names: 35 | if ((file_name.lower().find('.jpg') > 0) or (file_name.lower().find('.jpeg') > 0) or (file_name.lower().find('.png') > 0)): 36 | yield prep_from_image(file_name, img_dim, project['augmentations']), file_name 37 | 38 | 39 | def multi_predict(aug_gen, models, architecture): 40 | predicted = [] 41 | for img, _ in aug_gen: 42 | if architecture == 'resnet50': 43 | img = resnet_preprocess_input(img[np.newaxis].astype(np.float32)) 44 | elif architecture == 'xception': 45 | img = xception_preprocess_input(img[np.newaxis].astype(np.float32)) 46 | else: 47 | img = inception_v3_preprocess_input(img[np.newaxis].astype(np.float32)) 48 | for model in models: 49 | predicted.append(model.predict(img)) 50 | predicted = np.array(predicted).sum(axis=0) 51 | pred_list = list(predicted[0]) 52 | return predicted, pred_list 53 | 54 | def predict_model(project, weights, user_files): 55 | 56 | img_dim = project['img_dim'] * project['img_size'] 57 | conv_dim = project['conv_dim'] * project['img_size'] 58 | models = [] 59 | for weight in project[weights]: 60 | if project['architecture'] == 'resnet50': 61 | models.append(get_resnet_final_model(img_dim, conv_dim, project['number_categories'], weight, project['is_final'])) 62 | elif project['architecture'] == 'xception': 63 | models.append(get_xception_final_model(img_dim, conv_dim, project['number_categories'], weight, project['is_final'])) 64 | else: 65 | models.append(get_inception_v3_final_model(img_dim, conv_dim, project['number_categories'], weight, project['is_final'])) 66 | 67 | output = [] 68 | user_files = os.path.expanduser(user_files) 69 | if os.path.isdir(user_files): 70 | for aug_gen, file_name in tqdm(gen_from_directory(user_files, img_dim, project)): 71 | predicted, pred_list = multi_predict(aug_gen, models, project['architecture']) 72 | output.append([project[weights], file_name, project['categories'][np.argmax(predicted)]] + pred_list) 73 | 74 | elif ((user_files.find('.jpg') > 0) or (user_files.find('.jpeg') > 0) or (user_files.find('.png') > 0)): 75 | aug_gen = prep_from_image(user_files, img_dim, project['augmentations']) 76 | predicted, pred_list = multi_predict(aug_gen, models, project['architecture']) 77 | output.append([project[weights], user_files, project['categories'][np.argmax(predicted)]] + pred_list) 78 | 79 | else: 80 | print(colored('Should either be a directory or a .jpg, .jpeg, and .png', 'red')) 81 | return 82 | 83 | 84 | if len(output) > 0: 85 | columns = ['weights_used','file_name', 'predicted'] + project['categories'] 86 | pred_df = pd.DataFrame(output, columns = columns) 87 | 88 | predictions_file = os.path.join(project['path'], project['name'] + '_' + weights + '_predictions.csv') 89 | if os.path.isfile(predictions_file): 90 | old_pred_df = pd.read_csv(predictions_file) 91 | pred_df = 
pd.concat([pred_df, old_pred_df])
92 | 
93 |         os.makedirs(project['path'], exist_ok = True)
94 |         pred_df.to_csv(predictions_file, index = False)
95 |         print('Predictions saved to:', colored(predictions_file, 'cyan'))
96 | 
97 |     else:
98 |         print(colored('No image files found.', 'red'))
99 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Transfer - making transfer learning easy
2 | 
3 | This is a command line tool to perform transfer learning for image classification.
4 | 
5 | Currently, it will re-learn the resnet50, inception_v3 or xception models pre-trained on ImageNet.
6 | 
7 | Furthermore, it provides a framework to serve your models! You can export a trained model, import it on another computer later, make local predictions, or set up a REST API to make predictions.
8 | 
9 | [Here are some models I have trained ready for prediction!](http://www.mattso.ch/transfer-models)
10 | 
11 | ## How is this different? Why transfer?
12 | 
13 | Transfer pre-calculates and saves the early layer outputs for each model, then re-learns only the final several layers. This differs from other transfer learning approaches, which either re-learn just the very final layer or fine-tune every layer: because the early layers are computed once, transfer is *faster* while remaining *very accurate*. It also has built-in support for machine learning best practices like k-fold validation and ensembling of k-fold models. (A code sketch of this idea is in the appendix at the end of this README.)
14 | 
15 | As a benchmark, using transfer I was able to score 0.96599 on the [plant seedling classification](https://www.kaggle.com/c/plant-seedlings-classification) competition on Kaggle. Not quite as good as the fastai benchmark of ~0.98, but good enough for many applications!
16 | 
17 | Finally, transfer is meant to be a model delivery platform as well. Train a model with transfer, export it to save it, re-import it elsewhere via transfer, and make predictions on new images! It's a great way to share models with friends, colleagues and collaborators.
18 | 
19 | Transfer can manage multiple models simultaneously via a simple project-based organization.
20 | 
21 | ## What is the community saying?
22 | - **@thenomemac : You could probably code something better, but why... Just use transfer**
23 | - **@anonymous : Even I can use this!**
24 | 
25 | ## Software to pre-install
26 | 
27 | Please first install [tensorflow](https://www.tensorflow.org/install/) and Python 3. I recommend installing the latest Python via [Anaconda](https://anaconda.org/anaconda/python).
28 | 
29 | Install transfer with:
30 | 
31 | `pip install transfer`
32 | 
33 | That's it! You can test that transfer is correctly installed by typing:
34 | 
35 | `transfer`
36 | 
37 | You should see help and a list of available commands. Now we just need some images to classify.
38 | 
39 | ## Get your images ready!
40 | 
41 | Prior to starting, organize the pictures you want to classify by label in a folder. A great example of a project already organized like this is the Kaggle competition for [plant seedling classification](https://www.kaggle.com/c/plant-seedlings-classification).
42 | 
43 | In a theoretical example where you are classifying whether something is a **hat** or a **donkey**, you would organize the images in the following manner:
44 | 
45 | ```
46 | ~/donkey-vs-hat/hat/hat_1.jpg
47 | ~/donkey-vs-hat/hat/hat_2.jpg
48 | ~/donkey-vs-hat/hat/ridiculous_proper_english_lady_hat.jpg
49 | ...
50 | ~/donkey-vs-hat/donkey/donkey_1.jpg
51 | ~/donkey-vs-hat/donkey/super_cute_donkey.jpg
52 | ~/donkey-vs-hat/donkey/donkey_in_tree.jpg
53 | ...
54 | ```
55 | 
56 | Basically, put all of your hat pictures in:
57 | 
58 | `~/donkey-vs-hat/hat`
59 | 
60 | and all of your donkey pictures in:
61 | 
62 | `~/donkey-vs-hat/donkey`
63 | 
64 | ## Classifying images with transfer
65 | 
66 | First, configure a project with:
67 | 
68 | `transfer --configure`
69 | 
70 | Follow the prompts to point to your parent image directory (`~/donkey-vs-hat` in the above example) and to provide modeling parameters.
71 | 
72 | You can always see your projects by inspecting the local configuration file:
73 | 
74 | `~/.transfer/config.yaml`
75 | 
76 | ## Train your models!
77 | 
78 | Train your model with:
79 | 
80 | `transfer --run`
81 | 
82 | ## Predict on an image or directory
83 | 
84 | Transfer provides two prediction modes. Either make local predictions on a directory or a single image with:
85 | 
86 | `transfer --predict PATH_TO_IMAGES`
87 | 
88 | or serve your model via a simple local REST API:
89 | 
90 | `transfer --prediction-rest-api`
91 | 
92 | ## Save and share your model
93 | 
94 | Great, so you trained a model and you can make predictions. Now what? You can save your model and configuration for later import on another computer (with transfer installed, obviously) or even give it to a friend (they probably have difficulty telling the difference between donkeys and hats?).
95 | 
96 | Export your model with:
97 | 
98 | `transfer --export`
99 | 
100 | ## Import a pre-trained project
101 | 
102 | Did your friend send you a donkey-vs-hat model trained with transfer? Well, how about we import that:
103 | 
104 | `transfer --import IMPORT_CONFIG`
105 | 
106 | where IMPORT_CONFIG is the path to the tar.gz file containing the config.yaml and model files.
107 | 
108 | [Here are some models I have trained ready for prediction!](http://www.mattso.ch/transfer-models)
109 | 
110 | ## Contribute
111 | 
112 | If you use transfer and run into any issues or have suggestions for new features, please submit an issue on GitHub or make a pull request.
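
## Appendix: the caching idea in code

Under the hood, transfer splits each pre-trained network into a frozen "pre-model" (the early layers) and a small trainable "post-model" (the final several layers), and caches the pre-model's outputs so they are only ever computed once. The sketch below illustrates that idea with plain Keras calls rather than transfer's internal helpers. Here `images`, `labels`, and `number_categories` are placeholders you would supply, and transfer actually re-learns several final layers per architecture rather than the single dense head shown:

```python
import numpy as np
from keras.applications.resnet50 import ResNet50, preprocess_input
from keras.layers import Input, Dense
from keras.models import Model

# 1. Run the frozen early layers ONCE and cache the activations to disk.
base = ResNet50(include_top=False, weights='imagenet', pooling='avg')
features = base.predict(preprocess_input(images.astype('float32')))  # images: (n, 224, 224, 3)
np.save('features.npy', features)

# 2. Every training round fits only this small head on the cached features,
#    so the expensive convolutional stack is never re-evaluated.
x_in = Input(shape=(2048,))
head = Model(x_in, Dense(number_categories, activation='softmax')(x_in))
head.compile(optimizer='adam', loss='categorical_crossentropy',
             metrics=['categorical_accuracy'])
head.fit(np.load('features.npy'), labels, epochs=10)  # labels: one-hot (n, number_categories)
```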
113 | -------------------------------------------------------------------------------- /transfer/__main__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import os 4 | import readline 5 | import glob 6 | 7 | import yaml 8 | from colorama import init 9 | from termcolor import colored 10 | try: 11 | import tensorflow 12 | except ModuleNotFoundError: 13 | print(colored('Tensorflow not installed!', 'red')) 14 | print('Note: there are too many system specific things with this module') 15 | print('Please look up how to install it for your system:') 16 | print('') 17 | print(colored('https://www.tensorflow.org/install/', 'yellow')) 18 | print('') 19 | raise ModuleNotFoundError 20 | import keras 21 | 22 | from transfer.project import configure, configure_server, select_project, update_config, import_config, export_config 23 | from transfer import images_to_array, pre_model 24 | from transfer.model import train_model 25 | from transfer.predict_model import predict_model 26 | from transfer.augment_arrays import augment_arrays 27 | from transfer.input import model_input, model_individual_input, str_input 28 | from transfer.server import start_server 29 | 30 | 31 | class Completer(object): 32 | def path_completer(self, text, state): 33 | return [os.path.join(x, '') if os.path.isdir(x) else x for x in glob.glob(os.path.expanduser(text) + '*')][state] 34 | 35 | 36 | def main(args = None): 37 | ''' 38 | Main entry point for transfer command line tool. 39 | 40 | This essentially will marshall the user to the functions they need. 41 | ''' 42 | 43 | parser = argparse.ArgumentParser(description = 'Tool to perform transfer learning') 44 | 45 | parser.add_argument('-c','--configure', 46 | action = 'store_true', 47 | help = 'Configure transfer') 48 | 49 | parser.add_argument('-e','--export', 50 | action = 'store_true', 51 | dest = 'export_config', 52 | help = 'Export configuration and models') 53 | 54 | parser.add_argument('-i','--import', 55 | action = 'store', 56 | type = str, 57 | default = None, 58 | dest = 'import_config', 59 | help = 'Import configuration and models') 60 | 61 | parser.add_argument('-p','--project', 62 | action = 'store', 63 | type = str, 64 | default = None, 65 | dest = 'project', 66 | help = 'Specify a project, if not supplied it will be picked from configured projects') 67 | 68 | parser.add_argument('-r','--run', 69 | action = 'store_true', 70 | help = 'Run all transfer learning operations') 71 | 72 | parser.add_argument('-f','--final', 73 | action = 'store_true', 74 | help = 'Run final training on all layers: Warning SLOW!') 75 | 76 | parser.add_argument('-l','--last-weights', 77 | action = 'store_true', 78 | dest = 'last', 79 | help = 'Restart from the last weights, rather than the best intermediate weights') 80 | 81 | parser.add_argument('--predict', 82 | action = 'store', 83 | type = str, 84 | default = None, 85 | const = 'default', 86 | dest = 'predict', 87 | nargs='?', 88 | help = 'Predict model on file or directory') 89 | 90 | parser.add_argument('--prediction-rest-api', 91 | action = 'store_true', 92 | dest = 'rest_api', 93 | help = 'Start rest api to make predictions on files or directories') 94 | 95 | if len(sys.argv) == 1: 96 | parser.print_help() 97 | return 98 | 99 | args = parser.parse_args() 100 | 101 | if args.import_config is not None: 102 | import_config(args.import_config) 103 | return 104 | elif args.export_config: 105 | project = select_project(args.project) 106 | weights = 
model_input(project) 107 | ind = model_individual_input(project, weights) 108 | export_config(project, weights, ind) 109 | return 110 | elif args.configure: 111 | configure() 112 | return 113 | else: 114 | project = select_project(args.project) 115 | 116 | if args.run: 117 | 118 | if project['is_array'] == False: 119 | project = images_to_array(project) 120 | update_config(project) 121 | 122 | if project['is_augmented'] == False: 123 | project = augment_arrays(project) 124 | update_config(project) 125 | 126 | if project['is_pre_model'] == False: 127 | project = pre_model(project) 128 | update_config(project) 129 | 130 | project = train_model(project, final = args.final, last = args.last) 131 | update_config(project) 132 | 133 | print('') 134 | print(colored('Completed modeling round: ' + str(project['model_round']), 'cyan')) 135 | print('') 136 | print('Best current model: ', colored(project['best_weights'], 'yellow')) 137 | print('Last current model: ', colored(project['last_weights'], 'yellow')) 138 | print('') 139 | print('To further refine the model, run again with:') 140 | print('') 141 | print(colored(' transfer --run --project ' + project['name'], 'green')) 142 | print('') 143 | print('To make predictions:') 144 | print('') 145 | print(colored(' transfer --predict [optional dir or file] --project ' + project['name'], 'yellow')) 146 | print('') 147 | 148 | elif args.rest_api: 149 | if project['server_weights'] is not None: 150 | start_server(project, 'server_weights') 151 | 152 | elif project['best_weights'] is not None: 153 | weights = model_input(project) 154 | start_server(project, weights) 155 | 156 | else: 157 | print('Model is not trained. Please first run your project:') 158 | print('') 159 | print(colored(' transfer --run', 'green')) 160 | print('') 161 | elif args.predict is not None: 162 | if args.predict == 'default': 163 | 164 | completer = Completer() 165 | readline.set_completer_delims('\t') 166 | readline.parse_and_bind('tab: complete') 167 | readline.set_completer(completer.path_completer) 168 | args.predict = str_input('Enter a path to file(s): ') 169 | if project['server_weights'] is not None: 170 | predict_model(project, 'server_weights', args.predict) 171 | 172 | elif project['best_weights'] is not None: 173 | weights = model_input(project) 174 | print('Predicting on image(s) in: ', colored(args.predict, 'yellow')) 175 | predict_model(project, weights, args.predict) 176 | 177 | else: 178 | print('Model is not trained. 
Please first run your project:') 179 | print('') 180 | print(colored(' transfer --run', 'green')) 181 | print('') 182 | -------------------------------------------------------------------------------- /transfer/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | import numpy as np 5 | from keras.optimizers import Adam, SGD 6 | from keras.callbacks import ModelCheckpoint 7 | from keras.preprocessing.image import load_img 8 | from keras.applications.resnet50 import preprocess_input as resnet_preprocess_input 9 | from keras.applications.xception import preprocess_input as xception_preprocess_input 10 | from keras.applications.inception_v3 import preprocess_input as inception_v3_preprocess_input 11 | from keras import layers 12 | import pandas as pd 13 | from tqdm import tqdm 14 | from sklearn.model_selection import StratifiedKFold 15 | from sklearn.metrics import confusion_matrix 16 | import matplotlib.pyplot as plt 17 | import seaborn as sns 18 | from colorama import init 19 | from termcolor import colored 20 | 21 | from transfer.resnet50 import get_resnet_pre_post_model, get_resnet_final_model 22 | from transfer.xception import get_xception_pre_post_model, get_xception_final_model 23 | from transfer.inception_v3 import get_inception_v3_pre_post_model, get_inception_v3_final_model 24 | 25 | def gen_minibatches(array_dir, array_names, batch_size, architecture, final = False): 26 | 27 | array_names = list(array_names) 28 | 29 | while True: 30 | # in place shuffle 31 | np.random.shuffle(array_names) 32 | array_names_mb = array_names[:batch_size] 33 | 34 | arrays = [] 35 | labels = [] 36 | for array_name in array_names_mb: 37 | img_path = os.path.join(array_dir, array_name) 38 | array = np.load(img_path) 39 | if final: 40 | if architecture == 'resnet50': 41 | array = np.squeeze(resnet_preprocess_input(array[np.newaxis].astype(np.float32))) 42 | elif architecture == 'xception': 43 | array = np.squeeze(xception_preprocess_input(array[np.newaxis].astype(np.float32))) 44 | else: 45 | array = np.squeeze(inception_v3_preprocess_input(array[np.newaxis].astype(np.float32))) 46 | 47 | arrays.append(array) 48 | labels.append(np.load(img_path.replace('-img-','-label-'))) 49 | 50 | yield np.array(arrays), np.array(labels) 51 | 52 | 53 | def no_folds_generator(pre_model_files): 54 | yield [i for i in range(len(pre_model_files))], -1 55 | 56 | 57 | def train_model(project, final = False, last = False): 58 | weight_label = '-' + project['architecture'] + '-weights-' 59 | source_path = project['path'] 60 | weights_path = os.path.join(source_path, 'weights') 61 | plot_path = os.path.join(source_path, 'plots') 62 | if last: 63 | weights = 'last_weights' 64 | else: 65 | weights = 'best_weights' 66 | 67 | if final: 68 | weight_label += '-final-' 69 | use_path = os.path.join(source_path, 'augmented') 70 | else: 71 | use_path = os.path.join(source_path, 'pre_model') 72 | 73 | project['model_round'] += 1 74 | os.makedirs(weights_path, exist_ok=True) 75 | os.makedirs(plot_path, exist_ok=True) 76 | 77 | img_dim = project['img_dim'] * project['img_size'] 78 | conv_dim = project['conv_dim'] * project['img_size'] 79 | 80 | lr = project['learning_rate'] 81 | decay = project['learning_rate_decay'] 82 | 83 | all_files = os.listdir(use_path) 84 | pre_model_files = list(filter(lambda x: r'-img-' in x, all_files)) 85 | label_names = list(filter(lambda x: r'-label-' in x, all_files)) 86 | 87 | pre_model_files_df = pd.DataFrame({'files': 
pre_model_files}) 88 | pre_model_files_df['suffix'] = pre_model_files_df.apply(lambda row: row.files.split('.')[-1], axis = 1) 89 | pre_model_files_df = pre_model_files_df[pre_model_files_df.suffix == 'npy'] 90 | pre_model_files_df['ind'] = pre_model_files_df.apply(lambda row: row.files.split('-')[0], axis = 1).astype(int) 91 | pre_model_files_df['label'] = pre_model_files_df.apply(lambda row: row.files.split('-')[3], axis = 1) 92 | 93 | pre_model_files_df_dedup = pre_model_files_df.drop_duplicates(subset='ind') 94 | pre_model_files_df = pre_model_files_df.set_index(['ind']) 95 | 96 | pre_model_files.sort() 97 | label_names.sort() 98 | 99 | pre_model_files_arr = np.array(pre_model_files) 100 | label_names_arr = np.array(label_names) 101 | 102 | labels = [np.argmax(np.load(os.path.join(use_path, label_name))) for label_name in label_names] 103 | best_weights = [] 104 | last_weights = [] 105 | 106 | if project['kfold'] >= 3: 107 | kfold = StratifiedKFold(n_splits=project['kfold'], shuffle=True, random_state = project['seed']) 108 | kfold_generator = kfold.split(pre_model_files_df_dedup, pre_model_files_df_dedup.label) 109 | validate = True 110 | else: 111 | print('Too few k-folds selected, fitting on all data') 112 | kfold_generator = no_folds_generator(pre_model_files_df_dedup) 113 | validate = False 114 | 115 | for i, (train, test) in enumerate(kfold_generator): 116 | if project['kfold_every']: 117 | print('----- Fitting Fold', i, '-----') 118 | elif i > 0: 119 | break 120 | 121 | 122 | weights_name = project['name'] + weight_label + '-kfold-' + str(i) + '-round-' + str(project['model_round']) +'.hdf5' 123 | plot_name = project['name'] + weight_label + '-kfold-' + str(i) + '-round-' + str(project['model_round']) +'.png' 124 | 125 | if project[weights] is None: 126 | fold_weights = None 127 | else: 128 | fold_weights = project[weights][i] 129 | if final: 130 | if project['architecture'] == 'resnet50': 131 | model = get_resnet_final_model(img_dim, conv_dim, project['number_categories'], fold_weights, project['is_final']) 132 | elif project['architecture'] == 'xception': 133 | model = get_xception_final_model(img_dim, conv_dim, project['number_categories'], fold_weights, project['is_final']) 134 | else: 135 | model = get_inception_v3_final_model(img_dim, conv_dim, project['number_categories'], fold_weights, project['is_final']) 136 | 137 | for i, layer in enumerate(model.layers[1].layers): 138 | if len(layer.trainable_weights) > 0: 139 | if i < project['final_cutoff']: 140 | mult = 0.01 141 | else: 142 | mult = 0.1 143 | layer.learning_rate_multiplier = [mult for tw in layer.trainable_weights] 144 | 145 | else: 146 | if project['architecture'] == 'resnet50': 147 | pre_model, model = get_resnet_pre_post_model(img_dim, 148 | conv_dim, 149 | len(project['categories']), 150 | model_weights = fold_weights) 151 | elif project['architecture'] == 'xception': 152 | pre_model, model = get_xception_pre_post_model(img_dim, 153 | conv_dim, 154 | len(project['categories']), 155 | model_weights = fold_weights) 156 | else: 157 | pre_model, model = get_inception_v3_pre_post_model(img_dim, 158 | conv_dim, 159 | len(project['categories']), 160 | model_weights = fold_weights) 161 | 162 | pre_model_files_dedup_train = pre_model_files_df_dedup.iloc[train] 163 | train_ind = list(set(pre_model_files_dedup_train.ind)) 164 | pre_model_files_train = pre_model_files_df.loc[train_ind] 165 | 166 | gen_train = gen_minibatches(use_path, pre_model_files_train.files, project['batch_size'], project['architecture'], final = 
final) 167 | number_train_samples = len(pre_model_files_train) 168 | 169 | if validate: 170 | pre_model_files_dedup_test = pre_model_files_df_dedup.iloc[test] 171 | test_ind = list(set(pre_model_files_dedup_test.ind)) 172 | pre_model_files_test = pre_model_files_df.loc[test_ind] 173 | 174 | gen_test = gen_minibatches(use_path, pre_model_files_test.files, project['batch_size'], project['architecture'], final = final) 175 | number_test_samples = len(pre_model_files_test) 176 | validation_steps = (number_test_samples // project['batch_size']) 177 | 178 | weights_checkpoint_file = weights_name.split('.')[0] + '-kfold-' + str(i) + "-improvement-{epoch:02d}-{val_categorical_accuracy:.4f}.hdf5" 179 | checkpoint = ModelCheckpoint(os.path.join(weights_path, weights_checkpoint_file), 180 | monitor='val_categorical_accuracy', 181 | verbose=1, 182 | save_best_only=True, 183 | mode='max') 184 | 185 | callbacks_list = [checkpoint] 186 | else: 187 | gen_test = None 188 | validation_steps = None 189 | callbacks_list = None 190 | 191 | 192 | steps_per_epoch = (number_train_samples // project['batch_size']) 193 | for j in range(project['rounds']): 194 | optimizer = Adam(lr = lr, decay = decay) 195 | 196 | model.compile(optimizer = optimizer, 197 | loss = 'categorical_crossentropy', 198 | metrics = ['categorical_accuracy']) 199 | 200 | model.fit_generator(gen_train, 201 | steps_per_epoch = steps_per_epoch, 202 | epochs = project['cycle'] * (j + 1), 203 | verbose = 1, 204 | validation_data = gen_test, 205 | validation_steps = validation_steps, 206 | initial_epoch = j * project['cycle'], 207 | callbacks = callbacks_list) 208 | 209 | model.save_weights(os.path.join(weights_path, weights_name)) 210 | last_weights.append(os.path.join(weights_path, weights_name)) 211 | weights_names = os.listdir(weights_path) 212 | max_val = -1 213 | max_i = -1 214 | for j, name in enumerate(weights_names): 215 | if name.find(weights_name.split('.')[0]) >= 0: 216 | if (name.find(weight_label) >= 0) and (name.find('improvement') >= 0): 217 | val = int(name.split('.')[1]) 218 | if val > max_val: 219 | max_val = val 220 | max_i = j 221 | if project['plot']: 222 | print('Plotting confusion matrix') 223 | 224 | if max_i == -1: 225 | print('Loading last weights:', os.path.join(weights_path, weights_name)) 226 | model.load_weights(os.path.join(weights_path, weights_name)) 227 | else: 228 | print('Loading best weights:', os.path.join(weights_path, weights_names[max_i])) 229 | model.load_weights(os.path.join(weights_path, weights_names[max_i])) 230 | best_predictions = [] 231 | true_labels = [] 232 | 233 | print('Predicting test files') 234 | if validate: 235 | use_files = pre_model_files_test.files 236 | else: 237 | use_files = pre_model_files_train.files 238 | for array_name in tqdm(use_files): 239 | img_path = os.path.join(use_path, array_name) 240 | img = np.load(img_path) 241 | if final: 242 | if project['architecture'] == 'resnet50': 243 | img = np.squeeze(resnet_preprocess_input(img[np.newaxis].astype(np.float32))) 244 | elif project['architecture'] == 'xception': 245 | img = np.squeeze(xception_preprocess_input(img[np.newaxis].astype(np.float32))) 246 | else: 247 | img = np.squeeze(inception_v3_preprocess_input(img[np.newaxis].astype(np.float32))) 248 | prediction = model.predict(img[np.newaxis]) 249 | best_predictions.append(project['categories'][np.argmax(prediction)]) 250 | true_label = np.load(img_path.replace('-img-','-label-')) 251 | true_labels.append(project['categories'][np.argmax(true_label)]) 252 | 253 | cm = 
254 |             plt.clf()
255 |             sns.heatmap(pd.DataFrame(cm, project['categories'], project['categories']), annot = True, fmt = 'g')
256 |             plt.ylabel('Actual')
257 |             plt.xlabel('Predicted')
258 |             plt.xticks(rotation = 45, fontsize = 8)
259 |             plt.yticks(rotation = 45, fontsize = 8)
260 |             plt.title('Confusion matrix for fold: ' + str(i) + '\nweights: ' + weights_name)
261 |             plt.savefig(os.path.join(plot_path, plot_name))
262 |             print('Confusion matrix plot saved:', colored(os.path.join(plot_path, plot_name), 'magenta'))
263 | 
264 | 
265 |         if max_i == -1:
266 |             best_weights.append(os.path.join(weights_path, weights_name))
267 |         else:
268 |             best_weights.append(os.path.join(weights_path, weights_names[max_i]))
269 | 
270 |     project['number_categories'] = len(project['categories'])
271 |     project['best_weights'] = best_weights
272 |     project['last_weights'] = last_weights
273 |     project['is_final'] = final
274 | 
275 |     return project
276 | 
--------------------------------------------------------------------------------
/transfer/project.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import tarfile
4 | import readline
5 | import glob
6 | 
7 | import yaml
8 | import numpy as np
9 | from colorama import init
10 | from termcolor import colored
11 | 
12 | from transfer.input import int_input, float_input, bool_input, str_input
13 | 
14 | 
15 | class Completer(object):
16 |     def path_completer(self, text, state):
17 |         return [os.path.join(x, '') if os.path.isdir(x) else x for x in glob.glob(os.path.expanduser(text) + '*')][state]
18 | 
19 | 
20 | def configure():
21 |     '''
22 |     Configure the transfer environment and store
23 |     '''
24 |     completer = Completer()
25 |     readline.set_completer_delims('\t')
26 |     readline.parse_and_bind('tab: complete')
27 |     readline.set_completer(completer.path_completer)
28 | 
29 |     home = os.path.expanduser('~')
30 |     if os.path.isfile(os.path.join(home, '.transfer', 'config.yaml')):
31 |         with open(os.path.join(home, '.transfer', 'config.yaml'), 'r') as fp:
32 |             config = yaml.load(fp.read())
33 |     else:
34 |         config = []
35 | 
36 |     project_name = input('Name your project: ')
37 |     existing_project = None
38 |     for project in config:
39 |         if project_name == project['name']:
40 |             existing_project = project_name
41 |     if existing_project is not None:
42 |         print(colored('Project ' + project_name + ' already exists', 'red'))
43 |         overwrite = str_input('Would you like to overwrite this project? (yes or no) ', ['yes', 'no'])
44 |         if overwrite == 'no':
45 |             return
46 |         else:
47 |             config = [project for project in config if project_name != project['name']]
48 | 
49 |     image_path = os.path.expanduser(input('Select parent directory for your images: '))
50 |     path_unset = True
51 |     while path_unset:
52 |         project_path = os.path.expanduser(input('Select destination for your project: '))
53 |         if (project_path.find(image_path) == 0):
54 |             print('Project destination should not be same or within image directory!')
55 |         else:
56 |             path_unset = False
57 | 
58 |     print('Select architecture:')
59 |     print('[0] resnet50')
60 |     print('[1] xception')
61 |     print('[2] inception_v3')
62 |     architecture = int_input('choice', 0, 2, show_range = False)
63 |     if architecture == 0:
64 |         arch = 'resnet50'
65 |         img_dim = 224
66 |         conv_dim = 7
67 |         final_cutoff = 80
68 |     elif architecture == 1:
69 |         arch = 'xception'
70 |         img_dim = 299
71 |         conv_dim = 10
72 |         final_cutoff = 80
73 |     else:
74 |         arch = 'inception_v3'
75 |         img_dim = 299
76 |         conv_dim = 8
77 |         final_cutoff = 80
78 |     api_port = int_input('port for local prediction API (suggested: 5000)', 1024, 49151)
79 |     kfold = int_input('number of folds to use (suggested: 5)', 3, 10)
80 |     kfold_every = bool_input('Fit a model for every fold? (if false, just fit one)')
81 |     print('Warning: if working on a remote computer, you may not be able to plot!')
82 |     plot_cm = bool_input('Plot a confusion matrix after training?')
83 |     batch_size = int_input('batch size (suggested: 8)', 1, 64)
84 |     learning_rate = float_input('learning rate (suggested: 0.001)', 0, 1)
85 |     learning_rate_decay = float_input('learning rate decay (suggested: 0.000001)', 0, 1)
86 |     cycle = int_input('number of cycles before resetting the learning rate (suggested: 3)', 1, 10)
87 |     num_rounds = int_input('number of rounds (suggested: 3)', 1, 100)
88 |     print('Select image resolution:')
89 |     print('[0] low (' + str(img_dim) + ' px)')
90 |     print('[1] mid (' + str(img_dim * 2) + ' px)')
91 |     print('[2] high (' + str(img_dim * 4) + ' px)')
92 |     img_resolution_index = int_input('choice', 0, 2, show_range = False)
93 |     if img_resolution_index == 0:
94 |         img_size = 1
95 |     elif img_resolution_index == 1:
96 |         img_size = 2
97 |     else:
98 |         img_size = 4
99 |     use_augmentation = str_input('Would you like to add image augmentation? (yes or no) ', ['yes', 'no'])
100 |     if use_augmentation == 'yes':
101 |         augmentations = select_augmentations()
102 |     else:
103 |         augmentations = None
104 | 
105 |     project = {'name': project_name,
106 |                'img_path': image_path,
107 |                'path': project_path,
108 |                'plot': plot_cm,
109 |                'api_port': api_port,
110 |                'kfold': kfold,
111 |                'kfold_every': kfold_every,
112 |                'cycle': cycle,
113 |                'seed': np.random.randint(9999),
114 |                'batch_size': batch_size,
115 |                'learning_rate': learning_rate,
116 |                'learning_rate_decay': learning_rate_decay,
117 |                'final_cutoff': final_cutoff,
118 |                'rounds': num_rounds,
119 |                'img_size': img_size,
120 |                'augmentations': augmentations,
121 |                'architecture': arch,
122 |                'img_dim': img_dim,
123 |                'conv_dim': conv_dim,
124 |                'is_split': False,
125 |                'is_array': False,
126 |                'is_augmented': False,
127 |                'is_pre_model': False,
128 |                'is_final': False,
129 |                'model_round': 0,
130 |                'server_weights': None,
131 |                'last_weights': None,
132 |                'best_weights': None}
133 | 
134 |     config.append(project)
135 |     store_config(config)
136 |     print('')
137 |     print(colored('Project configuration saved!', 'cyan'))
138 |     print('')
139 |     print('To run project:')
140 |     print('')
141 |     print(colored(' transfer --run --project ' + project_name, 'green'))
142 |     print('or')
143 |     print(colored(' transfer -r -p ' + project_name, 'green'))
144 | 
145 | 
146 | def configure_server():
147 |     '''
148 |     Configure a transfer prediction server and store
149 |     '''
150 | 
151 |     home = os.path.expanduser('~')
152 |     if os.path.isfile(os.path.join(home, '.transfer', 'config.yaml')):
153 |         with open(os.path.join(home, '.transfer', 'config.yaml'), 'r') as fp:
154 |             config = yaml.load(fp.read())
155 |     else:
156 |         config = []
157 | 
158 |     project_name = input('Name your project: ')
159 |     existing_project = None
160 |     for project in config:
161 |         if project_name == project['name']:
162 |             existing_project = project_name
163 |     if existing_project is not None:
164 |         print(colored('Project ' + project_name + ' already exists', 'red'))
165 |         overwrite = str_input('Would you like to overwrite this project? (yes or no) ', ['yes', 'no'])
166 |         if overwrite == 'no':
167 |             return
168 |         else:
169 |             config = [project for project in config if project_name != project['name']]
170 | 
171 |     api_port = int_input('port for local prediction API (suggested: 5000)', 1024, 49151)
172 |     print('Select image resolution:')
173 |     print('[0] low (224 px)')
174 |     print('[1] mid (448 px)')
175 |     print('[2] high (896 px)')
176 |     img_resolution_index = int_input('choice', 0, 2, show_range = False)
177 |     if img_resolution_index == 0:
178 |         img_size = 1
179 |     elif img_resolution_index == 1:
180 |         img_size = 2
181 |     else:
182 |         img_size = 4
183 |     num_categories = int_input('number of image categories in your model', 0, 10000000)
184 | 
185 |     weights = False
186 |     while weights == False:
187 |         server_weights = os.path.expanduser(input('Select weights file: '))
188 |         if os.path.isfile(server_weights):
189 |             weights = True
190 |         else:
191 |             print('Cannot find the weight file:', server_weights)
192 | 
193 |     project = {'name': project_name,
194 |                'api_port': api_port,
195 |                'img_size': img_size,
196 |                'number_categories': num_categories,
197 |                'server_weights': server_weights}
198 | 
199 |     config.append(project)
200 |     store_config(config)
201 |     print('')
202 |     print(colored('Project configuration saved!', 'cyan'))
203 |     print('')
204 |     print('To start the server:')
205 |     print('')
206 |     print(colored(' transfer --prediction-rest-api --project ' + project_name, 'green'))
207 |     print('or')
208 |     print(colored(' transfer --prediction-rest-api -p ' + project_name, 'green'))
209 | 
210 | 
211 | def select_augmentations():
212 |     print('Select augmentations:')
213 |     print(colored('Note: defaults are all zero or false.', 'cyan'))
214 |     rounds = int_input('number of augmentation rounds', 1, 100)
215 |     # These three are not implemented because they require training, and the results
216 |     # would then need to get propagated over, which is complicated for prediction
217 | 
218 |     featurewise_center = False #bool_input('featurewise_center: set input mean to 0 over the dataset.')
219 |     featurewise_std_normalization = False #bool_input('featurewise_std_normalization: divide inputs by std of the dataset.')
220 |     zca_whitening = False #bool_input('zca_whitening: apply ZCA whitening.')
221 |     samplewise_center = False #bool_input('samplewise_center: set each sample mean to 0.')
222 |     samplewise_std_normalization = False #bool_input('samplewise_std_normalization: divide each input by its std.')
223 |     rotation_range = int_input('rotation_range: degrees', 0, 180, True)
224 |     width_shift_range = float_input('width_shift_range: fraction of total width.', 0., 1.)
225 |     height_shift_range = float_input('height_shift_range: fraction of total height.', 0., 1.)
226 |     shear_range = float_input('shear_range: shear intensity (shear angle in radians)', 0., np.pi/2)
227 |     zoom_range_in = float_input('zoom_range: amount of zoom in. 1.0 is no zoom, 0 is full zoom.', 0., 1.)
228 |     zoom_range_out = float_input('zoom_range: amount of zoom out. 1.0 is no zoom, 2.0 is full zoom.', 1., 2.)
229 |     channel_shift_range = float_input('channel_shift_range: shift range for each channel.', 0., 1.)
230 |     print('fill_mode: points outside the boundaries are filled according to the given mode.')
231 |     fill_mode = str_input('constant, nearest, reflect, or wrap. Default nearest: ', ['constant', 'nearest', 'reflect', 'wrap'])
232 |     if (fill_mode == 'constant'):
233 |         cval = float_input('cval: value used for points outside the boundaries', 0., 1.)
234 |     else:
235 |         cval = 0.0
236 |     horizontal_flip = bool_input('horizontal_flip: whether to randomly flip images horizontally.')
237 |     vertical_flip = bool_input('vertical_flip: whether to randomly flip images vertically.')
238 |     rescale = None #int_input('rescale: rescaling factor. If None or 0, no rescaling is applied, otherwise we multiply the data by the value provided.', 0, 255)
239 |     if rescale == 0:
240 |         rescale = None
241 | 
242 |     augmentations = {'rounds': rounds,
243 |                      'featurewise_center': featurewise_center,
244 |                      'featurewise_std_normalization': featurewise_std_normalization,
245 |                      'samplewise_center': samplewise_center,
246 |                      'samplewise_std_normalization': samplewise_std_normalization,
247 |                      'zca_whitening': zca_whitening,
248 |                      'rotation_range': rotation_range,
249 |                      'width_shift_range': width_shift_range,
250 |                      'height_shift_range': height_shift_range,
251 |                      'shear_range': shear_range,
252 |                      'zoom_range': [zoom_range_in, zoom_range_out],
253 |                      'channel_shift_range': channel_shift_range,
254 |                      'fill_mode': fill_mode,
255 |                      'cval': cval,
256 |                      'horizontal_flip': horizontal_flip,
257 |                      'vertical_flip': vertical_flip,
258 |                      'rescale': rescale}
259 |     return augmentations
260 | 
261 | 
262 | def select_project(user_provided_project):
263 |     '''
264 |     Select a project from configuration to run transfer on
265 | 
266 |     args:
267 |         user_provided_project (str): Project name that should match a project in the config
268 | 
269 |     returns:
270 |         project (dict): Configuration settings for a user selected project
271 | 
272 |     '''
273 | 
274 |     home = os.path.expanduser('~')
275 |     if os.path.isfile(os.path.join(home, '.transfer', 'config.yaml')):
276 |         with open(os.path.join(home, '.transfer', 'config.yaml'), 'r') as fp:
277 |             projects = yaml.load(fp.read())
278 |         if len(projects) == 1:
279 |             project = projects[0]
280 |         else:
281 |             if user_provided_project in [project['name'] for project in projects]:
282 |                 for inner_project in projects:
283 |                     if user_provided_project == inner_project['name']:
284 |                         project = inner_project
285 |             else:
286 |                 print('Select your project')
287 |                 for i, project in enumerate(projects):
288 |                     print('[' + str(i) + ']: ' + project['name'])
289 |                 project_index = int_input('project', 0, len(projects) - 1, show_range = False)
290 |                 project = projects[project_index]
291 |     else:
292 |         print('Transfer is not configured.')
293 |         print('Please run:')
294 |         print('')
295 |         print(colored(' transfer --configure', 'green'))
296 |         return
297 | 
298 |     print(colored('Project selected: ' + project['name'], 'cyan'))
299 |     return project
300 | 
301 | 
302 | def read_imported_config(project_path, project_name, projects = None):
303 | 
304 |     completer = Completer()
305 |     readline.set_completer_delims('\t')
306 |     readline.parse_and_bind('tab: complete')
307 |     readline.set_completer(completer.path_completer)
308 | 
309 |     # Oh god this logic is a disaster, user interfaces are hard
310 |     bad_user = True
311 |     while bad_user:
312 |         relearn_str = str_input('Do you want to learn on new images starting from these weights? (yes or no) ')
313 |         if relearn_str.lower() == 'yes' or relearn_str.lower() == 'y':
314 |             bad_user = False
315 |             relearn = True
316 |         elif relearn_str.lower() == 'no' or relearn_str.lower() == 'n':
317 |             bad_user = False
318 |             relearn = False
319 | 
320 |     unique_name = False
321 |     while unique_name == False:
322 |         unique_name = True
323 |         if projects is not None:
324 |             for project in projects:
325 |                 if project['name'] == project_name:
326 |                     print(colored('Project with this name already exists.', 'red'))
327 |                     project_name = str_input('Provide a new project name: ')
328 |                     unique_name = False
329 |     if relearn:
330 |         image_path = os.path.expanduser(input('Select parent directory for your images: '))
331 |         path_unset = True
332 |         while path_unset:
333 |             project_dest = os.path.expanduser(input('Select destination for your project: '))
334 |             if (project_dest.find(image_path) == 0):
335 |                 print('Project destination should not be same or within image directory!')
336 |             else:
337 |                 path_unset = False
338 |     else:
339 |         project_dest = os.path.expanduser(input('Select destination for your project: '))
340 | 
341 |     if os.path.isdir(project_dest) == False:
342 |         print('Creating directory:', project_dest)
343 |         os.makedirs(project_dest, exist_ok = True)
344 |     # You don't get to judge me t('-' t)
345 |     with open(os.path.join(project_path, 'config.yaml'), 'r') as fp:
346 |         import_project = yaml.load(fp.read())
347 |     import_project['name'] = project_name
348 |     import_project['path'] = project_dest
349 | 
350 |     if relearn:
351 | 
352 |         kfold = int_input('number of folds to use (suggested: 5)', 3, 10)
353 |         kfold_every = bool_input('Fit a model for every fold? (if false, just fit one)')
354 |         print('Warning: if working on a remote computer, you may not be able to plot!')
355 |         plot_cm = bool_input('Plot a confusion matrix after training?')
356 |         batch_size = int_input('batch size (suggested: 8)', 1, 64)
357 |         learning_rate = float_input('learning rate (suggested: 0.001)', 0, 1)
358 |         learning_rate_decay = float_input('learning rate decay (suggested: 0.000001)', 0, 1)
359 |         cycle = int_input('number of cycles before resetting the learning rate (suggested: 3)', 1, 10)
360 |         num_rounds = int_input('number of rounds (suggested: 3)', 1, 100)
361 | 
362 |         import_project['img_path'] = image_path
363 |         import_project['best_weights'] = [os.path.join(project_path, weight) for weight in os.listdir(project_path) if weight.find('.hdf5') > 0]
364 |         import_project['last_weights'] = import_project['best_weights']
365 |         import_project['server_weights'] = None
366 |         import_project['kfold'] = kfold
367 |         import_project['kfold_every'] = kfold_every
368 |         import_project['cycle'] = cycle
369 |         import_project['seed'] = np.random.randint(9999)
370 |         import_project['batch_size'] = batch_size
371 |         import_project['learning_rate'] = learning_rate
372 |         import_project['learning_rate_decay'] = learning_rate_decay
373 |         if 'final_cutoff' not in import_project.keys():
374 |             import_project['final_cutoff'] = 80
375 |         import_project['rounds'] = num_rounds
376 |         import_project['is_split'] = False
377 |         import_project['is_array'] = False
378 |         import_project['is_augmented'] = False
379 |         import_project['is_pre_model'] = False
380 |         import_project['model_round'] = 1
381 |         import_project['plot'] = plot_cm
382 | 
383 |         print('')
384 |         print('To re-learn on new images with this project:')
385 |         print('')
386 |         print(colored(' transfer --run --project ' + project_name, 'green'))
387 |         print('or')
388 |         print(colored(' transfer -r -p ' + project_name, 'green'))
389 |         print('')
390 |     else:
391 |         import_project['server_weights'] = [os.path.join(project_path, weight) for weight in os.listdir(project_path) if weight.find('.hdf5') > 0]
392 | 
393 |     return import_project
394 | 
395 | def import_config(config_file):
396 | 
397 |     config_file = os.path.expanduser(config_file)
398 |     print(config_file)
399 |     transfer_path = os.path.expanduser(os.path.join('~','.transfer'))
400 | 
401 |     import_temp_path = os.path.join(transfer_path, 'import-temp')
402 |     import_path = os.path.join(transfer_path, 'import')
403 |     shutil.rmtree(import_temp_path, ignore_errors = True)
404 |     os.makedirs(import_temp_path, exist_ok = True)
405 | 
406 |     if os.path.isfile(config_file) == False:
407 |         print('This is not a file:', colored(config_file, 'red'))
408 |         return
409 | 
410 |     with tarfile.open(config_file, mode = "r:gz") as tf:
411 |         tf.extractall(path = import_temp_path)
412 | 
413 |     for listed in os.listdir(import_temp_path):
414 |         if os.path.isdir(os.path.join(import_temp_path, listed)):
415 |             project_name = listed
416 | 
417 |     project_path = os.path.join(transfer_path, 'import', project_name)
418 |     shutil.rmtree(project_path, ignore_errors = True)
419 |     os.makedirs(os.path.join(transfer_path, 'import'), exist_ok = True)
420 | 
421 |     shutil.move(os.path.join(import_temp_path, project_name), import_path)
422 |     shutil.rmtree(import_temp_path, ignore_errors = True)
423 | 
424 |     print('Imported project:', colored(project_name, 'magenta'))
425 |     if os.path.isfile(os.path.join(transfer_path, 'config.yaml')):
426 |         with open(os.path.join(transfer_path, 'config.yaml'), 'r') as fp:
427 |             projects = yaml.load(fp.read())
428 | 
429 |         import_project = read_imported_config(project_path, project_name, projects)
430 | 
431 |         projects.append(import_project)
432 |         store_config(projects)
433 | 
434 |     else:
435 |         shutil.copy(os.path.join(project_path, 'config.yaml'), os.path.join(transfer_path, 'config.yaml'))
436 | 
437 |         import_project = read_imported_config(project_path, project_name)
438 | 
439 |         store_config([import_project])
440 | 
441 |     print('Project successfully imported!')
442 |     print('')
443 |     print('Make predictions with:')
444 |     print('')
445 |     print(colored(' transfer --predict [optional dir or file] --project ' + import_project['name'], 'yellow'))
446 |     print('')
447 |     print('Or start a prediction server with:')
448 |     print('')
449 |     print(colored(' transfer --prediction-rest-api --project ' + import_project['name'], 'yellow'))
450 | 
451 | 
452 | def export_config(config, weights, ind = None):
453 |     export_path = os.path.expanduser(os.path.join('~','.transfer','export', config['name']))
454 |     if ind is None:
455 |         export_tar = export_path + '_' + weights + '.tar.gz'
456 |     else:
457 |         export_tar = export_path + '_' + weights + '_kfold_' + str(ind) + '.tar.gz'
458 | 
459 |     os.makedirs(export_path, exist_ok = True)
460 |     server_weights = []
461 |     if ind is None:
462 |         for i in range(len(config[weights])):
463 |             server_weights.append(os.path.join(export_path, 'server_model_kfold_' + str(i) + '.hdf5'))
464 |             shutil.copy(config[weights][i], server_weights[-1])
465 |     else:
466 |         server_weights = [os.path.join(export_path, 'server_model_kfold_' + str(ind) + '.hdf5')]
467 |         shutil.copy(config[weights][ind], server_weights[-1])
468 | 
469 |     project = {'name': config['name'],
470 |                'api_port': config['api_port'],
471 |                'img_size': config['img_size'],
472 |                'img_dim': config['img_dim'],
473 |                'conv_dim': config['conv_dim'],
474 |                'final_cutoff': config['final_cutoff'],
475 |                'architecture': config['architecture'],
476 |                'number_categories': config['number_categories'],
477 |                'categories': config['categories'],
478 |                'augmentations': config['augmentations'],
479 |                'is_final': config['is_final'],
480 |                'server_weights': server_weights}
481 |     store_config(project, suffix = os.path.join('export', config['name']))
482 | 
483 |     with tarfile.open(export_tar, mode = "w:gz") as tf:
484 |         tf.add(os.path.expanduser(os.path.join('~', '.transfer', 'export', config['name'])), config['name'])
485 | 
486 |     shutil.rmtree(export_path, ignore_errors = True)
487 |     print('Project successfully exported, please save the following file for re-import to transfer')
488 |     print('')
489 |     print(colored(export_tar, 'green'))
490 | 
491 | 
492 | def store_config(config, suffix = None):
493 |     '''
494 |     Store configuration
495 | 
496 |     args:
497 |         config (list[dict]): configurations for each project
498 |     '''
499 |     home = os.path.expanduser('~')
500 |     if suffix is not None:
501 |         config_path = os.path.join(home, '.transfer', suffix)
502 |     else:
503 |         config_path = os.path.join(home, '.transfer')
504 | 
505 |     os.makedirs(config_path, exist_ok = True)
506 |     with open(os.path.join(config_path, 'config.yaml'), 'w') as fp:
507 |         yaml.dump(config, fp)
508 | 
509 | 
510 | def update_config(updated_project):
511 |     '''
512 |     Update project in configuration
513 | 
514 |     args:
515 |         updated_project (dict): Updated project configuration values
516 | 
517 |     '''
518 | 
519 |     home = os.path.expanduser('~')
520 |     if os.path.isfile(os.path.join(home, '.transfer', 'config.yaml')):
521 |         with open(os.path.join(home, '.transfer', 'config.yaml'), 'r') as fp:
522 |             projects = yaml.load(fp.read())
523 |         replace_index = -1
524 |         for i, project in enumerate(projects):
525 |             if project['name'] == updated_project['name']:
526 |                 replace_index = i
527 | 
528 |         if replace_index > -1:
529 |             projects[replace_index] = updated_project
530 |             store_config(projects)
531 |         else:
532 |             print('Not saving configuration')
533 |             print(colored('Project: ' + updated_project['name'] + ' was not found in configured projects!', 'red'))
534 | 
535 |     else:
536 |         print('Transfer is not configured.')
537 |         print('Please run:')
538 |         print('')
539 |         print(colored(' transfer --configure', 'cyan'))
540 |         return
541 | 
--------------------------------------------------------------------------------
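For reference, here is roughly what one project entry in ~/.transfer/config.yaml looks like after configure() runs and store_config() dumps the list. This is a minimal sketch: the project name, paths, and seed below are illustrative, the keys mirror the project dict built in configure(), and yaml.dump emits keys alphabetically by default.

    - api_port: 5000
      architecture: resnet50
      augmentations: null
      batch_size: 8
      best_weights: null
      conv_dim: 7
      cycle: 3
      final_cutoff: 80
      img_dim: 224
      img_path: /home/user/images
      img_size: 1
      is_array: false
      is_augmented: false
      is_final: false
      is_pre_model: false
      is_split: false
      kfold: 5
      kfold_every: true
      last_weights: null
      learning_rate: 0.001
      learning_rate_decay: 1.0e-06
      model_round: 0
      name: my-project
      path: /home/user/my-project
      plot: true
      rounds: 3
      seed: 1234
      server_weights: null

Because the file is a plain YAML list of such dicts, select_project(), update_config(), and import_config() can all round-trip it with yaml.load()/yaml.dump() and no custom serialization.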