├── delay ├── __init__.py ├── __init__.pyc ├── UniformDelay.pyc ├── ConstantDelay.py ├── PoissonDelay.py └── UniformDelay.py ├── classifiers ├── __init__.py ├── Images.pyc ├── __init__.pyc ├── UNSWClassifier.pyc ├── OMNIGLOTClassifier.pyc ├── Images.py ├── UNSWClassifier.py └── OMNIGLOTClassifier.py ├── rl_agents ├── __init__.py ├── DeROLAgent.pyc ├── __init__.pyc └── DeROLAgent.py ├── analyst_manager ├── __init__.py ├── __init__.pyc ├── AnalystManagement.pyc ├── UNSWAnalystManagement.pyc ├── UNSWAnalystManagement.py └── AnalystManagement.py ├── sample_handlers ├── __init__.py ├── __init__.pyc ├── ExperimentLogger.pyc ├── DelayClassification.pyc ├── ExperimentLogger.py └── DelayClassification.py ├── sample_generators ├── __init__.py ├── __init__.pyc ├── UNSWGenerator.pyc ├── OMNIGLOTGenerator.pyc ├── OMNIGLOTGenerator.py └── UNSWGenerator.py ├── trained_models └── backup │ └── models ├── datasets ├── OMNIGLOT.info └── UNSW.info ├── LICENSE ├── README.md ├── train_derol_OMNIGLOT.py └── train_derol_UNSW.py /delay/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /classifiers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rl_agents/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /analyst_manager/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sample_handlers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sample_generators/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /delay/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/antonpuz/DeROL/HEAD/delay/__init__.pyc -------------------------------------------------------------------------------- /classifiers/Images.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/antonpuz/DeROL/HEAD/classifiers/Images.pyc -------------------------------------------------------------------------------- /classifiers/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/antonpuz/DeROL/HEAD/classifiers/__init__.pyc -------------------------------------------------------------------------------- /delay/UniformDelay.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/antonpuz/DeROL/HEAD/delay/UniformDelay.pyc -------------------------------------------------------------------------------- /rl_agents/DeROLAgent.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/antonpuz/DeROL/HEAD/rl_agents/DeROLAgent.pyc -------------------------------------------------------------------------------- /rl_agents/__init__.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/antonpuz/DeROL/HEAD/rl_agents/__init__.pyc -------------------------------------------------------------------------------- /trained_models/backup/models: -------------------------------------------------------------------------------- 1 | Models are automatically saved for backup and stored in this folder -------------------------------------------------------------------------------- /analyst_manager/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/antonpuz/DeROL/HEAD/analyst_manager/__init__.pyc -------------------------------------------------------------------------------- /sample_handlers/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/antonpuz/DeROL/HEAD/sample_handlers/__init__.pyc -------------------------------------------------------------------------------- /classifiers/UNSWClassifier.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/antonpuz/DeROL/HEAD/classifiers/UNSWClassifier.pyc -------------------------------------------------------------------------------- /sample_generators/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/antonpuz/DeROL/HEAD/sample_generators/__init__.pyc -------------------------------------------------------------------------------- /classifiers/OMNIGLOTClassifier.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/antonpuz/DeROL/HEAD/classifiers/OMNIGLOTClassifier.pyc -------------------------------------------------------------------------------- /sample_generators/UNSWGenerator.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/antonpuz/DeROL/HEAD/sample_generators/UNSWGenerator.pyc -------------------------------------------------------------------------------- /sample_handlers/ExperimentLogger.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/antonpuz/DeROL/HEAD/sample_handlers/ExperimentLogger.pyc -------------------------------------------------------------------------------- /analyst_manager/AnalystManagement.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/antonpuz/DeROL/HEAD/analyst_manager/AnalystManagement.pyc -------------------------------------------------------------------------------- /sample_generators/OMNIGLOTGenerator.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/antonpuz/DeROL/HEAD/sample_generators/OMNIGLOTGenerator.pyc -------------------------------------------------------------------------------- /sample_handlers/DelayClassification.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/antonpuz/DeROL/HEAD/sample_handlers/DelayClassification.pyc -------------------------------------------------------------------------------- /analyst_manager/UNSWAnalystManagement.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/antonpuz/DeROL/HEAD/analyst_manager/UNSWAnalystManagement.pyc -------------------------------------------------------------------------------- /delay/ConstantDelay.py: -------------------------------------------------------------------------------- 1 | class ConstantDelay(): 2 | def __init__(self, delay=1.0): 3 | self.delay = delay 4 | 5 | def get_rand_delay(self): 6 | return int(1 + self.delay) # deterministic: always 1 + self.delay timesteps -------------------------------------------------------------------------------- /delay/PoissonDelay.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class PoissonDelay(): 5 | def __init__(self, lam=1.0): 6 | self.lam = lam 7 | 8 | def get_rand_delay(self): 9 | return int(1 + np.random.poisson(self.lam)) # 1 + a Poisson(lam)-distributed delay -------------------------------------------------------------------------------- /delay/UniformDelay.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class UniformDelay(): 5 | def __init__(self, par=1.0): 6 | self.par = par 7 | 8 | def get_rand_delay(self): 9 | return int(np.random.uniform(self.par + 1.1)) # the extra 1.1 keeps the top value reachable (inclusive) after int() truncation
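The three delay creators above share a single-method interface, get_rand_delay(), which returns an integer delay in framework timesteps, so they are interchangeable wherever the framework expects a delay source (e.g. the analyst managers). A minimal usage sketch; the parameter values are illustrative:

```python
from delay.ConstantDelay import ConstantDelay
from delay.PoissonDelay import PoissonDelay
from delay.UniformDelay import UniformDelay

# Any object exposing get_rand_delay() -> int can serve as a delay creator.
for creator in (ConstantDelay(delay=3.0), PoissonDelay(lam=9.0), UniformDelay(par=9.0)):
    print([creator.get_rand_delay() for _ in range(5)])  # five sampled delays, in timesteps
```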
-------------------------------------------------------------------------------- /datasets/OMNIGLOT.info: -------------------------------------------------------------------------------- 1 | # This document holds instructions for downloading and creating the OMNIGLOT sample dataset 2 | This dataset **must** be used in accordance with its own license and terms 3 | information taken from: https://github.com/brendenlake/omniglot 4 | original paper: http://science.sciencemag.org/content/350/6266/1332 5 | 6 | ## General Information 7 | This dataset contains 1623 different handwritten characters from 50 different alphabets. 8 | Each of the 1623 characters was drawn online via Amazon's Mechanical Turk by 20 different people. 9 | 10 | ## Download 11 | 1. Go to: https://github.com/brendenlake/omniglot/tree/master/python 12 | 2. Download any one of the zip files 13 | 14 | ## Folder Structure Preparation 15 | 1. Extract the zip files and pass the parent directory when running the algorithm -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018 Anton Puzanov and Kobi Cohen 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. -------------------------------------------------------------------------------- /datasets/UNSW.info: -------------------------------------------------------------------------------- 1 | # This document holds instructions for downloading and creating the UNSW-NB15 sample dataset 2 | This dataset **must** be used in accordance with its own license and terms 3 | information taken from: https://www.unsw.adfa.edu.au/unsw-canberra-cyber/cybersecurity/ADFA-NB15-Datasets/ 4 | 5 | The details of the UNSW-NB15 data set are published in the following papers: 6 | Moustafa, Nour, and Jill Slay. "UNSW-NB15: a comprehensive data set for network intrusion detection systems (UNSW-NB15 network data set)." Military Communications and Information Systems Conference (MilCIS), 2015. IEEE, 2015. 7 | Moustafa, Nour, and Jill Slay. "The evaluation of Network Anomaly Detection Systems: Statistical analysis of the UNSW-NB15 data set and the comparison with the KDD99 data set." Information Security Journal: A Global Perspective (2016): 1-14. 8 | 9 | 10 | ## General Information 11 | This dataset contains network session summary records. 12 | The raw network packets of the UNSW-NB15 data set were created with the IXIA PerfectStorm tool in the Cyber Range Lab of the Australian Centre for Cyber Security (ACCS) to generate a hybrid of real modern normal activities and synthetic contemporary attack behaviours. 13 | 14 | ## Download 15 | 1. Go to: https://www.unsw.adfa.edu.au/unsw-canberra-cyber/cybersecurity/ADFA-NB15-Datasets/ 16 | 2. Download any .csv files you wish to work with; UNSW-NB15_2.csv and UNSW-NB15_3.csv were used in this study 17 | 18 | ## Folder Structure Preparation 19 | 1. Column 45 (Excel column AV) holds the attack name, or is blank for normal behavior - referred to below as the "Record Category" 20 | 2. Partition the file by record category, with the following rules (see the helper sketch below): 21 | a. file names start with a capital letter; the normal-records file is called "Normal" 22 | b. files for training should be called <Category>_UNSW.csv - example Generic_UNSW.csv 23 | c. files for testing should be called <Category>_test_UNSW.csv - example Shellcode_test_UNSW.csv 24 | d. all categorical and time columns (1,2,3,4,5,6,14,21,22,45,46) should be deleted 25 | e. files should not contain any headers 26 | 3. Pass the parent directory with all the files when running the algorithm 27 |
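As an illustration of partitioning rule 2 above, a helper along these lines could split a raw UNSW-NB15 csv into the per-category files the generators expect. This script is not part of the repository; the column numbers simply follow the 1-based numbering used above:

```python
import csv
import os
from collections import defaultdict

# Categorical and time columns to delete per rule 2d (1-based), as 0-based indices.
DROP = {c - 1 for c in (1, 2, 3, 4, 5, 6, 14, 21, 22, 45, 46)}

def partition_unsw(src_csv, out_dir, suffix="_UNSW.csv"):
    # Use suffix="_test_UNSW.csv" to produce the test files of rule 2c.
    rows_by_category = defaultdict(list)
    with open(src_csv) as f:
        for row in csv.reader(f):
            # Column 45 (index 44) holds the attack name; blank means normal traffic.
            category = row[44].strip() or "Normal"
            rows_by_category[category].append(
                [v for i, v in enumerate(row) if i not in DROP])
    for category, rows in rows_by_category.items():
        # Rule 2a assumes category names already start with a capital letter;
        # files are written without headers (rule 2e).
        with open(os.path.join(out_dir, category + suffix), "w") as out:
            csv.writer(out).writerows(rows)
```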
-------------------------------------------------------------------------------- /classifiers/Images.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import random 4 | import math 5 | 6 | from scipy.ndimage import rotate, shift 7 | from scipy.misc import imread, imresize 8 | 9 | 10 | def get_shuffled_images(paths, labels, nb_samples=None, shuffle=True, last_class_offset=0, late_instances=1): 11 | if nb_samples is not None: 12 | before_third_samples = int(math.ceil(last_class_offset/len(labels))) 13 | all_but_last_samples = int(math.ceil(before_third_samples * float(len(labels)) / float(len(labels) - late_instances))) 14 | sampler_before = lambda x: random.sample(x, all_but_last_samples) 15 | sampler_after = lambda x: random.sample(x, nb_samples-before_third_samples) 16 | else: 17 | sampler_before = lambda x: x 18 | sampler_after = lambda x: x 19 | 20 | # print "Sampling from: " + str(paths) 21 | images_before = [(i, os.path.join(path, image)) for i, path in zip(labels[late_instances:], paths[late_instances:]) for image in sampler_before(os.listdir(path))] 22 | images_after = [(i, os.path.join(path, image)) for i, path in zip(labels, paths) for image in sampler_after(os.listdir(path))] 23 | 24 | if(shuffle): 25 | random.shuffle(images_before) 26 | random.shuffle(images_after) 27 | 28 | images = images_before + images_after 29 | images = images[0: len(labels) * nb_samples] 30 | return images 31 | 32 | def time_offset_label(labels_and_images): 33 | labels, images = zip(*labels_and_images) 34 | time_offset_labels = (None,) + labels[:-1] 35 | return zip(images, time_offset_labels) 36 | 37 | def load_transform(image_path, angle=0., s=(0,0), size=(20,20)): 38 | #Load the image 39 | original = imread(image_path, flatten=True) 40 | #Rotate the image 41 | rotated = np.maximum(np.minimum(rotate(original, angle=angle, cval=1.), 1.), 0.) 42 | #Shift the image 43 | shifted = shift(rotated, shift=s) 44 | #Resize the rotated and shifted image 45 | resized = np.asarray(imresize(shifted, size=size), dtype=np.float32) / 255 #Note here we coded manually as np.float32, it should be tf.float32 46 | #Invert the image 47 | inverted = 1.
- resized 48 | max_value = np.max(inverted) 49 | if max_value > 0: 50 | inverted /= max_value 51 | return inverted -------------------------------------------------------------------------------- /analyst_manager/UNSWAnalystManagement.py: -------------------------------------------------------------------------------- 1 | 2 | class UNSWAnalystManagement: 3 | def __init__(self, number_of_analysts, delay_creator, image_classifier, maximal_size=10000000): 4 | self.delay_creator = delay_creator 5 | self.current_delay_position = 0 6 | self.maximal_size = maximal_size 7 | self.delayed_classifications = {} 8 | self.number_of_analysts = number_of_analysts 9 | self.busy_analysts = 0 10 | self.job_queue = [] 11 | self.image_classifier = image_classifier 12 | 13 | def __add_item(self, item_to_add, delay=0): 14 | # Add element to be written 15 | position_after_delay = (self.current_delay_position + delay) % self.maximal_size 16 | to_be_written_after_delay = self.delayed_classifications.get(position_after_delay, []) 17 | to_be_written_after_delay.append(item_to_add) 18 | self.delayed_classifications[position_after_delay] = to_be_written_after_delay 19 | 20 | def __start_processing_of_job(self, buffer): #private 21 | wait_for = self.delay_creator.get_rand_delay() 22 | self.busy_analysts += 1 23 | self.__add_item(buffer, delay=wait_for) 24 | 25 | def analysts_load(self): 26 | return float(self.busy_analysts + len(self.job_queue))/self.number_of_analysts 27 | 28 | def add_classification_job(self, buffer): 29 | if self.busy_analysts 0 43 | 44 | def get_buffer_size(self): 45 | return self.maximal_size 46 | 47 | def get_sample(self): 48 | if (len(self.delayed_inputs.get(self.current_position, [])) == 0): 49 | print("in DelayClassification, asked for samples with empty buffer, returning None") 50 | return None 51 | else: 52 | to_be_written_after_delay = self.delayed_inputs.get(self.current_position, []) 53 | to_return = to_be_written_after_delay[0] 54 | to_be_written_after_delay = to_be_written_after_delay[1:] 55 | self.delayed_inputs[self.current_position] = to_be_written_after_delay 56 | self.delayed_samples -= 1 57 | return to_return 58 | 59 | def get_up_to_samples(self, number_of_samples): 60 | if (len(self.delayed_inputs.get(self.current_position, [])) == 0): 61 | print("in DelayClassification, asked for samples with empty buffer, returning None") 62 | return None 63 | else: 64 | to_be_written_after_delay = self.delayed_inputs.get(self.current_position, []) 65 | to_return = to_be_written_after_delay[0:number_of_samples] 66 | to_be_written_after_delay = to_be_written_after_delay[number_of_samples:] 67 | self.delayed_inputs[self.current_position] = to_be_written_after_delay 68 | return to_return 69 | 70 | def get_all_samples(self): 71 | if (len(self.delayed_inputs.get(self.current_position, [])) == 0): 72 | return [] 73 | else: 74 | to_return = self.delayed_inputs.get(self.current_position, []) 75 | self.delayed_inputs[self.current_position] = [] 76 | return to_return 77 | 78 | def get_number_of_delayed(self): 79 | return self.delayed_samples 80 | 81 | def get_load(self): 82 | return self.delayed_samples / float(self.max_load) -------------------------------------------------------------------------------- /analyst_manager/AnalystManagement.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Pool as ThreadPool 2 | import numpy as np 3 | 4 | from classifiers.OMNIGLOTClassifier import break_param_names_and_run_distance 5 | 6 | 7 | class 
AnalystManagement: 8 | def __init__(self, number_of_analysts, delay_creator, image_classifier, number_of_threads=20, maximal_delay_size=10000000): 9 | self.delay_creator = delay_creator 10 | self.maximal_size = maximal_delay_size 11 | if maximal_delay_size<=0: 12 | raise ValueError('maximal value for iterations should be positive, obtained {}'.format(maximal_delay_size)) 13 | self.current_delay_position = 0 14 | self.delayed_classifications = {} 15 | self.number_of_analysts = number_of_analysts 16 | self.busy_analysts = 0 17 | self.job_queue = [] 18 | # self.job_queue = Queue.Queue() 19 | self.image_classifier = image_classifier 20 | self.executionPool = ThreadPool(number_of_threads) 21 | 22 | def __add_item(self, item_to_add, delay=0): 23 | # Add element to be written 24 | position_after_delay = (self.current_delay_position + delay) % self.maximal_size 25 | to_be_written_after_delay = self.delayed_classifications.get(position_after_delay, []) 26 | to_be_written_after_delay.append(item_to_add) 27 | self.delayed_classifications[position_after_delay] = to_be_written_after_delay 28 | 29 | def __start_processing_of_job(self, buffer): #private 30 | wait_for = self.delay_creator.get_rand_delay() 31 | self.busy_analysts += 1 32 | self.__add_item(buffer, delay=wait_for) 33 | 34 | def analysts_load(self): 35 | return float(self.busy_analysts + len(self.job_queue))/self.number_of_analysts 36 | 37 | def add_classification_job(self, buffer): 38 | if self.busy_analysts 0): 77 | res = 1.0 - min(clean_distances) 78 | return res 79 | else: 80 | return 0.0 81 | 82 | -------------------------------------------------------------------------------- /sample_generators/OMNIGLOTGenerator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import random 4 | 5 | from classifiers.Images import load_transform, get_shuffled_images 6 | 7 | 8 | class OMNIGLOTGenerator(object): 9 | """Docstring for OmniglotGenerator""" 10 | def __init__(self, data_folder, letter_swap=1, batch_size=1, classes=5, samples_per_class=10, max_rotation=0., max_shift=0., img_size=(20,20), number_of_classes=30, max_iter=None, only_labels_and_images=False): 11 | super(OMNIGLOTGenerator, self).__init__() 12 | self.data_folder = data_folder 13 | self.letter_swap = letter_swap 14 | self.batch_size = batch_size 15 | self.number_of_classes = number_of_classes 16 | self.classes = classes 17 | self.samples_per_class = samples_per_class 18 | self.max_rotation = max_rotation 19 | self.max_shift = max_shift 20 | self.img_size = img_size 21 | self.max_iter = max_iter 22 | self.num_iter = 0 23 | self.only_labels_and_images = only_labels_and_images 24 | self.character_folders = [os.path.join(self.data_folder, family, character) \ 25 | for family in os.listdir(self.data_folder) \ 26 | if os.path.isdir(os.path.join(self.data_folder, family)) \ 27 | for character in os.listdir(os.path.join(self.data_folder, family))] 28 | self.working_characters = random.sample(self.character_folders, self.classes) 29 | self.working_labels = np.random.choice(self.number_of_classes, self.classes, replace=False).tolist() 30 | self.cacheDict = {} 31 | self.newest_swapped_letter = self.working_labels[0] 32 | 33 | 34 | def __iter__(self): 35 | return self 36 | 37 | def __next__(self): 38 | return self.next() 39 | 40 | def next(self): 41 | if (self.max_iter is None) or (self.num_iter < self.max_iter): 42 | self.num_iter += 1 43 | return (self.num_iter - 1), self.sample(self.classes) 44 | else: 45 | raise 
StopIteration 46 | 47 | def _get_class_for_folders(self, working_chars): 48 | char_labels = np.random.choice(range(self.number_of_classes), len(working_chars), replace=False) 49 | return char_labels.tolist() 50 | 51 | 52 | def get_image_by_name_and_chars(self, filename, angle, shift, img_size): 53 | e_key = (filename, angle, shift, img_size) 54 | if e_key in self.cacheDict: 55 | return self.cacheDict[e_key] 56 | else: 57 | obtained_image = load_transform(filename, angle=angle, s=shift, size=img_size).flatten() 58 | self.cacheDict[e_key] = obtained_image 59 | return obtained_image 60 | 61 | def sample(self, nb_samples): 62 | for i in range(self.letter_swap): 63 | index_to_swap = random.randint(0, nb_samples-1) 64 | self.newest_swapped_letter = self.working_labels[index_to_swap] 65 | self.working_characters[index_to_swap] = random.sample(self.character_folders, 1)[0] 66 | self.working_labels[index_to_swap] = np.random.randint(0,self.number_of_classes,1)[0] 67 | 68 | example_inputs = np.zeros((self.batch_size, nb_samples * self.samples_per_class, np.prod(self.img_size)), dtype=np.float32) 69 | example_outputs = np.zeros((self.batch_size, nb_samples * self.samples_per_class), dtype=np.float32) #notice hardcoded np.float32 here and above, change it to something else in tf 70 | folder_and_labels_only_list = [] 71 | 72 | for i in range(self.batch_size): 73 | labels_and_images = get_shuffled_images(self.working_characters, self.working_labels, nb_samples=self.samples_per_class) 74 | if(self.only_labels_and_images): 75 | if (folder_and_labels_only_list == []): 76 | folder_and_labels_only_list = [labels_and_images] 77 | else: 78 | folder_and_labels_only_list = np.vstack((folder_and_labels_only_list, [labels_and_images])) 79 | continue 80 | 81 | sequence_length = len(labels_and_images) 82 | labels, image_files = zip(*labels_and_images) 83 | 84 | angles = np.random.uniform(-self.max_rotation, self.max_rotation, size=sequence_length) 85 | shifts = np.random.uniform(-self.max_shift, self.max_shift, size=sequence_length) 86 | 87 | example_inputs[i] = np.asarray([self.get_image_by_name_and_chars(filename, angle=angle, shift=shift, img_size=self.img_size).flatten() \ 88 | for (filename, angle, shift) in zip(image_files, angles, shifts)], dtype=np.float32) 89 | example_outputs[i] = np.asarray(labels, dtype=np.int32) 90 | 91 | if (self.only_labels_and_images): 92 | return folder_and_labels_only_list 93 | else: 94 | return example_inputs, example_outputs 95 | 96 | def get_last_swapped_letter(self): 97 | return self.newest_swapped_letter 98 | 99 | 100 | -------------------------------------------------------------------------------- /classifiers/UNSWClassifier.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.naive_bayes import GaussianNB 3 | 4 | from sklearn.metrics.pairwise import paired_cosine_distances 5 | from sklearn.metrics.pairwise import paired_manhattan_distances 6 | from sklearn.metrics.pairwise import paired_euclidean_distances 7 | 8 | 9 | 10 | class UNSWClassifier(object): 11 | def __init__(self, number_of_classes, normal_memory_size=1000, mal_memory_size=50, max_distance=2.0, similarity='l2'): 12 | super(UNSWClassifier, self).__init__() 13 | self.number_of_classes = number_of_classes 14 | self.normal_memory_size = normal_memory_size 15 | self.mal_memory_size = mal_memory_size 16 | self.normal_labels = [] 17 | self.mal_labels = [] 18 | self.normal_data_storage = None 19 | self.mal_data_storage = None 20 | self.gnb = 
GaussianNB() 21 | self.max_distance = max_distance 22 | self.similarity = similarity 23 | 24 | def classify(self, data_v): 25 | if((len(self.normal_labels) + len(self.mal_labels)) < 2): 26 | print("Calling classification without enough image samples - returning zero vector") 27 | return np.zeros(self.number_of_classes) 28 | 29 | if(self.normal_data_storage == None): 30 | all_data = self.mal_data_storage 31 | elif(self.mal_data_storage == None): 32 | all_data = self.normal_data_storage 33 | else: 34 | all_data = self.normal_data_storage + self.mal_data_storage 35 | all_labels = self.normal_labels + self.mal_labels 36 | data_length = len(all_labels) 37 | 38 | min_distance = np.ones(self.number_of_classes) * self.max_distance 39 | 40 | if(self.similarity == 'cos'): 41 | distances = paired_cosine_distances(all_data, np.tile(data_v, (data_length, 1))) 42 | elif(self.similarity == 'l1'): 43 | distances = paired_manhattan_distances(all_data, np.tile(data_v, (data_length, 1))) 44 | else: 45 | distances = paired_euclidean_distances(all_data, np.tile(data_v, (data_length, 1))) 46 | 47 | 48 | for index_i in range(data_length): 49 | if (abs(distances[index_i]) < min_distance[all_labels[index_i]]): 50 | min_distance[all_labels[index_i]] = abs(distances[index_i]) 51 | 52 | min_distance = self.max_distance - min_distance 53 | return min_distance 54 | 55 | def add_image_sample(self, data_v, label): 56 | if(label >= self.number_of_classes): 57 | print("image label must not be greater than number of classes - 1") 58 | return 59 | 60 | if(label == 0): #Normal 61 | if (self.normal_data_storage == None): 62 | self.normal_data_storage = [data_v] 63 | self.normal_labels = [label] 64 | else: 65 | self.normal_data_storage.append(data_v) 66 | self.normal_labels.append(label) 67 | 68 | if (len(self.normal_labels) > self.normal_memory_size): 69 | self.normal_data_storage = self.normal_data_storage[1:] 70 | self.normal_labels = self.normal_labels[1:] 71 | else: 72 | if (self.mal_data_storage == None): 73 | self.mal_data_storage = [data_v] 74 | self.mal_labels = [label] 75 | else: 76 | self.mal_data_storage.append(data_v) 77 | self.mal_labels.append(label) 78 | 79 | if (len(self.mal_labels) > self.mal_memory_size): 80 | self.mal_data_storage = self.mal_data_storage[1:] 81 | self.mal_labels = self.mal_labels[1:] 82 | 83 | if(self.normal_data_storage == None): 84 | all_data = self.mal_data_storage 85 | elif(self.mal_data_storage == None): 86 | all_data = self.normal_data_storage 87 | else: 88 | all_data = self.normal_data_storage + self.mal_data_storage 89 | all_labels = self.normal_labels + self.mal_labels 90 | 91 | self.gnb = GaussianNB() 92 | self.gnb = self.gnb.fit(np.array(all_data), np.array(all_labels)) 93 | 94 | def distance_from_sample(self, all_data, single_samples): 95 | if(all_data == None or len(all_data) == 0): 96 | return None 97 | 98 | data_length = len(all_data) 99 | 100 | if(self.similarity == 'cos'): 101 | distances = paired_cosine_distances(all_data, np.tile(single_samples, (data_length, 1))) 102 | elif(self.similarity == 'l1'): 103 | distances = paired_manhattan_distances(all_data, np.tile(single_samples, (data_length, 1))) 104 | else: 105 | distances = paired_euclidean_distances(all_data, np.tile(single_samples, (data_length, 1))) 106 | 107 | min_distance = 5.0 108 | for index_i in range(data_length): 109 | if (abs(distances[index_i]) < min_distance): 110 | min_distance = abs(distances[index_i]) 111 | 112 | return min_distance 113 | 114 | 115 | def get_database_size(self): 116 | return 
len(self.normal_labels + self.mal_labels) 117 | 118 | 119 | 120 | 121 | 122 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Reinforcement One-Shot Learning (DeROL) Classification Framework 2 | 3 | This repository contains the DeROL framework code, as presented in "Deep Reinforcement One-Shot Learning for Artificially Intelligent Classification Systems". 4 | The paper has been uploaded to arXiv (identifier 1808.01527) and is accessible from: http://arxiv.org/abs/1808.01527 5 | 6 | If you use this code in your research, please cite our paper: [BibTeX](https://github.com/antonpuz/DeROL#please-cite-our-paper) 7 | 8 | 9 | 10 | ## Special Notes 11 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 12 | * Any code/data piece supplementary to this repository **must** be used in accordance with its own license and terms 13 | * [`datasets`](datasets) holds the instructions for obtaining the datasets 14 | 15 | ## Requirements 16 | * We ran the code on Ubuntu 16.04.4 LTS, but it should run easily on MacOS; using Windows would require fixing all the paths in "train_derol.py" 17 | * We used Python 2.7.12 with the following packages: tensorflow, numpy, scipy, sklearn 18 | 19 | ## Introduction 20 | In recent years there has been a sharp rise in applications in which significant events need to be classified but only a few training instances are available; these are known as cases of one-shot learning. Examples from different disciplines include surveillance and security, environmental monitoring, and patient monitoring. To handle this challenging task, organizations often use human analysts to classify events under high uncertainty. Existing algorithms use a threshold-based mechanism to decide whether to classify an object automatically or send it to an analyst for deeper inspection. However, this approach leads to a significant waste of resources, since it does not take the practical temporal constraints of system resources into account. Our contribution is threefold. First, we develop a novel Deep Reinforcement One-shot Learning (DeROL) framework to address this challenge. The basic idea of the DeROL algorithm is to train a deep-Q network to obtain a policy which is oblivious to the unseen classes in the testing data. Then, in real time, DeROL maps the current state of the one-shot learning process to operational actions based on the trained deep-Q network, to maximize the objective function. Second, we develop the first open-source software for practical artificially intelligent one-shot classification systems with limited resources, for the benefit of researchers and developers in related disciplines. Third, we present an extensive experimental study using the OMNIGLOT dataset for computer vision tasks and the UNSW-NB15 dataset for intrusion detection tasks that demonstrates the versatility and efficiency of the DeROL framework.
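In practice, the training scripts (train_derol_OMNIGLOT.py and train_derol_UNSW.py below) implement this deep-Q training by logging {sample, action, reward, future_reward} tuples and regressing the policy network toward the standard one-step Q-learning target, with gamma = 0.5 in the provided configurations (see the computation of q_calc_expr in the training loops):

    Q_target(s_t, a_t) = r_t + gamma * max_a Q(s_{t+1}, a)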
21 | 22 | ## Modules 23 | * **rl_agents**: Implements the policies. The policy is the conductor and the heart of our framework; we used a deep reinforcement learning policy 24 | with a Long Short-Term Memory (LSTM) component and one hidden layer. 25 | * **classifiers**: The classifier outputs a similarity measure to all the known classes; almost every dataset requires either a new or an adapted classifier. 26 | The classifier must support one-shot behavior; we used a sample bank with a distance metric to predict the similarity to all the classes. We implemented classifiers for two different datasets: 27 | For the OMNIGLOT dataset: we used ["A Modified Hausdorff Distance for Object Matching"](http://www.cse.msu.edu/prip/Files/DubuissonJain.pdf) as the distance metric 28 | For UNSW-NB15: we used the ["Euclidean Distance"](https://en.wikipedia.org/wiki/Euclidean_distance) as the distance metric 29 | * **analyst_manager**: manages the analysts' work cycle 30 | * **delay**: holds the random delay creators used in the framework 31 | * **sample_generators**: Implements the generators, which are responsible for reading data saved on disk and creating ordered batches from it. 32 | We implemented one generator per dataset: 33 | OMNIGLOTGenerator: reads image samples from the OMNIGLOT dataset and returns either raw pixel values or image names 34 | UNSWGenerator: reads "csv" files and normalizes them according to pre-calculated mean and standard deviation values. 35 | * **sample_handlers**: holds the extra sample-holding classes: 36 | DelayClassification: holds samples marked to be delayed by the policy 37 | ExperimentLogger: holds {sample, action, reward, future_reward} tuples, used later for training. (A sketch of how the training scripts wire these modules together follows below.)
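A minimal wiring sketch, mirroring the module construction in train_derol_OMNIGLOT.py (parameter values are taken from that script's configuration block; the data_folder path is illustrative):

```python
from analyst_manager.AnalystManagement import AnalystManagement
from classifiers.OMNIGLOTClassifier import OMNIGLOTClassifier
from delay.UniformDelay import UniformDelay
from rl_agents.DeROLAgent import DeROLAgent
from sample_generators.OMNIGLOTGenerator import OMNIGLOTGenerator
from sample_handlers.DelayClassification import DelayClassification
from sample_handlers.ExperimentLogger import ExperimentLogger

# Policy network: 30 classes plus 2 extra actions (send to analyst, delay).
rlAgent = DeROLAgent(batch_size=1, samples_per_batch=30, total_input_size=34,
                     actions=32, learning_rate=0.001)
image_classifier = OMNIGLOTClassifier(30, memory_size=25)        # one-shot sample bank
analyst_expr_manager = ExperimentLogger(30, graph_creator=None)  # {sample, action, reward, future_reward} log
samples_delay_manager = DelayClassification(max_load=100)        # samples the policy chose to delay
delay_creator = UniformDelay(9.0)                                # random analyst response times
job_manager = AnalystManagement(3, delay_creator, image_classifier)  # 3 analysts
generator = OMNIGLOTGenerator(data_folder="path/to/omniglot", batch_size=1,
                              samples_per_class=10, classes=3, only_labels_and_images=True)
```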
38 | 39 | ## Please Cite Our Paper 40 | @article{puzanov2020deep, 41 | title={Deep reinforcement one-shot learning for artificially intelligent classification in expert aided systems}, 42 | author={Puzanov, Anton and Zhang, Senyang and Cohen, Kobi}, 43 | journal={Engineering Applications of Artificial Intelligence}, 44 | volume={91}, 45 | pages={103589}, 46 | year={2020}, 47 | publisher={Elsevier} 48 | } 49 | -------------------------------------------------------------------------------- /classifiers/OMNIGLOTClassifier.py: -------------------------------------------------------------------------------- 1 | from scipy.ndimage import imread 2 | from scipy.misc import imresize 3 | from scipy.spatial.distance import cdist 4 | import numpy as np 5 | from multiprocessing import Pool as ThreadPool 6 | 7 | from classifiers import Images 8 | 9 | 10 | def _ModHausdorffDistance(itemA, itemB): 11 | D = cdist(itemA, itemB) 12 | mindist_A = D.min(axis=1) 13 | mindist_B = D.min(axis=0) 14 | mean_A = np.mean(mindist_A) 15 | mean_B = np.mean(mindist_B) 16 | return max(mean_A, mean_B) 17 | 18 | def break_param_names_and_run_distance(param): 19 | return (param[0], _ModHausdorffDistance(param[1], param[2])) 20 | 21 | class OMNIGLOTClassifier(object): 22 | def __init__(self, number_of_classes, memory_size=100, number_of_executors=100, image_size=(20,20)): 23 | super(OMNIGLOTClassifier, self).__init__() 24 | self.number_of_classes = number_of_classes 25 | self.image_size = image_size 26 | self.memory_size = memory_size 27 | self.image_names = np.array([]) 28 | self.image_loads = {} 29 | self.image_labels = {} 30 | self.different_labels = {} 31 | self.representing_image_cache = {} 32 | self.image_points_cache = {} 33 | self.executionPool = ThreadPool(min(number_of_executors, memory_size)) 34 | 35 | def LoadImgAsPoints(self, fn): 36 | if(fn in self.image_points_cache): 37 | return self.image_points_cache[fn] 38 | 39 | I = imread(fn, flatten=True) 40 | I = np.asarray(imresize(I, size=self.image_size), dtype=np.float32) 41 | I[I<255] = 0 42 | I = np.array(I, dtype=bool) 43 | I = np.logical_not(I) 44 | (row, col) = I.nonzero() 45 | D = np.array([row, col]) 46 | D = np.transpose(D) 47 | D = D.astype(float) 48 | n = D.shape[0] 49 | mean = np.mean(D, axis=0) 50 | for i in range(n): 51 | D[i, :] = D[i, :] - mean 52 | 53 | self.image_points_cache[fn] = D 54 | return D 55 | 56 | def load_representing_image(self, fn): 57 | if(fn in self.representing_image_cache): 58 | return self.representing_image_cache[fn] 59 | 60 | I = Images.load_transform(fn, size=self.image_size) 61 | 62 | self.representing_image_cache[fn] = I 63 | 64 | return I 65 | 66 | def classify(self, image_file): 67 | if(len(self.image_names) == 0): 68 | print("Calling classification without enough image samples - returning zero vector") 69 | return np.zeros(self.number_of_classes) 70 | 71 | points_image = self.LoadImgAsPoints(image_file) 72 | min_distance = np.ones(self.number_of_classes) * 2.0 73 | unbraked_values = [(self.image_labels[image], self.LoadImgAsPoints(image), points_image) for image in self.image_names] 74 | distances = self.executionPool.map(break_param_names_and_run_distance, unbraked_values) 75 | 76 | for distance in distances: 77 | if (distance[1] < min_distance[distance[0]]): 78 | min_distance[distance[0]] = distance[1] 79 | 80 | min_distance = 2.0 - min_distance 81 | return min_distance 82 | 83 | def add_image_sample(self, image_file, label): 84 | if(label >= self.number_of_classes): 85 | print("image label must not be greater than number 
of classes - 1") 86 | return 87 | 88 | if(image_file in self.image_loads): 89 | old_label = self.image_labels[image_file] 90 | if (self.different_labels[old_label] == 1): 91 | del self.different_labels[old_label] 92 | else: 93 | self.different_labels[old_label] -= 1 94 | 95 | if(not image_file in self.image_loads): 96 | self.image_names = np.append(self.image_names, image_file) 97 | if (label in self.different_labels): 98 | self.different_labels[label] += 1 99 | else: 100 | self.different_labels[label] = 1 101 | self.image_loads[image_file] = self.LoadImgAsPoints(image_file) 102 | self.image_labels[image_file] = label 103 | 104 | to_be_deleted = None 105 | if(len(self.image_names) > self.memory_size): 106 | to_be_deleted = self.image_names[0] 107 | self.image_names = self.image_names[1:] 108 | 109 | if(to_be_deleted != None): 110 | del self.image_loads[to_be_deleted] 111 | removed_label = self.image_labels[to_be_deleted] 112 | if(not removed_label in self.different_labels): 113 | print("error found") 114 | if (self.different_labels[removed_label] == 1): 115 | del self.different_labels[removed_label] 116 | else: 117 | self.different_labels[removed_label] -= 1 118 | del self.image_labels[to_be_deleted] 119 | 120 | if(sum(self.different_labels.values()) > self.memory_size): 121 | print("classifier reached an unexpected") 122 | 123 | if (len(self.image_loads) != len(self.image_labels) or len(self.image_labels) != len(self.image_names)): 124 | print("classifier reached an unexpected") 125 | 126 | def database_size(self): 127 | return len(self.different_labels) 128 | 129 | 130 | 131 | 132 | 133 | 134 | -------------------------------------------------------------------------------- /sample_generators/UNSWGenerator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import random 4 | import csv 5 | 6 | class UNSWGenerator(object): 7 | # STD 8 | # [3.11718494e+00, 5.05267629e+04, 1.59565885e+05, 6.01614215e+01, 3.68997951e+01, 2.04402226e+01, 5.60362463e+01, 9.02498798e+07, 4.38022457e+06, 7.34661981e+01, 1.20085728e+02, 1.21618092e+02, 1.21674944e+02, 1.43249037e+09, 1.43025723e+09, 1.53408714e+02, 3.38400001e+02, 3.48792432e-01, 5.02959031e+04, 1.62749588e+04, 2.81668871e+03, 3.03954671e+03, 1.58047419e+03, 3.18336467e-02, 1.71418292e-02, 1.64202905e-02, 4.47561323e-02, 5.28876325e-01, 3.88885807e-01, 1.35820667e-01, 1.37266282e-01, 7.48326410e+00, 7.43839560e+00, 5.30909327e+00, 5.48918817e+00, 5.40727891e+00, 3.86212078e+00, 7.39321762e+00] 9 | # MEAN 10 | # [5.83990359e-01, 4.43920514e+03, 3.82112115e+04, 4.98032817e+01, 3.16409877e+01, 5.55065064e+00, 1.73402538e+01, 2.01082635e+07, 2.74214517e+06, 3.56218705e+01, 4.57140918e+01, 1.65776746e+02, 1.65596682e+02, 1.39622413e+09, 1.39395230e+09, 1.25076663e+02, 3.05978873e+02, 9.58984344e-02, 4.96244366e+03, 1.58372001e+03, 7.78032492e+02, 2.25335137e+02, 9.18970630e+01, 4.39537934e-03, 2.37087431e-03, 2.02450502e-03, 2.00713999e-03, 1.54769779e-01, 1.05521278e-01, 1.82014026e-02, 1.83442595e-02, 6.72776896e+00, 6.54925922e+00, 4.51331784e+00, 5.03591995e+00, 2.60927913e+00, 1.92820439e+00, 4.06247991e+00] 11 | 12 | def __init__(self, data_folder, letter_swap=1, batch_size=1, classes=5, samples_per_class=10, max_iter=None, train=True): 13 | super(UNSWGenerator, self).__init__() 14 | self.data_folder = data_folder 15 | self.trainLabelDict = {"Normal":0, "Exploits":1, "Reconnaissance":1, "DoS":1, \ 16 | "Generic": 1, "Shellcode":1, "Analysis":1} 17 | self.trainAttackDict = 
{"Exploits": 1, "Reconnaissance": 1, "DoS": 1, \ 18 | "Generic": 1, "Shellcode": 1, "Analysis": 1} 19 | 20 | self.testLabelDict = {"Normal_test":0, "Backdoors_test":1, "Fuzzers_test": 1, "Worms_test":1, 21 | "Generic_test": 1, "Shellcode_test":1, "Analysis_test":1} 22 | self.testAttackDict = {"Backdoors_test": 1, "Fuzzers_test": 1, "Worms_test": 1, 23 | "Generic_test": 1, "Shellcode_test": 1, "Analysis_test": 1} 24 | 25 | if(train): 26 | self.possible_labels = ["Normal"] + list(self.trainAttackDict) 27 | self.active_dir = self.trainLabelDict 28 | else: 29 | self.possible_labels = ["Normal_test"] + list(self.testAttackDict) 30 | self.active_dir = self.testLabelDict 31 | self.all_data_in_memory = {} 32 | 33 | for label in self.possible_labels: 34 | file = os.path.join(self.data_folder, label + "_UNSW.csv") 35 | print("reading from {}".format(file)) 36 | with open(file, 'rb') as csvfile: 37 | spamreader = csv.reader(csvfile, quotechar='|') 38 | all_relevant_data = [row for row in spamreader] 39 | self.all_data_in_memory[label] = all_relevant_data 40 | 41 | 42 | self.letter_swap = letter_swap 43 | self.batch_size = batch_size 44 | self.classes = classes 45 | self.samples_per_class = samples_per_class 46 | self.max_iter = max_iter 47 | self.num_iter = 0 48 | self.features = 38 49 | self.working_labels = [] 50 | for i in range(self.classes - 2): 51 | new_index = random.randint(0, len(self.possible_labels) - 1) 52 | self.working_labels.append(self.possible_labels[new_index]) 53 | for i in range(2): 54 | if(train): 55 | self.working_labels.append("Normal") 56 | else: 57 | self.working_labels.append("Normal_test") 58 | self.cacheDict = {} 59 | self.newest_swapped_letter = self.working_labels[0] 60 | self.sample_mean = np.array([5.83990359e-01, 4.43920514e+03, 3.82112115e+04, 4.98032817e+01, 3.16409877e+01, 5.55065064e+00, 1.73402538e+01, 2.01082635e+07, 2.74214517e+06, 3.56218705e+01, 4.57140918e+01, 1.65776746e+02, 1.65596682e+02, 1.39622413e+09, 1.39395230e+09, 1.25076663e+02, 3.05978873e+02, 9.58984344e-02, 4.96244366e+03, 1.58372001e+03, 7.78032492e+02, 2.25335137e+02, 9.18970630e+01, 4.39537934e-03, 2.37087431e-03, 2.02450502e-03, 2.00713999e-03, 1.54769779e-01, 1.05521278e-01, 1.82014026e-02, 1.83442595e-02, 6.72776896e+00, 6.54925922e+00, 4.51331784e+00, 5.03591995e+00, 2.60927913e+00, 1.92820439e+00, 4.06247991e+00]) 61 | self.sample_std = np.array([[3.11718494e+00, 5.05267629e+04, 1.59565885e+05, 6.01614215e+01, 3.68997951e+01, 2.04402226e+01, 5.60362463e+01, 9.02498798e+07, 4.38022457e+06, 7.34661981e+01, 1.20085728e+02, 1.21618092e+02, 1.21674944e+02, 1.43249037e+09, 1.43025723e+09, 1.53408714e+02, 3.38400001e+02, 3.48792432e-01, 5.02959031e+04, 1.62749588e+04, 2.81668871e+03, 3.03954671e+03, 1.58047419e+03, 3.18336467e-02, 1.71418292e-02, 1.64202905e-02, 4.47561323e-02, 5.28876325e-01, 3.88885807e-01, 1.35820667e-01, 1.37266282e-01, 7.48326410e+00, 7.43839560e+00, 5.30909327e+00, 5.48918817e+00, 5.40727891e+00, 3.86212078e+00, 7.39321762e+00]]) 62 | 63 | 64 | def __iter__(self): 65 | return self 66 | 67 | def __next__(self): 68 | return self.next() 69 | 70 | def next(self): 71 | if (self.max_iter is None) or (self.num_iter < self.max_iter): 72 | self.num_iter += 1 73 | return (self.num_iter - 1), self.sample(self.classes) 74 | else: 75 | raise StopIteration 76 | 77 | def sample(self, nb_samples): 78 | for i in range(self.letter_swap): 79 | index_to_swap = random.randint(0, self.classes - 3) 80 | new_index = random.randint(0, len(self.possible_labels)-1) 81 | self.newest_swapped_letter = 
self.working_labels[0] 82 | self.working_labels[index_to_swap] = self.possible_labels[new_index] 83 | 84 | example_inputs = np.zeros((self.batch_size, nb_samples * self.samples_per_class, self.features), dtype=np.float32) 85 | example_outputs = np.zeros((self.batch_size, nb_samples * self.samples_per_class), dtype=np.float32) #notice hardcoded np.float32 here and above, change it to something else in tf 86 | 87 | for i in range(self.batch_size): 88 | labels_and_samples = [(self.active_dir[active_label], x) for active_label in self.working_labels for x in random.sample(self.all_data_in_memory[active_label], self.samples_per_class)] 89 | random.shuffle(labels_and_samples) 90 | 91 | sequence_length = len(labels_and_samples) 92 | labels, samples = zip(*labels_and_samples) 93 | samples = np.asarray(samples, dtype=np.float32) 94 | #normalize 95 | samples = (samples - self.sample_mean) / self.sample_std 96 | 97 | example_inputs[i] = samples 98 | example_outputs[i] = np.asarray(labels, dtype=np.int32) 99 | 100 | return example_inputs, example_outputs 101 | 102 | def get_last_swapped_letter(self): 103 | return self.newest_swapped_letter 104 | 105 | 106 | -------------------------------------------------------------------------------- /train_derol_OMNIGLOT.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | from analyst_manager.AnalystManagement import AnalystManagement 5 | from classifiers.OMNIGLOTClassifier import OMNIGLOTClassifier 6 | from delay.UniformDelay import UniformDelay 7 | from rl_agents.DeROLAgent import DeROLAgent 8 | from sample_generators.OMNIGLOTGenerator import OMNIGLOTGenerator 9 | from sample_handlers.DelayClassification import DelayClassification 10 | from sample_handlers.ExperimentLogger import ExperimentLogger 11 | 12 | np.set_printoptions(threshold='nan') 13 | import time 14 | import sys 15 | import os 16 | 17 | ##Configurations 18 | experiment_name = "one-shot-OMNIGLOT" 19 | batch_size = 1 20 | loaded_characters = 3 21 | total_classes = 30 22 | actions = total_classes + 2 23 | samples_per_class = 10 24 | samples_per_batch = loaded_characters * samples_per_class 25 | total_input_size = 34 26 | lstm_units = 200 27 | learning_rate = 0.001 28 | gamma = 0.5 29 | eps = 0.05 30 | number_of_analysts = 3 31 | delay_classification_penalty_base = -0.05 32 | delay_classification_penalty_multiplier = 2 33 | classification_delay_timesteps = 10 34 | wrong_classification_penalty = -1 35 | analyst_load_penalty_multiplier = -0.1 36 | full_analyst_load = 5.0 37 | analyst_delay_param = 9.0 38 | training_batches = 600000 39 | statistics_siplay_step = 50 40 | max_number_of_delayed = 100 41 | is_training_phase = True 42 | ##End Configurations 43 | 44 | 45 | tf.reset_default_graph() 46 | 47 | rlAgent = DeROLAgent(batch_size=batch_size, samples_per_batch=samples_per_batch, total_input_size=total_input_size, 48 | actions=actions, learning_rate=learning_rate) 49 | image_classifier = OMNIGLOTClassifier(total_classes, memory_size=25) 50 | analyst_expr_manager = ExperimentLogger(samples_per_batch, graph_creator=None) 51 | samples_delay_manager = DelayClassification(max_load=max_number_of_delayed) 52 | delay_creator = UniformDelay(analyst_delay_param) 53 | job_manager = AnalystManagement(number_of_analysts, delay_creator, image_classifier) 54 | 55 | init = tf.global_variables_initializer() 56 | 57 | def benchmark(argv): 58 | print "Starting Training from data: " + str(argv[0]) 59 | 60 | with tf.Session() as sess: 
61 | sess.run(init) 62 | 63 | # try to load a saved model 64 | model_loader = tf.train.Saver() 65 | current_model_folder = "./trained_models/backup/model-" + experiment_name 66 | if (os.path.exists(current_model_folder)): 67 | print("Loading pre calculaated model") 68 | model_loader.restore(sess, current_model_folder + "/model.ckpt") 69 | print("v1 : %s" % rlAgent.policy_biases['out'].eval()[0]) 70 | else: 71 | print("Creating the folder for the model to be stored in") 72 | os.makedirs(current_model_folder) 73 | 74 | generator = OMNIGLOTGenerator(data_folder=argv[0], batch_size=batch_size, samples_per_class=samples_per_class, 75 | classes=loaded_characters, only_labels_and_images=True) 76 | 77 | t0 = time.time() 78 | 79 | print("Starting real training") 80 | 81 | global eps 82 | global is_training_phase 83 | 84 | policy_online_state = (np.zeros([batch_size, lstm_units]), 85 | np.zeros([batch_size, lstm_units])) # Reset the policy layer's hidden state 86 | 87 | policy_training_state = (np.zeros([batch_size, lstm_units]), 88 | np.zeros([batch_size, lstm_units])) # Reset the policy layer's hidden state 89 | 90 | accuracy_agg = [] 91 | confusion_matrix = np.zeros(6) 92 | sample_per_batch_agg = [] 93 | reward_agg = [] 94 | loss_agg = [] 95 | 96 | r_m1 = None 97 | a_m1 = None 98 | x_m1 = None 99 | 100 | for i, image_and_labels in generator: 101 | 102 | for batch in range(batch_size): 103 | sample_counter = 0 104 | labels, image_files = zip(*image_and_labels[batch]) 105 | samples_length = len(image_files) 106 | if ((i % statistics_siplay_step == 0) and (batch == 0)): 107 | print("Starting policy training on step " + str(i) + ", over data with length: " + str(samples_length)) 108 | if(is_training_phase): 109 | print("Saving most recent model") 110 | save_path = model_loader.save(sess, current_model_folder + "/model.ckpt") 111 | print("Model saved in path: %s" % save_path) 112 | print("v1 : %s" % rlAgent.policy_biases['out'].eval()[0]) 113 | 114 | 115 | timeseries = np.array(image_files) 116 | batch_y = np.array([int(float(x)) for x in labels]) 117 | 118 | sample = 0 119 | reward_sampling = [] 120 | 121 | while sample < samples_length: 122 | 123 | x_sample_buffer = None 124 | classification_logits_sample = None 125 | sample_file = None 126 | is_delayed_sample = False 127 | delay_penalty = delay_classification_penalty_base 128 | y = -1 129 | 130 | if(samples_delay_manager.is_waiting_sample()): 131 | x_sample_buffer_tuple = samples_delay_manager.get_sample() 132 | delay_penalty = x_sample_buffer_tuple[0] 133 | x_sample_buffer = x_sample_buffer_tuple[1] 134 | sample_file = x_sample_buffer_tuple[3] 135 | classification_logits_sample = image_classifier.classify(sample_file) 136 | y = x_sample_buffer_tuple[2] 137 | is_delayed_sample = True 138 | else: 139 | classification_logits_sample = image_classifier.classify(timeseries[sample]) 140 | y = int(batch_y[sample]) 141 | x_sample_buffer = image_classifier.load_representing_image(timeseries[sample]) 142 | sample_file = timeseries[sample] 143 | 144 | accuracy_agg.append(np.argmax(classification_logits_sample) == y) 145 | sample_counter += 1 146 | 147 | sample_for_policy = np.concatenate( 148 | ([samples_delay_manager.get_load()], [delay_penalty], [job_manager.analysts_load()], [job_manager.distance_to_samples_in_work(sample_file)], np.reshape(classification_logits_sample, -1))) 149 | qp1, a, new_online_state = sess.run([rlAgent.policy_logits, rlAgent.predicted_action, rlAgent.policy_rnn_state], 150 | feed_dict={rlAgent.X: np.reshape(sample_for_policy, 
[batch_size, 1, -1]), 151 | rlAgent.policy_state_in: policy_online_state, rlAgent.timeseries_length: 1}) 152 | 153 | policy_online_state = new_online_state 154 | a = a[0] 155 | qp1 = qp1[0] 156 | 157 | rand_action_sample = np.random.rand(1) 158 | if rand_action_sample < eps: 159 | rand_action_sample = np.random.rand(1) 160 | if rand_action_sample < 0.25: 161 | a = total_classes 162 | elif rand_action_sample < 0.5: 163 | a = total_classes + 1 164 | elif rand_action_sample < 0.75: 165 | a = y 166 | else: 167 | a = np.random.randint(0, total_classes) 168 | 169 | reward = 1 170 | if (a == total_classes + 1): # Delay in classification 171 | if(samples_delay_manager.get_number_of_delayed() >= max_number_of_delayed ): 172 | reward = wrong_classification_penalty 173 | confusion_matrix[4] += 1 174 | else: 175 | reward = delay_penalty 176 | delay_penalty *= delay_classification_penalty_multiplier 177 | samples_delay_manager.add_item((delay_penalty, x_sample_buffer, y, sample_file), delay=classification_delay_timesteps) 178 | confusion_matrix[2] += 1 179 | elif (a == total_classes): # Asked for classification 180 | if(job_manager.analysts_load() < full_analyst_load): 181 | job_manager.add_classification_job((sample_file, y)) 182 | reward = analyst_load_penalty_multiplier * job_manager.analysts_load() 183 | confusion_matrix[3] += 1 184 | else: 185 | reward = wrong_classification_penalty 186 | confusion_matrix[5] += 1 187 | else: 188 | if (a != y): 189 | reward = wrong_classification_penalty 190 | confusion_matrix[1] += 1 191 | else: 192 | confusion_matrix[0] += 1 193 | 194 | reward_sampling.append(reward) 195 | max_q = max(qp1) 196 | if(a_m1 != None and is_training_phase): 197 | analyst_expr_manager.add_item( 198 | np.array([x_m1, a_m1, r_m1, max_q]), delay=0) 199 | x_m1 = sample_for_policy 200 | a_m1 = a 201 | r_m1 = reward 202 | 203 | # Train the policy 204 | if( (analyst_expr_manager.number_of_batches() >= batch_size) and is_training_phase): 205 | expr_matrix = np.reshape(analyst_expr_manager.create_batch(batch_size), [-1, 4]) 206 | x_matrix = np.reshape(np.vstack(expr_matrix[:, 0]), 207 | [batch_size, samples_per_batch, total_input_size]) 208 | action_matrix = np.reshape(expr_matrix[:, 1], [-1, samples_per_batch]) 209 | q_calc_expr = expr_matrix[:, 2] + gamma * expr_matrix[:, 3] 210 | q_calc_expr = np.reshape(q_calc_expr, [-1, samples_per_batch]) 211 | 212 | state_t, _, loss = sess.run([rlAgent.policy_rnn_state, rlAgent.updateModel, rlAgent.loss], 213 | feed_dict={rlAgent.X: x_matrix, rlAgent.actions: action_matrix, 214 | rlAgent.Q_calculation: q_calc_expr, rlAgent.policy_state_in: policy_training_state}) 215 | loss_agg.append(loss) 216 | policy_training_state = policy_online_state 217 | 218 | 219 | if(a != total_classes + 1): #advance timestamp count if sample is not delayed 220 | job_manager.advance_time() 221 | if(not samples_delay_manager.is_waiting_sample()): 222 | samples_delay_manager.advance_state() 223 | if(not is_delayed_sample): 224 | sample += 1 225 | 226 | reward_agg.append(np.sum(reward_sampling)) 227 | reward_sampling = [] 228 | sample_per_batch_agg.append(sample_counter) 229 | 230 | 231 | #Statistics 232 | if(i % statistics_siplay_step == 0 and i!=0): 233 | print "Batch " + str(i) + " finished after " + str(time.time() - t0) + " seconds" 234 | print "Average cycle reward is: " + str(np.sum(reward_agg) / statistics_siplay_step) 235 | print("averaged accuracy since last print is " + str(np.average(accuracy_agg))) 236 | if(is_training_phase): 237 | print("averaged loss since last 
print " + str(np.average(loss_agg))) 238 | print("confusion matrix: " + str(confusion_matrix / (batch_size * statistics_siplay_step))) 239 | print("Average number of samples: new: {}, total: {}".format(samples_per_batch, np.average(sample_per_batch_agg))) 240 | 241 | #Clear the counters 242 | confusion_matrix = np.zeros(6) 243 | accuracy_agg = [] 244 | sample_per_batch_agg = [] 245 | loss_agg = [] 246 | reward_agg = [] 247 | 248 | 249 | if i == training_batches: 250 | eps = 0.0 251 | is_training_phase = False 252 | 253 | if i == training_batches*2: 254 | print("Finished full system training, execution time " + str(time.time() - t0) + " seconds") 255 | exit(1) 256 | 257 | print("Finished full system training, it took " + str(time.time() - t0) + " seconds") 258 | 259 | 260 | if __name__ == '__main__': 261 | if(len(sys.argv) != 2): 262 | print("train_derol should be called with image root folder as parameter") 263 | exit(1) 264 | benchmark(sys.argv[1:]) -------------------------------------------------------------------------------- /train_derol_UNSW.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | from sample_generators.UNSWGenerator import UNSWGenerator 5 | 6 | np.set_printoptions(threshold='nan') 7 | import time 8 | import sys 9 | import math 10 | import os 11 | 12 | from analyst_manager.UNSWAnalystManagement import UNSWAnalystManagement 13 | from classifiers.UNSWClassifier import UNSWClassifier 14 | from delay.UniformDelay import UniformDelay 15 | from rl_agents.DeROLAgent import DeROLAgent 16 | from sample_handlers.DelayClassification import DelayClassification 17 | from sample_handlers.ExperimentLogger import ExperimentLogger 18 | 19 | 20 | ##Configurations 21 | experiment_name = "one-shot-UNSW-NB15" 22 | batch_size = 1 23 | loaded_characters = 3 24 | total_classes = 2 25 | actions = total_classes + 2 26 | samples_per_class = 20 27 | samples_per_batch = loaded_characters * samples_per_class 28 | total_input_size = 6 29 | lstm_units = 200 30 | learning_rate = 0.001 31 | gamma = 0.5 32 | eps = 0.05 33 | number_of_analysts = 2 34 | delay_classification_penalty_base = -1.0 35 | delay_classification_penalty_multiplier = 2 36 | classification_delay_timesteps = 5 37 | wrong_classification_penalty = -2 38 | analyst_load_penalty_multiplier = -0.5 39 | full_analyst_load = 5.0 40 | analyst_delay_param = 9.0 41 | training_batches = 500000 42 | statistics_siplay_step = 50 43 | max_number_of_delayed = 100 44 | is_training_phase = True 45 | ##End Configurations 46 | 47 | 48 | tf.reset_default_graph() 49 | 50 | rlAgent = DeROLAgent(batch_size=batch_size, samples_per_batch=samples_per_batch, total_input_size=total_input_size, 51 | actions=actions, learning_rate=learning_rate) 52 | image_classifier = UNSWClassifier(total_classes, normal_memory_size=10, mal_memory_size=10, max_distance=5.0, similarity='l2') 53 | analyst_expr_manager = ExperimentLogger(samples_per_batch, graph_creator=None) 54 | samples_delay_manager = DelayClassification(max_load=max_number_of_delayed) 55 | delay_creator = UniformDelay(analyst_delay_param) 56 | job_manager = UNSWAnalystManagement(number_of_analysts, delay_creator, image_classifier) 57 | 58 | 59 | init = tf.global_variables_initializer() 60 | 61 | def benchmark(argv): 62 | print "Starting Training from data: " + str(argv[0]) 63 | 64 | with tf.Session() as sess: 65 | sess.run(init) 66 | 67 | # try to load a saved model 68 | model_loader = tf.train.Saver() 69 | current_model_folder 
= "./trained_models/backup/model-" + experiment_name 70 | if (os.path.exists(current_model_folder)): 71 | print("Loading pre calculaated model") 72 | model_loader.restore(sess, current_model_folder + "/model.ckpt") 73 | print("v1 : %s" % rlAgent.policy_biases['out'].eval()[0]) 74 | else: 75 | print("Creating the folder for the model to be stored in") 76 | os.makedirs(current_model_folder) 77 | 78 | generator = UNSWGenerator(data_folder=argv[0], batch_size=batch_size, classes=loaded_characters, 79 | samples_per_class=samples_per_class) 80 | 81 | t0 = time.time() 82 | 83 | global eps 84 | global is_training_phase 85 | 86 | policy_online_state = (np.zeros([batch_size, lstm_units]), 87 | np.zeros([batch_size, lstm_units])) # Reset the policy layer's hidden state for iterative actions 88 | 89 | policy_training_state = (np.zeros([batch_size, lstm_units]), 90 | np.zeros( 91 | [batch_size, lstm_units])) # Reset the policy layer's hidden state for training 92 | 93 | accuracy_agg = [] 94 | confusion_matrix = np.zeros(6) 95 | sample_per_batch_agg = [] 96 | reward_agg = [] 97 | loss_agg = [] 98 | 99 | r_m1 = None 100 | a_m1 = None 101 | x_m1 = None 102 | 103 | for i, (image_files_t, labels_t) in generator: 104 | 105 | for batch in range(batch_size): 106 | image_files = image_files_t[batch] 107 | labels = labels_t[batch] 108 | sample_counter = 0 109 | samples_length = len(image_files) 110 | if ((i % statistics_siplay_step == 0) and (batch == 0)): 111 | print("Starting policy training on step " + str(i) + ", over data with length: " + str(samples_length)) 112 | if(is_training_phase): 113 | print("Saving most recent model") 114 | save_path = model_loader.save(sess, current_model_folder + "/model.ckpt") 115 | print("Model saved in path: %s" % save_path) 116 | print("v1 : %s" % rlAgent.policy_biases['out'].eval()[0]) 117 | 118 | 119 | timeseries = np.array(image_files) 120 | batch_y = np.array([int(x) for x in labels.flatten()]) 121 | 122 | sample = 0 123 | reward_sampling = [] 124 | 125 | while sample < samples_length: 126 | 127 | x_sample_buffer = None 128 | classification_logits_sample = None 129 | sample_file = None 130 | is_delayed_sample = False 131 | delay_penalty = delay_classification_penalty_base 132 | y = -1 133 | 134 | if(samples_delay_manager.is_waiting_sample()): 135 | x_sample_buffer_tuple = samples_delay_manager.get_sample() 136 | delay_penalty = x_sample_buffer_tuple[0] 137 | x_sample_buffer = x_sample_buffer_tuple[1] 138 | sample_file = x_sample_buffer_tuple[3] 139 | y = x_sample_buffer_tuple[2] 140 | classification_logits_sample = image_classifier.classify(sample_file) 141 | is_delayed_sample = True 142 | else: 143 | y = int(batch_y[sample]) 144 | classification_logits_sample = image_classifier.classify(timeseries[sample]) 145 | x_sample_buffer = timeseries[sample] 146 | sample_file = timeseries[sample] 147 | 148 | accuracy_agg.append(np.argmax(classification_logits_sample) == y) 149 | sample_counter += 1 150 | 151 | sample_for_policy = np.concatenate( 152 | ([samples_delay_manager.get_load()], [delay_penalty], [job_manager.analysts_load()], [job_manager.distance_to_samples_in_work(x_sample_buffer)], np.reshape(classification_logits_sample, -1))) 153 | qp1, a, new_online_state = sess.run([rlAgent.policy_logits, rlAgent.predicted_action, rlAgent.policy_rnn_state], 154 | feed_dict={rlAgent.X: np.reshape(sample_for_policy, [batch_size, 1, -1]), 155 | rlAgent.policy_state_in: policy_online_state, rlAgent.timeseries_length: 1}) 156 | 157 | policy_online_state = new_online_state 158 | a = a[0] 
                    rand_action_sample = np.random.rand(1)
                    if rand_action_sample < eps:
                        rand_action_sample = np.random.rand(1)
                        if rand_action_sample < 0.25:
                            a = total_classes
                        elif rand_action_sample < 0.5:
                            a = total_classes + 1
                        elif rand_action_sample < 0.75:
                            a = y
                        else:
                            a = np.random.randint(0, total_classes)

                    reward = 0
                    if a == total_classes + 1:  # Delay the classification
                        if samples_delay_manager.get_number_of_delayed() >= max_number_of_delayed:
                            reward = wrong_classification_penalty
                            confusion_matrix[4] += 1
                        else:
                            reward = delay_penalty
                            delay_penalty *= delay_classification_penalty_multiplier
                            samples_delay_manager.add_item((delay_penalty, x_sample_buffer, y, sample_file),
                                                           delay=classification_delay_timesteps)
                            confusion_matrix[2] += 1

                    elif a == total_classes:  # Asked for analyst classification
                        if job_manager.analysts_load() >= full_analyst_load:
                            reward = wrong_classification_penalty
                            confusion_matrix[5] += 1
                        else:
                            job_manager.add_classification_job((sample_file, y))
                            reward = math.floor(analyst_load_penalty_multiplier * job_manager.analysts_load() + 0.1)
                            confusion_matrix[3] += 1
                    else:  # Automatic prediction
                        if a != y:
                            if y == 0:
                                reward = wrong_classification_penalty
                            else:
                                # Missing a malicious sample is penalized far harder than a false alarm
                                reward = wrong_classification_penalty * 15
                            confusion_matrix[1] += 1
                        else:
                            confusion_matrix[0] += 1
                            if y != 0:
                                reward = 1

                    reward_sampling.append(reward)
                    max_q = max(qp1)
                    if a_m1 is not None and is_training_phase:
                        analyst_expr_manager.add_item(
                            np.array([x_m1, a_m1, r_m1, max_q]), delay=0)
                    x_m1 = sample_for_policy
                    a_m1 = a
                    r_m1 = reward
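
                    # The transition buffered above is trained on with a one-step
                    # Q-learning target, Q(s, a) <- r + gamma * max_a' Q(s', a'),
                    # computed as q_calc_expr below. For example (illustrative
                    # numbers only): with gamma = 0.5, a delay reward r = -1.0 and
                    # max_a' Q(s', a') = 2.0, the regression target for the taken
                    # action is -1.0 + 0.5 * 2.0 = 0.0.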
                    # Train the policy
                    if (analyst_expr_manager.number_of_batches() >= batch_size) and is_training_phase:
                        expr_matrix = np.reshape(analyst_expr_manager.create_batch(batch_size), [-1, 4])
                        x_matrix = np.reshape(np.vstack(expr_matrix[:, 0]),
                                              [batch_size, samples_per_batch, total_input_size])
                        action_matrix = np.reshape(expr_matrix[:, 1], [-1, samples_per_batch])
                        q_calc_expr = expr_matrix[:, 2] + gamma * expr_matrix[:, 3]
                        q_calc_expr = np.reshape(q_calc_expr, [-1, samples_per_batch])

                        state_t, _, loss = sess.run([rlAgent.policy_rnn_state, rlAgent.updateModel, rlAgent.loss],
                                                    feed_dict={rlAgent.X: x_matrix, rlAgent.actions: action_matrix,
                                                               rlAgent.Q_calculation: q_calc_expr,
                                                               rlAgent.policy_state_in: policy_training_state})
                        policy_training_state = policy_online_state
                        loss_agg.append(loss)

                    if a != total_classes + 1:  # advance the timestamp count if the sample was not delayed
                        job_manager.advance_time()
                        if not samples_delay_manager.is_waiting_sample():
                            samples_delay_manager.advance_state()
                    if not is_delayed_sample:
                        sample += 1

                sample_per_batch_agg.append(sample_counter)
                reward_agg.append(np.sum(reward_sampling))
                reward_sampling = []

            # Statistics
            if i % statistics_siplay_step == 0 and i != 0:
                print("Batch " + str(i) + " finished after " + str(time.time() - t0) + " seconds")
                print("Average cycle reward is: " + str(np.sum(reward_agg) / statistics_siplay_step))
                print("averaged accuracy since last print is " + str(np.average(accuracy_agg)))
                if is_training_phase:
                    print("averaged loss since last print " + str(np.average(loss_agg)))
                print("confusion matrix: " + str(confusion_matrix / (batch_size * statistics_siplay_step)))
                print("Average number of samples: new: {}, total: {}".format(samples_per_batch, np.average(sample_per_batch_agg)))

                # Clear the counters
                confusion_matrix = np.zeros(6)
                accuracy_agg = []
                sample_per_batch_agg = []
                loss_agg = []
                reward_agg = []

            # After training_batches batches, switch to pure evaluation (no exploration, no training)
            if i == training_batches:
                eps = 0.0
                is_training_phase = False

            if i == training_batches * 2:
                print("Finished full system training, execution time " + str(time.time() - t0) + " seconds")
                sys.exit(0)


if __name__ == '__main__':
    if len(sys.argv) != 2:
        print("train_derol should be called with the data root folder as a parameter")
        sys.exit(1)
    benchmark(sys.argv[1:])
--------------------------------------------------------------------------------