├── .dvc ├── .gitignore └── config ├── .gitignore ├── LICENSE ├── README.md ├── board_view.py ├── constants.py ├── create_datasets.py ├── data.dvc ├── incorrect_roi_images.p ├── models ├── .gitignore ├── base_model.py ├── cnn_2d.py ├── cnn_3d.py ├── mvcnn_cnn.py ├── mvcnn_xception.py ├── saved_models.dvc ├── vcnn1.py ├── vcnn2.py ├── vgg19.py ├── vgg19_3d.py ├── xception.py ├── xception_3d.py └── xception_kapp.py ├── non_defective_xml_files.dvc ├── result_images.dvc ├── run_models.py ├── run_models_kapp_integrated.py ├── solder_joint.py ├── solder_joint_container.py ├── solder_joint_container_obj.p ├── utils_basic.py └── utils_datagen.py /.dvc/.gitignore: -------------------------------------------------------------------------------- 1 | /config.local 2 | /updater 3 | /state-journal 4 | /state-wal 5 | /state 6 | /lock 7 | /tmp 8 | /updater.lock 9 | /cache 10 | -------------------------------------------------------------------------------- /.dvc/config: -------------------------------------------------------------------------------- 1 | ['remote "myremote"'] 2 | url = ..\dvc-storage 3 | [core] 4 | remote = myremote 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Don't track content of these folders 2 | images_roi_marked 3 | incorrect_roi_images 4 | original_dataset 5 | .idea 6 | __pycache__ 7 | 8 | 9 | # Compiled source # 10 | ################### 11 | *.com 12 | *.class 13 | *.dll 14 | *.exe 15 | *.o 16 | *.so 17 | 18 | # File types # 19 | ############## 20 | *.jpg 21 | *.pickle 22 | *.pkl 23 | *.xml 24 | *.csv 25 | *.npy 26 | *.numpy 27 | 28 | # Packages # 29 | ############ 30 | # it's better to unpack these files and commit the raw source 31 | # git has its own built in compression methods 32 | *.7z 33 | *.dmg 34 | *.gz 35 | *.iso 36 | *.jar 37 | *.rar 38 | *.tar 39 | *.zip 40 | /data 41 | /non_defective_xml_files 42 | 
/result_images 43 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 chinthysl 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AXI_PCB_defect_detection 2 | 3 | This repo contains data pre-processing, classification and defect detection methodologies for images from **Advanced X-ray Inspection (AXI)** of multi-layer PCBs. 4 | Please read our research paper for more details, and kindly cite it if you use this implementation.
5 | - https://ieeexplore.ieee.org/abstract/document/9442142/ 6 | - @inproceedings{zhang2020deep, 7 | title={Deep Learning Based Defect Detection for Solder Joints on Industrial X-Ray Circuit Board Images}, 8 | author={Zhang, Qianru and Zhang, Meng and Gamanayake, Chinthaka and Yuen, Chau and Geng, Zehao and Jayasekaraand, Hirunima and Zhang, Xuewen and Woo, Chia-wei and Low, Jenny and Liu, Xiang}, 9 | booktitle={2020 IEEE 18th International Conference on Industrial Informatics (INDIN)}, 10 | volume={1}, 11 | pages={74--79}, 12 | year={2020}, 13 | organization={IEEE} 14 | } 15 | 16 | - [AXI_PCB_defect_detection](#axi_pcb_defect_detection) 17 | * [Details of the dataset](#details-of-the-proprietary-dataset) 18 | * [General guidelines for contributors](#general-guidelines-for-contributors) 19 | * [Project file structure](#project-file-structure) 20 | * [Code explanation](#code-explanation) 21 | + [solder_joint.py](#solder_jointpy) 22 | + [board_view.py](#board_viewpy) 23 | + [solder_joint_container.py](#solder_joint_containerpy) 24 | + [main.py](#mainpy) 25 | * [DVC integration](#dvc-integration) 26 | 27 | ## Details of the proprietary dataset 28 | - Distinct .jpg images including all slices – 32377 29 | - Distinct ROIs with labels – 92208 30 | - PCB types – 17 31 | - Distinct XRay board views – 8063 32 | - Views containing correct ROIs – 6112, Views containing incorrect ROIs – 1951 33 | - Slices per XRay board view (slices per physical solder joint) – 3, 4, 5, 6 34 | - Number of solder joints – 22872 35 | - Correct joints – 15613, Incorrect joints – 7259 36 | - Correct-square solder joints – 14672 37 | - missing - 1496, short - 5605, insufficient - 4691, normal - 2880 38 | - 3 sliced joints - 5471, 4 sliced joints - 8590, 5 sliced joints - 97, 6 sliced joints - 277 39 | 40 | #### Special Notes: 41 | - A single XRay board view has multiple ROIs marked.
42 | - A single defective ROI (solder joint) can have multiple defect labels. 43 | - If an X-ray view contains one incorrect ROI, all solder joints in that view are considered incorrect. 44 | 45 | 46 | ## General Guidelines for contributors 47 | - Please add your data folders to the ignore file until DVC is set up. 48 | - Don't do `git add --all`. Please be aware of what you commit. 49 | 50 | ## Project file structure 51 | 52 | ```bash 53 | ├── non_defective_xml_files/ # put xml labels for non defective rois here 54 | ├── original_dataset/ # put two image folders and PTH2_reviewed.csv inside this 55 | ├── board_view.py # python class for a unique PCB BoardView 56 | ├── constants.py # constant values for the project 57 | ├── incorrect_roi_images.p # this file contains file names of all incorrect rois 58 | ├── main.py # script for creating objects and running ROI concatenation 59 | ├── solder_joint.py # python class for a unique PCB SolderJoint 60 | ├── solder_joint_container.py # python class containing all BoardView and SolderJoint objs 61 | ├── solder_joint_container_obj.p # saved pickle obj for SolderJointContainer class 62 | └── utils_basic.py # helper functions 63 | ``` 64 | 65 | ## Code explanation 66 | 67 | ### solder_joint.py 68 | 69 | ``` 70 | +-----------------------------------------------------------------------+ 71 | | SolderJoint | 72 | +-----------------------------------------------------------------------+ 73 | |+ self.component_name | 74 | |+ self.defect_id | 75 | |+ self.defect_name | 76 | |+ self.roi | 77 | |+ self.is_square | 78 | |+ self.slice_dict | 79 | +-----------------------------------------------------------------------+ 80 | |+__init__(self, component_name, defect_id, defect_name, roi) | 81 | |+add_slice(self, file_location) | 82 | |+concat_first_four_slices_and_resize(self, width, height) | 83 | +-----------------------------------------------------------------------+ 84 | ``` 85 | 86 | - `add_slice(self,
file_location)`: 87 | 88 | This method adds the slice id and the corresponding image location of that slice to `self.slice_dict`. 89 | 90 | - `concat_first_four_slices_and_resize(self, width, height)`: 91 | 92 | Only the first four slices (ids 0, 1, 2, 3) are concatenated into a 2D square image. The concatenated 2D image and the label of the joint are returned; catch them in the `SolderJointContainer` object and save them to disk accordingly. 93 | 94 | If you want to write a new channel-concatenation method, modify this method (for example, to generate a grayscale image per slice and concatenate them into a pickle or NumPy file). 95 | 96 | ### board_view.py 97 | 98 | ``` 99 | +-----------------------------------------------------------------------+ 100 | | BoardView | 101 | +-----------------------------------------------------------------------+ 102 | |+ self.view_identifier | 103 | |+ self.is_incorrect_view | 104 | |+ self.solder_joint_dict | 105 | |+ self.slice_dict | 106 | +-----------------------------------------------------------------------+ 107 | |+__init__(self, view_identifier) | 108 | |+add_solder_joint(self, component, defect_id, defect_name, roi) | 109 | |+add_slice(self, file_location) | 110 | |+add_slices_to_solder_joints(self) | 111 | +-----------------------------------------------------------------------+ 112 | ``` 113 | 114 | - This class contains all the SolderJoint objects and the details of all slices for that board view of the PCB.
115 | 116 | ### solder_joint_container.py 117 | 118 | ``` 119 | +-----------------------------------------------------------------------+ 120 | | SolderJointContainer | 121 | +-----------------------------------------------------------------------+ 122 | |+ self.board_view_dict | 123 | |+ self.new_image_name_mapping_dict | 124 | |+ self.csv_details_dict | 125 | |+ self.incorrect_board_view_ids | 126 | +-----------------------------------------------------------------------+ 127 | |+__init__(self) | 128 | |+mark_all_images_with_rois(self) | 129 | |+load_non_defective_rois_to_container(self) | 130 | |+find_flag_incorrect_roi_board_view_objs(self) | 131 | |+save_concat_images_first_four_slices(self, width=128, height=128) | 132 | |+print_container_details(self) | 133 | |+write_csv_defect_roi_images(self) | 134 | +-----------------------------------------------------------------------+ 135 | ``` 136 | 137 | - `save_concat_images_first_four_slices(self, width=128, height=128)`: 138 | 139 | In this method, only the first four slices (ids 0, 1, 2, 3) are concatenated into a 2D square image. The concatenated 2D image and the joint label returned by each `SolderJoint` object are saved to disk. 140 | 141 | If you want to write a new channel-concatenation method, modify this method. 142 | 143 | ### main.py 144 | 145 | A `SolderJointContainer` object is created inside this script; that object is then used to call the required method to generate concatenated images. (In the current tree this script is `create_datasets.py`.) 146 | 147 | ## DVC integration 148 | 149 | This section explains how to set up data version control and storage for the project. Data version control is integrated into this project. You can simply train any model by changing the parameters of the neural network models. After training you'll see changes in the data folders; record them with `dvc commit`, followed by a `git commit`. Trained models and data are tracked in a local DVC storage.
150 | -------------------------------------------------------------------------------- /board_view.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from solder_joint import SolderJoint 4 | 5 | 6 | class BoardView: 7 | def __init__(self, view_identifier): 8 | self.view_identifier = view_identifier 9 | self.solder_joint_dict = {} 10 | self.slice_dict = {} 11 | self.is_incorrect_view = False 12 | 13 | logging.info('BoardView obj created for view id: %s', self.view_identifier) 14 | 15 | def add_solder_joint(self, component, defect_id, defect_name, roi): 16 | my_tupple = tuple([roi[0], roi[1], roi[2], roi[3], defect_id]) 17 | if my_tupple in self.solder_joint_dict.keys(): 18 | logging.info('ROI+Defect found inside the solder_joint_dict, won\'t add a new joint') 19 | else: 20 | logging.info('Adding new SolderJoint obj for the new ROI+Defect') 21 | self.solder_joint_dict[my_tupple] = SolderJoint(component, defect_id, defect_name, roi) 22 | 23 | def add_slice(self, file_location): 24 | slice_id = int(file_location[-5]) 25 | self.slice_dict[slice_id] = file_location 26 | for solder_joint_obj in self.solder_joint_dict.values(): 27 | solder_joint_obj.add_slice(slice_id, file_location) 28 | 29 | def add_slices_to_solder_joints(self): 30 | for slice_id in self.slice_dict.keys(): 31 | file_location = self.slice_dict[slice_id] 32 | for solder_joint_obj in self.solder_joint_dict.values(): 33 | solder_joint_obj.add_slice(slice_id, file_location) 34 | 35 | -------------------------------------------------------------------------------- /constants.py: -------------------------------------------------------------------------------- 1 | DEFECT_NAMES_DICT = {89: "missing", 164: "insufficient", 145: "short", 0: "normal"} 2 | -------------------------------------------------------------------------------- /create_datasets.py: -------------------------------------------------------------------------------- 1 | import logging 2 
| import pickle 3 | 4 | from solder_joint_container import SolderJointContainer 5 | 6 | logging.basicConfig(level=logging.DEBUG) 7 | 8 | # Set CREATE_OBJ = True to rebuild the container from the raw dataset and pickle it to solder_joint_container_obj.p; the container is always (re)loaded from that pickle below. 9 | CREATE_OBJ = False 10 | if CREATE_OBJ: 11 | # creating solder joint object will read the data set and create all other objects 12 | solder_joint_container_obj = SolderJointContainer() 13 | # from a single slice of incorrect roi images whole board view objects will be marked 14 | solder_joint_container_obj.find_flag_incorrect_roi_board_view_objs() 15 | # load non defective solder joint info to the container 16 | solder_joint_container_obj.load_non_defective_rois_to_container() 17 | 18 | with open('solder_joint_container_obj.p', 'wb') as output_handle: 19 | pickle.dump(solder_joint_container_obj, output_handle, pickle.HIGHEST_PROTOCOL) 20 | 21 | # ********************************************************************************************************************** 22 | with open('solder_joint_container_obj.p', 'rb') as input_handle: # NOTE(review): pickle.load can run arbitrary code; only load a container pickle produced by this project. 23 | solder_joint_container_obj = pickle.load(input_handle) 24 | 25 | # # # this method will create a directory full of images with marked rois 26 | # solder_joint_container_obj.mark_all_images_with_rois() 27 | 28 | # # methods to create extracted roi joint image or pickle data sets 29 | # # solder_joint_container_obj.save_concat_images_first_four_slices_2d() 30 | # solder_joint_container_obj.save_concat_images_first_four_slices_2d_pickle() 31 | # solder_joint_container_obj.save_concat_images_first_four_slices_3d_pickle() 32 | # solder_joint_container_obj.save_concat_images_all_slices_2d_pickle() 33 | # solder_joint_container_obj.save_concat_images_all_slices_3d_pickle() 34 | # solder_joint_container_obj.save_concat_images_all_slices_inverse_3d_pickle() 35 | # solder_joint_container_obj.save_concat_images_first_four_slices_list_rgb_pickle() 36 | #
solder_joint_container_obj.save_concat_images_first_four_slices_2d_rotated_pickle() 37 | # solder_joint_container_obj.save_concat_images_first_four_slices_2d_more_normal_pickle() 38 | solder_joint_container_obj.save_concat_images_first_four_slices_list_more_normal_pickle() 39 | 40 | # solder_joint_container_obj.save_4slices_individual_pickle() 41 | 42 | # # print details of BoardView objects and SolderJoint objects 43 | # solder_joint_container_obj.print_solder_joint_resolution_details() 44 | # solder_joint_container_obj.print_container_details() 45 | 46 | -------------------------------------------------------------------------------- /data.dvc: -------------------------------------------------------------------------------- 1 | md5: 2e10abaf47729002df92f7752d3fe3f2 2 | outs: 3 | - md5: f52d9804b9765aae3fee0e645ffbccb7.dir 4 | path: data 5 | cache: true 6 | metric: false 7 | persist: false 8 | -------------------------------------------------------------------------------- /incorrect_roi_images.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chinthysl/AXI_PCB_defect_detection/fe81c1d8e144ce5434aee78548cdc026d0d53d1c/incorrect_roi_images.p -------------------------------------------------------------------------------- /models/.gitignore: -------------------------------------------------------------------------------- 1 | /saved_models 2 | -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import seaborn as sn 4 | import pandas as pd 5 | import matplotlib.pyplot as plt 6 | import keras 7 | from sklearn.metrics import classification_report, confusion_matrix 8 | from abc import ABC, abstractmethod 9 | from utils_basic import save_logs 10 | 11 | 12 | class BaseModel(ABC): 13 | @abstractmethod 14 | def build_model(self, 
input_shape, n_classes): 15 | pass 16 | 17 | def fit(self, training_generator, testing_generator, samples_per_train_epoch, samples_per_test_epoch, epochs): 18 | 19 | start_time = time.time() 20 | hist = self.model.fit_generator(generator=training_generator, 21 | validation_data=testing_generator, 22 | samples_per_epoch=samples_per_train_epoch, 23 | validation_steps=samples_per_test_epoch, 24 | epochs=epochs, 25 | callbacks=self.callbacks, 26 | use_multiprocessing=False, 27 | workers=1, 28 | verbose=self.verbose) 29 | duration = time.time() - start_time 30 | 31 | model = keras.models.load_model(self.output_directory + '/best_model.hdf5') 32 | test_loss, test_acc = model.evaluate_generator(testing_generator, steps=samples_per_test_epoch) 33 | print('test_loss:', test_loss, 'test_acc:', test_acc) 34 | 35 | y_pred = model.predict_generator(testing_generator, steps=samples_per_test_epoch) 36 | y_true = testing_generator.gen_class_labels() 37 | save_logs(self.output_directory, hist, y_pred, y_true, duration, lr=False) 38 | 39 | keras.backend.clear_session() 40 | 41 | def predict(self, testing_generator, samples_per_test_epoch): 42 | model = keras.models.load_model(self.output_directory + '/best_model.hdf5') 43 | test_loss, test_acc = model.evaluate_generator(testing_generator, steps=samples_per_test_epoch) 44 | print('test_loss:', test_loss, 'test_acc:', test_acc) 45 | y_pred = model.predict_generator(testing_generator, steps=samples_per_test_epoch) 46 | 47 | y_pred = np.argmax(y_pred, axis=1) 48 | y_true = testing_generator.gen_class_labels() 49 | print('Confusion Matrix') 50 | 51 | target_names = testing_generator.integer_mapping_dict.keys() 52 | mat = confusion_matrix(y_true, y_pred) 53 | print(mat) 54 | df_cm = pd.DataFrame(mat, index=target_names, columns=target_names) 55 | plt.figure() 56 | sn.heatmap(df_cm, annot=True, cmap='Blues', fmt='g') 57 | plt.show() 58 | 59 | print('Classification Report') 60 | print(classification_report(y_true, y_pred, 
target_names=target_names)) 61 | 62 | keras.backend.clear_session() 63 | 64 | 65 | -------------------------------------------------------------------------------- /models/cnn_2d.py: -------------------------------------------------------------------------------- 1 | from keras.layers import Conv2D, MaxPool2D, Flatten, Dense 2 | from keras.layers import Dropout, Input, BatchNormalization 3 | from keras.losses import categorical_crossentropy 4 | from keras.optimizers import Adam, Adadelta 5 | from keras.models import Model 6 | from keras.utils import plot_model 7 | from keras import callbacks 8 | 9 | from utils_datagen import TrainValTensorBoard 10 | from utils_basic import chk_n_mkdir 11 | from models.base_model import BaseModel 12 | 13 | 14 | class CNN2D(BaseModel): 15 | def __init__(self, output_directory, input_shape, n_classes, verbose=False): 16 | self.output_directory = output_directory + '/cnn_2d' 17 | chk_n_mkdir(self.output_directory) 18 | self.model = self.build_model(input_shape, n_classes) 19 | if verbose: 20 | self.model.summary() 21 | self.verbose = verbose 22 | self.model.save_weights(self.output_directory + '/model_init.hdf5') 23 | 24 | def build_model(self, input_shape, n_classes): 25 | ## input layer 26 | input_layer = Input(input_shape) 27 | 28 | ## convolutional layers 29 | conv_layer1 = Conv2D(filters=8, kernel_size=(3, 3), activation='relu')(input_layer) 30 | conv_layer2 = Conv2D(filters=16, kernel_size=(3, 3), activation='relu')(conv_layer1) 31 | 32 | ## add max pooling to obtain the most imformatic features 33 | pooling_layer1 = MaxPool2D(pool_size=(2, 2))(conv_layer2) 34 | 35 | conv_layer3 = Conv2D(filters=32, kernel_size=(3, 3), activation='relu')(pooling_layer1) 36 | conv_layer4 = Conv2D(filters=64, kernel_size=(3, 3), activation='relu')(conv_layer3) 37 | pooling_layer2 = MaxPool2D(pool_size=(2, 2))(conv_layer4) 38 | 39 | ## perform batch normalization on the convolution outputs before feeding it to MLP architecture 40 | pooling_layer2 = 
BatchNormalization()(pooling_layer2) 41 | flatten_layer = Flatten()(pooling_layer2) 42 | 43 | ## create an MLP architecture with dense layers : 4096 -> 512 -> 10 44 | ## add dropouts to avoid overfitting / perform regularization 45 | dense_layer1 = Dense(units=2048, activation='relu')(flatten_layer) 46 | dense_layer1 = Dropout(0.4)(dense_layer1) 47 | dense_layer2 = Dense(units=512, activation='relu')(dense_layer1) 48 | dense_layer2 = Dropout(0.4)(dense_layer2) 49 | output_layer = Dense(units=n_classes, activation='softmax')(dense_layer2) 50 | 51 | ## define the model with input layer and output layer 52 | model = Model(inputs=input_layer, outputs=output_layer) 53 | model.summary() 54 | 55 | plot_model(model, to_file=self.output_directory + '/model_graph.png', show_shapes=True, show_layer_names=True) 56 | 57 | model.compile(loss=categorical_crossentropy, optimizer=Adadelta(), metrics=['acc']) 58 | 59 | # model save 60 | file_path = self.output_directory + '/best_model.hdf5' 61 | model_checkpoint = callbacks.ModelCheckpoint(filepath=file_path, monitor='loss', save_best_only=True) 62 | 63 | # Tensorboard log 64 | log_dir = self.output_directory + '/tf_logs' 65 | chk_n_mkdir(log_dir) 66 | tb_cb = TrainValTensorBoard(log_dir=log_dir) 67 | 68 | self.callbacks = [model_checkpoint, tb_cb] 69 | return model 70 | 71 | 72 | -------------------------------------------------------------------------------- /models/cnn_3d.py: -------------------------------------------------------------------------------- 1 | from keras.layers import Conv3D, MaxPool3D, Flatten, Dense 2 | from keras.layers import Dropout, Input, BatchNormalization 3 | from keras.losses import categorical_crossentropy 4 | from keras.optimizers import Adam, Adadelta 5 | from keras.models import Model 6 | from keras.utils import plot_model 7 | from keras import callbacks 8 | 9 | from utils_datagen import TrainValTensorBoard 10 | from utils_basic import chk_n_mkdir 11 | from models.base_model import BaseModel 12 | 13 
| 14 | class CNN3D(BaseModel): # 3-D CNN classifier; all artifacts go under <output_directory>/cnn_3d. 15 | def __init__(self, output_directory, input_shape, n_classes, verbose=False): 16 | self.output_directory = output_directory + '/cnn_3d' 17 | chk_n_mkdir(self.output_directory) 18 | self.model = self.build_model(input_shape, n_classes) 19 | if verbose: 20 | self.model.summary() 21 | self.verbose = verbose 22 | self.model.save_weights(self.output_directory + '/model_init.hdf5') 23 | 24 | def build_model(self, input_shape, n_classes): # Build and compile the model; also sets self.callbacks. Conv3D expects a 4-D per-sample input_shape, presumably (height, width, slices, channels) - TODO confirm 25 | ## input layer 26 | input_layer = Input(input_shape) 27 | 28 | ## convolutional layers 29 | conv_layer1 = Conv3D(filters=8, kernel_size=(3, 3, 2), activation='relu')(input_layer) 30 | conv_layer2 = Conv3D(filters=16, kernel_size=(3, 3, 2), activation='relu')(conv_layer1) 31 | 32 | ## add max pooling to obtain the most informative features 33 | pooling_layer1 = MaxPool3D(pool_size=(2, 2, 2))(conv_layer2) 34 | 35 | conv_layer3 = Conv3D(filters=32, kernel_size=(3, 3, 1), activation='relu')(pooling_layer1) 36 | conv_layer4 = Conv3D(filters=64, kernel_size=(3, 3, 1), activation='relu')(conv_layer3) 37 | pooling_layer2 = MaxPool3D(pool_size=(2, 2, 1))(conv_layer4) 38 | 39 | ## perform batch normalization on the convolution outputs before feeding it to MLP architecture 40 | pooling_layer2 = BatchNormalization()(pooling_layer2) 41 | flatten_layer = Flatten()(pooling_layer2) 42 | 43 | ## create an MLP architecture with dense layers : 2048 -> 512 -> n_classes 44 | ## add dropouts to avoid overfitting / perform regularization 45 | dense_layer1 = Dense(units=2048, activation='relu')(flatten_layer) 46 | dense_layer1 = Dropout(0.4)(dense_layer1) 47 | dense_layer2 = Dense(units=512, activation='relu')(dense_layer1) 48 | dense_layer2 = Dropout(0.4)(dense_layer2) 49 | output_layer = Dense(units=n_classes, activation='softmax')(dense_layer2) 50 | 51 | ## define the model with input layer and output layer 52 | model = Model(inputs=input_layer, outputs=output_layer) 53 | model.summary() # printed a second time by __init__ when verbose=True 54 | 55 | plot_model(model,
to_file=self.output_directory + '/model_graph.png', show_shapes=True, show_layer_names=True) 56 | 57 | model.compile(loss=categorical_crossentropy, optimizer=Adadelta(), metrics=['acc']) 58 | 59 | # model save 60 | file_path = self.output_directory + '/best_model.hdf5' 61 | model_checkpoint = callbacks.ModelCheckpoint(filepath=file_path, monitor='loss', save_best_only=True) # NOTE(review): monitors training loss, not val_loss 62 | 63 | # Tensorboard log 64 | log_dir = self.output_directory + '/tf_logs' 65 | chk_n_mkdir(log_dir) 66 | tb_cb = TrainValTensorBoard(log_dir=log_dir) 67 | 68 | self.callbacks = [model_checkpoint, tb_cb] 69 | return model 70 | 71 | -------------------------------------------------------------------------------- /models/mvcnn_cnn.py: -------------------------------------------------------------------------------- 1 | from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, concatenate, Maximum 2 | from keras.layers import Dropout, Input, BatchNormalization 3 | from keras.losses import categorical_crossentropy 4 | from keras.optimizers import Adam, Adadelta 5 | from keras.models import Model 6 | from keras.utils import plot_model 7 | from keras import callbacks 8 | 9 | from utils_datagen import TrainValTensorBoard 10 | from utils_basic import chk_n_mkdir 11 | from models.base_model import BaseModel 12 | 13 | 14 | class MVCNN_CNN(BaseModel): # Multi-view CNN over four inputs; all artifacts go under <output_directory>/mvcnn_cnn. 15 | def __init__(self, output_directory, input_shape, n_classes, verbose=False): 16 | self.output_directory = output_directory + '/mvcnn_cnn' 17 | chk_n_mkdir(self.output_directory) 18 | self.model = self.build_model(input_shape, n_classes) 19 | if verbose: 20 | self.model.summary() 21 | self.verbose = verbose 22 | self.model.save_weights(self.output_directory + '/model_init.hdf5') 23 | 24 | def cnn(self, input_shape): # Build one view branch (two conv/conv/pool stages, batch norm, flatten) as a Model from its own Input to the flattened features. 25 | # ## input layer 26 | input_layer = Input(input_shape) 27 | 28 | ## convolutional layers 29 | conv_layer1 = Conv2D(filters=8, kernel_size=(3, 3), activation='relu')(input_layer) 30 | conv_layer2 = Conv2D(filters=16, kernel_size=(3,
3), activation='relu')(conv_layer1) 31 | 32 | ## add max pooling to obtain the most informative features 33 | pooling_layer1 = MaxPooling2D(pool_size=(2, 2))(conv_layer2) 34 | 35 | conv_layer3 = Conv2D(filters=32, kernel_size=(3, 3), activation='relu')(pooling_layer1) 36 | conv_layer4 = Conv2D(filters=64, kernel_size=(3, 3), activation='relu')(conv_layer3) 37 | pooling_layer2 = MaxPooling2D(pool_size=(2, 2))(conv_layer4) 38 | 39 | ## perform batch normalization on the convolution outputs before feeding it to MLP architecture 40 | pooling_layer2 = BatchNormalization()(pooling_layer2) 41 | flatten_layer = Flatten()(pooling_layer2) 42 | 43 | # (unused) earlier branch variant, kept for reference: 44 | # input_layer = Input(input_shape) 45 | # conv_layer1 = Conv2D(filters=16, kernel_size=(3, 3), activation='relu')(input_layer) 46 | # conv_layer2 = Conv2D(filters=32, kernel_size=(3, 3), activation='relu')(conv_layer1) 47 | # pooling_layer1 = MaxPooling2D(pool_size=(4, 4))(conv_layer2) 48 | # # dropout_layer1 = Dropout(0.25)(pooling_layer1) 49 | # dropout_layer1 =BatchNormalization()(pooling_layer1) 50 | # flatten_layer = Flatten()(dropout_layer1) 51 | 52 | # Create model.
53 | model = Model(input_layer, flatten_layer) 54 | return model 55 | 56 | def build_model(self, input_shape, n_classes): # Each self.cnn() call creates new layers, so the four view branches do NOT share weights; their outputs are fused by an element-wise Maximum. 57 | cnn_1 = self.cnn(input_shape) 58 | input_1 = cnn_1.input 59 | output_1 = cnn_1.output 60 | cnn_2 = self.cnn(input_shape) 61 | input_2 = cnn_2.input 62 | output_2 = cnn_2.output 63 | cnn_3 = self.cnn(input_shape) 64 | input_3 = cnn_3.input 65 | output_3 = cnn_3.output 66 | cnn_4 = self.cnn(input_shape) 67 | input_4 = cnn_4.input 68 | output_4 = cnn_4.output 69 | 70 | concat_layer = Maximum()([output_1, output_2, output_3, output_4]) # element-wise max over the four views; nothing is concatenated despite the name 71 | concat_layer = Dropout(0.5)(concat_layer) 72 | dense_layer1 = Dense(units=1024, activation='relu')(concat_layer) 73 | dense_layer1 = Dropout(0.5)(dense_layer1) 74 | output_layer = Dense(n_classes, activation='softmax', name='predictions')(dense_layer1) 75 | 76 | model = Model(inputs=[input_1, input_2, input_3, input_4], outputs=[output_layer]) 77 | model.summary() 78 | plot_model(model, to_file=self.output_directory + '/model_graph.png', show_shapes=True, show_layer_names=True) 79 | model.compile(loss=categorical_crossentropy, optimizer=Adadelta(lr=0.01), metrics=['accuracy']) # NOTE(review): Keras 2 documents Adadelta's default learning rate as 1.0 (CNN2D/CNN3D use the default); confirm 0.01 is intended 80 | 81 | # model save 82 | file_path = self.output_directory + '/best_model.hdf5' 83 | model_checkpoint = callbacks.ModelCheckpoint(filepath=file_path, monitor='loss', save_best_only=True) # NOTE(review): monitors training loss, not val_loss 84 | 85 | # Tensorboard log 86 | log_dir = self.output_directory + '/tf_logs' 87 | chk_n_mkdir(log_dir) 88 | tb_cb = TrainValTensorBoard(log_dir=log_dir) 89 | 90 | self.callbacks = [model_checkpoint, tb_cb] 91 | return model 92 | 93 | -------------------------------------------------------------------------------- /models/mvcnn_xception.py: -------------------------------------------------------------------------------- 1 | import keras 2 | import os 3 | 4 | from keras.layers import Conv2D, SeparableConv2D, MaxPooling2D, GlobalMaxPooling2D, Flatten, Dense, concatenate, Maximum 5 | from keras.layers import Dropout, Input, BatchNormalization,
Activation, add, GlobalAveragePooling2D 6 | from keras.losses import categorical_crossentropy 7 | from keras.optimizers import Adam 8 | from keras.models import Model 9 | from keras.utils import plot_model 10 | import keras.backend as backend 11 | import keras.utils as keras_utils 12 | 13 | from utils_datagen import TrainValTensorBoard 14 | from utils_basic import chk_n_mkdir 15 | from models.base_model import BaseModel 16 | 17 | 18 | class MVCNN_XCEPTION(BaseModel): 19 | def __init__(self, output_directory, input_shape, n_classes, verbose=False): 20 | self.output_directory = output_directory + '/mvcnn_xception' 21 | chk_n_mkdir(self.output_directory) 22 | self.model = self.build_model(input_shape, n_classes) 23 | if verbose: 24 | self.model.summary() 25 | self.verbose = verbose 26 | self.model.save_weights(self.output_directory + '/model_init.hdf5') 27 | 28 | def xception(self, include_top=True, weights='imagenet', input_tensor=None, pooling=None, classes=1000): 29 | 30 | TF_WEIGHTS_PATH = ( 31 | 'https://github.com/fchollet/deep-learning-models/releases/download/v0.4/' 32 | 'xception_weights_tf_dim_ordering_tf_kernels.h5') 33 | TF_WEIGHTS_PATH_NO_TOP = ( 34 | 'https://github.com/fchollet/deep-learning-models/releases/download/v0.4/' 35 | 'xception_weights_tf_dim_ordering_tf_kernels_notop.h5') 36 | 37 | if not (weights in {'imagenet', None} or os.path.exists(weights)): 38 | raise ValueError('The `weights` argument should be either ' 39 | '`None` (random initialization), `imagenet` ' 40 | '(pre-training on ImageNet), ' 41 | 'or the path to the weights file to be loaded.') 42 | 43 | if weights == 'imagenet' and include_top and classes != 1000: 44 | raise ValueError('If using `weights` as `"imagenet"` with `include_top`' 45 | ' as true, `classes` should be 1000') 46 | 47 | # Determine proper input shape 48 | input_shape = (299, 299, 3) 49 | 50 | if input_tensor is None: 51 | img_input = Input(shape=input_shape) 52 | else: 53 | if not 
backend.is_keras_tensor(input_tensor): 54 | img_input = Input(tensor=input_tensor, shape=input_shape) 55 | else: 56 | img_input = input_tensor 57 | 58 | channel_axis = -1 59 | 60 | x = Conv2D(32, (3, 3), strides=(2, 2), use_bias=False)(img_input) 61 | x = BatchNormalization(axis=channel_axis)(x) 62 | x = Activation('relu')(x) 63 | x = Conv2D(64, (3, 3), use_bias=False)(x) 64 | x = BatchNormalization(axis=channel_axis)(x) 65 | x = Activation('relu')(x) 66 | 67 | residual = Conv2D(128, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x) 68 | residual = BatchNormalization(axis=channel_axis)(residual) 69 | 70 | x = SeparableConv2D(128, (3, 3), padding='same', use_bias=False)(x) 71 | x = BatchNormalization(axis=channel_axis)(x) 72 | x = Activation('relu')(x) 73 | x = SeparableConv2D(128, (3, 3), padding='same', use_bias=False)(x) 74 | x = BatchNormalization(axis=channel_axis)(x) 75 | 76 | x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x) 77 | x = add([x, residual]) 78 | 79 | residual = Conv2D(256, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x) 80 | residual = BatchNormalization(axis=channel_axis)(residual) 81 | 82 | x = Activation('relu')(x) 83 | x = SeparableConv2D(256, (3, 3), padding='same', use_bias=False)(x) 84 | x = BatchNormalization(axis=channel_axis)(x) 85 | x = Activation('relu')(x) 86 | x = SeparableConv2D(256, (3, 3), padding='same', use_bias=False)(x) 87 | x = BatchNormalization(axis=channel_axis)(x) 88 | 89 | x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x) 90 | x = add([x, residual]) 91 | 92 | residual = Conv2D(728, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x) 93 | residual = BatchNormalization(axis=channel_axis)(residual) 94 | 95 | x = Activation('relu')(x) 96 | x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False)(x) 97 | x = BatchNormalization(axis=channel_axis)(x) 98 | x = Activation('relu')(x) 99 | x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False)(x) 100 | x = 
BatchNormalization(axis=channel_axis)(x) 101 | 102 | x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x) 103 | x = add([x, residual]) 104 | 105 | for i in range(8): 106 | residual = x 107 | 108 | x = Activation('relu')(x) 109 | x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False)(x) 110 | x = BatchNormalization(axis=channel_axis)(x) 111 | x = Activation('relu')(x) 112 | x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False)(x) 113 | x = BatchNormalization(axis=channel_axis)(x) 114 | x = Activation('relu')(x) 115 | x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False)(x) 116 | x = BatchNormalization(axis=channel_axis)(x) 117 | 118 | x = add([x, residual]) 119 | 120 | residual = Conv2D(1024, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x) 121 | residual = BatchNormalization(axis=channel_axis)(residual) 122 | 123 | x = Activation('relu')(x) 124 | x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False)(x) 125 | x = BatchNormalization(axis=channel_axis)(x) 126 | x = Activation('relu')(x) 127 | x = SeparableConv2D(1024, (3, 3), padding='same', use_bias=False)(x) 128 | x = BatchNormalization(axis=channel_axis)(x) 129 | 130 | x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x) 131 | x = add([x, residual]) 132 | 133 | x = SeparableConv2D(1536, (3, 3), padding='same', use_bias=False)(x) 134 | x = BatchNormalization(axis=channel_axis)(x) 135 | x = Activation('relu')(x) 136 | 137 | x = SeparableConv2D(2048, (3, 3), padding='same', use_bias=False)(x) 138 | x = BatchNormalization(axis=channel_axis)(x) 139 | x = Activation('relu')(x) 140 | 141 | if include_top: 142 | x = GlobalAveragePooling2D(name='avg_pool')(x) 143 | x = Dense(classes, activation='softmax', name='predictions')(x) 144 | else: 145 | if pooling == 'avg': 146 | x = GlobalAveragePooling2D()(x) 147 | elif pooling == 'max': 148 | x = GlobalMaxPooling2D()(x) 149 | 150 | # Ensure that the model takes into account 151 | # any potential predecessors of 
`input_tensor`. 152 | if input_tensor is not None: 153 | inputs = keras_utils.get_source_inputs(input_tensor) 154 | else: 155 | inputs = img_input 156 | # Create model. 157 | model = Model(inputs, x) 158 | 159 | # Load weights. 160 | if weights == 'imagenet': 161 | if include_top: 162 | weights_path = keras_utils.get_file( 163 | 'xception_weights_tf_dim_ordering_tf_kernels.h5', 164 | TF_WEIGHTS_PATH, 165 | cache_subdir='models', 166 | file_hash='0a58e3b7378bc2990ea3b43d5981f1f6') 167 | else: 168 | weights_path = keras_utils.get_file( 169 | 'xception_weights_tf_dim_ordering_tf_kernels_notop.h5', 170 | TF_WEIGHTS_PATH_NO_TOP, 171 | cache_subdir='models', 172 | file_hash='b0042744bf5b25fce3cb969f33bebb97') 173 | model.load_weights(weights_path) 174 | if backend.backend() == 'theano': 175 | keras_utils.convert_all_kernels_in_model(model) 176 | elif weights is not None: 177 | model.load_weights(weights) 178 | 179 | return model 180 | 181 | def build_model(self, input_shape, n_classes): 182 | xception_1 = self.xception(include_top=False, pooling='max') 183 | for layer in xception_1.layers: 184 | layer.trainable = False 185 | input_1 = xception_1.input 186 | output_1 = xception_1.output 187 | xception_2 = self.xception(include_top=False, pooling='max') 188 | for layer in xception_2.layers: 189 | layer.trainable = False 190 | input_2 = xception_2.input 191 | output_2 = xception_2.output 192 | xception_3 = self.xception(include_top=False, pooling='max') 193 | for layer in xception_3.layers: 194 | layer.trainable = False 195 | input_3 = xception_3.input 196 | output_3 = xception_3.output 197 | xception_4 = self.xception(include_top=False, pooling='max') 198 | for layer in xception_4.layers: 199 | layer.trainable = False 200 | input_4 = xception_4.input 201 | output_4 = xception_4.output 202 | 203 | concat_layer = Maximum()([output_1, output_2, output_3, output_4]) 204 | concat_layer.trainable = False 205 | # concat_layer = Dropout(0.25)(concat_layer) 206 | # dense_layer1 = 
Dense(units=1024, activation='relu')(concat_layer) 207 | dense_layer1 = Dropout(0.5)(concat_layer) 208 | output_layer = Dense(n_classes, activation='softmax', name='predictions')(dense_layer1) 209 | 210 | model = Model(inputs=[input_1, input_2, input_3, input_4], outputs=[output_layer]) 211 | model.summary() 212 | plot_model(model, to_file=self.output_directory + '/model_graph.png', show_shapes=True, show_layer_names=True) 213 | model.compile(loss=categorical_crossentropy, optimizer=Adam(lr=0.01), metrics=['acc']) 214 | 215 | # model save 216 | file_path = self.output_directory + '/best_model.hdf5' 217 | model_checkpoint = keras.callbacks.ModelCheckpoint(filepath=file_path, monitor='loss', save_best_only=True) 218 | 219 | # Tensorboard log 220 | log_dir = self.output_directory + '/tf_logs' 221 | chk_n_mkdir(log_dir) 222 | tb_cb = TrainValTensorBoard(log_dir=log_dir) 223 | 224 | self.callbacks = [model_checkpoint, tb_cb] 225 | return model 226 | 227 | 228 | -------------------------------------------------------------------------------- /models/saved_models.dvc: -------------------------------------------------------------------------------- 1 | md5: f2f2f88b356078e7b5147950ccf87b92 2 | outs: 3 | - md5: 1574ce19441fca2c5303c2c59e2e5bc8.dir 4 | path: saved_models 5 | cache: true 6 | metric: false 7 | persist: false 8 | -------------------------------------------------------------------------------- /models/vcnn1.py: -------------------------------------------------------------------------------- 1 | from keras.layers import Conv2D, MaxPool2D, Flatten, Dense 2 | from keras.layers import Dropout, Input, BatchNormalization 3 | from keras.losses import categorical_crossentropy 4 | from keras.optimizers import Adam 5 | from keras.models import Model 6 | from keras.utils import plot_model 7 | from keras import callbacks 8 | 9 | from utils_datagen import TrainValTensorBoard 10 | from utils_basic import chk_n_mkdir 11 | from models.base_model import BaseModel 12 | 13 | 14 | 
class VCNN1(BaseModel):
    """Small 2-D CNN classifier: three 3x3 convolutions, two max-pools, dense head."""

    def __init__(self, output_directory, input_shape, n_classes, verbose=False):
        """Build and compile the model, then snapshot its initial weights.

        Args:
            output_directory: parent directory for model artefacts.
            input_shape: shape of one input image (channels last).
            n_classes: number of softmax outputs.
            verbose: print the model summary when True.
        """
        # NOTE(review): this is a 2-D model but it writes to '/vcnn1_3d'; the name is
        # kept so existing checkpoints/logs stay where callers expect them.
        self.output_directory = output_directory + '/vcnn1_3d'
        chk_n_mkdir(self.output_directory)
        self.model = self.build_model(input_shape, n_classes)
        if verbose:
            self.model.summary()
        self.verbose = verbose
        self.model.save_weights(self.output_directory + '/model_init.hdf5')

    def build_model(self, input_shape, n_classes):
        """Assemble and compile the network; also sets `self.callbacks`.

        Args:
            input_shape: shape of one input image (channels last).
            n_classes: number of softmax outputs.

        Returns:
            The compiled keras Model.
        """
        # input layer
        input_layer = Input(input_shape)

        # convolutional layers
        conv_layer1 = Conv2D(filters=16, kernel_size=(3, 3), activation='relu')(input_layer)
        pooling_layer1 = MaxPool2D(pool_size=(2, 2), strides=(2, 2))(conv_layer1)

        conv_layer2 = Conv2D(filters=32, kernel_size=(3, 3), activation='relu')(pooling_layer1)

        conv_layer3 = Conv2D(filters=32, kernel_size=(3, 3), activation='relu')(conv_layer2)
        pooling_layer2 = MaxPool2D(pool_size=(2, 2), strides=(2, 2))(conv_layer3)
        dropout_layer = Dropout(0.5)(pooling_layer2)

        # Collapse the (rows, cols, channels) feature map into one vector per sample.
        # Without this, Dense acts on the last axis only. The softmax output is then
        # 4-D and does not match one-hot (batch, n_classes) targets.
        flatten_layer = Flatten()(dropout_layer)

        dense_layer = Dense(units=2048, activation='relu')(flatten_layer)
        output_layer = Dense(units=n_classes, activation='softmax')(dense_layer)

        # define the model with input layer and output layer
        model = Model(inputs=input_layer, outputs=output_layer)
        model.summary()

        plot_model(model, to_file=self.output_directory + '/model_graph.png', show_shapes=True, show_layer_names=True)

        model.compile(loss=categorical_crossentropy, optimizer=Adam(), metrics=['acc'])

        # Keep the weights with the lowest training loss.
        file_path = self.output_directory + '/best_model.hdf5'
        model_checkpoint = callbacks.ModelCheckpoint(filepath=file_path, monitor='loss', save_best_only=True)

        # TensorBoard logging.
        log_dir = self.output_directory + '/tf_logs'
        chk_n_mkdir(log_dir)
        tb_cb = TrainValTensorBoard(log_dir=log_dir)

        self.callbacks = [model_checkpoint, tb_cb]
        return model


-------------------------------------------------------------------------------- /models/vcnn2.py: -------------------------------------------------------------------------------- 1 | from keras.layers import Conv2D, MaxPool2D, Flatten, Dense, Concatenate, Activation 2 | from keras.layers import Dropout, Input, BatchNormalization 3 | from keras.losses import categorical_crossentropy 4 | from keras.optimizers import Adam 5 | from keras.models import Model 6 | from keras.utils import plot_model 7 | from keras import callbacks 8 | 9 | from utils_datagen import TrainValTensorBoard 10 | from utils_basic import chk_n_mkdir 11 | from models.base_model import BaseModel 12 | 13 | 14 | class VCNN1(BaseModel): 15 | def __init__(self, output_directory, input_shape, n_classes, verbose=False): 16 | self.output_directory = output_directory + '/vcnn1_3d' 17 | chk_n_mkdir(self.output_directory) 18 | self.model = self.build_model(input_shape, n_classes) 19 | if verbose: 20 | self.model.summary() 21 | self.verbose = verbose 22 | self.model.save_weights(self.output_directory + '/model_init.hdf5') 23 | 24 | def build_model(self, input_shape, n_classes): 25 | ## input layer 26 | input_layer = Input(input_shape) 27 | 28 | ## convolutional layers 29 | conv_layer1_1 = Conv2D(filters=16, kernel_size=(1, 1), activation=None, padding='same')(input_layer) 30 | conv_layer1_2 = Conv2D(filters=16, kernel_size=(3, 3), activation=None, padding='same')(input_layer) 31 | conv_layer1_3 = Conv2D(filters=16, kernel_size=(5, 5), activation=None, padding='same')(input_layer) 32 | concat_layer1 = Concatenate([conv_layer1_1, conv_layer1_2, conv_layer1_3]) 33 | activation_layer1 = Activation('relu')(concat_layer1) 34 | dropout_layer1 =Dropout(0.2)(activation_layer1) 35 | 36 | conv_layer2_1 = Conv2D(filters=16, kernel_size=(1, 1), activation=None, padding='same')(dropout_layer1) 37 | conv_layer2_2 = Conv2D(filters=16, kernel_size=(3, 3), activation=None, padding='same')(dropout_layer1) 38 | concat_layer2 = 
Concatenate([conv_layer2_1, conv_layer2_2]) 39 | activation_layer2 = Activation('relu')(concat_layer2) 40 | dropout_layer2 =Dropout(0.2)(activation_layer2) 41 | 42 | conv_layer2 = Conv2D(filters=32, kernel_size=(3, 3), activation='relu')(pooling_layer1) 43 | 44 | conv_layer3 = Conv2D(filters=32, kernel_size=(3, 3), activation='relu')(conv_layer2) 45 | pooling_layer2 = MaxPool2D(pool_size=(2, 2), strides=(2, 2))(conv_layer3) 46 | dropout_layer =Dropout(0.5)(pooling_layer2) 47 | 48 | dense_layer = Dense(units=2048, activation='relu')(dropout_layer) 49 | output_layer = Dense(units=n_classes, activation='softmax')(dense_layer) 50 | 51 | ## define the model with input layer and output layer 52 | model = Model(inputs=input_layer, outputs=output_layer) 53 | model.summary() 54 | 55 | plot_model(model, to_file=self.output_directory + '/model_graph.png', show_shapes=True, show_layer_names=True) 56 | 57 | model.compile(loss=categorical_crossentropy, optimizer=Adam(), metrics=['acc']) 58 | 59 | # model save 60 | file_path = self.output_directory + '/best_model.hdf5' 61 | model_checkpoint = callbacks.ModelCheckpoint(filepath=file_path, monitor='loss', save_best_only=True) 62 | 63 | # Tensorboard log 64 | log_dir = self.output_directory + '/tf_logs' 65 | chk_n_mkdir(log_dir) 66 | tb_cb = TrainValTensorBoard(log_dir=log_dir) 67 | 68 | self.callbacks = [model_checkpoint, tb_cb] 69 | return model 70 | 71 | -------------------------------------------------------------------------------- /models/vgg19.py: -------------------------------------------------------------------------------- 1 | from keras.layers import Conv2D, MaxPool2D, Flatten, Dense 2 | from keras.layers import Dropout, Input, BatchNormalization 3 | from keras.losses import categorical_crossentropy 4 | from keras.optimizers import Adadelta 5 | from keras.models import Model 6 | from keras.utils import plot_model 7 | from keras import callbacks 8 | 9 | from utils_datagen import TrainValTensorBoard 10 | from utils_basic 
import chk_n_mkdir 11 | from models.base_model import BaseModel 12 | 13 | 14 | class VGG19(BaseModel): 15 | def __init__(self, output_directory, input_shape, n_classes, verbose=False): 16 | self.output_directory = output_directory + '/vgg19' 17 | chk_n_mkdir(self.output_directory) 18 | self.model = self.build_model(input_shape, n_classes) 19 | if verbose: 20 | self.model.summary() 21 | self.verbose = verbose 22 | self.model.save_weights(self.output_directory + 'model_init.hdf5') 23 | 24 | def build_model(self, input_shape, n_classes): 25 | # input layer 26 | input_layer = Input(input_shape) 27 | # Block 1 28 | x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv1')(input_layer) 29 | x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv2')(x) 30 | x = MaxPool2D((2, 2), strides=(2, 2), name='block1_pool')(x) 31 | 32 | # Block 2 33 | x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv1')(x) 34 | x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv2')(x) 35 | x = MaxPool2D((2, 2), strides=(2, 2), name='block2_pool')(x) 36 | 37 | # Block 3 38 | x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv1')(x) 39 | x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv2')(x) 40 | x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv3')(x) 41 | x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv4')(x) 42 | x = MaxPool2D((2, 2), strides=(2, 2), name='block3_pool')(x) 43 | 44 | # Block 4 45 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv1')(x) 46 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv2')(x) 47 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv3')(x) 48 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv4')(x) 49 | x = MaxPool2D((2, 2), strides=(2, 2), 
name='block4_pool')(x) 50 | 51 | # Block 5 52 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv1')(x) 53 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv2')(x) 54 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv3')(x) 55 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv4')(x) 56 | x = MaxPool2D((2, 2), strides=(2, 2), name='block5_pool')(x) 57 | 58 | # Classification block 59 | x = Flatten(name='flatten')(x) 60 | x = Dense(4096, activation='relu', name='fc1')(x) 61 | x = Dense(4096, activation='relu', name='fc2')(x) 62 | output_layer = Dense(n_classes, activation='softmax', name='predictions')(x) 63 | 64 | ## define the model with input layer and output layer 65 | model = Model(inputs=input_layer, outputs=output_layer) 66 | model.summary() 67 | 68 | plot_model(model, to_file=self.output_directory + '/model_graph.png', show_shapes=True, show_layer_names=True) 69 | 70 | model.compile(loss=categorical_crossentropy, optimizer=Adadelta(lr=0.1), metrics=['acc']) 71 | 72 | # model save 73 | file_path = self.output_directory + '/best_model.hdf5' 74 | model_checkpoint = callbacks.ModelCheckpoint(filepath=file_path, monitor='loss', save_best_only=True) 75 | 76 | # Tensorboard log 77 | log_dir = self.output_directory + '/tf_logs' 78 | chk_n_mkdir(log_dir) 79 | tb_cb = TrainValTensorBoard(log_dir=log_dir) 80 | 81 | self.callbacks = [model_checkpoint, tb_cb] 82 | return model 83 | -------------------------------------------------------------------------------- /models/vgg19_3d.py: -------------------------------------------------------------------------------- 1 | from keras.layers import Conv3D, MaxPool3D, Flatten, Dense 2 | from keras.layers import Dropout, Input, BatchNormalization 3 | from keras.losses import categorical_crossentropy 4 | from keras.optimizers import Adadelta 5 | from keras.models import Model 6 | from keras.utils import plot_model 7 | from 
keras import callbacks 8 | 9 | from utils_datagen import TrainValTensorBoard 10 | from utils_basic import chk_n_mkdir 11 | from models.base_model import BaseModel 12 | 13 | 14 | class VGG193D(BaseModel): 15 | def __init__(self, output_directory, input_shape, n_classes, verbose=False): 16 | self.output_directory = output_directory + '/vgg19_3d' 17 | chk_n_mkdir(self.output_directory) 18 | self.model = self.build_model(input_shape, n_classes) 19 | if verbose: 20 | self.model.summary() 21 | self.verbose = verbose 22 | self.model.save_weights(self.output_directory + 'model_init.hdf5') 23 | 24 | def build_model(self, input_shape, n_classes): 25 | # input layer 26 | input_layer = Input(input_shape) 27 | # Block 1 28 | x = Conv3D(64, (3, 3, 3), activation='relu', padding='same', name='block1_conv1')(input_layer) 29 | x = Conv3D(64, (3, 3, 3), activation='relu', padding='same', name='block1_conv2')(x) 30 | x = MaxPool3D((2, 2, 2), strides=(2, 2, 1), name='block1_pool')(x) 31 | 32 | # Block 2 33 | x = Conv3D(128, (3, 3, 3), activation='relu', padding='same', name='block2_conv1')(x) 34 | x = Conv3D(128, (3, 3, 3), activation='relu', padding='same', name='block2_conv2')(x) 35 | x = MaxPool3D((2, 2, 2), strides=(2, 2, 1), name='block2_pool')(x) 36 | 37 | # Block 3 38 | x = Conv3D(256, (3, 3, 2), activation='relu', padding='same', name='block3_conv1')(x) 39 | x = Conv3D(256, (3, 3, 2), activation='relu', padding='same', name='block3_conv2')(x) 40 | x = Conv3D(256, (3, 3, 2), activation='relu', padding='same', name='block3_conv3')(x) 41 | x = Conv3D(256, (3, 3, 2), activation='relu', padding='same', name='block3_conv4')(x) 42 | x = MaxPool3D((2, 2, 2), strides=(2, 2, 1), name='block3_pool')(x) 43 | 44 | # Block 4 45 | x = Conv3D(512, (3, 3, 1), activation='relu', padding='same', name='block4_conv1')(x) 46 | x = Conv3D(512, (3, 3, 1), activation='relu', padding='same', name='block4_conv2')(x) 47 | x = Conv3D(512, (3, 3, 1), activation='relu', padding='same', 
name='block4_conv3')(x) 48 | x = Conv3D(512, (3, 3, 1), activation='relu', padding='same', name='block4_conv4')(x) 49 | x = MaxPool3D((2, 2, 1), strides=(2, 2, 1), name='block4_pool')(x) 50 | 51 | # Block 5 52 | x = Conv3D(512, (3, 3, 1), activation='relu', padding='same', name='block5_conv1')(x) 53 | x = Conv3D(512, (3, 3, 1), activation='relu', padding='same', name='block5_conv2')(x) 54 | x = Conv3D(512, (3, 3, 1), activation='relu', padding='same', name='block5_conv3')(x) 55 | x = Conv3D(512, (3, 3, 1), activation='relu', padding='same', name='block5_conv4')(x) 56 | x = MaxPool3D((2, 2, 1), strides=(2, 2, 1), name='block5_pool')(x) 57 | 58 | # Classification block 59 | x = Flatten(name='flatten')(x) 60 | x = Dense(4096, activation='relu', name='fc1')(x) 61 | x = Dropout(0.4)(x) 62 | x = Dense(4096, activation='relu', name='fc2')(x) 63 | x = Dropout(0.4)(x) 64 | output_layer = Dense(n_classes, activation='softmax', name='predictions')(x) 65 | 66 | ## define the model with input layer and output layer 67 | model = Model(inputs=input_layer, outputs=output_layer) 68 | model.summary() 69 | 70 | plot_model(model, to_file=self.output_directory + '/model_graph.png', show_shapes=True, show_layer_names=True) 71 | 72 | model.compile(loss=categorical_crossentropy, optimizer=Adadelta(lr=0.1), metrics=['acc']) 73 | 74 | # model save 75 | file_path = self.output_directory + '/best_model.hdf5' 76 | model_checkpoint = callbacks.ModelCheckpoint(filepath=file_path, monitor='loss', save_best_only=True) 77 | 78 | # Tensorboard log 79 | log_dir = self.output_directory + '/tf_logs' 80 | chk_n_mkdir(log_dir) 81 | tb_cb = TrainValTensorBoard(log_dir=log_dir) 82 | 83 | self.callbacks = [model_checkpoint, tb_cb] 84 | return model 85 | 86 | -------------------------------------------------------------------------------- /models/xception.py: -------------------------------------------------------------------------------- 1 | from keras.layers import Conv2D, SeparableConv2D, MaxPooling2D, 
Flatten, Dense 2 | from keras.layers import Dropout, Input, BatchNormalization, Activation, add, GlobalAveragePooling2D 3 | from keras.losses import categorical_crossentropy 4 | from keras.optimizers import Adadelta 5 | from keras.models import Model 6 | from keras.utils import plot_model 7 | from keras import callbacks 8 | 9 | from utils_datagen import TrainValTensorBoard 10 | from utils_basic import chk_n_mkdir 11 | from models.base_model import BaseModel 12 | 13 | 14 | class XCEPTION(BaseModel): 15 | def __init__(self, output_directory, input_shape, n_classes, verbose=False): 16 | self.output_directory = output_directory + '/xception' 17 | chk_n_mkdir(self.output_directory) 18 | self.model = self.build_model(input_shape, n_classes) 19 | if verbose: 20 | self.model.summary() 21 | self.verbose = verbose 22 | self.model.save_weights(self.output_directory + '/model_init.hdf5') 23 | 24 | def build_model(self, input_shape, n_classes): 25 | # input layer 26 | input_layer = Input(input_shape) 27 | channel_axis = -1 # channel_axis = 1 if backend.image_data_format() == 'channels_first' else -1 28 | # Block 1 29 | x = Conv2D(32, (3, 3), strides=(2, 2), use_bias=False, name='block1_conv1')(input_layer) 30 | x = BatchNormalization(axis=channel_axis, name='block1_conv1_bn')(x) 31 | x = Activation('relu', name='block1_conv1_act')(x) 32 | x = Conv2D(64, (3, 3), use_bias=False, name='block1_conv2')(x) 33 | x = BatchNormalization(axis=channel_axis, name='block1_conv2_bn')(x) 34 | x = Activation('relu', name='block1_conv2_act')(x) 35 | 36 | residual = Conv2D(128, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x) 37 | residual = BatchNormalization(axis=channel_axis)(residual) 38 | 39 | # Block 2 40 | x = SeparableConv2D(128, (3, 3), padding='same', use_bias=False, name='block2_sepconv1')(x) 41 | x = BatchNormalization(axis=channel_axis, name='block2_sepconv1_bn')(x) 42 | x = Activation('relu', name='block2_sepconv2_act')(x) 43 | x = SeparableConv2D(128, (3, 3), 
padding='same', use_bias=False, name='block2_sepconv2')(x) 44 | x = BatchNormalization(axis=channel_axis, name='block2_sepconv2_bn')(x) 45 | 46 | x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='block2_pool')(x) 47 | x = add([x, residual]) 48 | 49 | residual = Conv2D(256, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x) 50 | residual = BatchNormalization(axis=channel_axis)(residual) 51 | 52 | # Block 3 53 | x = Activation('relu', name='block3_sepconv1_act')(x) 54 | x = SeparableConv2D(256, (3, 3), padding='same', use_bias=False, name='block3_sepconv1')(x) 55 | x = BatchNormalization(axis=channel_axis, name='block3_sepconv1_bn')(x) 56 | x = Activation('relu', name='block3_sepconv2_act')(x) 57 | x = SeparableConv2D(256, (3, 3), padding='same', use_bias=False, name='block3_sepconv2')(x) 58 | x = BatchNormalization(axis=channel_axis, name='block3_sepconv2_bn')(x) 59 | 60 | x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='block3_pool')(x) 61 | x = add([x, residual]) 62 | 63 | residual = Conv2D(728, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x) 64 | residual = BatchNormalization(axis=channel_axis)(residual) 65 | 66 | # Block 4 67 | x = Activation('relu', name='block4_sepconv1_act')(x) 68 | x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name='block4_sepconv1')(x) 69 | x = BatchNormalization(axis=channel_axis, name='block4_sepconv1_bn')(x) 70 | x = Activation('relu', name='block4_sepconv2_act')(x) 71 | x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name='block4_sepconv2')(x) 72 | x = BatchNormalization(axis=channel_axis, name='block4_sepconv2_bn')(x) 73 | 74 | x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='block4_pool')(x) 75 | x = add([x, residual]) 76 | 77 | # Block 5-12 78 | for i in range(8): 79 | residual = x 80 | prefix = 'block' + str(i + 5) 81 | 82 | x = Activation('relu', name=prefix + '_sepconv1_act')(x) 83 | x = SeparableConv2D(728, (3, 3), padding='same', 
use_bias=False, name=prefix + '_sepconv1')(x) 84 | x = BatchNormalization(axis=channel_axis, name=prefix + '_sepconv1_bn')(x) 85 | x = Activation('relu', name=prefix + '_sepconv2_act')(x) 86 | x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name=prefix + '_sepconv2')(x) 87 | x = BatchNormalization(axis=channel_axis, name=prefix + '_sepconv2_bn')(x) 88 | x = Activation('relu', name=prefix + '_sepconv3_act')(x) 89 | x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name=prefix + '_sepconv3')(x) 90 | x = BatchNormalization(axis=channel_axis, name=prefix + '_sepconv3_bn')(x) 91 | x = add([x, residual]) 92 | 93 | residual = Conv2D(1024, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x) 94 | residual = BatchNormalization(axis=channel_axis)(residual) 95 | 96 | # Block 13 97 | x = Activation('relu', name='block13_sepconv1_act')(x) 98 | x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name='block13_sepconv1')(x) 99 | x = BatchNormalization(axis=channel_axis, name='block13_sepconv1_bn')(x) 100 | x = Activation('relu', name='block13_sepconv2_act')(x) 101 | x = SeparableConv2D(1024, (3, 3), padding='same', use_bias=False, name='block13_sepconv2')(x) 102 | x = BatchNormalization(axis=channel_axis, name='block13_sepconv2_bn')(x) 103 | 104 | x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='block13_pool')(x) 105 | x = add([x, residual]) 106 | 107 | # Block 14 108 | x = SeparableConv2D(1536, (3, 3), padding='same', use_bias=False, name='block14_sepconv1')(x) 109 | x = BatchNormalization(axis=channel_axis, name='block14_sepconv1_bn')(x) 110 | x = Activation('relu', name='block14_sepconv1_act')(x) 111 | x = SeparableConv2D(2048, (3, 3), padding='same', use_bias=False,name='block14_sepconv2')(x) 112 | x = BatchNormalization(axis=channel_axis, name='block14_sepconv2_bn')(x) 113 | x = Activation('relu', name='block14_sepconv2_act')(x) 114 | 115 | # Classification block 116 | x = 
GlobalAveragePooling2D(name='avg_pool')(x) 117 | output_layer = Dense(n_classes, activation='softmax', name='predictions')(x) 118 | 119 | # define the model with input layer and output layer 120 | model = Model(inputs=input_layer, outputs=output_layer) 121 | model.summary() 122 | 123 | plot_model(model, to_file=self.output_directory + '/model_graph.png', show_shapes=True, show_layer_names=True) 124 | 125 | model.compile(loss=categorical_crossentropy, optimizer=Adadelta(lr=0.1), metrics=['acc']) 126 | 127 | # model save 128 | file_path = self.output_directory + '/best_model.hdf5' 129 | model_checkpoint = callbacks.ModelCheckpoint(filepath=file_path, monitor='loss', save_best_only=True) 130 | 131 | # Tensorboard log 132 | log_dir = self.output_directory + '/tf_logs' 133 | chk_n_mkdir(log_dir) 134 | tb_cb = TrainValTensorBoard(log_dir=log_dir) 135 | 136 | self.callbacks = [model_checkpoint, tb_cb] 137 | return model 138 | 139 | -------------------------------------------------------------------------------- /models/xception_3d.py: -------------------------------------------------------------------------------- 1 | from keras.layers import Conv3D, MaxPooling3D, Flatten, Dense 2 | from keras.layers import Dropout, Input, BatchNormalization, Activation, add, GlobalAveragePooling3D 3 | from keras.losses import categorical_crossentropy 4 | from keras.optimizers import Adadelta 5 | from keras.models import Model 6 | from keras.utils import plot_model 7 | from keras import callbacks 8 | 9 | from utils_datagen import TrainValTensorBoard 10 | from utils_basic import chk_n_mkdir 11 | from models.base_model import BaseModel 12 | 13 | 14 | class XCEPTION3D(BaseModel): 15 | def __init__(self, output_directory, input_shape, n_classes, verbose=False): 16 | self.output_directory = output_directory + '/xception_3d' 17 | chk_n_mkdir(self.output_directory) 18 | self.model = self.build_model(input_shape, n_classes) 19 | if verbose: 20 | self.model.summary() 21 | self.verbose = verbose 22 
| self.model.save_weights(self.output_directory + '/model_init.hdf5') 23 | 24 | def build_model(self, input_shape, n_classes): 25 | # input layer 26 | input_layer = Input(input_shape) 27 | channel_axis = -1 # channel_axis = 1 if backend.image_data_format() == 'channels_first' else -1 28 | 29 | # Block 1 30 | x = Conv3D(8, (3, 3, 3), use_bias=False, name='block1_conv1')(input_layer) 31 | x = BatchNormalization(axis=channel_axis, name='block1_conv1_bn')(x) 32 | x = Activation('relu', name='block1_conv1_act')(x) 33 | x = Conv3D(8, (3, 3, 2), use_bias=False, name='block1_conv2')(x) 34 | x = BatchNormalization(axis=channel_axis, name='block1_conv2_bn')(x) 35 | x = Activation('relu', name='block1_conv2_act')(x) 36 | 37 | residual = Conv3D(16, (1, 1, 1), strides=(2, 2, 1), padding='same', use_bias=False)(x) 38 | residual = BatchNormalization(axis=channel_axis)(residual) 39 | 40 | # Block 2 41 | x = Conv3D(16, (3, 3, 1), padding='same', use_bias=False, name='block2_conv1')(x) 42 | x = BatchNormalization(axis=channel_axis, name='block2_conv1_bn')(x) 43 | 44 | x = MaxPooling3D((3, 3, 1), strides=(2, 2, 1), padding='same', name='block2_pool')(x) 45 | x = add([x, residual]) 46 | 47 | residual = Conv3D(32, (1, 1, 1), strides=(2, 2, 1), padding='same', use_bias=False)(x) 48 | residual = BatchNormalization(axis=channel_axis)(residual) 49 | 50 | # Block 3 51 | x = Activation('relu', name='block3_conv1_act')(x) 52 | x = Conv3D(32, (3, 3, 1), padding='same', use_bias=False, name='block3_conv1')(x) 53 | x = BatchNormalization(axis=channel_axis, name='block3_conv1_bn')(x) 54 | 55 | x = MaxPooling3D((3, 3, 1), strides=(2, 2, 1), padding='same', name='block3_pool')(x) 56 | x = add([x, residual]) 57 | 58 | # Block 4 59 | x = Conv3D(64, (3, 3, 1), padding='same', use_bias=False, name='block4_conv1')(x) 60 | x = BatchNormalization(axis=channel_axis, name='block4_conv1_bn')(x) 61 | x = Activation('relu', name='block4_conv1_act')(x) 62 | 63 | # Classification block 64 | x = 
GlobalAveragePooling3D(name='avg_pool')(x) 65 | output_layer = Dense(n_classes, activation='softmax', name='predictions')(x) 66 | 67 | # ## create an MLP architecture with dense layers : 4096 -> 512 -> 10 68 | # ## add dropouts to avoid overfitting / perform regularization 69 | # dense_layer1 = Dense(units=2048, activation='relu')(x) 70 | # dense_layer1 = Dropout(0.4)(dense_layer1) 71 | # dense_layer2 = Dense(units=512, activation='relu')(dense_layer1) 72 | # dense_layer2 = Dropout(0.4)(dense_layer2) 73 | # output_layer = Dense(units=n_classes, activation='softmax')(dense_layer2) 74 | 75 | # define the model with input layer and output layer 76 | model = Model(inputs=input_layer, outputs=output_layer) 77 | model.summary() 78 | 79 | plot_model(model, to_file=self.output_directory + '/model_graph.png', show_shapes=True, show_layer_names=True) 80 | 81 | model.compile(loss=categorical_crossentropy, optimizer=Adadelta(lr=0.1), metrics=['acc']) 82 | 83 | # model save 84 | file_path = self.output_directory + '/best_model.hdf5' 85 | model_checkpoint = keras.callbacks.ModelCheckpoint(filepath=file_path, monitor='loss', save_best_only=True) 86 | 87 | # Tensorboard log 88 | log_dir = self.output_directory + '/tf_logs' 89 | chk_n_mkdir(log_dir) 90 | tb_cb = TrainValTensorBoard(log_dir=log_dir) 91 | 92 | self.callbacks = [model_checkpoint, tb_cb] 93 | return model 94 | -------------------------------------------------------------------------------- /models/xception_kapp.py: -------------------------------------------------------------------------------- 1 | from keras.layers import Conv2D, SeparableConv2D, MaxPooling2D, Flatten, Dense 2 | from keras.layers import Dropout, Input, BatchNormalization, Activation, add, GlobalAveragePooling2D 3 | from keras.losses import categorical_crossentropy 4 | from keras.optimizers import Adam 5 | from keras.utils import plot_model 6 | from keras import callbacks 7 | from keras import models 8 | from keras.applications import Xception 9 | 10 | 
from utils_datagen import TrainValTensorBoard
from utils_basic import chk_n_mkdir
from models.base_model import BaseModel


class XCEPTION_APP(BaseModel):
    """Frozen ImageNet Xception backbone topped with a small trainable softmax head.

    The convolutional base is ``keras.applications.Xception`` loaded with
    ImageNet weights and ``include_top=False``. Every one of its layers is
    marked non-trainable, so only the head learns. The head is:
    Flatten -> Dense(1024, relu) -> Dropout(0.5) -> Dense(n_classes, softmax).

    Files written under ``<output_directory>/xception_kapp``:
        model_graph.png   -- architecture plot (written while building)
        model_init.hdf5   -- weights right after construction
        best_model.hdf5   -- written later by the ModelCheckpoint callback
        tf_logs/          -- TensorBoard log directory

    NOTE(review): the Keras docs state that ``weights='imagenet'`` needs a
    3-channel ``input_shape`` of at least 71x71 -- confirm the caller's shape.
    """

    def __init__(self, output_directory, input_shape, n_classes, verbose=False):
        """Prepare the run directory, build and compile the model, snapshot its weights.

        Args:
            output_directory: parent folder (string); '/xception_kapp' is appended to it.
            input_shape: image shape handed straight to ``Xception``.
            n_classes: width of the final softmax layer.
            verbose: when true, print the summary again after building.
                NOTE(review): ``build_model`` already prints it unconditionally,
                so verbose=True shows the summary twice.
        """
        run_dir = output_directory + '/xception_kapp'
        self.output_directory = run_dir
        chk_n_mkdir(run_dir)

        net = self.build_model(input_shape, n_classes)
        self.model = net
        if verbose:
            net.summary()
        self.verbose = verbose

        # Keep the untrained starting point on disk so runs can be reset/compared.
        net.save_weights(run_dir + '/model_init.hdf5')

    def build_model(self, input_shape, n_classes):
        """Assemble and compile the transfer-learning classifier.

        Side effects: prints the summary, writes ``model_graph.png``, creates
        the ``tf_logs`` folder and stores the training callbacks on
        ``self.callbacks`` (best-on-loss checkpoint + TensorBoard logger).

        Returns:
            The compiled ``Sequential`` model.
        """
        backbone = Xception(weights='imagenet', include_top=False, input_shape=input_shape)

        # Pure feature extraction: no pretrained layer is updated during training.
        for frozen_layer in backbone.layers:
            frozen_layer.trainable = False

        classifier = models.Sequential([
            backbone,
            Flatten(),
            Dense(1024, activation='relu'),
            Dropout(0.5),
            Dense(n_classes, activation='softmax', name='predictions'),
        ])

        base_dir = self.output_directory
        classifier.summary()
        plot_model(classifier, to_file=base_dir + '/model_graph.png', show_shapes=True, show_layer_names=True)
        classifier.compile(loss=categorical_crossentropy, optimizer=Adam(lr=0.01), metrics=['acc'])

        # Checkpoint tracks training loss (not a validation metric) and keeps only the best.
        best_weights_path = base_dir + '/best_model.hdf5'
        checkpoint = callbacks.ModelCheckpoint(filepath=best_weights_path, monitor='loss', save_best_only=True)

        tensorboard_dir = base_dir + '/tf_logs'
        chk_n_mkdir(tensorboard_dir)
        self.callbacks = [checkpoint, TrainValTensorBoard(log_dir=tensorboard_dir)]
        return classifier
-------------------------------------------------------------------------------- /non_defective_xml_files.dvc: -------------------------------------------------------------------------------- 1 | md5: 95ba84bc7f7a3f57df1fe8fb88041c81 2 | outs: 3 | - md5: 09e023a9b75c98dbc0445c4852db8bf1.dir 4 | path: non_defective_xml_files 5 | cache: true 6 | metric: false 7 | persist: false 8 | -------------------------------------------------------------------------------- /result_images.dvc: -------------------------------------------------------------------------------- 1 | md5: 78a149405f6b8b4b47ede83c26222084 2 | outs: 3 | - md5: 819032fe705cf34a03306d8f3bb989ab.dir 4 | path: result_images 5 | cache: true 6 | metric: false 7 | persist: false 8 | -------------------------------------------------------------------------------- /run_models.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from utils_datagen import DataGenerator, create_partition_label_dict_multi_class, create_partition_label_dict_binary 4 | from models.cnn_2d import CNN2D 5 | from models.cnn_3d import CNN3D 6 | from models.vgg19 import VGG19 7 | from models.vgg19_3d import VGG193D 8 | from models.xception import XCEPTION 9 | from models.xception_3d import XCEPTION3D 10 | from models.mvcnn_xception import MVCNN_XCEPTION 11 | from models.mvcnn_cnn import MVCNN_CNN 12 | 13 | logging.basicConfig(level=logging.DEBUG) 14 | 15 | CLASSIFY_BINARY = True 16 | TRAIN = False 17 | MODEL3D = False 18 | 19 | # model selection 20 | model_types = ['CNN2D', 'CNN3D', 'VGG19', 'VGG193D', 'XCEPTION', 'XCEPTION3D', 'MVCNN_XCEPTION', 'MVCNN_CNN'] 21 | model_type = model_types[7] 22 | 23 | # Dataset selection 24 | pickle_files = ['./data/rois_all_slices_2d.p', './data/rois_all_slices_3d.p', './data/rois_all_slices_inverse_3d.p', 25 | './data/rois_first_four_slices_2d.p', './data/rois_first_four_slices_2d_rotated.p', 26 | './data/rois_first_four_slices_2d_more_normal.p', 
'./data/rois_first_four_slices_list_more_normal.p', 27 | './data/rois_first_four_slices_3d.p'] 28 | pickle_file = pickle_files[6] 29 | n_slices = 4 30 | 31 | if CLASSIFY_BINARY: 32 | n_classes = 2 33 | partition, labels, integer_mapping_label_dict, image_shape = create_partition_label_dict_binary(pickle_file) 34 | print(integer_mapping_label_dict) 35 | else: 36 | n_classes = 4 37 | partition, labels, integer_mapping_label_dict, image_shape = create_partition_label_dict_multi_class(pickle_file) 38 | 39 | img_width, img_height = image_shape[0], image_shape[1] 40 | if MODEL3D: 41 | input_shape = (img_width, img_height, n_slices, 1) 42 | else: 43 | input_shape = (img_width, img_height, 1) 44 | 45 | # input_shape = (128,128,1) 46 | 47 | if model_type == 'CNN2D': 48 | model = CNN2D('./models/saved_models/', input_shape, n_classes, True) 49 | if model_type == 'CNN3D': 50 | model = CNN3D('./models/saved_models/', input_shape, n_classes, True) 51 | if model_type == 'VGG19': 52 | model = VGG19('./models/saved_models/', input_shape, n_classes, True) 53 | if model_type == 'VGG193D': 54 | model = VGG193D('./models/saved_models/', input_shape, n_classes, True) 55 | if model_type == 'XCEPTION': 56 | model = XCEPTION('./models/saved_models/', input_shape, n_classes, True) 57 | if model_type == 'XCEPTION3D': 58 | model = XCEPTION3D('./models/saved_models/', input_shape, n_classes, True) 59 | if model_type == 'MVCNN_XCEPTION': 60 | model = MVCNN_XCEPTION('./models/saved_models/', input_shape, n_classes, True) 61 | if model_type == 'MVCNN_CNN': 62 | model = MVCNN_CNN('./models/saved_models/', input_shape, n_classes, True) 63 | 64 | # Parameters 65 | n_train_samples = len(partition['train']) 66 | n_test_samples = len(partition['test']) 67 | 68 | batch_size = 32 69 | epochs = 200 70 | samples_per_train_epoch = n_train_samples // batch_size 71 | samples_per_test_epoch = n_test_samples // batch_size 72 | 73 | params = {'dim': input_shape, 74 | 'batch_size': batch_size, 75 | 'n_classes': 
n_classes, 76 | 'shuffle': False} 77 | 78 | # Generators 79 | testing_generator = DataGenerator(pickle_file, partition['test'], labels, integer_mapping_label_dict, **params) 80 | training_generator = DataGenerator(pickle_file, partition['train'], labels, integer_mapping_label_dict, **params) 81 | 82 | if TRAIN: 83 | model.fit(training_generator, testing_generator, samples_per_train_epoch, samples_per_test_epoch, epochs) 84 | 85 | else: 86 | model.predict(testing_generator, samples_per_test_epoch) 87 | -------------------------------------------------------------------------------- /run_models_kapp_integrated.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pickle 3 | 4 | from utils_datagen import DataGeneratorIndividual 5 | from models.mvcnn_xception import MVCNN_XCEPTION 6 | 7 | logging.basicConfig(level=logging.DEBUG) 8 | 9 | CLASSIFY_BINARY = True 10 | TRAIN = False 11 | 12 | # model selection 13 | model_types = ['MVCNN_XCEPTION'] 14 | model_type = model_types[0] 15 | n_slices = 4 16 | 17 | if CLASSIFY_BINARY: 18 | n_classes = 2 19 | pickle_file_name = './data/joints_4slices/details_list.p' 20 | with open(pickle_file_name, 'rb') as handle: 21 | details_list = pickle.load(handle) 22 | partition, temp_labels, temp_integer_mapping_label_dict, image_shape = details_list[0], details_list[1], details_list[2], details_list[3] 23 | labels = {} 24 | for image_name, value in temp_labels.items(): 25 | if value == 3: 26 | new_value = 1 27 | else: 28 | new_value = 0 29 | labels[image_name] = new_value 30 | integer_mapping_label_dict = {'defect': 0, 'normal': 1} 31 | else: 32 | n_classes = 4 33 | pickle_file_name = './data/joints_4slices/details_list.p' 34 | with open(pickle_file_name, 'rb') as handle: 35 | details_list = pickle.load(handle) 36 | partition, labels, integer_mapping_label_dict, image_shape = details_list[0], details_list[1], details_list[2], details_list[3] 37 | 38 | img_width, img_height = 
image_shape[0], image_shape[1] 39 | input_shape = image_shape 40 | 41 | if model_type == 'MVCNN_XCEPTION': 42 | model = MVCNN_XCEPTION('./models/saved_models/', input_shape, n_classes, True) 43 | 44 | # Parameters 45 | n_train_samples = len(partition['train']) 46 | n_test_samples = len(partition['test']) 47 | 48 | batch_size = 32 49 | epochs = 1000 50 | samples_per_train_epoch = n_train_samples // batch_size 51 | samples_per_test_epoch = n_test_samples // batch_size 52 | 53 | params = {'dim': input_shape, 54 | 'batch_size': batch_size, 55 | 'n_classes': n_classes, 56 | 'shuffle': False} 57 | 58 | 59 | # Generators 60 | testing_generator = DataGeneratorIndividual('./data/joints_4slices/', partition['test'], labels, integer_mapping_label_dict, **params) 61 | training_generator = DataGeneratorIndividual('./data/joints_4slices/', partition['train'], labels, integer_mapping_label_dict,**params) 62 | 63 | 64 | if TRAIN: 65 | model.fit(training_generator, testing_generator, samples_per_train_epoch, samples_per_test_epoch, epochs) 66 | 67 | else: 68 | model.predict(testing_generator, samples_per_test_epoch) 69 | -------------------------------------------------------------------------------- /solder_joint.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import logging 4 | 5 | 6 | class SolderJoint: 7 | def __init__(self, component_name, defect_id, defect_name, roi): 8 | self.component_name = component_name 9 | self.defect_id = defect_id 10 | self.defect_name = defect_name 11 | self.roi = roi 12 | self.x_min, self.y_min, self.x_max, self.y_max = roi[0], roi[1], roi[2], roi[3] 13 | self.slice_dict = {} 14 | 15 | logging.info('Created solder join, component_name:%s defect_id:%d defect_name:%s roi:%d,%d,%d,%d', 16 | self.component_name, self.defect_id, self.defect_name, self.roi[0], self.roi[1], self.roi[2], 17 | self.roi[3]) 18 | 19 | def is_square(self): 20 | if (self.x_max - self.x_min) == 
(self.y_max - self.y_min): 21 | return True 22 | else: 23 | return False 24 | 25 | def add_slice(self, slice_id, file_location): 26 | self.slice_dict[slice_id] = file_location 27 | logging.info('slice id: %d added to the joint, image name: %s', slice_id, file_location) 28 | 29 | # this method concat 0,1,2,3 slices in 2d plain if those slices exist and roi is square 30 | # if you want to concat slices in different method write a new function like this 31 | def concat_first_four_slices_2d(self): 32 | logging.debug('start concatenating image, joint type: %s', self.defect_name) 33 | 34 | if not self.is_square(): 35 | logging.debug('joint roi is rectangular, canceling concatenation') 36 | return None, None 37 | 38 | if len(self.slice_dict.keys()) < 4: 39 | logging.error('Number of slice < 4, canceling concatenation') 40 | return None, None 41 | 42 | if len(self.slice_dict.keys()) > 4: 43 | logging.error('Number of slice > 4, canceling concatenation') 44 | return None, None 45 | 46 | slices_list = [None, None, None, None] 47 | for slice_id in range(4): 48 | # check whether 1st 4 slices are available 49 | if slice_id not in self.slice_dict.keys(): 50 | logging.error('First 4 slices not available, canceling concatenation') 51 | return None, None 52 | 53 | img = cv2.imread(self.slice_dict[slice_id]) 54 | # there's a bug here. 
image slicing doesn't give a perfect square sometimes 55 | img_roi = img[self.y_min:self.y_max, self.x_min:self.x_max] 56 | img_roi_gray = cv2.cvtColor(img_roi, cv2.COLOR_BGR2GRAY) 57 | if img_roi_gray is None: 58 | logging.error('Slice read is None, canceling concatenation') 59 | return None, None 60 | resized_image = cv2.resize(img_roi_gray, (128, 128), interpolation=cv2.INTER_AREA) 61 | 62 | if resized_image is None: 63 | logging.error('Error occured in opencv ROI extraction') 64 | return None, None 65 | slices_list[slice_id] = resized_image 66 | 67 | im_h1 = cv2.hconcat(slices_list[0:2]) 68 | im_h2 = cv2.hconcat(slices_list[2:4]) 69 | im_concat = cv2.vconcat([im_h1, im_h2]) 70 | 71 | if im_concat is None: 72 | logging.error('Error occured in opencv ROI concat, is none, skipping concatenation') 73 | return None, None 74 | 75 | logging.debug('First 4 slices available, concatenation done') 76 | return im_concat, self.defect_name 77 | 78 | def concat_first_four_slices_2d_4rotated(self): 79 | logging.debug('start concatenating image, joint type: %s', self.defect_name) 80 | 81 | if not self.is_square(): 82 | logging.debug('joint roi is rectangular, canceling concatenation') 83 | return None, None 84 | 85 | if len(self.slice_dict.keys()) < 4: 86 | logging.error('Number of slice < 4, canceling concatenation') 87 | return None, None 88 | 89 | if len(self.slice_dict.keys()) > 4: 90 | logging.error('Number of slice > 4, canceling concatenation') 91 | return None, None 92 | 93 | slices_list = [None, None, None, None] 94 | try: 95 | for slice_id in range(4): 96 | if slice_id not in self.slice_dict.keys(): 97 | logging.error('First 4 slices not available, canceling concatenation') 98 | return None, None 99 | 100 | img = cv2.imread(self.slice_dict[slice_id]) 101 | # there's a bug here. 
image slicing doesn't give a perfect square sometimes 102 | img_roi = img[self.y_min:self.y_max, self.x_min:self.x_max] 103 | img_roi_gray = cv2.cvtColor(img_roi, cv2.COLOR_BGR2GRAY) 104 | resized_image = cv2.resize(img_roi_gray, (64, 64), interpolation=cv2.INTER_LINEAR) 105 | slices_list[slice_id] = resized_image 106 | 107 | im_h1 = cv2.hconcat(slices_list[0:2]) 108 | im_h2 = cv2.hconcat(slices_list[2:4]) 109 | im_concat = cv2.vconcat([im_h1, im_h2]) 110 | im_concat = im_concat.astype(np.float32) / 255.0 111 | im_concat = np.expand_dims(im_concat, axis=2) 112 | 113 | (h, w) = im_concat.shape[:2] 114 | center = (w / 2, h / 2) 115 | rot_mat = cv2.getRotationMatrix2D(center, 90, 1.0) 116 | rotated90 = cv2.warpAffine(im_concat, rot_mat, (h, w)) 117 | rotated90 = np.expand_dims(rotated90, axis=2) 118 | rot_mat = cv2.getRotationMatrix2D(center, 180, 1.0) 119 | rotated180 = cv2.warpAffine(im_concat, rot_mat, (w, h)) 120 | rotated180 = np.expand_dims(rotated180, axis=2) 121 | rot_mat = cv2.getRotationMatrix2D(center, 270, 1.0) 122 | rotated270 = cv2.warpAffine(im_concat, rot_mat, (h, w)) 123 | rotated270 = np.expand_dims(rotated270, axis=2) 124 | 125 | except cv2.error as e: 126 | logging.error('OpenCV Error: %s, canceling concatenation', e) 127 | return None, None 128 | 129 | logging.debug('First 4 slices available, concatenation done') 130 | return [im_concat, rotated90, rotated180, rotated270], self.defect_name 131 | 132 | def concat_first_four_slices_list_rgb(self): 133 | logging.debug('start concatenating image, joint type: %s', self.defect_name) 134 | 135 | if not self.is_square(): 136 | logging.debug('joint roi is rectangular, canceling concatenation') 137 | return None, None 138 | 139 | if len(self.slice_dict.keys()) < 4: 140 | logging.error('Number of slice < 4, canceling concatenation') 141 | return None, None 142 | 143 | if len(self.slice_dict.keys()) > 4: 144 | logging.error('Number of slice > 4, canceling concatenation') 145 | return None, None 146 | 147 | 
slices_list = [None, None, None, None] 148 | for slice_id in range(4): 149 | # check whether 1st 4 slices are available 150 | if slice_id not in self.slice_dict.keys(): 151 | logging.error('First 4 slices not available, canceling concatenation') 152 | return None, None 153 | 154 | try: 155 | img = cv2.imread(self.slice_dict[slice_id]) 156 | img_roi = img[self.y_min:self.y_max, self.x_min:self.x_max] 157 | resized_image = cv2.resize(img_roi, (299, 299), interpolation=cv2.INTER_LINEAR) 158 | # resized_image = resized_image.astype(np.float32) // 255.0 159 | slices_list[slice_id] = resized_image 160 | except cv2.error as e: 161 | logging.error('OpenCV Error: %s, canceling concatenation', e) 162 | return None, None 163 | 164 | logging.debug('First 4 slices available, concatenation done') 165 | return slices_list, self.defect_name 166 | 167 | def concat_first_four_slices_list(self): 168 | logging.debug('start concatenating image, joint type: %s', self.defect_name) 169 | 170 | if not self.is_square(): 171 | logging.debug('joint roi is rectangular, canceling concatenation') 172 | return None, None 173 | 174 | if len(self.slice_dict.keys()) < 4: 175 | logging.error('Number of slice < 4, canceling concatenation') 176 | return None, None 177 | 178 | if len(self.slice_dict.keys()) > 4: 179 | logging.error('Number of slice > 4, canceling concatenation') 180 | return None, None 181 | 182 | slices_list = [None, None, None, None] 183 | for slice_id in range(4): 184 | # check whether 1st 4 slices are available 185 | if slice_id not in self.slice_dict.keys(): 186 | logging.error('First 4 slices not available, canceling concatenation') 187 | return None, None 188 | 189 | try: 190 | img = cv2.imread(self.slice_dict[slice_id]) 191 | img_roi = img[self.y_min:self.y_max, self.x_min:self.x_max] 192 | resized_image = cv2.resize(img_roi, (128, 128), interpolation=cv2.INTER_LINEAR) 193 | resized_image = cv2.cvtColor(resized_image, cv2.COLOR_BGR2GRAY) 194 | resized_image = 
resized_image.astype(np.float32) / 255.0 195 | slices_list[slice_id] = resized_image 196 | except cv2.error as e: 197 | logging.error('OpenCV Error: %s, canceling concatenation', e) 198 | return None, None 199 | 200 | logging.debug('First 4 slices available, concatenation done') 201 | return slices_list, self.defect_name 202 | 203 | def concat_first_four_slices_list_4rotated(self): 204 | logging.debug('start concatenating image, joint type: %s', self.defect_name) 205 | 206 | if not self.is_square(): 207 | logging.debug('joint roi is rectangular, canceling concatenation') 208 | return None, None 209 | 210 | if len(self.slice_dict.keys()) < 4: 211 | logging.error('Number of slice < 4, canceling concatenation') 212 | return None, None 213 | 214 | if len(self.slice_dict.keys()) > 4: 215 | logging.error('Number of slice > 4, canceling concatenation') 216 | return None, None 217 | 218 | slices_list = [None, None, None, None] 219 | for slice_id in range(4): 220 | # check whether 1st 4 slices are available 221 | if slice_id not in self.slice_dict.keys(): 222 | logging.error('First 4 slices not available, canceling concatenation') 223 | return None, None 224 | 225 | try: 226 | img = cv2.imread(self.slice_dict[slice_id]) 227 | img_roi = img[self.y_min:self.y_max, self.x_min:self.x_max] 228 | resized_image = cv2.resize(img_roi, (128, 128), interpolation=cv2.INTER_LINEAR) 229 | resized_image = cv2.cvtColor(resized_image, cv2.COLOR_BGR2GRAY) 230 | resized_image = np.expand_dims(resized_image, axis=2) 231 | resized_image = resized_image.astype(np.float32) / 255.0 232 | slices_list[slice_id] = resized_image 233 | except cv2.error as e: 234 | logging.error('OpenCV Error: %s, canceling concatenation', e) 235 | return None, None 236 | 237 | img_list_of_lists = [slices_list, None, None, None] 238 | rotated_list = [None, None, None, None] 239 | 240 | for j, angle in enumerate([90, 180, 270]): 241 | (h, w) = slices_list[0].shape[:2] 242 | center = (w / 2, h / 2) 243 | for i in 
range(4): 244 | rot_mat = cv2.getRotationMatrix2D(center, angle, 1.0) 245 | rotated = cv2.warpAffine(slices_list[i], rot_mat, (h, w)) 246 | rotated = np.expand_dims(rotated, axis=2) 247 | rotated_list[i] = rotated 248 | img_list_of_lists[j+1] = rotated_list 249 | 250 | logging.debug('First 4 slices available, concatenation done') 251 | return img_list_of_lists, self.defect_name 252 | 253 | def concat_first_four_slices_3d(self): 254 | logging.debug('start concatenating image, joint type: %s', self.defect_name) 255 | 256 | if not self.is_square(): 257 | logging.debug('joint roi is rectangular, canceling concatenation') 258 | return None, None 259 | 260 | if len(self.slice_dict.keys()) < 4: 261 | logging.error('Number of slice < 4, canceling concatenation') 262 | return None, None 263 | 264 | if len(self.slice_dict.keys()) > 4: 265 | logging.error('Number of slice > 4, canceling concatenation') 266 | return None, None 267 | 268 | logging.debug('start concatenating image, joint type: %s', self.defect_name) 269 | 270 | slices_list = [None, None, None, None] 271 | for slice_id in range(4): 272 | # check whether 1st 4 slices are available 273 | if slice_id not in self.slice_dict.keys(): 274 | logging.error('First 4 slices not available, canceling concatenation') 275 | return None, None 276 | 277 | img = cv2.imread(self.slice_dict[slice_id]) 278 | # there's a bug here. 
image slicing doesn't give a perfect square sometimes 279 | img_roi = img[self.y_min:self.y_max, self.x_min:self.x_max] 280 | img_roi_gray = cv2.cvtColor(img_roi, cv2.COLOR_BGR2GRAY) 281 | if img_roi_gray is None: 282 | logging.error('Slice read is None, canceling concatenation') 283 | return None, None 284 | resized_image = cv2.resize(img_roi_gray, (128, 128), interpolation=cv2.INTER_AREA) 285 | resized_image = resized_image.astype(np.float32) / 255 286 | 287 | if resized_image is None: 288 | logging.error('Error occured in opencv ROI extraction') 289 | return None, None 290 | 291 | slices_list[slice_id] = resized_image 292 | 293 | # logging.debug(slices_list[0].shape) 294 | stacked_np_array = np.stack(slices_list, axis=2) 295 | # logging.debug(stacked_np_array.shape) 296 | stacked_np_array = np.expand_dims(stacked_np_array, axis=4) 297 | logging.debug('3d image shape: %s', stacked_np_array.shape) 298 | 299 | logging.debug('First 4 slices available, concatenation done') 300 | return stacked_np_array, self.defect_name 301 | 302 | def concat_pad_all_slices_2d(self): 303 | logging.debug('start concatenating image, joint type: %s', self.defect_name) 304 | 305 | if not self.is_square(): 306 | logging.error('joint roi is rectangular, canceling concatenation') 307 | return None, None 308 | 309 | blank_image = np.zeros(shape=[128, 128], dtype=np.uint8) 310 | slices_list = [None, None, None, None, None, None] 311 | for slice_id in range(6): 312 | if slice_id in self.slice_dict.keys(): 313 | img = cv2.imread(self.slice_dict[slice_id]) 314 | # there's a bug here. 
image slicing doesn't give a perfect square sometimes 315 | img_roi = img[self.y_min:self.y_max, self.x_min:self.x_max] 316 | img_roi_gray = cv2.cvtColor(img_roi, cv2.COLOR_BGR2GRAY) 317 | if img_roi_gray is None: 318 | logging.error('Slice read is None, canceling concatenation') 319 | return None, None 320 | resized_image = cv2.resize(img_roi_gray, (128, 128), interpolation=cv2.INTER_AREA) 321 | 322 | if resized_image is None: 323 | logging.error('Error occured in opencv ROI extraction') 324 | return None, None 325 | slices_list[slice_id] = resized_image 326 | 327 | else: 328 | slices_list[slice_id] = blank_image 329 | logging.debug('blank slice added to slice: %d', slice_id) 330 | 331 | im_h1 = cv2.hconcat(slices_list[0:3]) 332 | im_h2 = cv2.hconcat(slices_list[3:6]) 333 | im_concat = cv2.vconcat([im_h1, im_h2]) 334 | 335 | if im_concat is None: 336 | logging.error('im_concat is none, skipping concatenation') 337 | return None, None 338 | 339 | logging.debug('concatenation done') 340 | return im_concat, self.defect_name 341 | 342 | def concat_pad_all_slices_3d(self): 343 | logging.debug('start concatenating image, joint type: %s', self.defect_name) 344 | 345 | if not self.is_square(): 346 | logging.error('joint roi is rectangular, canceling concatenation') 347 | return None, None 348 | 349 | blank_image = np.zeros(shape=[128, 128], dtype=np.uint8) 350 | slices_list = [None, None, None, None, None, None] 351 | for slice_id in range(6): 352 | if slice_id in self.slice_dict.keys(): 353 | img = cv2.imread(self.slice_dict[slice_id]) 354 | # there's a bug here. 
image slicing doesn't give a perfect square sometimes 355 | img_roi = img[self.y_min:self.y_max, self.x_min:self.x_max] 356 | img_roi_gray = cv2.cvtColor(img_roi, cv2.COLOR_BGR2GRAY) 357 | if img_roi_gray is None: 358 | logging.error('Slice read is None, canceling concatenation') 359 | return None, None 360 | resized_image = cv2.resize(img_roi_gray, (128, 128), interpolation=cv2.INTER_AREA) 361 | resized_image = resized_image.astype(np.float32) / 255 362 | 363 | if resized_image is None: 364 | logging.error('Error occured in opencv ROI extraction') 365 | return None, None 366 | 367 | slices_list[slice_id] = resized_image 368 | 369 | else: 370 | slices_list[slice_id] = blank_image 371 | logging.debug('blank slice added to slice: %d', slice_id) 372 | 373 | # logging.debug('xmax, xmin, ymax, ymin: %d, %d, %d, %d', self.x_max, self.x_min, self.y_max, self.y_min) 374 | # logging.debug('Slices shapes: %s, %s, %s, %s, %s, %s', slices_list[0].shape, slices_list[1].shape, 375 | # slices_list[2].shape, slices_list[3].shape, slices_list[4].shape, slices_list[5].shape) 376 | 377 | # logging.debug(slices_list[0].shape) 378 | stacked_np_array = np.stack(slices_list, axis=2) 379 | # logging.debug(stacked_np_array.shape) 380 | stacked_np_array = np.expand_dims(stacked_np_array, axis=4) 381 | logging.debug('3d image shape: %s', stacked_np_array.shape) 382 | 383 | logging.debug('Padded and concatenated 6 slices in 3d') 384 | return stacked_np_array, self.defect_name 385 | 386 | def concat_pad_all_slices_inverse_3d(self): 387 | logging.debug('start concatenating image, joint type: %s', self.defect_name) 388 | 389 | if not self.is_square(): 390 | logging.error('joint roi is rectangular, canceling concatenation') 391 | return None, None 392 | 393 | blank_image = np.ones(shape=[128, 128], dtype=np.uint8) 394 | slices_list = [None, None, None, None, None, None] 395 | for slice_id in range(6): 396 | if slice_id in self.slice_dict.keys(): 397 | img = cv2.imread(self.slice_dict[slice_id]) 
398 | # there's a bug here. image slicing doesn't give a perfect square sometimes 399 | img_roi = img[self.y_min:self.y_max, self.x_min:self.x_max] 400 | img_roi_gray = cv2.cvtColor(img_roi, cv2.COLOR_BGR2GRAY) 401 | img_roi_gray = cv2.bitwise_not(img_roi_gray) 402 | if img_roi_gray is None: 403 | logging.error('Slice read is None, canceling concatenation') 404 | return None, None 405 | resized_image = cv2.resize(img_roi_gray, (128, 128), interpolation=cv2.INTER_AREA) 406 | resized_image = resized_image.astype(np.float32) / 255 407 | 408 | if resized_image is None: 409 | logging.error('Error occured in opencv ROI extraction') 410 | return None, None 411 | 412 | slices_list[slice_id] = resized_image 413 | 414 | else: 415 | slices_list[slice_id] = blank_image 416 | logging.debug('blank slice added to slice: %d', slice_id) 417 | 418 | # logging.debug(slices_list[0].shape) 419 | stacked_np_array = np.stack(slices_list, axis=2) 420 | # logging.debug(stacked_np_array.shape) 421 | stacked_np_array = np.expand_dims(stacked_np_array, axis=4) 422 | logging.debug('3d image shape: %s', stacked_np_array.shape) 423 | 424 | logging.debug('Padded and concatenated 6 slices in 3d') 425 | return stacked_np_array, self.defect_name 426 | -------------------------------------------------------------------------------- /solder_joint_container.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pickle 3 | import os 4 | import csv 5 | import re 6 | import cv2 7 | import sys 8 | import random 9 | import matplotlib.pyplot as plt 10 | 11 | from constants import DEFECT_NAMES_DICT 12 | from utils_basic import chk_n_mkdir 13 | 14 | from board_view import BoardView 15 | 16 | 17 | class SolderJointContainer: 18 | def __init__(self): 19 | self.board_view_dict = {} 20 | self.new_image_name_mapping_dict = {} 21 | self.csv_details_dict = {} 22 | self.incorrect_board_view_ids = [] 23 | 24 | with open('original_dataset/PTH2_reviewed.csv', 
mode='r') as csv_file: 25 | csv_reader = csv.DictReader(csv_file) 26 | 27 | for row in csv_reader: 28 | component_name = row["component"] 29 | defect_id = int(row["defect_type_id"]) 30 | defect_name = DEFECT_NAMES_DICT[defect_id] 31 | roi = [int(float(str)) for str in row["roi_original"].split()] 32 | file_location = 'original_dataset\\' + row["image_filename"].strip('C:\Projects\pia-test\\') 33 | view_identifier = file_location[:-5] 34 | 35 | if view_identifier in self.board_view_dict.keys(): 36 | logging.debug('fount view_identifier inside the board_view_dict') 37 | board_view_obj = self.board_view_dict[view_identifier] 38 | board_view_obj.add_slice(file_location) 39 | else: 40 | logging.debug('adding new BoardView obj to the board_view_dict') 41 | board_view_obj = BoardView(view_identifier) 42 | self.board_view_dict[view_identifier] = board_view_obj 43 | 44 | board_view_obj.add_solder_joint(component_name, defect_id, defect_name, roi) 45 | 46 | # csv_details_dict is only made for file tracking purpose 47 | if file_location in self.csv_details_dict.keys(): 48 | self.csv_details_dict[file_location].append([component_name, defect_id, defect_name, roi]) 49 | else: 50 | self.csv_details_dict[file_location] = [] 51 | self.csv_details_dict[file_location].append([component_name, defect_id, defect_name, roi]) 52 | 53 | logging.debug('csv row details, component_name:%s, defect_name:%s, roi:%d,%d,%d,%d', component_name, 54 | defect_name, roi[0], roi[1], roi[2], roi[3]) 55 | 56 | for idx, file_loc in enumerate(self.csv_details_dict.keys()): 57 | raw_image_name = file_loc[-12:] 58 | image_name_with_idx = str(idx) + "_" + raw_image_name 59 | self.new_image_name_mapping_dict[image_name_with_idx] = file_loc 60 | 61 | logging.info('SolderJointContainer obj created') 62 | 63 | # this method create images in a seperate directory marked with rois 64 | def mark_all_images_with_rois(self): 65 | logging.info("Num of images to be marked:%d", len(self.csv_details_dict.keys())) 66 | for 
idx, file_loc in enumerate(self.csv_details_dict.keys()): 67 | raw_image_name = file_loc[-12:] 68 | destination_path = './images_roi_marked/' 69 | chk_n_mkdir(destination_path) 70 | destination_image_path = destination_path + str(idx) + "_" + raw_image_name 71 | 72 | src_image = cv2.imread(file_loc) 73 | if src_image is None: 74 | logging.error('Could not open or find the image: %s', file_loc) 75 | exit(0) 76 | 77 | for feature_list in self.csv_details_dict[file_loc]: 78 | component_name = feature_list[0] 79 | defect_name = feature_list[2] 80 | roi = feature_list[3] 81 | if defect_name == "missing": 82 | num = '-mis' 83 | elif defect_name == "short": 84 | num = '-sht' 85 | else: 86 | num = '-inf' 87 | 88 | # draw the ROI 89 | cv2.rectangle(src_image, (roi[0], roi[1]), (roi[2], roi[3]), (255, 0, 0), 2) 90 | cv2.putText(src_image, component_name + num, (roi[0], roi[1]), cv2.FONT_HERSHEY_SIMPLEX, 1.0, 91 | (0, 0, 255), 92 | lineType=cv2.LINE_AA) 93 | 94 | # cv2.imshow("Labeled Image", src_image) 95 | # k = cv2.waitKey(5000) 96 | # if k == 27: # If escape was pressed exit 97 | # cv2.destroyAllWindows() 98 | # break 99 | cv2.imwrite(destination_image_path, src_image) 100 | logging.debug('ROI marked image saved: %s', destination_image_path) 101 | 102 | # this method generates new SolderJoint objs with non_defective labels from xml data 103 | def load_non_defective_rois_to_container(self): 104 | 105 | annotation_dir = './non_defective_xml_files' 106 | annotation_files = os.listdir(annotation_dir) 107 | 108 | for file in annotation_files: 109 | fp = open(annotation_dir + '/' + file, 'r') 110 | annotation_file = fp.read() 111 | file_name = annotation_file[annotation_file.index('') + 10:annotation_file.index('')] 112 | x_min_start = [m.start() for m in re.finditer('', annotation_file)] 113 | x_min_end = [m.start() for m in re.finditer('', annotation_file)] 114 | y_min_start = [m.start() for m in re.finditer('', annotation_file)] 115 | y_min_end = [m.start() for m in 
re.finditer('', annotation_file)] 116 | x_max_start = [m.start() for m in re.finditer('', annotation_file)] 117 | x_max_end = [m.start() for m in re.finditer('', annotation_file)] 118 | y_max_start = [m.start() for m in re.finditer('', annotation_file)] 119 | y_max_end = [m.start() for m in re.finditer('', annotation_file)] 120 | 121 | x_min = [int(annotation_file[x_min_start[j] + 6:x_min_end[j]]) for j in range(len(x_min_start))] 122 | y_min = [int(annotation_file[y_min_start[j] + 6:y_min_end[j]]) for j in range(len(y_min_start))] 123 | x_max = [int(annotation_file[x_max_start[j] + 6:x_max_end[j]]) for j in range(len(x_max_start))] 124 | y_max = [int(annotation_file[y_max_start[j] + 6:y_max_end[j]]) for j in range(len(y_max_start))] 125 | act_file_name = self.new_image_name_mapping_dict[file_name] 126 | view_identifier = act_file_name[:-5] 127 | board_view_obj = self.board_view_dict[view_identifier] 128 | 129 | for i in range(len(x_min)): 130 | 131 | x_min_i = x_min[i] 132 | y_min_i = y_min[i] 133 | x_max_i = x_max[i] 134 | y_max_i = y_max[i] 135 | width = x_max[i] - x_min[i] 136 | height = y_max[i] - y_min[i] 137 | 138 | logging.debug('Start height/width:' + str(height) + '/' + str(width)) 139 | if width > height: 140 | threshold = (width - height) / width * 100.0 141 | if threshold > 10.0: 142 | logging.debug('height < width:' + str(height) + '/' + str(width)) 143 | continue 144 | else: 145 | if (width - height) % 2 == 0: 146 | y_min_i = y_min_i - (width - height) // 2 147 | y_max_i = y_max_i + (width - height) // 2 148 | else: 149 | y_min_i = y_min_i - (width - height) // 2 - 1 150 | y_max_i = y_max_i + (width - height) // 2 151 | 152 | height = y_max_i - y_min_i 153 | logging.debug('new height/width:' + str(height) + '/' + str(width)) 154 | 155 | if width < height: 156 | threshold = (height - width) / height * 100.0 157 | if threshold > 10.0: 158 | logging.debug('height > width:' + str(height) + '/' + str(width)) 159 | continue 160 | else: 161 | if (height - 
width) % 2 == 0: 162 | x_min_i = x_min_i - (height - width) // 2 163 | x_max_i = x_max_i + (height - width) // 2 164 | else: 165 | x_min_i = x_min_i - (height - width) // 2 - 1 166 | x_max_i = x_max_i + (height - width) // 2 167 | 168 | width = x_max_i - x_min_i 169 | logging.debug('new height/width:' + str(height) + '/' + str(width)) 170 | 171 | if not (x_max_i - x_min_i) == (y_max_i - y_min_i): 172 | logging.error('w,h,xmin,xmax,ymin,ymax: %d,%d,%d,%d,%d,%d', width, height, x_min_i, x_max_i, y_min_i, y_max_i) 173 | sys.exit() 174 | 175 | board_view_obj.add_solder_joint('unknown', -1, 'normal', [x_min_i, y_min_i, x_max_i, y_max_i]) 176 | board_view_obj.add_slices_to_solder_joints() 177 | 178 | @staticmethod 179 | def create_incorrect_roi_pickle_file(self): 180 | images_list = os.listdir('./incorrect_roi_images') 181 | with open('incorrect_roi_images.p', 'wb') as filehandle: 182 | pickle.dump(images_list, filehandle) 183 | 184 | def find_flag_incorrect_roi_board_view_objs(self): 185 | with open('incorrect_roi_images.p', 'rb') as filehandle: 186 | incorrect_roi_list = pickle.load(filehandle) 187 | temp_list = [] 188 | for image_name in incorrect_roi_list: 189 | temp_list.append(self.new_image_name_mapping_dict[image_name][:-5]) 190 | 191 | list_set = set(temp_list) 192 | temp_list = list(list_set) 193 | self.incorrect_board_view_ids.extend(temp_list) 194 | 195 | for board_view_obj in self.board_view_dict.values(): 196 | if board_view_obj.view_identifier in self.incorrect_board_view_ids: 197 | board_view_obj.is_incorrect_view = True 198 | logging.debug('marked board view %s as incorrect', board_view_obj.view_identifier) 199 | 200 | def print_container_details(self): 201 | board_views = 0 202 | solder_joints = 0 203 | missing_defects = 0 204 | short_defects = 0 205 | insuf_defects = 0 206 | normal_defects = 0 207 | 208 | joints_with_1_slices = 0 209 | joints_with_2_slices = 0 210 | joints_with_3_slices = 0 211 | joints_with_4_slices = 0 212 | joints_with_5_slices = 0 
213 | joints_with_6_slices = 0 214 | 215 | for board_view_obj in self.board_view_dict.values(): 216 | if not board_view_obj.is_incorrect_view: 217 | board_views += 1 218 | for solder_joint_obj in board_view_obj.solder_joint_dict.values(): 219 | solder_joints += 1 220 | 221 | if len(solder_joint_obj.slice_dict.keys()) == 1: 222 | joints_with_1_slices += 1 223 | if len(solder_joint_obj.slice_dict.keys()) == 2: 224 | joints_with_2_slices += 1 225 | if len(solder_joint_obj.slice_dict.keys()) == 3: 226 | joints_with_3_slices += 1 227 | if len(solder_joint_obj.slice_dict.keys()) == 4: 228 | joints_with_4_slices += 1 229 | if len(solder_joint_obj.slice_dict.keys()) == 5: 230 | joints_with_5_slices += 1 231 | if len(solder_joint_obj.slice_dict.keys()) == 6: 232 | joints_with_6_slices += 1 233 | 234 | label = solder_joint_obj.defect_name 235 | if label == 'normal': 236 | normal_defects += 1 237 | if label == 'missing': 238 | missing_defects += 1 239 | if label == 'insufficient': 240 | insuf_defects += 1 241 | if label == 'short': 242 | short_defects += 1 243 | 244 | print('*****correct view details*****') 245 | print('board_views:', board_views, 'solder_joints:', solder_joints, 'missing_defects:', missing_defects, 246 | 'short_defects:', short_defects, 'insuf_defects:', insuf_defects, 'normal_defects:', normal_defects) 247 | print('joints_with_1_slices:', joints_with_1_slices, 'joints_with_2_slices:', joints_with_1_slices, 248 | 'joints_with_3_slices:', joints_with_3_slices, 'joints_with_4_slices:', joints_with_4_slices, 249 | 'joints_with_5_slices', joints_with_5_slices, 'joints_with_6_slices', joints_with_6_slices) 250 | 251 | board_views = 0 252 | solder_joints = 0 253 | missing_defects = 0 254 | short_defects = 0 255 | insuf_defects = 0 256 | normal_defects = 0 257 | 258 | joints_with_1_slices = 0 259 | joints_with_2_slices = 0 260 | joints_with_3_slices = 0 261 | joints_with_4_slices = 0 262 | joints_with_5_slices = 0 263 | joints_with_6_slices = 0 264 | 265 | for 
board_view_obj in self.board_view_dict.values(): 266 | if board_view_obj.is_incorrect_view: 267 | board_views += 1 268 | 269 | for solder_joint_obj in board_view_obj.solder_joint_dict.values(): 270 | solder_joints += 1 271 | 272 | if len(solder_joint_obj.slice_dict.keys()) == 1: 273 | joints_with_1_slices += 1 274 | if len(solder_joint_obj.slice_dict.keys()) == 2: 275 | joints_with_2_slices += 1 276 | if len(solder_joint_obj.slice_dict.keys()) == 3: 277 | joints_with_3_slices += 1 278 | if len(solder_joint_obj.slice_dict.keys()) == 4: 279 | joints_with_4_slices += 1 280 | if len(solder_joint_obj.slice_dict.keys()) == 5: 281 | joints_with_5_slices += 1 282 | if len(solder_joint_obj.slice_dict.keys()) == 6: 283 | joints_with_6_slices += 1 284 | 285 | label = solder_joint_obj.defect_name 286 | if label == 'normal': 287 | normal_defects += 1 288 | if label == 'missing': 289 | missing_defects += 1 290 | if label == 'insufficient': 291 | insuf_defects += 1 292 | if label == 'short': 293 | short_defects += 1 294 | 295 | print('*****incorrect view details*****') 296 | print('board_views:', board_views, 'solder_joints:', solder_joints, 'missing_defects:', missing_defects, 297 | 'short_defects:', short_defects, 'insuf_defects:', insuf_defects, 'normal_defects:', normal_defects) 298 | print('joints_with_1_slices:', joints_with_1_slices, 'joints_with_2_slices:', joints_with_1_slices, 299 | 'joints_with_3_slices:', joints_with_3_slices, 'joints_with_4_slices:', joints_with_4_slices, 300 | 'joints_with_5_slices', joints_with_5_slices, 'joints_with_6_slices', joints_with_6_slices) 301 | 302 | board_views = 0 303 | solder_joints = 0 304 | missing_defects = 0 305 | short_defects = 0 306 | insuf_defects = 0 307 | normal_defects = 0 308 | 309 | joints_with_3_slices = 0 310 | joints_with_4_slices = 0 311 | joints_with_5_slices = 0 312 | joints_with_6_slices = 0 313 | 314 | for board_view_obj in self.board_view_dict.values(): 315 | if not board_view_obj.is_incorrect_view: 316 | 
board_views += 1 317 | for solder_joint_obj in board_view_obj.solder_joint_dict.values(): 318 | if solder_joint_obj.is_square: 319 | solder_joints += 1 320 | 321 | if len(solder_joint_obj.slice_dict.keys()) == 3: 322 | joints_with_3_slices += 1 323 | if len(solder_joint_obj.slice_dict.keys()) == 4: 324 | joints_with_4_slices += 1 325 | if len(solder_joint_obj.slice_dict.keys()) == 5: 326 | joints_with_5_slices += 1 327 | if len(solder_joint_obj.slice_dict.keys()) == 6: 328 | joints_with_6_slices += 1 329 | 330 | label = solder_joint_obj.defect_name 331 | if label == 'normal': 332 | normal_defects += 1 333 | if label == 'missing': 334 | missing_defects += 1 335 | if label == 'insufficient': 336 | insuf_defects += 1 337 | if label == 'short': 338 | short_defects += 1 339 | 340 | print('*****correct square roi details*****') 341 | print('board_views:', board_views, 'solder_joints:', solder_joints, 'missing_defects:', missing_defects, 342 | 'short_defects:', short_defects, 'insuf_defects:', insuf_defects, 'normal_defects:', normal_defects) 343 | print('joints_with_3_slices:', joints_with_3_slices, 'joints_with_4_slices:', joints_with_4_slices, 344 | 'joints_with_5_slices', joints_with_5_slices, 'joints_with_6_slices', joints_with_6_slices) 345 | 346 | with open('./data/rois_first_four_slices_2d.p', 'rb') as handle: 347 | img_dict = pickle.load(handle) 348 | missing, short, insuf, normal = 0, 0, 0, 0 349 | for val_list in img_dict.values(): 350 | if val_list[1] == 'missing': 351 | missing += 1 352 | if val_list[1] == 'short': 353 | short += 1 354 | if val_list[1] == 'insufficient': 355 | insuf += 1 356 | if val_list[1] == 'normal': 357 | normal += 1 358 | 359 | print('after concat 4 slices, missing:', missing, 'short:', short, 'insufficient:', insuf, 'normal:', normal) 360 | 361 | with open('./data/rois_all_slices_3d.p', 'rb') as handle: 362 | img_dict = pickle.load(handle) 363 | missing, short, insuf, normal = 0, 0, 0, 0 364 | for val_list in img_dict.values(): 365 | 
if val_list[1] == 'missing': 366 | missing += 1 367 | if val_list[1] == 'short': 368 | short += 1 369 | if val_list[1] == 'insufficient': 370 | insuf += 1 371 | if val_list[1] == 'normal': 372 | normal += 1 373 | 374 | print('after concat all slices, missing:', missing, 'short:', short, 'insufficient:', insuf, 'normal:', normal) 375 | 376 | def print_solder_joint_resolution_details(self): 377 | res_dict = {} 378 | 379 | for board_view_obj in self.board_view_dict.values(): 380 | if not board_view_obj.is_incorrect_view: 381 | for solder_joint_obj in board_view_obj.solder_joint_dict.values(): 382 | if solder_joint_obj.is_square: 383 | res = solder_joint_obj.x_max - solder_joint_obj.x_min 384 | if res in res_dict.keys(): 385 | res_dict[res] = res_dict[res] + 1 386 | else: 387 | res_dict[res] = 1 388 | 389 | print('*****resolution details*****') 390 | plt.figure() 391 | plt.title('Distribution of resolution values') 392 | plt.xlabel('Resolution') 393 | plt.ylabel('Number of solder joints') 394 | plt.bar(res_dict.keys(), res_dict.values(), 1.0, color='g') 395 | plt.show() 396 | 397 | sum = 0 398 | count = 0 399 | for key in res_dict.keys(): 400 | sum += key * res_dict[key] 401 | count += res_dict[key] 402 | print('mean:', sum/count) 403 | 404 | def save_concat_images_first_four_slices_2d(self): 405 | chk_n_mkdir('./data/roi_concatenated_four_slices_2d/short/') 406 | chk_n_mkdir('./data/roi_concatenated_four_slices_2d/insufficient/') 407 | chk_n_mkdir('./data/roi_concatenated_four_slices_2d/missing/') 408 | chk_n_mkdir('./data/roi_concatenated_four_slices_2d/normal/') 409 | img_count = 0 410 | for board_view_obj in self.board_view_dict.values(): 411 | if not board_view_obj.is_incorrect_view: 412 | logging.debug('Concatenating images in board_view_obj: %s', board_view_obj.view_identifier) 413 | for solder_joint_obj in board_view_obj.solder_joint_dict.values(): 414 | concat_image, label = solder_joint_obj.concat_first_four_slices_2d() 415 | if concat_image is not None: 416 
| img_count += 1 417 | destination_image_path = './data/roi_concatenated_four_slices/' + label + '/' + str(img_count) \ 418 | + '.jpg' 419 | cv2.imwrite(destination_image_path, concat_image) 420 | logging.debug('saving concatenated image, joint type: %s', label) 421 | logging.info('saved images of concatenated 4 slices in 2d') 422 | 423 | def save_concat_images_first_four_slices_2d_pickle(self): 424 | image_dict = {} 425 | img_count = 0 426 | for board_view_obj in self.board_view_dict.values(): 427 | if not board_view_obj.is_incorrect_view: 428 | logging.debug('Concatenating images in board_view_obj: %s', board_view_obj.view_identifier) 429 | for solder_joint_obj in board_view_obj.solder_joint_dict.values(): 430 | concat_image, label = solder_joint_obj.concat_first_four_slices_2d() 431 | if concat_image is not None: 432 | img_count += 1 433 | image_dict[img_count] = [concat_image, label] 434 | 435 | with open('./data/rois_first_four_slices_2d.p', 'wb') as handle: 436 | pickle.dump(image_dict, handle, protocol=pickle.HIGHEST_PROTOCOL) 437 | logging.info('saved images of concatenated 4 slices in 2d to pickle') 438 | 439 | def save_concat_images_first_four_slices_2d_more_pickle(self): 440 | image_dict = {} 441 | img_count = 0 442 | for board_view_obj in self.board_view_dict.values(): 443 | if not board_view_obj.is_incorrect_view: 444 | logging.debug('Concatenating images in board_view_obj: %s', board_view_obj.view_identifier) 445 | for solder_joint_obj in board_view_obj.solder_joint_dict.values(): 446 | concat_images, label = solder_joint_obj.concat_first_four_slices_2d_4rotated() 447 | if concat_images is not None: 448 | img_count += 1 449 | image_dict[img_count] = [concat_images[0], label] 450 | img_count += 1 451 | image_dict[img_count] = [concat_images[1], label] 452 | img_count += 1 453 | image_dict[img_count] = [concat_images[2], label] 454 | img_count += 1 455 | image_dict[img_count] = [concat_images[3], label] 456 | 457 | with 
open('./data/rois_first_four_slices_2d_rotated.p', 'wb') as handle: 458 | pickle.dump(image_dict, handle, protocol=pickle.HIGHEST_PROTOCOL) 459 | logging.info('saved images of concatenated 4 slices in 2d to pickle') 460 | 461 | def save_concat_images_first_four_slices_2d_more_normal_pickle(self): 462 | image_dict = {} 463 | img_count = 0 464 | for board_view_obj in self.board_view_dict.values(): 465 | if not board_view_obj.is_incorrect_view: 466 | logging.debug('Concatenating images in board_view_obj: %s', board_view_obj.view_identifier) 467 | for solder_joint_obj in board_view_obj.solder_joint_dict.values(): 468 | concat_images, label = solder_joint_obj.concat_first_four_slices_2d_4rotated() 469 | if concat_images is not None: 470 | if label == 'normal': 471 | img_count += 1 472 | image_dict[img_count] = [concat_images[0], label] 473 | img_count += 1 474 | image_dict[img_count] = [concat_images[1], label] 475 | img_count += 1 476 | image_dict[img_count] = [concat_images[2], label] 477 | img_count += 1 478 | image_dict[img_count] = [concat_images[3], label] 479 | else: 480 | img_count += 1 481 | image_dict[img_count] = [concat_images[0], label] 482 | 483 | with open('./data/rois_first_four_slices_2d_more_normal.p', 'wb') as handle: 484 | pickle.dump(image_dict, handle, protocol=pickle.HIGHEST_PROTOCOL) 485 | logging.info('saved images of concatenated 4 slices in 2d to pickle') 486 | 487 | def save_concat_images_first_four_slices_list_more_normal_pickle(self): 488 | image_dict = {} 489 | img_count = 0 490 | for board_view_obj in self.board_view_dict.values(): 491 | if not board_view_obj.is_incorrect_view: 492 | logging.debug('Concatenating images in board_view_obj: %s', board_view_obj.view_identifier) 493 | for solder_joint_obj in board_view_obj.solder_joint_dict.values(): 494 | concat_images, label = solder_joint_obj.concat_first_four_slices_list_4rotated() 495 | if concat_images is not None: 496 | if label == 'normal': 497 | img_count += 1 498 | 
image_dict[img_count] = [concat_images[0], label] 499 | img_count += 1 500 | image_dict[img_count] = [concat_images[1], label] 501 | img_count += 1 502 | image_dict[img_count] = [concat_images[2], label] 503 | img_count += 1 504 | image_dict[img_count] = [concat_images[3], label] 505 | else: 506 | img_count += 1 507 | image_dict[img_count] = [concat_images[0], label] 508 | 509 | with open('./data/rois_first_four_slices_list_more_normal.p', 'wb') as handle: 510 | pickle.dump(image_dict, handle, protocol=pickle.HIGHEST_PROTOCOL) 511 | logging.info('saved images of concatenated 4 slices in 2d to pickle') 512 | 513 | def save_concat_images_first_four_slices_list_rgb_pickle(self): 514 | image_dict = {} 515 | img_count = 0 516 | for board_view_obj in self.board_view_dict.values(): 517 | if not board_view_obj.is_incorrect_view: 518 | logging.debug('Concatenating images in board_view_obj: %s', board_view_obj.view_identifier) 519 | for solder_joint_obj in board_view_obj.solder_joint_dict.values(): 520 | stacked_list, label = solder_joint_obj.concat_first_four_slices_list_rgb() 521 | if stacked_list is not None: 522 | img_count += 1 523 | image_dict[img_count] = [stacked_list, label] 524 | 525 | with open('./data/rois_first_four_slices_list_rgb.p', 'wb') as handle: 526 | pickle.dump(image_dict, handle, protocol=pickle.HIGHEST_PROTOCOL) 527 | logging.info('saved images of concatenated 4 slices in a list to pickle') 528 | 529 | @staticmethod 530 | def save_4slices_individual_pickle(self): 531 | chk_n_mkdir('./data/joints_4slices') 532 | with open('./data/rois_first_four_slices_list_rgb.p', 'rb') as handle: 533 | image_dict = pickle.load(handle) 534 | 535 | train_key_list = [] 536 | test_key_list = [] 537 | partition = {'train': train_key_list, 'test': test_key_list} 538 | 539 | classes = ['missing', 'insufficient', 'short', 'normal'] 540 | integer_mapping_dict = {x: i for i, x in enumerate(classes)} 541 | missing_key_list = [] 542 | insufficient_key_list = [] 543 | 
short_key_list = [] 544 | normal_key_list = [] 545 | 546 | labels = {} 547 | for image_key in image_dict.keys(): 548 | label = image_dict[image_key][1] 549 | pickle_file_name = str(image_key) + '.p' 550 | labels[pickle_file_name] = integer_mapping_dict[label] 551 | if label == 'missing': 552 | missing_key_list.append(image_key) 553 | elif label == 'insufficient': 554 | insufficient_key_list.append(image_key) 555 | elif label == 'short': 556 | short_key_list.append(image_key) 557 | elif label == 'normal': 558 | normal_key_list.append(image_key) 559 | 560 | pickle_file_name = './data/joints_4slices/' + str(image_key) + '.p' 561 | with open(pickle_file_name, 'wb') as handle: 562 | pickle.dump(image_dict[image_key][0], handle, protocol=pickle.HIGHEST_PROTOCOL) 563 | 564 | for key_list in [missing_key_list, insufficient_key_list, short_key_list, normal_key_list]: 565 | num_train = int(len(key_list) * 0.9) 566 | num_images = len(key_list) 567 | train_indices = random.sample(range(num_images), num_train) 568 | test_indices = [] 569 | for index in range(num_images): 570 | if index not in train_indices: 571 | test_indices.append(index) 572 | 573 | for index in train_indices: 574 | pickle_file_name = str(key_list[index]) + '.p' 575 | train_key_list.append(pickle_file_name) 576 | 577 | for index in test_indices: 578 | pickle_file_name = str(key_list[index]) + '.p' 579 | test_key_list.append(pickle_file_name) 580 | 581 | if isinstance(next(iter(image_dict.values()))[0], list): 582 | image_shape = next(iter(image_dict.values()))[0][0].shape 583 | print('slices list selected, image shape:', image_shape) 584 | else: 585 | image_shape = next(iter(image_dict.values()))[0].shape 586 | 587 | pickle_file_name = './data/joints_4slices/details_list.p' 588 | with open(pickle_file_name, 'wb') as handle: 589 | details_list = [partition, labels, integer_mapping_dict, image_shape] 590 | pickle.dump(details_list, handle, protocol=pickle.HIGHEST_PROTOCOL) 591 | 592 | def 
save_concat_images_first_four_slices_3d_pickle(self): 593 | image_dict = {} 594 | img_count = 0 595 | for board_view_obj in self.board_view_dict.values(): 596 | if not board_view_obj.is_incorrect_view: 597 | logging.debug('Concatenating images in board_view_obj: %s', board_view_obj.view_identifier) 598 | for solder_joint_obj in board_view_obj.solder_joint_dict.values(): 599 | stacked_np_array, label = solder_joint_obj.concat_first_four_slices_3d() 600 | if stacked_np_array is not None: 601 | img_count += 1 602 | image_dict[img_count] = [stacked_np_array, label] 603 | 604 | with open('./data/rois_first_four_slices_3d.p', 'wb') as handle: 605 | pickle.dump(image_dict, handle, protocol=pickle.HIGHEST_PROTOCOL) 606 | logging.info('saved images of concatenated 4 slices in 3d to pickle') 607 | 608 | def save_concat_images_all_slices_2d_pickle(self): 609 | image_dict = {} 610 | img_count = 0 611 | for board_view_obj in self.board_view_dict.values(): 612 | if not board_view_obj.is_incorrect_view: 613 | logging.debug('Concatenating images in board_view_obj: %s', board_view_obj.view_identifier) 614 | for solder_joint_obj in board_view_obj.solder_joint_dict.values(): 615 | concat_image, label = solder_joint_obj.concat_pad_all_slices_2d() 616 | if concat_image is not None: 617 | img_count += 1 618 | image_dict[img_count] = [concat_image, label] 619 | 620 | with open('./data/rois_all_slices_2d.p', 'wb') as handle: 621 | pickle.dump(image_dict, handle, protocol=pickle.HIGHEST_PROTOCOL) 622 | logging.info('saved images of concatenated 6 slices in 2d to pickle') 623 | 624 | def save_concat_images_all_slices_3d_pickle(self): 625 | image_dict = {} 626 | img_count = 0 627 | for board_view_obj in self.board_view_dict.values(): 628 | if not board_view_obj.is_incorrect_view: 629 | logging.debug('Concatenating images in board_view_obj: %s', board_view_obj.view_identifier) 630 | for solder_joint_obj in board_view_obj.solder_joint_dict.values(): 631 | stacked_np_array, label = 
solder_joint_obj.concat_pad_all_slices_3d() 632 | if stacked_np_array is not None: 633 | img_count += 1 634 | image_dict[img_count] = [stacked_np_array, label] 635 | 636 | with open('./data/rois_all_slices_3d.p', 'wb') as handle: 637 | pickle.dump(image_dict, handle, protocol=pickle.HIGHEST_PROTOCOL) 638 | logging.info('saved images of concatenated 6 slices in 3d to pickle') 639 | 640 | def save_concat_images_all_slices_inverse_3d_pickle(self): 641 | image_dict = {} 642 | img_count = 0 643 | for board_view_obj in self.board_view_dict.values(): 644 | if not board_view_obj.is_incorrect_view: 645 | logging.debug('Concatenating images in board_view_obj: %s', board_view_obj.view_identifier) 646 | for solder_joint_obj in board_view_obj.solder_joint_dict.values(): 647 | stacked_np_array, label = solder_joint_obj.concat_pad_all_slices_inverse_3d() 648 | if stacked_np_array is not None: 649 | img_count += 1 650 | image_dict[img_count] = [stacked_np_array, label] 651 | 652 | with open('./data/rois_all_slices_inverse_3d.p', 'wb') as handle: 653 | pickle.dump(image_dict, handle, protocol=pickle.HIGHEST_PROTOCOL) 654 | logging.info('saved images of concatenated 6 inverse slices in 3d to pickle') 655 | -------------------------------------------------------------------------------- /solder_joint_container_obj.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chinthysl/AXI_PCB_defect_detection/fe81c1d8e144ce5434aee78548cdc026d0d53d1c/solder_joint_container_obj.p -------------------------------------------------------------------------------- /utils_basic.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | 6 | from sklearn.metrics import accuracy_score 7 | from sklearn.metrics import precision_score 8 | from sklearn.metrics import recall_score 9 | 10 | 11 | def chk_n_mkdir(path): 12 
| if not os.path.exists(path): 13 | os.makedirs(path) 14 | 15 | 16 | def calculate_metrics(y_true, y_pred, duration, y_true_val=None, y_pred_val=None): 17 | res = pd.DataFrame(data=np.zeros((1, 4), dtype=np.float), index=[0], 18 | columns=['precision', 'accuracy', 'recall', 'duration']) 19 | res['precision'] = precision_score(y_true, y_pred, average='macro') 20 | res['accuracy'] = accuracy_score(y_true, y_pred) 21 | 22 | if y_true_val is not None: 23 | # this is useful when transfer learning is used with cross validation 24 | res['accuracy_val'] = accuracy_score(y_true_val, y_pred_val) 25 | 26 | res['recall'] = recall_score(y_true, y_pred, average='macro') 27 | res['duration'] = duration 28 | return res 29 | 30 | 31 | def plot_epochs_metric(hist, file_name, metric='loss'): 32 | plt.figure() 33 | plt.plot(hist.history[metric]) 34 | plt.plot(hist.history['val_' + metric]) 35 | plt.title('model ' + metric) 36 | plt.ylabel(metric, fontsize='large') 37 | plt.xlabel('epoch', fontsize='large') 38 | plt.legend(['train', 'val'], loc='upper left') 39 | plt.savefig(file_name, bbox_inches='tight') 40 | plt.close() 41 | 42 | 43 | def save_logs(output_directory, hist, y_pred, y_true, duration, lr=True, y_true_val=None, y_pred_val=None): 44 | hist_df = pd.DataFrame(hist.history) 45 | hist_df.to_csv(output_directory + '/history.csv', index=False) 46 | 47 | df_metrics = calculate_metrics(y_true, y_pred, duration, y_true_val, y_pred_val) 48 | df_metrics.to_csv(output_directory + '/df_metrics.csv', index=False) 49 | 50 | index_best_model = hist_df['loss'].idxmin() 51 | row_best_model = hist_df.loc[index_best_model] 52 | 53 | df_best_model = pd.DataFrame(data=np.zeros((1, 6), dtype=np.float), index=[0], 54 | columns=['best_model_train_loss', 'best_model_val_loss', 'best_model_train_acc', 55 | 'best_model_val_acc', 'best_model_learning_rate', 'best_model_nb_epoch']) 56 | 57 | df_best_model['best_model_train_loss'] = row_best_model['loss'] 58 | df_best_model['best_model_val_loss'] = 
row_best_model['val_loss'] 59 | df_best_model['best_model_train_acc'] = row_best_model['acc'] 60 | df_best_model['best_model_val_acc'] = row_best_model['val_acc'] 61 | if lr == True: 62 | df_best_model['best_model_learning_rate'] = row_best_model['lr'] 63 | df_best_model['best_model_nb_epoch'] = index_best_model 64 | 65 | df_best_model.to_csv(output_directory + '/df_best_model.csv', index=False) 66 | 67 | # plot losses 68 | plot_epochs_metric(hist, output_directory + '/epochs_loss.png') 69 | 70 | return df_metrics 71 | -------------------------------------------------------------------------------- /utils_datagen.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import keras 3 | import os 4 | import tensorflow as tf 5 | from keras.callbacks import TensorBoard 6 | import pickle 7 | import random 8 | import cv2 9 | 10 | 11 | def create_partition_label_dict_multi_class(pickle_file): 12 | with open(pickle_file, 'rb') as handle: 13 | image_dict = pickle.load(handle) 14 | 15 | train_key_list = [] 16 | test_key_list = [] 17 | partition = {'train': train_key_list, 'test': test_key_list} 18 | 19 | classes = ['missing', 'insufficient', 'short', 'normal'] 20 | integer_mapping_dict = {x: i for i, x in enumerate(classes)} 21 | missing_key_list = [] 22 | insufficient_key_list = [] 23 | short_key_list = [] 24 | normal_key_list = [] 25 | 26 | labels = {} 27 | for image_key in image_dict.keys(): 28 | label = image_dict[image_key][1] 29 | labels[image_key] = integer_mapping_dict[label] 30 | if label == 'missing': 31 | missing_key_list.append(image_key) 32 | elif label == 'insufficient': 33 | insufficient_key_list.append(image_key) 34 | elif label == 'short': 35 | short_key_list.append(image_key) 36 | elif label == 'normal': 37 | normal_key_list.append(image_key) 38 | 39 | for key_list in [missing_key_list, insufficient_key_list, short_key_list, normal_key_list]: 40 | num_train = int(len(key_list) * 0.9) 41 | num_images = 
len(key_list) 42 | train_indices = random.sample(range(num_images), num_train) 43 | test_indices = [] 44 | for index in range(num_images): 45 | if index not in train_indices: 46 | test_indices.append(index) 47 | 48 | for index in train_indices: 49 | train_key_list.append(key_list[index]) 50 | 51 | for index in test_indices: 52 | test_key_list.append(key_list[index]) 53 | 54 | if isinstance(next(iter(image_dict.values()))[0], list): 55 | image_shape = next(iter(image_dict.values()))[0][0].shape 56 | print('slices list selected, image shape:', image_shape) 57 | else: 58 | image_shape = next(iter(image_dict.values()))[0].shape 59 | 60 | return partition, labels, integer_mapping_dict, image_shape 61 | 62 | 63 | def create_partition_label_dict_binary(pickle_file): 64 | with open(pickle_file, 'rb') as handle: 65 | image_dict = pickle.load(handle) 66 | 67 | train_key_list = [] 68 | test_key_list = [] 69 | partition = {'train': train_key_list, 'test': test_key_list} 70 | 71 | integer_mapping_dict = {'missing': 0, 'insufficient': 0, 'short': 0, 'normal': 1} 72 | defective_key_list = [] 73 | normal_key_list = [] 74 | 75 | labels = {} 76 | for image_key in image_dict.keys(): 77 | label = image_dict[image_key][1] 78 | labels[image_key] = integer_mapping_dict[label] 79 | if label == 'missing' or label == 'insufficient' or label == 'short': 80 | defective_key_list.append(image_key) 81 | elif label == 'normal': 82 | normal_key_list.append(image_key) 83 | 84 | for key_list in [defective_key_list, normal_key_list]: 85 | num_train = int(len(key_list) * 0.9) 86 | num_images = len(key_list) 87 | train_indices = random.sample(range(num_images), num_train) 88 | test_indices = [] 89 | for index in range(num_images): 90 | if index not in train_indices: 91 | test_indices.append(index) 92 | 93 | for index in train_indices: 94 | train_key_list.append(key_list[index]) 95 | 96 | for index in test_indices: 97 | test_key_list.append(key_list[index]) 98 | 99 | if 
isinstance(next(iter(image_dict.values()))[0], list): 100 | image_shape = next(iter(image_dict.values()))[0][0].shape 101 | print('slices list selected, image shape:', image_shape) 102 | else: 103 | image_shape = next(iter(image_dict.values()))[0].shape 104 | 105 | integer_mapping_dict = {'defect': 0, 'normal': 1} 106 | return partition, labels, integer_mapping_dict, image_shape 107 | 108 | 109 | class DataGenerator(keras.utils.Sequence): 110 | 'Generates data for Keras' 111 | def __init__(self, pickle_file, image_keys, labels, integer_mapping_dict, batch_size, dim, n_classes, shuffle=True): 112 | 'Initialization' 113 | with open(pickle_file, 'rb') as handle: 114 | image_dict = pickle.load(handle) 115 | self.image_dict = image_dict 116 | self.labels = labels 117 | self.image_keys = image_keys 118 | self.integer_mapping_dict = integer_mapping_dict 119 | self.dim = dim 120 | self.batch_size = batch_size 121 | self.n_classes = n_classes 122 | self.shuffle = shuffle 123 | self.on_epoch_end() 124 | 125 | def gen_class_labels(self): 126 | classes = [] 127 | end_idx = len(self.image_keys) // self.batch_size * self.batch_size 128 | for i in self.image_keys[0:end_idx]: 129 | classes.append(self.labels[i]) 130 | return classes 131 | 132 | def __len__(self): 133 | 'Denotes the number of batches per epoch' 134 | return int(np.floor(len(self.image_keys) / self.batch_size)) 135 | 136 | def __getitem__(self, index): 137 | 'Generate one batch of data' 138 | # Generate indexes of the batch 139 | indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size] 140 | 141 | # Find list of IDs 142 | image_keys_temp = [self.image_keys[k] for k in indexes] 143 | 144 | # Generate data 145 | X, y = self.__data_generation(image_keys_temp) 146 | return X, y 147 | 148 | # X1, X2, X3, X4, y = self.__data_generation(image_keys_temp) 149 | # return [X1, X2, X3, X4], y 150 | 151 | def on_epoch_end(self): 152 | 'Updates indexes after each epoch' 153 | self.indexes = 
np.arange(len(self.image_keys)) 154 | if self.shuffle == True: 155 | np.random.shuffle(self.indexes) 156 | 157 | def __data_generation(self, image_keys_temp): 158 | # Generates data containing batch_size samples, X : (n_samples, *dim, n_channels) 159 | 160 | # # Initialization 161 | # X = np.empty((self.batch_size, *self.dim)) 162 | # y = np.empty(self.batch_size, dtype=int) 163 | # 164 | # # Generate data 165 | # for i, ID in enumerate(image_keys_temp): 166 | # # Store sample 167 | # image = self.image_dict[ID][0] 168 | # # if len(image.shape) == 2: 169 | # # image = np.expand_dims(image, axis=2) 170 | # X[i, ] = image 171 | # 172 | # # Store class 173 | # y[i] = self.labels[ID] 174 | # 175 | # return X, keras.utils.to_categorical(y, num_classes=self.n_classes) 176 | 177 | # # Initialization 178 | # X = np.empty((self.batch_size, *self.dim)) 179 | # y = np.empty(self.batch_size, dtype=int) 180 | # 181 | # # Generate data 182 | # for i, ID in enumerate(image_keys_temp): 183 | # # Store sample 184 | # image = self.image_dict[ID][0] 185 | # image = cv2.resize(image, dsize=(128, 128), interpolation=cv2.INTER_AREA) 186 | # image = np.expand_dims(image, axis=2) 187 | # X[i, ] = image 188 | # 189 | # # Store class 190 | # y[i] = self.labels[ID] 191 | # 192 | # return X, keras.utils.to_categorical(y, num_classes=self.n_classes) 193 | 194 | # # Initialization 195 | # X1 = np.empty((self.batch_size, 128, 128, 1)) 196 | # X2 = np.empty((self.batch_size, 128, 128, 1)) 197 | # X3 = np.empty((self.batch_size, 128, 128, 1)) 198 | # X4 = np.empty((self.batch_size, 128, 128, 1)) 199 | # y = np.empty(self.batch_size, dtype=int) 200 | # 201 | # # Generate data 202 | # for i, ID in enumerate(image_keys_temp): 203 | # # Store sample 204 | # images = self.image_dict[ID][0] 205 | # # print(images[:,:,0,:].shape) 206 | # image1 = images[:,:,0,:] 207 | # image2 = images[:,:,1,:] 208 | # image3 = images[:,:,2,:] 209 | # image4 = images[:,:,3,:] 210 | # X1[i,] = image1 211 | # X2[i,] = 
image2 212 | # X3[i,] = image3 213 | # X4[i,] = image4 214 | # 215 | # # Store class 216 | # y[i] = self.labels[ID] 217 | # 218 | # return [X1, X2, X3, X4], keras.utils.to_categorical(y, num_classes=self.n_classes) 219 | 220 | # Initialization 221 | X1 = np.empty((self.batch_size, *self.dim)) 222 | X2 = np.empty((self.batch_size, *self.dim)) 223 | X3 = np.empty((self.batch_size, *self.dim)) 224 | X4 = np.empty((self.batch_size, *self.dim)) 225 | y = np.empty(self.batch_size, dtype=int) 226 | 227 | # Generate data 228 | for i, ID in enumerate(image_keys_temp): 229 | # Store sample 230 | slices_list = self.image_dict[ID][0] 231 | 232 | image1 = slices_list[0] 233 | image2 = slices_list[1] 234 | image3 = slices_list[2] 235 | image4 = slices_list[3] 236 | X1[i,] = image1 237 | X2[i,] = image2 238 | X3[i,] = image3 239 | X4[i,] = image4 240 | 241 | # Store class 242 | y[i] = self.labels[ID] 243 | 244 | return [X1, X2, X3, X4], keras.utils.to_categorical(y, num_classes=self.n_classes) 245 | 246 | 247 | class DataGeneratorIndividual(keras.utils.Sequence): 248 | 'Generates data for Keras' 249 | def __init__(self, pickle_file_dir, image_keys, labels, integer_mapping_dict, batch_size, dim, n_classes, shuffle=True): 250 | 'Initialization' 251 | self.pickle_file_dir = pickle_file_dir 252 | self.labels = labels 253 | self.image_keys = image_keys 254 | self.integer_mapping_dict = integer_mapping_dict 255 | self.dim = dim 256 | self.batch_size = batch_size 257 | self.n_classes = n_classes 258 | self.shuffle = shuffle 259 | self.on_epoch_end() 260 | 261 | def gen_class_labels(self): 262 | classes = [] 263 | end_idx = len(self.image_keys) // self.batch_size * self.batch_size 264 | for i in self.image_keys[0:end_idx]: 265 | classes.append(self.labels[i]) 266 | return classes 267 | 268 | def __len__(self): 269 | 'Denotes the number of batches per epoch' 270 | return int(np.floor(len(self.image_keys) / self.batch_size)) 271 | 272 | def __getitem__(self, index): 273 | 'Generate one 
batch of data' 274 | # Generate indexes of the batch 275 | indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size] 276 | 277 | # Find list of IDs 278 | image_keys_temp = [self.image_keys[k] for k in indexes] 279 | 280 | # Generate data 281 | X, y = self.__data_generation(image_keys_temp) 282 | return X, y 283 | 284 | # X1, X2, X3, X4, y = self.__data_generation(image_keys_temp) 285 | # return [X1, X2, X3, X4], y 286 | 287 | def on_epoch_end(self): 288 | 'Updates indexes after each epoch' 289 | self.indexes = np.arange(len(self.image_keys)) 290 | if self.shuffle == True: 291 | np.random.shuffle(self.indexes) 292 | 293 | def __data_generation(self, image_keys_temp): 294 | X1 = np.empty((self.batch_size, *self.dim)) 295 | X2 = np.empty((self.batch_size, *self.dim)) 296 | X3 = np.empty((self.batch_size, *self.dim)) 297 | X4 = np.empty((self.batch_size, *self.dim)) 298 | y = np.empty(self.batch_size, dtype=int) 299 | 300 | # Generate data 301 | for i, ID in enumerate(image_keys_temp): 302 | # Store sample 303 | with open(self.pickle_file_dir + ID, 'rb') as handle: 304 | slices_list = pickle.load(handle) 305 | 306 | image1 = slices_list[0] 307 | image2 = slices_list[1] 308 | image3 = slices_list[2] 309 | image4 = slices_list[3] 310 | X1[i,] = image1 311 | X2[i,] = image2 312 | X3[i,] = image3 313 | X4[i,] = image4 314 | 315 | # Store class 316 | y[i] = self.labels[ID] 317 | 318 | return [X1, X2, X3, X4], keras.utils.to_categorical(y, num_classes=self.n_classes) 319 | 320 | 321 | class TrainValTensorBoard(TensorBoard): 322 | def __init__(self, log_dir='./log', **kwargs): 323 | self.log_dir = log_dir 324 | # Make the original `TensorBoard` log to a subdirectory 'training' 325 | training_log_dir = os.path.join(self.log_dir, '/training') 326 | super(TrainValTensorBoard, self).__init__(training_log_dir, **kwargs) 327 | 328 | # Log the validation metrics to a separate subdirectory 329 | self.val_log_dir = os.path.join(self.log_dir, '/validation') 330 | 331 | 
def set_model(self, model): 332 | # Setup writer for validation metrics 333 | self.val_writer = tf.summary.FileWriter(self.val_log_dir) 334 | super(TrainValTensorBoard, self).set_model(model) 335 | 336 | def on_epoch_end(self, epoch, logs=None): 337 | # Pop the validation logs and handle them separately with 338 | # `self.val_writer`. Also rename the keys so that they can 339 | # be plotted on the same figure with the training metrics 340 | logs = logs or {} 341 | val_logs = {k.replace('val_', ''): v for k, v in logs.items() if k.startswith('val_')} 342 | for name, value in val_logs.items(): 343 | summary = tf.Summary() 344 | summary_value = summary.value.add() 345 | summary_value.simple_value = value.item() 346 | summary_value.tag = name 347 | self.val_writer.add_summary(summary, epoch) 348 | self.val_writer.flush() 349 | 350 | # Pass the remaining logs to `TensorBoard.on_epoch_end` 351 | logs = {k: v for k, v in logs.items() if not k.startswith('val_')} 352 | super(TrainValTensorBoard, self).on_epoch_end(epoch, logs) 353 | 354 | def on_train_end(self, logs=None): 355 | super(TrainValTensorBoard, self).on_train_end(logs) 356 | self.val_writer.close() 357 | --------------------------------------------------------------------------------