├── oversight
│   ├── notifiers
│   │   ├── __init__.py
│   │   ├── image_file_notifier.py
│   │   ├── pushover_notifier.py
│   │   └── smtp_notifier.py
│   ├── __init__.py
│   ├── image_buffer.py
│   ├── signals.py
│   ├── logging_config.py
│   ├── cnn_classifier.py
│   ├── monitor.py
│   ├── image_source.py
│   └── retrain.py
├── tests
│   ├── __init__.py
│   ├── data
│   │   └── picture.jpeg
│   ├── test_utils.py
│   ├── test_smtp_notifier.py
│   ├── test_pushover_notifier.py
│   ├── test_image_buffer.py
│   └── test_monitor.py
├── requirements.txt
├── .dockerignore
├── Dockerfile
├── bin
│   ├── train.sh
│   └── resizer.sh
├── setup.py
├── .gitignore
├── oversight_runner.py
├── README.md
└── LICENSE

/oversight/notifiers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /oversight/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'bcarson' 2 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'bcarson' 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | blinker==1.4 2 | nose==1.3.7 3 | pillow==3.2.0 4 | requests==2.10.0 5 | tensorflow==0.10.0 -------------------------------------------------------------------------------- /tests/data/picture.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hebenon/oversight/HEAD/tests/data/picture.jpeg -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | .git 2 | .idea 3 | tests 4 | .gitignore 5 | README.md 6 | requirements.txt 7 | setup.py 8 | Dockerfile -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM tensorflow/tensorflow 2 | 3 | MAINTAINER Ben Carson "ben.carson@bigpond.com" 4 | 5 | # Version of Pillow in the container is O.L.D. 6 | RUN pip install --upgrade pillow blinker requests 7 | 8 | ADD . 
/opt/oversight/ 9 | 10 | RUN chmod +x /opt/oversight/bin/* 11 | 12 | CMD cd /opt/oversight && python oversight_runner.py 13 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | __author__ = 'bcarson' 2 | 3 | import io 4 | from PIL import Image 5 | 6 | def load_image(path='./data/picture.jpeg', output_width=1000, output_height=565): 7 | image = Image.open(path) 8 | resized_image = image.resize((output_width, output_height)) 9 | 10 | output_buffer = io.BytesIO() 11 | resized_image.save(output_buffer, 'jpeg') 12 | 13 | return output_buffer.getvalue() -------------------------------------------------------------------------------- /bin/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This idea was pretty much ripped wholesale from https://github.com/xblaster/tensor-guess 4 | python /opt/oversight/oversight/retrain.py \ 5 | --bottleneck_dir=/oversight_data/bottlenecks \ 6 | --how_many_training_steps 4000 \ 7 | --model_dir=/oversight_data/inception \ 8 | --output_graph=/oversight_data/retrained_graph.pb \ 9 | --output_labels=/oversight_data/retrained_labels.txt \ 10 | --image_dir /oversight_data/labelled_images 11 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | __author__ = 'bcarson' 2 | 3 | try: 4 | from setuptools import setup 5 | except ImportError: 6 | from distutils.core import setup 7 | 8 | config = { 9 | 'description': 'Oversight', 10 | 'author': 'Ben Carson', 11 | 'url': 'https://github.com/hebenon/oversight', 12 | 'download_url': 'https://github.com/hebenon/oversight', 13 | 'author_email': 'ben.carson@bigpond.com', 14 | 'version': '0.2', 15 | 'install_requires': ['blinker', 'nose', 'pillow', 'requests', 'tensorflow'], 16 | 'packages': ['oversight'], 17 | 'scripts': [], 18 | 'name': 'oversight' 19 | } 20 | 21 | setup(**config) 22 | -------------------------------------------------------------------------------- /oversight/image_buffer.py: -------------------------------------------------------------------------------- 1 | __author__ = 'bcarson' 2 | 3 | from signals import image, image_buffer 4 | 5 | 6 | class ImageBuffer(object): 7 | def __init__(self, buffer_length=3): 8 | self.buffer_length = buffer_length 9 | self.image_buffer = [] 10 | 11 | # Handle image events 12 | image.connect(self.handle_image) 13 | image_buffer.connect(self.handle_image_buffer) 14 | 15 | def handle_image(self, sender, **data): 16 | source_image = data['image'] 17 | 18 | if source_image is not None: 19 | self.image_buffer = [source_image] + self.image_buffer[:int(self.buffer_length) - 1] 20 | 21 | def handle_image_buffer(self, sender): 22 | return self.image_buffer -------------------------------------------------------------------------------- /oversight/signals.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Ben Carson. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | __author__ = 'bcarson' 16 | 17 | from blinker import signal 18 | 19 | image = signal('image') 20 | image_analysis = signal('image_analysis') 21 | image_buffer = signal('image_buffer') 22 | 23 | trigger_event = signal('trigger_event') 24 | -------------------------------------------------------------------------------- /oversight/logging_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Ben Carson. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | __author__ = 'bcarson' 16 | 17 | 18 | LOGGING_CONFIG = { 19 | 'version': 1, 20 | 'disable_existing_loggers': True, 21 | 'formatters': { 22 | 'simple': { 23 | 'format': '%(asctime)s %(levelname)s -- %(message)s' 24 | } 25 | }, 26 | 'handlers': { 27 | 'console': { 28 | 'level': 'DEBUG', 29 | 'class': 'logging.StreamHandler', 30 | 'formatter': 'simple' 31 | } 32 | }, 33 | 'loggers': { 34 | 'root': { 35 | 'level': 'DEBUG', 36 | 'handlers': ['console'] 37 | } 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /bin/resizer.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | __author__ = 'bcarson' 4 | # Simple script to resize raw images into standard format. 
5 | 6 | import sys 7 | import os 8 | import argparse 9 | 10 | from PIL import Image 11 | 12 | parser = argparse.ArgumentParser(description='resizer') 13 | parser.add_argument('--input_path', default='./labelled_images_raw') 14 | parser.add_argument('--output_path', default='./labelled_images') 15 | parser.add_argument('--output_width', default='1000') 16 | parser.add_argument('--output_height', default='565') 17 | 18 | if __name__ == "__main__": 19 | args = parser.parse_args() 20 | 21 | output_width = int(args.output_width) 22 | output_height = int(args.output_height) 23 | 24 | input_path = args.input_path 25 | output_path = args.output_path 26 | 27 | for category in [path for path in os.listdir(input_path) if os.path.isdir(os.path.join(input_path, path))]: 28 | output_category_path = os.path.join(output_path, category) 29 | input_category_path = os.path.join(input_path, category) 30 | 31 | if not os.path.exists(output_category_path): 32 | os.mkdir(output_category_path) 33 | 34 | for image_file in os.listdir(input_category_path): 35 | output_file_path = os.path.join(output_category_path, image_file) 36 | input_file_path = os.path.join(input_category_path, image_file) 37 | 38 | im = Image.open(input_file_path) 39 | im.resize((output_width, output_height)).save(output_file_path) -------------------------------------------------------------------------------- /oversight/notifiers/image_file_notifier.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Ben Carson. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | __author__ = 'bcarson' 16 | 17 | import logging 18 | import os 19 | 20 | from oversight.signals import trigger_event 21 | 22 | logger = logging.getLogger('root') 23 | 24 | class ImageFileNotifier(object): 25 | """ 26 | A notification sink that responds to trigger events by storing the image file. 27 | """ 28 | def __init__(self, labelled_images_root): 29 | self.labelled_images_root = labelled_images_root 30 | 31 | trigger_event.connect(self.handle_trigger_event) 32 | 33 | def handle_trigger_event(self, sender, **data): 34 | # Write the image into a file under the image root, in a sub-folder named for the prediction. 35 | filename = '%s-%s.jpg' % (data['source'], data['timestamp'].strftime('%Y%m%d-%H%M%S')) 36 | image_path = os.path.join(self.labelled_images_root, data['prediction'], filename) 37 | 38 | try: 39 | file_handle = open(image_path, 'wb') 40 | file_handle.write(data['image']) 41 | file_handle.close() 42 | except IOError as e: 43 | logger.error("IO Error: %s", e) 44 | 45 | -------------------------------------------------------------------------------- /tests/test_smtp_notifier.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Ben Carson. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | __author__ = 'bcarson' 16 | 17 | from datetime import datetime 18 | 19 | import mock 20 | from blinker import ANY 21 | from nose.tools import with_setup 22 | 23 | from oversight.notifiers.smtp_notifier import SmtpNotifier 24 | from oversight.signals import trigger_event 25 | from tests.test_utils import load_image 26 | 27 | 28 | def teardown(): 29 | # Disconnect any triggers 30 | for receiver in trigger_event.receivers_for(ANY): 31 | trigger_event.disconnect(receiver) 32 | 33 | 34 | @with_setup(teardown=teardown) 35 | @mock.patch('smtplib.SMTP_SSL') 36 | def test_send_mail_to_recipients(mock_smtplib): 37 | from_address = 'Test ' 38 | recipients = ['first@test.test', 'second@test.test'] 39 | smtp_server = 'smtp.test.test' 40 | 41 | smtp_notifier_instance = SmtpNotifier(from_address=from_address, recipients=recipients, smtp_server=smtp_server) 42 | 43 | # Set up test conditions 44 | now = datetime.utcnow() 45 | test_image = load_image() 46 | 47 | trigger_event.send('test', prediction='test event', probability=0.95, source="test_source", timestamp=now, image=test_image) 48 | 49 | # Verify result 50 | assert mock_smtplib.return_value.sendmail.call_count == 1 51 | mock_smtplib.return_value.sendmail.assert_called_with(from_address, recipients, mock.ANY) 52 | -------------------------------------------------------------------------------- /tests/test_pushover_notifier.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Ben Carson. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | __author__ = 'bcarson' 16 | 17 | import calendar 18 | from datetime import datetime 19 | 20 | import mock 21 | from blinker import ANY 22 | from nose.tools import with_setup 23 | 24 | from oversight.notifiers.pushover_notifier import PushoverNotifier 25 | from oversight.signals import trigger_event 26 | from tests.test_utils import load_image 27 | 28 | 29 | def teardown(): 30 | # Disconnect any triggers 31 | for receiver in trigger_event.receivers_for(ANY): 32 | trigger_event.disconnect(receiver) 33 | 34 | 35 | @with_setup(teardown=teardown) 36 | @mock.patch('requests.post') 37 | def test_send_notification(mock_requests): 38 | pushover_user = 'test_user' 39 | pushover_token = 'test_token' 40 | pushover_device = 'test_device' 41 | 42 | pushover_notifier_instance = PushoverNotifier(pushover_user=pushover_user, 43 | pushover_token=pushover_token, 44 | pushover_device=pushover_device) 45 | 46 | # Set up test conditions 47 | now = datetime.utcnow() 48 | test_image = load_image() 49 | test_prediction = 'test event' 50 | test_source = 'test_source' 51 | 52 | trigger_event.send('test', prediction=test_prediction, probability=0.95, timestamp=now, image=test_image, source=test_source) 53 | 54 | # Verify result 55 | assert mock_requests.call_count == 1 56 | assert mock_requests.call_args[1]['data']['user'] == pushover_user 57 | assert mock_requests.call_args[1]['data']['token'] == pushover_token 58 | assert mock_requests.call_args[1]['data']['timestamp'] == calendar.timegm(now.timetuple()) 59 | assert test_prediction in mock_requests.call_args[1]['data']['message'] 60 | assert test_source in mock_requests.call_args[1]['data']['message'] 61 | -------------------------------------------------------------------------------- /oversight/cnn_classifier.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Ben Carson. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | __author__ = 'bcarson' 16 | 17 | import tensorflow as tf 18 | import os 19 | 20 | from signals import image, image_analysis 21 | 22 | 23 | class CNNClassifier(object): 24 | """ 25 | Classifier class which uses a Convolutional Neural Network (CNN) to predict the contents of the image. 26 | The CNNClassifier requires a graph that has been pre-trained with labels of the expected events. 
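The model directory is expected to contain the retrained_graph.pb and retrained_labels.txt files produced by retrain.py (see bin/train.sh).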
27 | """ 28 | 29 | def __init__(self, model_directory, session): 30 | self.session = session 31 | 32 | graph_file = os.path.expanduser(os.path.join(model_directory, "retrained_graph.pb")) 33 | label_file = os.path.expanduser(os.path.join(model_directory, "retrained_labels.txt")) 34 | 35 | self.labels = [line.rstrip() for line in tf.gfile.GFile(label_file)] 36 | 37 | with tf.gfile.FastGFile(graph_file, 'rb') as f: 38 | graph_def = tf.GraphDef() 39 | graph_def.ParseFromString(f.read()) 40 | _ = tf.import_graph_def(graph_def, name='') 41 | 42 | image.connect(self.predict) 43 | 44 | def predict(self, sender, **data): 45 | image_data = data['image'] 46 | 47 | # Feed the image_data as input to the graph and get first prediction 48 | softmax_tensor = self.session.graph.get_tensor_by_name('final_result:0') 49 | 50 | predictions = self.session.run(softmax_tensor, {'DecodeJpeg/contents:0': image_data}) 51 | 52 | # Sort to show labels of first prediction in order of confidence 53 | top_k = predictions[0].argsort()[-len(predictions[0]):][::-1] 54 | 55 | results = [] 56 | 57 | for node_id in top_k: 58 | results += [(self.labels[node_id], predictions[0][node_id])] 59 | 60 | # Emit analysis results 61 | image_analysis.send(self, source=data['source'], timestamp=data['timestamp'], image=image_data, predictions=results) 62 | -------------------------------------------------------------------------------- /oversight/notifiers/pushover_notifier.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Ben Carson. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | __author__ = 'bcarson' 16 | 17 | import logging 18 | import calendar 19 | 20 | import requests 21 | 22 | from oversight.signals import trigger_event, image_buffer 23 | 24 | logger = logging.getLogger('root') 25 | 26 | class PushoverNotifier(object): 27 | """ 28 | A notification sink that sends notifications via pushover. 29 | """ 30 | def __init__(self, pushover_user, pushover_token, pushover_device=None): 31 | self.pushover_user = pushover_user 32 | self.pushover_token = pushover_token 33 | self.pushover_device = pushover_device 34 | 35 | trigger_event.connect(self.handle_trigger_event) 36 | 37 | def handle_trigger_event(self, sender, **data): 38 | # Pushover doesn't support images, so just send the event. 39 | notification_data = { 40 | "user": self.pushover_user, 41 | "token": self.pushover_token, 42 | "message": "Camera %s, event: %s" % (data['source'], data['prediction']), 43 | "timestamp": calendar.timegm(data['timestamp'].timetuple()) 44 | } 45 | 46 | # Optionally, set the device. 
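# (If no device is specified, Pushover delivers the notification to all of the user's devices.)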
47 | if self.pushover_device: 48 | notification_data['device'] = self.pushover_device 49 | 50 | try: 51 | r = requests.post("https://api.pushover.net/1/messages.json", data=notification_data) 52 | 53 | if r.status_code != 200: 54 | logger.error("Failed to send notification (%d): %s" % (r.status_code, r.text)) 55 | except requests.ConnectionError as e: 56 | logger.error("Connection Error: %s", e) 57 | except requests.HTTPError as e: 58 | logger.error("HTTP Error: %s", e) 59 | 60 | def get_image_buffer(self): 61 | image_set = [] 62 | for (receiver, return_value) in image_buffer.send(self): 63 | image_set.extend(return_value) 64 | 65 | return image_set -------------------------------------------------------------------------------- /tests/test_image_buffer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Ben Carson. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | __author__ = 'bcarson' 16 | 17 | from datetime import datetime, timedelta 18 | 19 | from blinker import ANY 20 | from nose.tools import with_setup 21 | 22 | from oversight.signals import image, image_buffer 23 | from oversight.image_buffer import ImageBuffer 24 | 25 | from test_utils import load_image 26 | 27 | 28 | def get_image_buffer(): 29 | image_set = [] 30 | for (receiver, return_value) in image_buffer.send('test'): 31 | image_set.extend(return_value) 32 | 33 | return image_set 34 | 35 | 36 | def teardown(): 37 | # Disconnect any triggers 38 | for receiver in image.receivers_for(ANY): 39 | image.disconnect(receiver) 40 | 41 | for receiver in image_buffer.receivers_for(ANY): 42 | image_buffer.disconnect(receiver) 43 | 44 | @with_setup(teardown=teardown) 45 | def test_return_images_if_less_than_length(): 46 | image_buffer_instance = ImageBuffer(buffer_length=4) 47 | 48 | # Set up test conditions 49 | now = datetime.utcnow() 50 | test_image = load_image() 51 | 52 | image.send('test', source="test_source", timestamp=now, image=test_image) 53 | 54 | result = get_image_buffer() 55 | 56 | # Verify result 57 | assert len(result) == 1 58 | assert result[0] == test_image 59 | 60 | @with_setup(teardown=teardown) 61 | def test_return_buffer_length_images(): 62 | buffer_length = 4 63 | image_buffer_instance = ImageBuffer(buffer_length=buffer_length) 64 | 65 | # Set up test conditions 66 | now = datetime.utcnow() 67 | 68 | test_images = [] 69 | for i in xrange(0, buffer_length + 1): 70 | test_image = load_image() 71 | test_images.append(test_image) 72 | image.send('test', source="test_source", timestamp=now + timedelta(seconds=i), image=test_image) 73 | 74 | result = get_image_buffer() 75 | 76 | # Verify result 77 | assert len(result) == buffer_length 78 | 79 | offset = len(test_images) - buffer_length 80 | for i in xrange(0, len(result)): 81 | assert result[i] == test_images[i + offset] 
-------------------------------------------------------------------------------- /oversight/notifiers/smtp_notifier.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Ben Carson. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | __author__ = 'bcarson' 16 | 17 | import smtplib 18 | 19 | from email.mime.image import MIMEImage 20 | from email.mime.multipart import MIMEMultipart 21 | 22 | from oversight.signals import trigger_event, image_buffer 23 | 24 | COMMASPACE = ', ' 25 | 26 | 27 | class SmtpNotifier(object): 28 | """ 29 | A notification sink that sends notifications via SMTP. 30 | """ 31 | def __init__(self, from_address, recipients, smtp_server, username=None, password=None): 32 | self.smtp_server = smtp_server 33 | self.username = username 34 | self.password = password 35 | self.from_address = from_address 36 | self.recipients = recipients 37 | 38 | trigger_event.connect(self.handle_trigger_event) 39 | 40 | def handle_trigger_event(self, sender, **data): 41 | # Retrieve the image buffer, and add the event image, if not there. 42 | images = self.get_image_buffer() 43 | event_image = data['image'] 44 | 45 | if event_image not in images: 46 | images = [event_image] + images 47 | 48 | # Create the container (outer) email message. 49 | msg = MIMEMultipart() 50 | msg['Subject'] = 'Camera %s, event: %s' % (data['source'], data['prediction']) 51 | 52 | # Set addresses 53 | msg['From'] = self.from_address 54 | msg['To'] = COMMASPACE.join(self.recipients) 55 | msg.preamble = msg['Subject'] 56 | 57 | for image in images: 58 | # The image data is already raw bytes; let the MIMEImage class guess the image subtype. 59 | img = MIMEImage(image) 60 | msg.attach(img) 61 | 62 | # Send the email via our own SMTP server. 
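# Note: username and password are stored by __init__ but never used; a server that
# requires authentication would need an s.login(self.username, self.password) call after connecting.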
63 | s = smtplib.SMTP_SSL(self.smtp_server) 64 | s.sendmail(self.from_address, self.recipients, msg.as_string()) 65 | s.quit() 66 | 67 | def get_image_buffer(self): 68 | image_set = [] 69 | for (receiver, return_value) in image_buffer.send(self): 70 | image_set.extend(return_value) 71 | 72 | return image_set -------------------------------------------------------------------------------- /oversight/monitor.py: -------------------------------------------------------------------------------- 1 | __author__ = 'bcarson' 2 | 3 | import logging 4 | 5 | from threading import Timer 6 | 7 | from signals import image_analysis, trigger_event 8 | 9 | logger = logging.getLogger('root') 10 | 11 | 12 | class Monitor(object): 13 | def __init__(self, triggers, notification_delay=2): 14 | self.triggers = triggers 15 | self.notification_delay = notification_delay 16 | self.active_triggers = dict() 17 | 18 | self.notification_timer = None 19 | 20 | image_analysis.connect(self.handle_image_analysis) 21 | 22 | def send_notification(self, prediction, probability, source, timestamp, image): 23 | logger.debug("Sending trigger: %s (%f) @ %s" % (prediction, probability, str(timestamp))) 24 | trigger_event.send(self, prediction=prediction, probability=probability, 25 | source=source, timestamp=timestamp, image=image) 26 | 27 | def handle_image_analysis(self, sender, **data): 28 | # Get predictions 29 | predictions = data['predictions'] 30 | source = data['source'] 31 | 32 | # Check if this is a new source or not. 33 | if source not in self.active_triggers: 34 | self.active_triggers[source] = set() 35 | 36 | # Check for a result 37 | for (prediction, probability) in predictions: 38 | logger.debug("prediction %s: %f", prediction, probability) 39 | 40 | # The graph uses softmax in the final layer, so it's *unlikely* that this will be useful. 41 | # That being said, it's possible to configure multiple triggers with low thresholds. 42 | if prediction in self.triggers and probability >= self.triggers[prediction]: 43 | # Prevent alarm storms by not acting on active triggers 44 | if prediction not in self.active_triggers[source]: 45 | logger.warning("Trigger event active: %s %f", prediction, probability) 46 | self.active_triggers[source].add(prediction) 47 | 48 | # Only send a notification if one isn't already triggered. 
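# The delay gives the image buffer time to collect frames from after the event, and any
# triggers that fire while the timer is alive are folded into the pending notification.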
49 | if not self.notification_timer or not self.notification_timer.isAlive(): 50 | self.notification_timer = Timer(self.notification_delay, self.send_notification, 51 | (prediction, probability, source, data['timestamp'], data['image'])) 52 | self.notification_timer.start() 53 | else: 54 | # Log any clearing alarms 55 | if prediction in self.active_triggers[source]: 56 | logger.warning("Trigger event ended: %s %f", prediction, probability) 57 | 58 | self.active_triggers[source].discard(prediction) # Remove from active triggers (if it exists) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.gitignore.io/api/python,intellij 3 | 4 | ### Python ### 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | env/ 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *,cover 50 | .hypothesis/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # IPython Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv/ 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | 93 | # Rope project settings 94 | .ropeproject 95 | 96 | 97 | ### Intellij ### 98 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm 99 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 100 | 101 | # User-specific stuff: 102 | .idea/workspace.xml 103 | .idea/tasks.xml 104 | .idea/dictionaries 105 | .idea/vcs.xml 106 | .idea/jsLibraryMappings.xml 107 | 108 | # Sensitive or high-churn files: 109 | .idea/dataSources.ids 110 | .idea/dataSources.xml 111 | .idea/dataSources.local.xml 112 | .idea/sqlDataSources.xml 113 | .idea/dynamic.xml 114 | .idea/uiDesigner.xml 115 | 116 | # Gradle: 117 | .idea/gradle.xml 118 | .idea/libraries 119 | 120 | # Mongo Explorer plugin: 121 | .idea/mongoSettings.xml 122 | 123 | ## File-based project format: 124 | *.iws 125 | 126 | ## Plugin-specific files: 127 | 128 | # IntelliJ 129 | /out/ 130 | 131 | # mpeltonen/sbt-idea plugin 132 | .idea_modules/ 133 | 134 | # JIRA plugin 135 | atlassian-ide-plugin.xml 136 | 137 | # Crashlytics plugin (for Android Studio and IntelliJ) 138 | com_crashlytics_export_strings.xml 139 | crashlytics.properties 140 | crashlytics-build.properties 141 | fabric.properties 142 | 143 | ### Intellij Patch ### 144 | # Comment Reason: 
https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 145 | 146 | # *.iml 147 | # modules.xml 148 | # .idea/misc.xml 149 | # *.ipr 150 | 151 | ## Oversight specific files 152 | images -------------------------------------------------------------------------------- /oversight/image_source.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Ben Carson. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | __author__ = 'bcarson' 16 | 17 | import requests 18 | import io 19 | import logging 20 | import random 21 | import time 22 | 23 | from datetime import datetime 24 | from threading import Timer 25 | 26 | from PIL import Image 27 | 28 | from signals import image 29 | 30 | logger = logging.getLogger('root') 31 | 32 | 33 | class ImageSource(object): 34 | """ 35 | ImageSource will generate a stream of image events. 36 | It periodically connects to a URL and downloads an image to generate each event. 37 | """ 38 | 39 | def __init__(self, download_url, username=None, password=None, tag=None, download_frequency=2.0, output_width=1000, output_height=565): 40 | self.download_url = download_url 41 | 42 | if username is not None: 43 | self.authorisation = (username, password) 44 | else: 45 | self.authorisation = None 46 | 47 | self.tag = tag 48 | 49 | # Size of images to work with 50 | self.output_width = output_width 51 | self.output_height = output_height 52 | 53 | self.download_frequency = download_frequency 54 | 55 | Timer(self.download_frequency * random.random(), self.get_image).start() 56 | 57 | def get_image(self): 58 | start = time.time() 59 | downloaded_image = None 60 | resized_image = None 61 | 62 | try: 63 | request = requests.get(self.download_url, auth=self.authorisation) 64 | 65 | if request.status_code == 200: 66 | downloaded_image = io.BytesIO(request.content) 67 | 68 | except requests.ConnectionError as e: 69 | logger.error("Connection Error: %s", e) 70 | except requests.HTTPError as e: 71 | logger.error("HTTP Error: %s", e) 72 | 73 | if downloaded_image is not None: 74 | try: 75 | resized_image = self.get_resized_image(downloaded_image) 76 | except IOError as e: 77 | logger.error("Failed to resize image: %s", e) 78 | 79 | if resized_image is not None: 80 | image.send(self, timestamp=datetime.utcnow(), image=resized_image, source=self.tag) 81 | 82 | next_time = max(self.download_frequency - (time.time() - start), 0) 83 | Timer(next_time, self.get_image).start() 84 | 85 | def get_resized_image(self, image_input): 86 | """ 87 | Given a raw image from an image source, resize it to a standard size. Doing this gives more consistent 88 | predictions against the training set. 89 | 90 | :param image_input: A buffer with the raw image data. 91 | :return: Resized image data in jpeg format. 
92 | """ 93 | image = Image.open(image_input) 94 | resized_image = image.resize((self.output_width, self.output_height)) 95 | 96 | output_buffer = io.BytesIO() 97 | resized_image.save(output_buffer, 'jpeg') 98 | 99 | return output_buffer.getvalue() 100 | -------------------------------------------------------------------------------- /tests/test_monitor.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Ben Carson. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | __author__ = 'bcarson' 16 | 17 | from datetime import datetime, timedelta 18 | 19 | from blinker import ANY 20 | from nose.tools import with_setup 21 | 22 | from oversight.signals import image_analysis, trigger_event 23 | from oversight.monitor import Monitor 24 | 25 | from test_utils import load_image 26 | 27 | 28 | def teardown(): 29 | # Disconnect any triggers 30 | for receiver in image_analysis.receivers_for(ANY): 31 | image_analysis.disconnect(receiver) 32 | 33 | for receiver in trigger_event.receivers_for(ANY): 34 | trigger_event.disconnect(receiver) 35 | 36 | @with_setup(teardown=teardown) 37 | def test_generate_event_if_over_threshold(): 38 | monitor = Monitor(triggers={'test_event': 0.5}, notification_delay=0) 39 | 40 | # Set up test conditions 41 | now = datetime.utcnow() 42 | image = load_image() 43 | generated_events = [] 44 | 45 | def event_received(sender,**data): 46 | generated_events.append(data) 47 | 48 | trigger_event.connect(event_received) 49 | 50 | image_analysis.send('test', source="test_source", timestamp=now, image=image, predictions=[('test_event', 0.99)]) 51 | 52 | # Verify result 53 | expected = dict(timestamp=now, source="test_source", image=image, prediction='test_event', probability=0.99) 54 | 55 | assert len(generated_events) is 1 56 | assert generated_events[0] == expected 57 | 58 | @with_setup(teardown=teardown) 59 | def test_no_event_if_under_threshold(): 60 | monitor = Monitor(triggers={'test_event': 0.5}, notification_delay=0) 61 | 62 | # Set up test conditions 63 | now = datetime.utcnow() 64 | image = load_image() 65 | generated_events = [] 66 | 67 | def event_received(sender,**data): 68 | generated_events.append(data) 69 | 70 | trigger_event.connect(event_received) 71 | 72 | image_analysis.send('test', source="test_source", timestamp=now, image=image, predictions=[('test_event', 0.4)]) 73 | 74 | # Verify result 75 | assert len(generated_events) is 0 76 | 77 | @with_setup(teardown=teardown) 78 | def test_no_event_if_already_active(): 79 | monitor = Monitor(triggers={'test_event': 0.5}, notification_delay=0) 80 | 81 | # Set up test conditions 82 | now = datetime.utcnow() 83 | image = load_image() 84 | generated_events = [] 85 | 86 | def event_received(sender,**data): 87 | generated_events.append(data) 88 | 89 | trigger_event.connect(event_received) 90 | 91 | # Trigger once 92 | 
image_analysis.send('test', source="test_source", timestamp=now, image=image, predictions=[('test_event', 0.99)]) 93 | 94 | # Second prediction - should not generate 95 | image_analysis.send('test', source="test_source", timestamp=now + timedelta(seconds=1), image=image, predictions=[('test_event', 0.99)]) 96 | 97 | # Verify result 98 | expected = dict(timestamp=now, source="test_source", image=image, prediction='test_event', probability=0.99) 99 | 100 | assert len(generated_events) is 1 101 | assert generated_events[0] == expected 102 | 103 | @with_setup(teardown=teardown) 104 | def test_generate_event_once_active_cleared(): 105 | monitor = Monitor(triggers={'test_event': 0.5}, notification_delay=0) 106 | 107 | # Set up test conditions 108 | now = datetime.utcnow() 109 | image = load_image() 110 | generated_events = [] 111 | 112 | def event_received(sender,**data): 113 | generated_events.append(data) 114 | 115 | trigger_event.connect(event_received) 116 | 117 | # Trigger initial 118 | image_analysis.send('test', source="test_source", timestamp=now, image=image, predictions=[('test_event', 0.99)]) 119 | 120 | # Clear the first alarm. 121 | image_analysis.send('test', source="test_source", timestamp=now + timedelta(seconds=1), image=image, predictions=[('test_event', 0.4)]) 122 | 123 | # Send the third event 124 | image_analysis.send('test', source="test_source", timestamp=now + timedelta(seconds=2), image=image, predictions=[('test_event', 0.99)]) 125 | 126 | # Verify result 127 | expected_first = dict(source="test_source", timestamp=now, image=image, prediction='test_event', probability=0.99) 128 | expected_second = dict(source="test_source", timestamp=now + timedelta(seconds=2), image=image, prediction='test_event', probability=0.99) 129 | 130 | assert len(generated_events) is 2 131 | assert generated_events[0] == expected_first 132 | assert generated_events[1] == expected_second -------------------------------------------------------------------------------- /oversight_runner.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Ben Carson. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | __author__ = 'bcarson' 16 | 17 | import argparse 18 | import logging 19 | import logging.config 20 | import time 21 | import urlparse 22 | 23 | import os 24 | import tensorflow as tf 25 | 26 | from oversight.cnn_classifier import CNNClassifier 27 | from oversight.image_buffer import ImageBuffer 28 | from oversight.image_source import ImageSource 29 | from oversight.logging_config import LOGGING_CONFIG 30 | from oversight.monitor import Monitor 31 | from oversight.notifiers.pushover_notifier import PushoverNotifier 32 | from oversight.notifiers.smtp_notifier import SmtpNotifier 33 | from oversight.notifiers.image_file_notifier import ImageFileNotifier 34 | 35 | logger = logging.getLogger('root') 36 | 37 | 38 | def validate_args(): 39 | """ 40 | Validate the provided arguments and return a parsed object. 41 | Because we use os.environ to provide the default value for unprovided arguments, 42 | we have to manually check for mandatory arguments. 43 | 44 | :return: parsed ArgumentParser object 45 | """ 46 | parser = argparse.ArgumentParser(description='oversight') 47 | parser.add_argument('--download_urls', default=os.environ.get('OVERSIGHT_DOWNLOAD_URLS', '').split(' '), nargs='*') 48 | parser.add_argument('--model_directory', default=os.environ.get('OVERSIGHT_MODEL_DIRECTORY', '~/.oversight')) 49 | parser.add_argument('--image_buffer_length', default=os.environ.get('OVERSIGHT_IMAGE_BUFFER_LENGTH', 3), type=int) 50 | parser.add_argument('--notification_delay', default=os.environ.get('OVERSIGHT_NOTIFICATION_DELAY', 2), type=int) 51 | parser.add_argument('--smtp_recipients', default=os.environ.get('OVERSIGHT_SMTP_RECIPIENTS', ''), nargs='*') 52 | parser.add_argument('--smtp_host', default=os.environ.get('OVERSIGHT_SMTP_HOST', '')) 53 | parser.add_argument('--pushover_user', default=os.environ.get('OVERSIGHT_PUSHOVER_USER', '')) 54 | parser.add_argument('--pushover_token', default=os.environ.get('OVERSIGHT_PUSHOVER_TOKEN', '')) 55 | parser.add_argument('--pushover_device', default=os.environ.get('OVERSIGHT_PUSHOVER_DEVICE', '')) 56 | parser.add_argument('--image_storage_directory', default=os.environ.get('OVERSIGHT_IMAGE_STORAGE_DIRECTORY')) 57 | parser.add_argument('--triggers', default=os.environ.get('OVERSIGHT_TRIGGERS', '').split(' '), nargs='*') 58 | parser.add_argument('--log_level', default=os.environ.get('OVERSIGHT_LOG_LEVEL', 'INFO')) 59 | args = parser.parse_args() 60 | 61 | # Mandatory args 62 | if len(args.download_urls) < 1: 63 | exit(parser.print_usage()) 64 | 65 | return args 66 | 67 | 68 | def parse_triggers(trigger_args): 69 | """ 70 | Parses a list of trigger:threshold pairs to return a dictionary of triggers. 71 | I've used a function instead of a list slice for the opportunity to apply validation. 72 | 73 | :param trigger_args: List of trigger:threshold pairs. 74 | :return: Dictionary of trigger -> threshold. 75 | """ 76 | triggers = {} 77 | 78 | for trigger_pair in trigger_args: 79 | (trigger, threshold) = trigger_pair.split(':') 80 | triggers[trigger] = float(threshold) 81 | 82 | return triggers 83 | 84 | 85 | def create_image_sources(image_args): 86 | """ 87 | Parses a list of name:url pairs to create a list of image sources. 88 | :param image_args: list of name:url pairs that describe image sources. 
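Example pair (matching the README's download_urls format): front:https://user:password@10.0.0.150/Streaming/channels/1/picture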
89 | :return: list of image sources 90 | """ 91 | image_sources = [] 92 | for image_source in image_args: 93 | (tag, image_url) = image_source.split(":", 1) 94 | parsed = urlparse.urlparse(image_url) 95 | 96 | plain_url = "%s://%s%s%s" % (parsed.scheme, 97 | parsed.hostname, 98 | ":%s" % parsed.port if parsed.port else "", parsed.path) 99 | image_sources.append(ImageSource(plain_url, parsed.username, parsed.password, tag=tag)) 100 | 101 | return image_sources 102 | 103 | 104 | def main(_): 105 | args = validate_args() 106 | 107 | # Configure logging 108 | LOGGING_CONFIG["loggers"]["root"]["level"] = args.log_level 109 | logging.config.dictConfig(LOGGING_CONFIG) 110 | 111 | with tf.Session() as sess: 112 | # Create notifiers 113 | notifiers = [] 114 | smtp_recipients = [recipient for arg in (args.smtp_recipients if isinstance(args.smtp_recipients, list) else [args.smtp_recipients]) for recipient in arg.split(',') if recipient] 115 | if len(smtp_recipients) > 0 and args.smtp_host: 116 | notifiers.append(SmtpNotifier('Oversight ', smtp_recipients, args.smtp_host)) 117 | 118 | if args.pushover_user and args.pushover_token: 119 | notifiers.append(PushoverNotifier(args.pushover_user, args.pushover_token, args.pushover_device or None)) 120 | 121 | if args.image_storage_directory: 122 | notifiers.append(ImageFileNotifier(args.image_storage_directory)) 123 | 124 | # Parse any triggers, and create monitor 125 | triggers = parse_triggers(args.triggers) 126 | logger.info('Triggers: ' + str(triggers)) 127 | monitor = Monitor(triggers, args.notification_delay) 128 | 129 | # Create classifiers 130 | logger.info('Loading classifier...') 131 | classifier = CNNClassifier(args.model_directory, sess) 132 | 133 | # Create image buffer 134 | image_buffer = ImageBuffer(args.image_buffer_length * len(args.download_urls)) 135 | 136 | # Create image source 137 | logger.info('Creating image sources...') 138 | image_sources = create_image_sources(args.download_urls) 139 | 140 | while True: 141 | time.sleep(2) 142 | 143 | if __name__ == '__main__': 144 | logger.info('Initialising Tensorflow...') 145 | tf.app.run() 146 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Oversight 2 | Oversight is an application that uses machine learning to detect events in IP camera feeds. 3 | 4 | ## Building 5 | The easiest way to build and run Oversight is using [Docker](https://docker.com). From the directory where you checked out the Oversight code, build the container: 6 | 7 | docker build -t oversight . 8 | 9 | ## Usage 10 | ### Training the model 11 | Before you can run Oversight, you need to train a model that will recognise the different events that you want the system to react to. To do this, you will first need to prepare a set of images that represent each event. For this example, we will use three categories: person, car and nothing, but you can use any categories you want. The last condition, nothing, is important so that the system can tell when none of the other conditions has occurred. 12 | 13 | First, prepare a directory structure for the labelled images: 14 | 15 | . 16 | +-- _config.yml 17 | +-- _labelled_images 18 | | +-- car 19 | | +-- nothing 20 | | +-- person 21 | 22 | In each of the labelled sub-folders place example images from your cameras. In the nothing folder, place images that don't include any of the active categories (e.g. people or cars) you want to track. Now you're ready to train the model. 
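If your raw captures come in mixed resolutions, the bundled bin/resizer.sh script can standardise them first. A sketch of the assumed invocation (the values shown are simply the script's argparse defaults; note the script is Python despite its extension):

    python bin/resizer.sh --input_path ./labelled_images_raw --output_path ./labelled_images --output_width 1000 --output_height 565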
23 | 24 | #### With Docker: 25 | It's best to create a data volume to store your labelled images and the trained model as Docker can struggle with file permissions between the host and container: 26 | 27 | docker create -v /oversight_data --name oversight-data oversight 28 | 29 | Then, you can get the local path that the data path corresponds to using docker inspect: 30 | 31 | docker inspect oversight-data 32 | 33 | The mount path will be in the "Mounts" section: 34 | 35 | "Mounts": [ 36 | { 37 | "Name": "dde7f9d22ac3183bacbcd93eaac7fd62472990d8c6b01858df31ed506b634926", 38 | "Source": "/var/lib/docker/volumes/dde7f9d22ac3183bacbcd93eaac7fd62472990d8c6b01858df31ed506b634926/_data", 39 | "Destination": "/oversight_data", 40 | "Driver": "local", 41 | "Mode": "", 42 | "RW": true, 43 | "Propagation": "" 44 | } 45 | ] 46 | 47 | You can then stage your labelled images: 48 | 49 | cp -R labelled_images /var/lib/docker/volumes/dde7f9d22ac3183bacbcd93eaac7fd62472990d8c6b01858df31ed506b634926/_data/ 50 | 51 | Finally, you can start training the model: 52 | 53 | sudo docker run --volumes-from oversight-data -d --name oversight-train oversight /opt/oversight/bin/train.sh 54 | 55 | sudo docker logs --follow oversight-train 56 | 57 | Looking for images in 'car' 58 | Looking for images in 'nothing' 59 | Looking for images in 'person' 60 | Creating bottleneck at /oversight_data/bottlenecks/nothing/NYd2kqBk_0002.jpg.txt 61 | Creating bottleneck at /oversight_data/bottlenecks/nothing/NYd2kqBk_0003.jpg.txt 62 | Creating bottleneck at /oversight_data/bottlenecks/nothing/NYd2kqBk_0004.jpg.txt 63 | Creating bottleneck at /oversight_data/bottlenecks/nothing/NYd2kqBk_0006.jpg.txt 64 | Creating bottleneck at /oversight_data/bottlenecks/nothing/NYd2kqBk_0007.jpg.txt 65 | .... 66 | 67 | Once the process has completed, you're ready to run Oversight. 68 | 69 | #### Without Docker: 70 | Training produces some intermediary files, so it's suggested to create a data directory for oversight. 71 | 72 | python /oversight/retrain.py \ 73 | --bottleneck_dir=/bottlenecks \ 74 | --how_many_training_steps 4000 \ 75 | --model_dir=/inception \ 76 | --output_graph=/retrained_graph.pb \ 77 | --output_labels=/retrained_labels.txt \ 78 | --image_dir 79 | 80 | ### Running Oversight 81 | Oversight doesn't use config files, but uses command line arguments and environment variables to control its configuration. This allows you to easily inject different configurations at runtime, according to your environment. Any variable supplied as a command line argument will overwrite an environment variable. 82 | 83 | #### Command Line Arguments 84 | 85 | ##### --download_urls, OVERSIGHT_DOWNLOAD_URLS \[TAG:DOWNLOAD_URL ...\] 86 | One or more pairs of tag:download-url pairs. The tag is the name of the camera, and the download-url is the url on the camera where you can retrieve a still jpeg image. 87 | 88 | E.g. --download_urls "front:https://user:password@10.0.0.150/Streaming/channels/1/picture side:https://user:password@10.0.0.151/Streaming/channels/1/picture" 89 | 90 | ##### --model_directory, OVERSIGHT_MODEL_DIRECTORY \[MODEL_DIRECTORY\] 91 | Path to the directory where the pre-trained model is stored. Default is ~/.oversight 92 | 93 | E.g. --model_directory ~/.oversight 94 | 95 | ##### --image_buffer_length, OVERSIGHT_IMAGE_BUFFER_LENGTH \[BUFFER_LENGTH\] 96 | How many images from each camera to store at a time. Images in the buffer will be sent with email notifications. Default is 3. 97 | 98 | E.g. 
--image_buffer_length 5 99 | 100 | ##### --notification_delay, OVERSIGHT_NOTIFICATION_DELAY \[DELAY\] 101 | How long in seconds to wait after a trigger event before sending a notification. This allows the notification to include images before and after the trigger event. Default is 2 seconds. 102 | 103 | E.g. --notification_delay 2 104 | 105 | ##### --smtp_recipients, OVERSIGHT_SMTP_RECIPIENTS \[RECIPIENTS ...\] 106 | A comma-separated list of email addresses to receive email notifications when trigger events occur. 107 | 108 | E.g. --smtp_recipients someone@somewhere.com,another@another.com 109 | 110 | ##### --smtp_host, OVERSIGHT_SMTP_HOST \[HOST\] 111 | The SMTP host to send email notifications through. 112 | 113 | E.g. --smtp_host mail.somewhere.com 114 | 115 | ##### --pushover_user, OVERSIGHT_PUSHOVER_USER \[PUSHOVER_USER\] 116 | The Pushover API user to use when sending Pushover notifications. 117 | 118 | E.g. --pushover_user auh139ds2lkjxcjbcv73489351xc823 119 | 120 | ##### --pushover_token, OVERSIGHT_PUSHOVER_TOKEN \[PUSHOVER_TOKEN\] 121 | The Pushover API token to use when sending Pushover notifications. 122 | 123 | E.g. --pushover_token cx952oiiuv24ccvx586sdklakjd7c6v346c8bc612lzxbe 124 | 125 | ##### --image_storage_directory, OVERSIGHT_IMAGE_STORAGE_DIRECTORY \[IMAGE_STORAGE_DIRECTORY\] 126 | Location to store captured images that trigger events. This will only capture an image when the event is first triggered, and won't capture subsequent frames while the trigger is active. Once the trigger is deactivated, a new trigger will capture another image. 127 | 128 | E.g. --image_storage_directory ~/.oversight/captured_images 129 | 130 | ##### --triggers, OVERSIGHT_TRIGGERS \[TRIGGER:LEVEL ...\] 131 | Triggers are a set of trigger events, each with a level. The trigger events should correspond to the events that the Oversight model has been trained on. The level for each trigger is normalised from 0.0 to 1.0, and represents the probability of that event being the subject of an image captured from a camera: if it is highly unlikely that the event is in the image, the probability will be closer to 0.0; if it is highly likely, it will be closer to 1.0. The configured trigger level is the minimum threshold to activate that trigger. As an example, suppose a trigger of 'person:0.80' has been configured. If an image has been analysed and the probability of the 'person' event is less than 0.80, the trigger will not fire; if the probability is greater than or equal to 0.80, it will fire. 132 | 133 | E.g. --triggers "person:0.80 car:0.90 unicorn:0.20" 134 | 135 | ##### --log_level, OVERSIGHT_LOG_LEVEL \[LOG_LEVEL\] 136 | The level of logging to apply. Valid options (from most verbose to least verbose) are DEBUG, INFO, WARNING, ERROR. 137 | 138 | #### Running With Docker: 139 | To run Oversight with Docker, you can override relevant environment options to configure it at runtime. In the example below, a data volume is connected that contains the pre-trained model. 
140 | 141 | sudo docker run \ 142 | --volumes-from oversight-data \ 143 | -e "OVERSIGHT_DOWNLOAD_URLS=side:http://user:password@192.168.0.1/Streaming/channels/1/picture front:http://user:password@192.168.0.2/Streaming/channels/1/picture" \ 144 | -e "OVERSIGHT_MODEL_DIRECTORY=/oversight_data" \ 145 | -e "OVERSIGHT_IMAGE_BUFFER_LENGTH=3" \ 146 | -e "OVERSIGHT_SMTP_RECIPIENTS=not_a_real_person@oversight.tech" \ 147 | -e "OVERSIGHT_SMTP_HOST=your.mailserver.com" \ 148 | -e "OVERSIGHT_TRIGGERS=car:0.85 person:0.90" \ 149 | --name oversight -d oversight 150 | 151 | #### Running Without Docker: 152 | Without Docker, you can either override the relevant environment variables or supply them as command line options: 153 | 154 | python oversight_runner.py --download_urls "front:http://user:password@192.168.0.1/Streaming/channels/1/picture" --model_directory "~/.oversight" --smtp_recipients "not_a_real_person@oversight.tech" --smtp_host "your.mailserver.com" 155 | 156 | ## The Future 157 | - A more general model that doesn't require individual training. 158 | - More descriptive output than raw categories (using something like [im2txt](https://github.com/tensorflow/models/tree/master/im2txt)). 159 | - More notifiers, e.g. [Twilio](https://www.twilio.com). 160 | - Better secrets management (e.g. integration with [Vault](https://vaultproject.io)) 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2016 Ben Carson 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /oversight/retrain.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Simple transfer learning with an Inception v3 architecture model which 16 | displays summaries in TensorBoard. 17 | 18 | This example shows how to take a Inception v3 architecture model trained on 19 | ImageNet images, and train a new top layer that can recognize other classes of 20 | images. 21 | 22 | The top layer receives as input a 2048-dimensional vector for each image. We 23 | train a softmax layer on top of this representation. Assuming the softmax layer 24 | contains N labels, this corresponds to learning N + 2048*N model parameters 25 | corresponding to the learned biases and weights. 26 | 27 | Here's an example, which assumes you have a folder containing class-named 28 | subfolders, each full of images for each label. The example folder flower_photos 29 | should have a structure like this: 30 | 31 | ~/flower_photos/daisy/photo1.jpg 32 | ~/flower_photos/daisy/photo2.jpg 33 | ... 34 | ~/flower_photos/rose/anotherphoto77.jpg 35 | ... 36 | ~/flower_photos/sunflower/somepicture.jpg 37 | 38 | The subfolder names are important, since they define what label is applied to 39 | each image, but the filenames themselves don't matter. Once your images are 40 | prepared, you can run the training with a command like this: 41 | 42 | bazel build third_party/tensorflow/examples/image_retraining:retrain && \ 43 | bazel-bin/third_party/tensorflow/examples/image_retraining/retrain \ 44 | --image_dir ~/flower_photos 45 | 46 | You can replace the image_dir argument with any folder containing subfolders of 47 | images. The label for each image is taken from the name of the subfolder it's 48 | in. 49 | 50 | This produces a new model file that can be loaded and run by any TensorFlow 51 | program, for example the label_image sample code. 52 | 53 | 54 | To use with TensorBoard: 55 | 56 | By default, this script will log summaries to /tmp/retrain_logs directory 57 | 58 | Visualize the summaries with this command: 59 | 60 | tensorboard --logdir /tmp/retrain_logs 61 | 62 | """ 63 | from __future__ import absolute_import 64 | from __future__ import division 65 | from __future__ import print_function 66 | 67 | from datetime import datetime 68 | import glob 69 | import hashlib 70 | import os.path 71 | import random 72 | import re 73 | import sys 74 | import tarfile 75 | 76 | import numpy as np 77 | from six.moves import urllib 78 | import tensorflow as tf 79 | 80 | from tensorflow.python.client import graph_util 81 | from tensorflow.python.framework import tensor_shape 82 | from tensorflow.python.platform import gfile 83 | from tensorflow.python.util import compat 84 | 85 | 86 | import struct 87 | 88 | FLAGS = tf.app.flags.FLAGS 89 | 90 | # Input and output file flags. 
91 | tf.app.flags.DEFINE_string('image_dir', '', 92 | """Path to folders of labeled images.""") 93 | tf.app.flags.DEFINE_string('output_graph', '/tmp/output_graph.pb', 94 | """Where to save the trained graph.""") 95 | tf.app.flags.DEFINE_string('output_labels', '/tmp/output_labels.txt', 96 | """Where to save the trained graph's labels.""") 97 | tf.app.flags.DEFINE_string('summaries_dir', '/tmp/retrain_logs', 98 | """Where to save summary logs for TensorBoard.""") 99 | 100 | # Details of the training configuration. 101 | tf.app.flags.DEFINE_integer('how_many_training_steps', 4000, 102 | """How many training steps to run before ending.""") 103 | tf.app.flags.DEFINE_float('learning_rate', 0.01, 104 | """How large a learning rate to use when training.""") 105 | tf.app.flags.DEFINE_integer( 106 | 'testing_percentage', 10, 107 | """What percentage of images to use as a test set.""") 108 | tf.app.flags.DEFINE_integer( 109 | 'validation_percentage', 10, 110 | """What percentage of images to use as a validation set.""") 111 | tf.app.flags.DEFINE_integer('eval_step_interval', 10, 112 | """How often to evaluate the training results.""") 113 | tf.app.flags.DEFINE_integer('train_batch_size', 100, 114 | """How many images to train on at a time.""") 115 | tf.app.flags.DEFINE_integer('test_batch_size', 500, 116 | """How many images to test on at a time. This""" 117 | """ test set is only used infrequently to verify""" 118 | """ the overall accuracy of the model.""") 119 | tf.app.flags.DEFINE_integer( 120 | 'validation_batch_size', 100, 121 | """How many images to use in an evaluation batch. This validation set is""" 122 | """ used much more often than the test set, and is an early indicator of""" 123 | """ how accurate the model is during training.""") 124 | 125 | # File-system cache locations. 126 | tf.app.flags.DEFINE_string('model_dir', '/tmp/imagenet', 127 | """Path to classify_image_graph_def.pb, """ 128 | """imagenet_synset_to_human_label_map.txt, and """ 129 | """imagenet_2012_challenge_label_map_proto.pbtxt.""") 130 | tf.app.flags.DEFINE_string( 131 | 'bottleneck_dir', '/tmp/bottleneck', 132 | """Path to cache bottleneck layer values as files.""") 133 | tf.app.flags.DEFINE_string('final_tensor_name', 'final_result', 134 | """The name of the output classification layer in""" 135 | """ the retrained graph.""") 136 | 137 | # Controls the distortions used during training. 138 | tf.app.flags.DEFINE_boolean( 139 | 'flip_left_right', False, 140 | """Whether to randomly flip half of the training images horizontally.""") 141 | tf.app.flags.DEFINE_integer( 142 | 'random_crop', 0, 143 | """A percentage determining how much of a margin to randomly crop off the""" 144 | """ training images.""") 145 | tf.app.flags.DEFINE_integer( 146 | 'random_scale', 0, 147 | """A percentage determining how much to randomly scale up the size of the""" 148 | """ training images by.""") 149 | tf.app.flags.DEFINE_integer( 150 | 'random_brightness', 0, 151 | """A percentage determining how much to randomly multiply the training""" 152 | """ image input pixels up or down by.""") 153 | 154 | # These are all parameters that are tied to the particular model architecture 155 | # we're using for Inception v3. These include things like tensor names and their 156 | # sizes. If you want to adapt this script to work with another model, you will 157 | # need to update these to reflect the values in the network you're using. 
158 | # pylint: disable=line-too-long 159 | DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz' 160 | # pylint: enable=line-too-long 161 | BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0' 162 | BOTTLENECK_TENSOR_SIZE = 2048 163 | MODEL_INPUT_WIDTH = 299 164 | MODEL_INPUT_HEIGHT = 299 165 | MODEL_INPUT_DEPTH = 3 166 | JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0' 167 | RESIZED_INPUT_TENSOR_NAME = 'ResizeBilinear:0' 168 | MAX_NUM_IMAGES_PER_CLASS = 2 ** 27 - 1 # ~134M 169 | 170 | 171 | def create_image_lists(image_dir, testing_percentage, validation_percentage): 172 | """Builds a list of training images from the file system. 173 | 174 | Analyzes the sub folders in the image directory, splits them into stable 175 | training, testing, and validation sets, and returns a data structure 176 | describing the lists of images for each label and their paths. 177 | 178 | Args: 179 | image_dir: String path to a folder containing subfolders of images. 180 | testing_percentage: Integer percentage of the images to reserve for tests. 181 | validation_percentage: Integer percentage of images reserved for validation. 182 | 183 | Returns: 184 | A dictionary containing an entry for each label subfolder, with images split 185 | into training, testing, and validation sets within each label. 186 | """ 187 | if not gfile.Exists(image_dir): 188 | print("Image directory '" + image_dir + "' not found.") 189 | return None 190 | result = {} 191 | sub_dirs = [x[0] for x in os.walk(image_dir)] 192 | # The root directory comes first, so skip it. 193 | is_root_dir = True 194 | for sub_dir in sub_dirs: 195 | if is_root_dir: 196 | is_root_dir = False 197 | continue 198 | extensions = ['jpg', 'jpeg', 'JPG', 'JPEG'] 199 | file_list = [] 200 | dir_name = os.path.basename(sub_dir) 201 | if dir_name == image_dir: 202 | continue 203 | print("Looking for images in '" + dir_name + "'") 204 | for extension in extensions: 205 | file_glob = os.path.join(image_dir, dir_name, '*.' + extension) 206 | file_list.extend(glob.glob(file_glob)) 207 | if not file_list: 208 | print('No files found') 209 | continue 210 | if len(file_list) < 20: 211 | print('WARNING: Folder has less than 20 images, which may cause issues.') 212 | elif len(file_list) > MAX_NUM_IMAGES_PER_CLASS: 213 | print('WARNING: Folder {} has more than {} images. Some images will ' 214 | 'never be selected.'.format(dir_name, MAX_NUM_IMAGES_PER_CLASS)) 215 | label_name = re.sub(r'[^a-z0-9]+', ' ', dir_name.lower()) 216 | training_images = [] 217 | testing_images = [] 218 | validation_images = [] 219 | for file_name in file_list: 220 | base_name = os.path.basename(file_name) 221 | # We want to ignore anything after '_nohash_' in the file name when 222 | # deciding which set to put an image in, the data set creator has a way of 223 | # grouping photos that are close variations of each other. For example 224 | # this is used in the plant disease data set to group multiple pictures of 225 | # the same leaf. 226 | hash_name = re.sub(r'_nohash_.*$', '', file_name) 227 | # This looks a bit magical, but we need to decide whether this file should 228 | # go into the training, testing, or validation sets, and we want to keep 229 | # existing files in the same set even if more files are subsequently 230 | # added. 231 | # To do that, we need a stable way of deciding based on just the file name 232 | # itself, so we do a hash of that and then use that to generate a 233 | # probability value that we use to assign it. 
234 | hash_name_hashed = hashlib.sha1(compat.as_bytes(hash_name)).hexdigest() 235 | percentage_hash = ((int(hash_name_hashed, 16) % 236 | (MAX_NUM_IMAGES_PER_CLASS + 1)) * 237 | (100.0 / MAX_NUM_IMAGES_PER_CLASS)) 238 | if percentage_hash < validation_percentage: 239 | validation_images.append(base_name) 240 | elif percentage_hash < (testing_percentage + validation_percentage): 241 | testing_images.append(base_name) 242 | else: 243 | training_images.append(base_name) 244 | result[label_name] = { 245 | 'dir': dir_name, 246 | 'training': training_images, 247 | 'testing': testing_images, 248 | 'validation': validation_images, 249 | } 250 | return result 251 | 252 | 253 | def get_image_path(image_lists, label_name, index, image_dir, category): 254 | """"Returns a path to an image for a label at the given index. 255 | 256 | Args: 257 | image_lists: Dictionary of training images for each label. 258 | label_name: Label string we want to get an image for. 259 | index: Int offset of the image we want. This will be moduloed by the 260 | available number of images for the label, so it can be arbitrarily large. 261 | image_dir: Root folder string of the subfolders containing the training 262 | images. 263 | category: Name string of set to pull images from - training, testing, or 264 | validation. 265 | 266 | Returns: 267 | File system path string to an image that meets the requested parameters. 268 | 269 | """ 270 | if label_name not in image_lists: 271 | tf.logging.fatal('Label does not exist %s.', label_name) 272 | label_lists = image_lists[label_name] 273 | if category not in label_lists: 274 | tf.logging.fatal('Category does not exist %s.', category) 275 | category_list = label_lists[category] 276 | if not category_list: 277 | tf.logging.fatal('Label %s has no images in the category %s.', 278 | label_name, category) 279 | mod_index = index % len(category_list) 280 | base_name = category_list[mod_index] 281 | sub_dir = label_lists['dir'] 282 | full_path = os.path.join(image_dir, sub_dir, base_name) 283 | return full_path 284 | 285 | 286 | def get_bottleneck_path(image_lists, label_name, index, bottleneck_dir, 287 | category): 288 | """"Returns a path to a bottleneck file for a label at the given index. 289 | 290 | Args: 291 | image_lists: Dictionary of training images for each label. 292 | label_name: Label string we want to get an image for. 293 | index: Integer offset of the image we want. This will be moduloed by the 294 | available number of images for the label, so it can be arbitrarily large. 295 | bottleneck_dir: Folder string holding cached files of bottleneck values. 296 | category: Name string of set to pull images from - training, testing, or 297 | validation. 298 | 299 | Returns: 300 | File system path string to an image that meets the requested parameters. 301 | """ 302 | return get_image_path(image_lists, label_name, index, bottleneck_dir, 303 | category) + '.txt' 304 | 305 | 306 | def create_inception_graph(): 307 | """"Creates a graph from saved GraphDef file and returns a Graph object. 308 | 309 | Returns: 310 | Graph holding the trained Inception network, and various tensors we'll be 311 | manipulating. 
312 | """ 313 | with tf.Session() as sess: 314 | model_filename = os.path.join( 315 | FLAGS.model_dir, 'classify_image_graph_def.pb') 316 | with gfile.FastGFile(model_filename, 'rb') as f: 317 | graph_def = tf.GraphDef() 318 | graph_def.ParseFromString(f.read()) 319 | bottleneck_tensor, jpeg_data_tensor, resized_input_tensor = ( 320 | tf.import_graph_def(graph_def, name='', return_elements=[ 321 | BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME, 322 | RESIZED_INPUT_TENSOR_NAME])) 323 | return sess.graph, bottleneck_tensor, jpeg_data_tensor, resized_input_tensor 324 | 325 | 326 | def run_bottleneck_on_image(sess, image_data, image_data_tensor, 327 | bottleneck_tensor): 328 | """Runs inference on an image to extract the 'bottleneck' summary layer. 329 | 330 | Args: 331 | sess: Current active TensorFlow Session. 332 | image_data: String of raw JPEG data. 333 | image_data_tensor: Input data layer in the graph. 334 | bottleneck_tensor: Layer before the final softmax. 335 | 336 | Returns: 337 | Numpy array of bottleneck values. 338 | """ 339 | bottleneck_values = sess.run( 340 | bottleneck_tensor, 341 | {image_data_tensor: image_data}) 342 | bottleneck_values = np.squeeze(bottleneck_values) 343 | return bottleneck_values 344 | 345 | 346 | def maybe_download_and_extract(): 347 | """Download and extract model tar file. 348 | 349 | If the pretrained model we're using doesn't already exist, this function 350 | downloads it from the TensorFlow.org website and unpacks it into a directory. 351 | """ 352 | dest_directory = FLAGS.model_dir 353 | if not os.path.exists(dest_directory): 354 | os.makedirs(dest_directory) 355 | filename = DATA_URL.split('/')[-1] 356 | filepath = os.path.join(dest_directory, filename) 357 | if not os.path.exists(filepath): 358 | 359 | def _progress(count, block_size, total_size): 360 | sys.stdout.write('\r>> Downloading %s %.1f%%' % 361 | (filename, 362 | float(count * block_size) / float(total_size) * 100.0)) 363 | sys.stdout.flush() 364 | 365 | filepath, _ = urllib.request.urlretrieve(DATA_URL, 366 | filepath, 367 | _progress) 368 | print() 369 | statinfo = os.stat(filepath) 370 | print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') 371 | tarfile.open(filepath, 'r:gz').extractall(dest_directory) 372 | 373 | 374 | def ensure_dir_exists(dir_name): 375 | """Makes sure the folder exists on disk. 376 | 377 | Args: 378 | dir_name: Path string to the folder we want to create. 379 | """ 380 | if not os.path.exists(dir_name): 381 | os.makedirs(dir_name) 382 | 383 | 384 | def write_list_of_floats_to_file(list_of_floats , file_path): 385 | """Writes a given list of floats to a binary file. 386 | 387 | Args: 388 | list_of_floats: List of floats we want to write to a file. 389 | file_path: Path to a file where list of floats will be stored. 390 | 391 | """ 392 | 393 | s = struct.pack('d' * BOTTLENECK_TENSOR_SIZE, *list_of_floats) 394 | with open(file_path, 'wb') as f: 395 | f.write(s) 396 | 397 | 398 | def read_list_of_floats_from_file(file_path): 399 | """Reads list of floats from a given file. 400 | 401 | Args: 402 | file_path: Path to a file where list of floats was stored. 403 | Returns: 404 | Array of bottleneck values (list of floats). 
405 | 406 | """ 407 | 408 | with open(file_path, 'rb') as f: 409 | s = struct.unpack('d' * BOTTLENECK_TENSOR_SIZE, f.read()) 410 | return list(s) 411 | 412 | 413 | bottleneck_path_2_bottleneck_values = {} 414 | 415 | 416 | def get_or_create_bottleneck(sess, image_lists, label_name, index, image_dir, 417 | category, bottleneck_dir, jpeg_data_tensor, 418 | bottleneck_tensor): 419 | """Retrieves or calculates bottleneck values for an image. 420 | 421 | If a cached version of the bottleneck data exists on-disk, return that, 422 | otherwise calculate the data and save it to disk for future use. 423 | 424 | Args: 425 | sess: The current active TensorFlow Session. 426 | image_lists: Dictionary of training images for each label. 427 | label_name: Label string we want to get an image for. 428 | index: Integer offset of the image we want. This will be modulo-ed by the 429 | available number of images for the label, so it can be arbitrarily large. 430 | image_dir: Root folder string of the subfolders containing the training 431 | images. 432 | category: Name string of which set to pull images from - training, testing, 433 | or validation. 434 | bottleneck_dir: Folder string holding cached files of bottleneck values. 435 | jpeg_data_tensor: The tensor to feed loaded jpeg data into. 436 | bottleneck_tensor: The output tensor for the bottleneck values. 437 | 438 | Returns: 439 | Numpy array of values produced by the bottleneck layer for the image. 440 | """ 441 | label_lists = image_lists[label_name] 442 | sub_dir = label_lists['dir'] 443 | sub_dir_path = os.path.join(bottleneck_dir, sub_dir) 444 | ensure_dir_exists(sub_dir_path) 445 | bottleneck_path = get_bottleneck_path(image_lists, label_name, index, 446 | bottleneck_dir, category) 447 | if not os.path.exists(bottleneck_path): 448 | print('Creating bottleneck at ' + bottleneck_path) 449 | image_path = get_image_path(image_lists, label_name, index, image_dir, 450 | category) 451 | if not gfile.Exists(image_path): 452 | tf.logging.fatal('File does not exist %s', image_path) 453 | image_data = gfile.FastGFile(image_path, 'rb').read() 454 | bottleneck_values = run_bottleneck_on_image(sess, image_data, 455 | jpeg_data_tensor, 456 | bottleneck_tensor) 457 | bottleneck_string = ','.join(str(x) for x in bottleneck_values) 458 | with open(bottleneck_path, 'w') as bottleneck_file: 459 | bottleneck_file.write(bottleneck_string) 460 | 461 | with open(bottleneck_path, 'r') as bottleneck_file: 462 | bottleneck_string = bottleneck_file.read() 463 | bottleneck_values = [float(x) for x in bottleneck_string.split(',')] 464 | return bottleneck_values 465 | 466 | 467 | def cache_bottlenecks(sess, image_lists, image_dir, bottleneck_dir, 468 | jpeg_data_tensor, bottleneck_tensor): 469 | """Ensures all the training, testing, and validation bottlenecks are cached. 470 | 471 | Because we're likely to read the same image multiple times (if there are no 472 | distortions applied during training) it can speed things up a lot if we 473 | calculate the bottleneck layer values once for each image during 474 | preprocessing, and then just read those cached values repeatedly during 475 | training. Here we go through all the images we've found, calculate those 476 | values, and save them off. 477 | 478 | Args: 479 | sess: The current active TensorFlow Session. 480 | image_lists: Dictionary of training images for each label. 481 | image_dir: Root folder string of the subfolders containing the training 482 | images. 
483 | bottleneck_dir: Folder string holding cached files of bottleneck values. 484 | jpeg_data_tensor: Input tensor for jpeg data from file. 485 | bottleneck_tensor: The penultimate output layer of the graph. 486 | 487 | Returns: 488 | Nothing. 489 | """ 490 | how_many_bottlenecks = 0 491 | ensure_dir_exists(bottleneck_dir) 492 | for label_name, label_lists in image_lists.items(): 493 | for category in ['training', 'testing', 'validation']: 494 | category_list = label_lists[category] 495 | for index, unused_base_name in enumerate(category_list): 496 | get_or_create_bottleneck(sess, image_lists, label_name, index, 497 | image_dir, category, bottleneck_dir, 498 | jpeg_data_tensor, bottleneck_tensor) 499 | how_many_bottlenecks += 1 500 | if how_many_bottlenecks % 100 == 0: 501 | print(str(how_many_bottlenecks) + ' bottleneck files created.') 502 | 503 | 504 | def get_random_cached_bottlenecks(sess, image_lists, how_many, category, 505 | bottleneck_dir, image_dir, jpeg_data_tensor, 506 | bottleneck_tensor): 507 | """Retrieves bottleneck values for cached images. 508 | 509 | If no distortions are being applied, this function can retrieve the cached 510 | bottleneck values directly from disk for images. It picks a random set of 511 | images from the specified category. 512 | 513 | Args: 514 | sess: Current TensorFlow Session. 515 | image_lists: Dictionary of training images for each label. 516 | how_many: The number of bottleneck values to return. 517 | category: Name string of which set to pull from - training, testing, or 518 | validation. 519 | bottleneck_dir: Folder string holding cached files of bottleneck values. 520 | image_dir: Root folder string of the subfolders containing the training 521 | images. 522 | jpeg_data_tensor: The layer to feed jpeg image data into. 523 | bottleneck_tensor: The bottleneck output layer of the CNN graph. 524 | 525 | Returns: 526 | List of bottleneck arrays and their corresponding ground truths. 527 | """ 528 | class_count = len(image_lists.keys()) 529 | bottlenecks = [] 530 | ground_truths = [] 531 | for unused_i in range(how_many): 532 | label_index = random.randrange(class_count) 533 | label_name = list(image_lists.keys())[label_index] 534 | image_index = random.randrange(MAX_NUM_IMAGES_PER_CLASS + 1) 535 | bottleneck = get_or_create_bottleneck(sess, image_lists, label_name, 536 | image_index, image_dir, category, 537 | bottleneck_dir, jpeg_data_tensor, 538 | bottleneck_tensor) 539 | ground_truth = np.zeros(class_count, dtype=np.float32) 540 | ground_truth[label_index] = 1.0 541 | bottlenecks.append(bottleneck) 542 | ground_truths.append(ground_truth) 543 | return bottlenecks, ground_truths 544 | 545 | 546 | def get_random_distorted_bottlenecks( 547 | sess, image_lists, how_many, category, image_dir, input_jpeg_tensor, 548 | distorted_image, resized_input_tensor, bottleneck_tensor): 549 | """Retrieves bottleneck values for training images, after distortions. 550 | 551 | If we're training with distortions like crops, scales, or flips, we have to 552 | recalculate the full model for every image, and so we can't use cached 553 | bottleneck values. Instead we find random images for the requested category, 554 | run them through the distortion graph, and then the full graph to get the 555 | bottleneck results for each. 556 | 557 | Args: 558 | sess: Current TensorFlow Session. 559 | image_lists: Dictionary of training images for each label. 560 | how_many: The integer number of bottleneck values to return. 
561 | category: Name string of which set of images to fetch - training, testing, 562 | or validation. 563 | image_dir: Root folder string of the subfolders containing the training 564 | images. 565 | input_jpeg_tensor: The input layer we feed the image data to. 566 | distorted_image: The output node of the distortion graph. 567 | resized_input_tensor: The input node of the recognition graph. 568 | bottleneck_tensor: The bottleneck output layer of the CNN graph. 569 | 570 | Returns: 571 | List of bottleneck arrays and their corresponding ground truths. 572 | """ 573 | class_count = len(image_lists.keys()) 574 | bottlenecks = [] 575 | ground_truths = [] 576 | for unused_i in range(how_many): 577 | label_index = random.randrange(class_count) 578 | label_name = list(image_lists.keys())[label_index] 579 | image_index = random.randrange(MAX_NUM_IMAGES_PER_CLASS + 1) 580 | image_path = get_image_path(image_lists, label_name, image_index, image_dir, 581 | category) 582 | if not gfile.Exists(image_path): 583 | tf.logging.fatal('File does not exist %s', image_path) 584 | jpeg_data = gfile.FastGFile(image_path, 'rb').read() 585 | # Note that we materialize the distorted_image_data as a numpy array before 586 | # sending running inference on the image. This involves 2 memory copies and 587 | # might be optimized in other implementations. 588 | distorted_image_data = sess.run(distorted_image, 589 | {input_jpeg_tensor: jpeg_data}) 590 | bottleneck = run_bottleneck_on_image(sess, distorted_image_data, 591 | resized_input_tensor, 592 | bottleneck_tensor) 593 | ground_truth = np.zeros(class_count, dtype=np.float32) 594 | ground_truth[label_index] = 1.0 595 | bottlenecks.append(bottleneck) 596 | ground_truths.append(ground_truth) 597 | return bottlenecks, ground_truths 598 | 599 | 600 | def should_distort_images(flip_left_right, random_crop, random_scale, 601 | random_brightness): 602 | """Whether any distortions are enabled, from the input flags. 603 | 604 | Args: 605 | flip_left_right: Boolean whether to randomly mirror images horizontally. 606 | random_crop: Integer percentage setting the total margin used around the 607 | crop box. 608 | random_scale: Integer percentage of how much to vary the scale by. 609 | random_brightness: Integer range to randomly multiply the pixel values by. 610 | 611 | Returns: 612 | Boolean value indicating whether any distortions should be applied. 613 | """ 614 | return (flip_left_right or (random_crop != 0) or (random_scale != 0) or 615 | (random_brightness != 0)) 616 | 617 | 618 | def add_input_distortions(flip_left_right, random_crop, random_scale, 619 | random_brightness): 620 | """Creates the operations to apply the specified distortions. 621 | 622 | During training it can help to improve the results if we run the images 623 | through simple distortions like crops, scales, and flips. These reflect the 624 | kind of variations we expect in the real world, and so can help train the 625 | model to cope with natural data more effectively. Here we take the supplied 626 | parameters and construct a network of operations to apply them to an image. 627 | 628 | Cropping 629 | ~~~~~~~~ 630 | 631 | Cropping is done by placing a bounding box at a random position in the full 632 | image. The cropping parameter controls the size of that box relative to the 633 | input image. If it's zero, then the box is the same size as the input and no 634 | cropping is performed. If the value is 50%, then the crop box will be half the 635 | width and height of the input. 
In a diagram it looks like this: 636 | 637 | < width > 638 | +---------------------+ 639 | | | 640 | | width - crop% | 641 | | < > | 642 | | +------+ | 643 | | | | | 644 | | | | | 645 | | | | | 646 | | +------+ | 647 | | | 648 | | | 649 | +---------------------+ 650 | 651 | Scaling 652 | ~~~~~~~ 653 | 654 | Scaling is a lot like cropping, except that the bounding box is always 655 | centered and its size varies randomly within the given range. For example if 656 | the scale percentage is zero, then the bounding box is the same size as the 657 | input and no scaling is applied. If it's 50%, then the bounding box will be in 658 | a random range between half the width and height and full size. 659 | 660 | Args: 661 | flip_left_right: Boolean whether to randomly mirror images horizontally. 662 | random_crop: Integer percentage setting the total margin used around the 663 | crop box. 664 | random_scale: Integer percentage of how much to vary the scale by. 665 | random_brightness: Integer range to randomly multiply the pixel values by. 666 | graph. 667 | 668 | Returns: 669 | The jpeg input layer and the distorted result tensor. 670 | """ 671 | 672 | jpeg_data = tf.placeholder(tf.string, name='DistortJPGInput') 673 | decoded_image = tf.image.decode_jpeg(jpeg_data, channels=MODEL_INPUT_DEPTH) 674 | decoded_image_as_float = tf.cast(decoded_image, dtype=tf.float32) 675 | decoded_image_4d = tf.expand_dims(decoded_image_as_float, 0) 676 | margin_scale = 1.0 + (random_crop / 100.0) 677 | resize_scale = 1.0 + (random_scale / 100.0) 678 | margin_scale_value = tf.constant(margin_scale) 679 | resize_scale_value = tf.random_uniform(tensor_shape.scalar(), 680 | minval=1.0, 681 | maxval=resize_scale) 682 | scale_value = tf.mul(margin_scale_value, resize_scale_value) 683 | precrop_width = tf.mul(scale_value, MODEL_INPUT_WIDTH) 684 | precrop_height = tf.mul(scale_value, MODEL_INPUT_HEIGHT) 685 | precrop_shape = tf.pack([precrop_height, precrop_width]) 686 | precrop_shape_as_int = tf.cast(precrop_shape, dtype=tf.int32) 687 | precropped_image = tf.image.resize_bilinear(decoded_image_4d, 688 | precrop_shape_as_int) 689 | precropped_image_3d = tf.squeeze(precropped_image, squeeze_dims=[0]) 690 | cropped_image = tf.random_crop(precropped_image_3d, 691 | [MODEL_INPUT_HEIGHT, MODEL_INPUT_WIDTH, 692 | MODEL_INPUT_DEPTH]) 693 | if flip_left_right: 694 | flipped_image = tf.image.random_flip_left_right(cropped_image) 695 | else: 696 | flipped_image = cropped_image 697 | brightness_min = 1.0 - (random_brightness / 100.0) 698 | brightness_max = 1.0 + (random_brightness / 100.0) 699 | brightness_value = tf.random_uniform(tensor_shape.scalar(), 700 | minval=brightness_min, 701 | maxval=brightness_max) 702 | brightened_image = tf.mul(flipped_image, brightness_value) 703 | distort_result = tf.expand_dims(brightened_image, 0, name='DistortResult') 704 | return jpeg_data, distort_result 705 | 706 | 707 | def variable_summaries(var, name): 708 | """Attach a lot of summaries to a Tensor (for TensorBoard visualization).""" 709 | with tf.name_scope('summaries'): 710 | mean = tf.reduce_mean(var) 711 | tf.scalar_summary('mean/' + name, mean) 712 | with tf.name_scope('stddev'): 713 | stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean))) 714 | tf.scalar_summary('stddev/' + name, stddev) 715 | tf.scalar_summary('max/' + name, tf.reduce_max(var)) 716 | tf.scalar_summary('min/' + name, tf.reduce_min(var)) 717 | tf.histogram_summary(name, var) 718 | 719 | 720 | def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor): 
721 | """Adds a new softmax and fully-connected layer for training. 722 | 723 | We need to retrain the top layer to identify our new classes, so this function 724 | adds the right operations to the graph, along with some variables to hold the 725 | weights, and then sets up all the gradients for the backward pass. 726 | 727 | The set up for the softmax and fully-connected layers is based on: 728 | https://tensorflow.org/versions/master/tutorials/mnist/beginners/index.html 729 | 730 | Args: 731 | class_count: Integer of how many categories of things we're trying to 732 | recognize. 733 | final_tensor_name: Name string for the new final node that produces results. 734 | bottleneck_tensor: The output of the main CNN graph. 735 | 736 | Returns: 737 | The tensors for the training and cross entropy results, and tensors for the 738 | bottleneck input and ground truth input. 739 | """ 740 | with tf.name_scope('input'): 741 | bottleneck_input = tf.placeholder_with_default( 742 | bottleneck_tensor, shape=[None, BOTTLENECK_TENSOR_SIZE], 743 | name='BottleneckInputPlaceholder') 744 | 745 | ground_truth_input = tf.placeholder(tf.float32, 746 | [None, class_count], 747 | name='GroundTruthInput') 748 | 749 | # Organizing the following ops as `final_training_ops` so they're easier 750 | # to see in TensorBoard 751 | layer_name = 'final_training_ops' 752 | with tf.name_scope(layer_name): 753 | with tf.name_scope('weights'): 754 | layer_weights = tf.Variable(tf.truncated_normal([BOTTLENECK_TENSOR_SIZE, class_count], stddev=0.001), name='final_weights') 755 | variable_summaries(layer_weights, layer_name + '/weights') 756 | with tf.name_scope('biases'): 757 | layer_biases = tf.Variable(tf.zeros([class_count]), name='final_biases') 758 | variable_summaries(layer_biases, layer_name + '/biases') 759 | with tf.name_scope('Wx_plus_b'): 760 | logits = tf.matmul(bottleneck_input, layer_weights) + layer_biases 761 | tf.histogram_summary(layer_name + '/pre_activations', logits) 762 | 763 | final_tensor = tf.nn.softmax(logits, name=final_tensor_name) 764 | tf.histogram_summary(final_tensor_name + '/activations', final_tensor) 765 | 766 | with tf.name_scope('cross_entropy'): 767 | cross_entropy = tf.nn.softmax_cross_entropy_with_logits( 768 | logits, ground_truth_input) 769 | with tf.name_scope('total'): 770 | cross_entropy_mean = tf.reduce_mean(cross_entropy) 771 | tf.scalar_summary('cross entropy', cross_entropy_mean) 772 | 773 | with tf.name_scope('train'): 774 | train_step = tf.train.GradientDescentOptimizer(FLAGS.learning_rate).minimize( 775 | cross_entropy_mean) 776 | 777 | return (train_step, cross_entropy_mean, bottleneck_input, ground_truth_input, 778 | final_tensor) 779 | 780 | 781 | def add_evaluation_step(result_tensor, ground_truth_tensor): 782 | """Inserts the operations we need to evaluate the accuracy of our results. 783 | 784 | Args: 785 | result_tensor: The new final node that produces results. 786 | ground_truth_tensor: The node we feed ground truth data 787 | into. 788 | 789 | Returns: 790 | Nothing. 
791 | """ 792 | with tf.name_scope('accuracy'): 793 | with tf.name_scope('correct_prediction'): 794 | correct_prediction = tf.equal(tf.argmax(result_tensor, 1), \ 795 | tf.argmax(ground_truth_tensor, 1)) 796 | with tf.name_scope('accuracy'): 797 | evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 798 | tf.scalar_summary('accuracy', evaluation_step) 799 | return evaluation_step 800 | 801 | 802 | def main(_): 803 | # Setup the directory we'll write summaries to for TensorBoard 804 | if tf.gfile.Exists(FLAGS.summaries_dir): 805 | tf.gfile.DeleteRecursively(FLAGS.summaries_dir) 806 | tf.gfile.MakeDirs(FLAGS.summaries_dir) 807 | 808 | # Set up the pre-trained graph. 809 | maybe_download_and_extract() 810 | graph, bottleneck_tensor, jpeg_data_tensor, resized_image_tensor = ( 811 | create_inception_graph()) 812 | 813 | # Look at the folder structure, and create lists of all the images. 814 | image_lists = create_image_lists(FLAGS.image_dir, FLAGS.testing_percentage, 815 | FLAGS.validation_percentage) 816 | class_count = len(image_lists.keys()) 817 | if class_count == 0: 818 | print('No valid folders of images found at ' + FLAGS.image_dir) 819 | return -1 820 | if class_count == 1: 821 | print('Only one valid folder of images found at ' + FLAGS.image_dir + 822 | ' - multiple classes are needed for classification.') 823 | return -1 824 | 825 | # See if the command-line flags mean we're applying any distortions. 826 | do_distort_images = should_distort_images( 827 | FLAGS.flip_left_right, FLAGS.random_crop, FLAGS.random_scale, 828 | FLAGS.random_brightness) 829 | sess = tf.Session() 830 | 831 | if do_distort_images: 832 | # We will be applying distortions, so setup the operations we'll need. 833 | distorted_jpeg_data_tensor, distorted_image_tensor = add_input_distortions( 834 | FLAGS.flip_left_right, FLAGS.random_crop, FLAGS.random_scale, 835 | FLAGS.random_brightness) 836 | else: 837 | # We'll make sure we've calculated the 'bottleneck' image summaries and 838 | # cached them on disk. 839 | cache_bottlenecks(sess, image_lists, FLAGS.image_dir, FLAGS.bottleneck_dir, 840 | jpeg_data_tensor, bottleneck_tensor) 841 | 842 | # Add the new layer that we'll be training. 843 | (train_step, cross_entropy, bottleneck_input, ground_truth_input, 844 | final_tensor) = add_final_training_ops(len(image_lists.keys()), 845 | FLAGS.final_tensor_name, 846 | bottleneck_tensor) 847 | 848 | # Create the operations we need to evaluate the accuracy of our new layer. 849 | evaluation_step = add_evaluation_step(final_tensor, ground_truth_input) 850 | 851 | # Merge all the summaries and write them out to /tmp/retrain_logs (by default) 852 | merged = tf.merge_all_summaries() 853 | train_writer = tf.train.SummaryWriter(FLAGS.summaries_dir + '/train', 854 | sess.graph) 855 | validation_writer = tf.train.SummaryWriter(FLAGS.summaries_dir + '/validation') 856 | 857 | # Set up all our weights to their initial default values. 858 | init = tf.initialize_all_variables() 859 | sess.run(init) 860 | 861 | # Run the training for as many cycles as requested on the command line. 862 | for i in range(FLAGS.how_many_training_steps): 863 | # Get a batch of input bottleneck values, either calculated fresh every time 864 | # with distortions applied, or from the cache stored on disk. 
865 | if do_distort_images: 866 | train_bottlenecks, train_ground_truth = get_random_distorted_bottlenecks( 867 | sess, image_lists, FLAGS.train_batch_size, 'training', 868 | FLAGS.image_dir, distorted_jpeg_data_tensor, 869 | distorted_image_tensor, resized_image_tensor, bottleneck_tensor) 870 | else: 871 | train_bottlenecks, train_ground_truth = get_random_cached_bottlenecks( 872 | sess, image_lists, FLAGS.train_batch_size, 'training', 873 | FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor, 874 | bottleneck_tensor) 875 | # Feed the bottlenecks and ground truth into the graph, and run a training 876 | # step. Capture training summaries for TensorBoard with the `merged` op. 877 | train_summary, _ = sess.run([merged, train_step], 878 | feed_dict={bottleneck_input: train_bottlenecks, 879 | ground_truth_input: train_ground_truth}) 880 | train_writer.add_summary(train_summary, i) 881 | 882 | # Every so often, print out how well the graph is training. 883 | is_last_step = (i + 1 == FLAGS.how_many_training_steps) 884 | if (i % FLAGS.eval_step_interval) == 0 or is_last_step: 885 | train_accuracy, cross_entropy_value = sess.run( 886 | [evaluation_step, cross_entropy], 887 | feed_dict={bottleneck_input: train_bottlenecks, 888 | ground_truth_input: train_ground_truth}) 889 | print('%s: Step %d: Train accuracy = %.1f%%' % (datetime.now(), i, 890 | train_accuracy * 100)) 891 | print('%s: Step %d: Cross entropy = %f' % (datetime.now(), i, 892 | cross_entropy_value)) 893 | validation_bottlenecks, validation_ground_truth = ( 894 | get_random_cached_bottlenecks( 895 | sess, image_lists, FLAGS.validation_batch_size, 'validation', 896 | FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor, 897 | bottleneck_tensor)) 898 | # Run a validation step and capture training summaries for TensorBoard 899 | # with the `merged` op. 900 | validation_summary, validation_accuracy = sess.run( 901 | [merged, evaluation_step], 902 | feed_dict={bottleneck_input: validation_bottlenecks, 903 | ground_truth_input: validation_ground_truth}) 904 | validation_writer.add_summary(validation_summary, i) 905 | print('%s: Step %d: Validation accuracy = %.1f%%' % 906 | (datetime.now(), i, validation_accuracy * 100)) 907 | 908 | # We've completed all our training, so run a final test evaluation on 909 | # some new images we haven't used before. 910 | test_bottlenecks, test_ground_truth = get_random_cached_bottlenecks( 911 | sess, image_lists, FLAGS.test_batch_size, 'testing', 912 | FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor, 913 | bottleneck_tensor) 914 | test_accuracy = sess.run( 915 | evaluation_step, 916 | feed_dict={bottleneck_input: test_bottlenecks, 917 | ground_truth_input: test_ground_truth}) 918 | print('Final test accuracy = %.1f%%' % (test_accuracy * 100)) 919 | 920 | # Write out the trained graph and labels with the weights stored as constants. 921 | output_graph_def = graph_util.convert_variables_to_constants( 922 | sess, graph.as_graph_def(), [FLAGS.final_tensor_name]) 923 | with gfile.FastGFile(FLAGS.output_graph, 'wb') as f: 924 | f.write(output_graph_def.SerializeToString()) 925 | with gfile.FastGFile(FLAGS.output_labels, 'w') as f: 926 | f.write('\n'.join(image_lists.keys()) + '\n') 927 | 928 | 929 | if __name__ == '__main__': 930 | tf.app.run() 931 | --------------------------------------------------------------------------------