├── .dvc
│   ├── .gitignore
│   └── config
├── .gitignore
├── LICENSE
├── README.md
├── board_view.py
├── constants.py
├── create_datasets.py
├── data.dvc
├── incorrect_roi_images.p
├── models
│   ├── .gitignore
│   ├── base_model.py
│   ├── cnn_2d.py
│   ├── cnn_3d.py
│   ├── mvcnn_cnn.py
│   ├── mvcnn_xception.py
│   ├── saved_models.dvc
│   ├── vcnn1.py
│   ├── vcnn2.py
│   ├── vgg19.py
│   ├── vgg19_3d.py
│   ├── xception.py
│   ├── xception_3d.py
│   └── xception_kapp.py
├── non_defective_xml_files.dvc
├── result_images.dvc
├── run_models.py
├── run_models_kapp_integrated.py
├── solder_joint.py
├── solder_joint_container.py
├── solder_joint_container_obj.p
├── utils_basic.py
└── utils_datagen.py
/.dvc/.gitignore:
--------------------------------------------------------------------------------
1 | /config.local
2 | /updater
3 | /state-journal
4 | /state-wal
5 | /state
6 | /lock
7 | /tmp
8 | /updater.lock
9 | /cache
10 |
--------------------------------------------------------------------------------
/.dvc/config:
--------------------------------------------------------------------------------
1 | ['remote "myremote"']
2 | url = ..\dvc-storage
3 | [core]
4 | remote = myremote
5 |
--------------------------------------------------------------------------------
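A note on the config above: a local DVC remote like `myremote` is normally registered via the CLI rather than written by hand. A minimal sketch of the setup and day-to-day tracking commands (assuming the sibling `dvc-storage` folder from the `url` field; `..\dvc-storage` is the Windows spelling of the same path):

```bash
# one-time setup: initialize DVC and register the local folder as the default remote
dvc init
dvc remote add -d myremote ../dvc-storage   # yields the [core] remote entry shown above

# track a data folder, record the .dvc file in git, and push the data to the remote
dvc add data
git add data.dvc .gitignore
git commit -m "Track data with DVC"
dvc push
```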
/.gitignore:
--------------------------------------------------------------------------------
1 | # Don't track content of these folders
2 | images_roi_marked
3 | incorrect_roi_images
4 | original_dataset
5 | .idea
6 | __pycache__
7 |
8 |
9 | # Compiled source #
10 | ###################
11 | *.com
12 | *.class
13 | *.dll
14 | *.exe
15 | *.o
16 | *.so
17 |
18 | # File types #
19 | ##############
20 | *.jpg
21 | *.pickle
22 | *.pkl
23 | *.xml
24 | *.csv
25 | *.npy
26 | *.numpy
27 |
28 | # Packages #
29 | ############
30 | # it's better to unpack these files and commit the raw source
31 | # git has its own built in compression methods
32 | *.7z
33 | *.dmg
34 | *.gz
35 | *.iso
36 | *.jar
37 | *.rar
38 | *.tar
39 | *.zip
40 | /data
41 | /non_defective_xml_files
42 | /result_images
43 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 chinthysl
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AXI_PCB_defect_detection
2 |
3 | This repo contains data pre-processing, classification, and defect detection methodologies for **Advanced X-Ray Inspection (AXI)** images of multi-layer PCB boards.
4 | Please refer to our research paper for more details, and kindly cite it if you use this implementation.
5 | - https://ieeexplore.ieee.org/abstract/document/9442142/?casa_token=FNPEvLNODrgAAAAA:Lwm3mFCDg-BgJiSLl1uhefLUv_ApdkNBMbwECzTi1KEGnGX1PohgRLILGQKf3l7Dugr-vuQ7gDZdt4U
6 | - @inproceedings{zhang2020deep,
7 | title={Deep Learning Based Defect Detection for Solder Joints on Industrial X-Ray Circuit Board Images},
8 | author={Zhang, Qianru and Zhang, Meng and Gamanayake, Chinthaka and Yuen, Chau and Geng, Zehao and Jayasekara, Hirunima and Zhang, Xuewen and Woo, Chia-wei and Low, Jenny and Liu, Xiang},
9 | booktitle={2020 IEEE 18th International Conference on Industrial Informatics (INDIN)},
10 | volume={1},
11 | pages={74--79},
12 | year={2020},
13 | organization={IEEE}
14 | }
15 |
16 | - [AXI_PCB_defect_detection](#axi_pcb_defect_detection)
17 |   * [Details of the dataset](#details-of-the-proprietary-dataset)
18 |   * [General guidelines for contributors](#general-guidelines-for-contributors)
19 |   * [Project file structure](#project-file-structure)
20 |   * [Code explanation](#code-explanation)
21 |     + [solder_joint.py](#solder_jointpy)
22 |     + [board_view.py](#board_viewpy)
23 |     + [solder_joint_container.py](#solder_joint_containerpy)
24 |     + [main.py](#mainpy)
25 |   * [DVC integration](#dvc-integration)
26 |
27 | ## Details of the proprietary dataset
28 | - Distinct .jpg images including all slices – 32377
29 | - Distinct ROIs with labels – 92208
30 | - PCB types – 17
31 | - Distinct XRay board views – 8063
32 | - Views containing correct ROIs – 6112, Views containing incorrect ROIs – 1951
33 | - Slices per XRay board view (slices per physical solder joint) – 3, 4, 5, 6
34 | - Number of solder joints – 22872
35 | - Correct joints – 15613, Incorrect joints – 7259
36 | - Correct-square solder joints – 14672
37 | - missing - 1496, short - 5605, insufficient - 4691, normal - 2880
38 | - 3 sliced joints - 5471, 4 sliced joints - 8590, 5 sliced joints - 97, 6 sliced joints - 277
39 |
40 | #### Special Notes:
41 | - A single XRay board view has multiple ROIs marked.
42 | - A single defective ROI (solder joint) can have multiple defect labels.
43 | - If an X-Ray view contains one incorrect ROI, all solder joints in that view are considered incorrect.
44 |
45 |
46 | ## General Guidelines for contributors
47 | - Please add your data folders to the ignore file until DVC is set up.
48 | - Don't do `git add --all`. Please be aware of what you commit.
49 |
50 | ## Project file structure
51 |
52 | ```bash
53 | ├── non_defective_xml_files/ # put xml labels for non defective rois here
54 | ├── original_dataset/ # put two image folders and PTH2_reviewed.csv inside this
55 | ├── board_view.py # python class for a unique PCB BoardView
56 | ├── constants.py # constant values for the project
57 | ├── incorrect_roi_images.p # this file contains file names of all incorrect rois
58 | ├── main.py # script for creating objects and run ROI concatenation
59 | ├── solder_joint.py # python class for a unique PCB SolderJoint
60 | ├── solder_joint_container.py # python class containing all BoardView and SolderJoint objs
61 | ├── solder_joint_container_obj.p # saved pickle obj for SolderJointContainer class
62 | └── utils_basic.py               # helper functions
63 | ```
64 |
65 | ## Code explanation
66 |
67 | ### solder_joint.py
68 |
69 | ```
70 | +-----------------------------------------------------------------------+
71 | | SolderJoint |
72 | +-----------------------------------------------------------------------+
73 | |+ self.component_name |
74 | |+ self.defect_id |
75 | |+ self.defect_name |
76 | |+ self.roi |
77 | |+ self.is_square |
78 | |+ self.slice_dict |
79 | +-----------------------------------------------------------------------+
80 | |+__init__(self, component_name, defect_id, defect_name, roi) |
81 | |+add_slice(self, slice_id, file_location) |
82 | |+concat_first_four_slices_and_resize(self, width, height) |
83 | +-----------------------------------------------------------------------+
84 | ```
85 |
86 | - `add_slice(self, slice_id, file_location)`:
87 |
88 | This method adds the slice id and the corresponding image location of that slice to `self.slice_dict`.
89 |
90 | - `concat_first_four_slices_and_resize(self, width, height)`:
91 |
92 | In this method, only the first four slices (ids 0, 1, 2, 3) are concatenated into a 2D square grid. The concatenated 2D image and the joint's label are returned; they are caught by the `SolderJointContainer` object and saved to disk accordingly.
93 |
94 | If you want to implement a new channel-concatenation scheme, modify this method (for example, to generate a grayscale image per slice and save the stack as a pickle or numpy file). A rough sketch of the existing 2×2 tiling is given after this README listing.
95 |
96 | ### board_view.py
97 |
98 | ```
99 | +-----------------------------------------------------------------------+
100 | | BoardView |
101 | +-----------------------------------------------------------------------+
102 | |+ self.view_identifier |
103 | |+ self.is_incorrect_view |
104 | |+ self.solder_joint_dict |
105 | |+ self.slice_dict |
106 | +-----------------------------------------------------------------------+
107 | |+__init__(self, view_identifier) |
108 | |+add_solder_joint(self, component, defect_id, defect_name, roi) |
109 | |+add_slice(self, file_location) |
110 | |+add_slices_to_solder_joints(self) |
111 | +-----------------------------------------------------------------------+
112 | ```
113 |
114 | - This class holds all the `SolderJoint` objects and all slice details for one board view of the PCB.
115 |
116 | ### solder_joint_container.py
117 |
118 | ```
119 | +-----------------------------------------------------------------------+
120 | | SolderJointContainer |
121 | +-----------------------------------------------------------------------+
122 | |+ self.board_view_dict |
123 | |+ self.new_image_name_mapping_dict |
124 | |+ self.csv_details_dict |
125 | |+ self.incorrect_board_view_ids |
126 | +-----------------------------------------------------------------------+
127 | |+__init__(self) |
128 | |+mark_all_images_with_rois(self) |
129 | |+load_non_defective_rois_to_container(self) |
130 | |+find_flag_incorrect_roi_board_view_objs(self) |
131 | |+save_concat_images_first_four_slices(self, width=128, height=128) |
132 | |+print_container_details(self) |
133 | |+write_csv_defect_roi_images(self) |
134 | +-----------------------------------------------------------------------+
135 | ```
136 |
137 | - `save_concat_images_first_four_slices(self, width=128, height=128)`:
138 |
139 | In this method, only the first four slices (ids 0, 1, 2, 3) are concatenated into a 2D square grid. The concatenated 2D image and joint label returned from each `SolderJoint` object are saved to disk accordingly.
140 |
141 | If you want to write a new method of channel concatenation please modify this method.
142 |
143 | ### main.py
144 |
145 | A `SolderJointContainer` object is created inside this script. Using that object, we call the required methods to generate the concatenated images.
146 |
147 | ## DVC integration
148 |
149 | This section explains how data version control and storage are set up for the project. DVC is integrated into this repo: you can train any model by changing the parameters of the neural network models, and after training you will see changes in the data folders. Commit those changes with DVC, followed by a git commit (see the command sketch under `/.dvc/config` above). Trained models and data are tracked in a local DVC storage.
150 |
--------------------------------------------------------------------------------
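The `save_concat_images_first_four_slices` / `concat_first_four_slices_and_resize` bodies described in the README live in `solder_joint.py` and `solder_joint_container.py` and are not shown in this excerpt. As a rough, hypothetical sketch of the 2×2 tiling they describe, assuming four equally sized grayscale slices read with OpenCV:

```python
import cv2
import numpy as np

def concat_four_slices(slice_paths, width=128, height=128):
    """Tile slice ids 0..3 into a 2x2 grid and resize (illustration only)."""
    imgs = [cv2.imread(p, cv2.IMREAD_GRAYSCALE) for p in slice_paths[:4]]
    top = np.hstack([imgs[0], imgs[1]])
    bottom = np.hstack([imgs[2], imgs[3]])
    grid = np.vstack([top, bottom])           # square 2D image when all slices share a size
    return cv2.resize(grid, (width, height))  # default 128x128 matches the container method
```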
/board_view.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from solder_joint import SolderJoint
4 |
5 |
6 | class BoardView:
7 | def __init__(self, view_identifier):
8 | self.view_identifier = view_identifier
9 | self.solder_joint_dict = {}
10 | self.slice_dict = {}
11 | self.is_incorrect_view = False
12 |
13 | logging.info('BoardView obj created for view id: %s', self.view_identifier)
14 |
15 | def add_solder_joint(self, component, defect_id, defect_name, roi):
16 | my_tuple = (roi[0], roi[1], roi[2], roi[3], defect_id)
17 | if my_tuple in self.solder_joint_dict:
18 | logging.info('ROI+Defect found inside the solder_joint_dict, won\'t add a new joint')
19 | else:
20 | logging.info('Adding new SolderJoint obj for the new ROI+Defect')
21 | self.solder_joint_dict[my_tuple] = SolderJoint(component, defect_id, defect_name, roi)
22 |
23 | def add_slice(self, file_location):
24 | slice_id = int(file_location[-5])  # slice index is the single digit right before '.jpg' in the file name
25 | self.slice_dict[slice_id] = file_location
26 | for solder_joint_obj in self.solder_joint_dict.values():
27 | solder_joint_obj.add_slice(slice_id, file_location)
28 |
29 | def add_slices_to_solder_joints(self):
30 | for slice_id in self.slice_dict.keys():
31 | file_location = self.slice_dict[slice_id]
32 | for solder_joint_obj in self.solder_joint_dict.values():
33 | solder_joint_obj.add_slice(slice_id, file_location)
34 |
35 |
--------------------------------------------------------------------------------
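`BoardView.add_slice` derives the slice id from `file_location[-5]`, i.e. the single digit just before the `.jpg` extension. A quick illustration with a hypothetical file name following that convention:

```python
# hypothetical path in the expected naming convention: ..._<slice_id>.jpg
file_location = 'original_dataset/view_0001_3.jpg'
slice_id = int(file_location[-5])  # picks up the '3' before '.jpg'
print(slice_id)  # 3 -- note this breaks for multi-digit ids or other extensions
```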
/constants.py:
--------------------------------------------------------------------------------
1 | DEFECT_NAMES_DICT = {89: "missing", 164: "insufficient", 145: "short", 0: "normal"}
2 |
--------------------------------------------------------------------------------
/create_datasets.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import pickle
3 |
4 | from solder_joint_container import SolderJointContainer
5 |
6 | logging.basicConfig(level=logging.DEBUG)
7 |
8 | # if you don't have the pickle file for the container obj make CREATE_OBJ = True
9 | CREATE_OBJ = False
10 | if CREATE_OBJ:
11 | # creating solder joint object will read the data set and create all other objects
12 | solder_joint_container_obj = SolderJointContainer()
13 | # from a single slice of incorrect roi images whole board view objects will be marked
14 | solder_joint_container_obj.find_flag_incorrect_roi_board_view_objs()
15 | # load non defective solder joint info to the container
16 | solder_joint_container_obj.load_non_defective_rois_to_container()
17 |
18 | with open('solder_joint_container_obj.p', 'wb') as output_handle:
19 | pickle.dump(solder_joint_container_obj, output_handle, pickle.HIGHEST_PROTOCOL)
20 |
21 | # **********************************************************************************************************************
22 | with open('solder_joint_container_obj.p', 'rb') as input_handle:
23 | solder_joint_container_obj = pickle.load(input_handle)
24 |
25 | # # # this method will create a directory full of images with marked rois
26 | # solder_joint_container_obj.mark_all_images_with_rois()
27 |
28 | # # methods to create extracted roi joint image or pickle data sets
29 | # # solder_joint_container_obj.save_concat_images_first_four_slices_2d()
30 | # solder_joint_container_obj.save_concat_images_first_four_slices_2d_pickle()
31 | # solder_joint_container_obj.save_concat_images_first_four_slices_3d_pickle()
32 | # solder_joint_container_obj.save_concat_images_all_slices_2d_pickle()
33 | # solder_joint_container_obj.save_concat_images_all_slices_3d_pickle()
34 | # solder_joint_container_obj.save_concat_images_all_slices_inverse_3d_pickle()
35 | # solder_joint_container_obj.save_concat_images_first_four_slices_list_rgb_pickle()
36 | # solder_joint_container_obj.save_concat_images_first_four_slices_2d_rotated_pickle()
37 | # solder_joint_container_obj.save_concat_images_first_four_slices_2d_more_normal_pickle()
38 | solder_joint_container_obj.save_concat_images_first_four_slices_list_more_normal_pickle()
39 |
40 | # solder_joint_container_obj.save_4slices_individual_pickle()
41 |
42 | # # print details of BoardView objects and SolderJoint objects
43 | # solder_joint_container_obj.print_solder_joint_resolution_details()
44 | # solder_joint_container_obj.print_container_details()
45 |
46 |
--------------------------------------------------------------------------------
/data.dvc:
--------------------------------------------------------------------------------
1 | md5: 2e10abaf47729002df92f7752d3fe3f2
2 | outs:
3 | - md5: f52d9804b9765aae3fee0e645ffbccb7.dir
4 | path: data
5 | cache: true
6 | metric: false
7 | persist: false
8 |
--------------------------------------------------------------------------------
/incorrect_roi_images.p:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chinthysl/AXI_PCB_defect_detection/fe81c1d8e144ce5434aee78548cdc026d0d53d1c/incorrect_roi_images.p
--------------------------------------------------------------------------------
/models/.gitignore:
--------------------------------------------------------------------------------
1 | /saved_models
2 |
--------------------------------------------------------------------------------
/models/base_model.py:
--------------------------------------------------------------------------------
1 | import time
2 | import numpy as np
3 | import seaborn as sn
4 | import pandas as pd
5 | import matplotlib.pyplot as plt
6 | import keras
7 | from sklearn.metrics import classification_report, confusion_matrix
8 | from abc import ABC, abstractmethod
9 | from utils_basic import save_logs
10 |
11 |
12 | class BaseModel(ABC):
13 | @abstractmethod
14 | def build_model(self, input_shape, n_classes):
15 | pass
16 |
17 | def fit(self, training_generator, testing_generator, samples_per_train_epoch, samples_per_test_epoch, epochs):
18 |
19 | start_time = time.time()
20 | hist = self.model.fit_generator(generator=training_generator,
21 | validation_data=testing_generator,
22 | steps_per_epoch=samples_per_train_epoch,
23 | validation_steps=samples_per_test_epoch,
24 | epochs=epochs,
25 | callbacks=self.callbacks,
26 | use_multiprocessing=False,
27 | workers=1,
28 | verbose=self.verbose)
29 | duration = time.time() - start_time
30 |
31 | model = keras.models.load_model(self.output_directory + '/best_model.hdf5')
32 | test_loss, test_acc = model.evaluate_generator(testing_generator, steps=samples_per_test_epoch)
33 | print('test_loss:', test_loss, 'test_acc:', test_acc)
34 |
35 | y_pred = model.predict_generator(testing_generator, steps=samples_per_test_epoch)
36 | y_true = testing_generator.gen_class_labels()
37 | save_logs(self.output_directory, hist, y_pred, y_true, duration, lr=False)
38 |
39 | keras.backend.clear_session()
40 |
41 | def predict(self, testing_generator, samples_per_test_epoch):
42 | model = keras.models.load_model(self.output_directory + '/best_model.hdf5')
43 | test_loss, test_acc = model.evaluate_generator(testing_generator, steps=samples_per_test_epoch)
44 | print('test_loss:', test_loss, 'test_acc:', test_acc)
45 | y_pred = model.predict_generator(testing_generator, steps=samples_per_test_epoch)
46 |
47 | y_pred = np.argmax(y_pred, axis=1)
48 | y_true = testing_generator.gen_class_labels()
49 | print('Confusion Matrix')
50 |
51 | target_names = list(testing_generator.integer_mapping_dict.keys())
52 | mat = confusion_matrix(y_true, y_pred)
53 | print(mat)
54 | df_cm = pd.DataFrame(mat, index=target_names, columns=target_names)
55 | plt.figure()
56 | sn.heatmap(df_cm, annot=True, cmap='Blues', fmt='g')
57 | plt.show()
58 |
59 | print('Classification Report')
60 | print(classification_report(y_true, y_pred, target_names=target_names))
61 |
62 | keras.backend.clear_session()
63 |
64 |
65 |
--------------------------------------------------------------------------------
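`BaseModel.fit` and `BaseModel.predict` assume the generators expose two members beyond the usual Keras generator protocol: a `gen_class_labels()` method returning the true integer labels in prediction order, and an `integer_mapping_dict` mapping class names to integers. The real generators live in `utils_datagen.py` (not shown in this excerpt); the following is only a minimal sketch of that assumed interface:

```python
import numpy as np
import keras

class JointSequence(keras.utils.Sequence):
    """Hypothetical generator satisfying the interface BaseModel relies on."""
    def __init__(self, x, y, batch_size=32):
        self.x, self.y, self.batch_size = x, y, batch_size
        self.integer_mapping_dict = {'normal': 0, 'missing': 1, 'short': 2, 'insufficient': 3}

    def __len__(self):
        return int(np.ceil(len(self.x) / self.batch_size))

    def __getitem__(self, idx):
        batch = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
        return self.x[batch], keras.utils.to_categorical(self.y[batch], len(self.integer_mapping_dict))

    def gen_class_labels(self):
        return self.y  # integer labels, aligned with the prediction order
```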
/models/cnn_2d.py:
--------------------------------------------------------------------------------
1 | from keras.layers import Conv2D, MaxPool2D, Flatten, Dense
2 | from keras.layers import Dropout, Input, BatchNormalization
3 | from keras.losses import categorical_crossentropy
4 | from keras.optimizers import Adam, Adadelta
5 | from keras.models import Model
6 | from keras.utils import plot_model
7 | from keras import callbacks
8 |
9 | from utils_datagen import TrainValTensorBoard
10 | from utils_basic import chk_n_mkdir
11 | from models.base_model import BaseModel
12 |
13 |
14 | class CNN2D(BaseModel):
15 | def __init__(self, output_directory, input_shape, n_classes, verbose=False):
16 | self.output_directory = output_directory + '/cnn_2d'
17 | chk_n_mkdir(self.output_directory)
18 | self.model = self.build_model(input_shape, n_classes)
19 | if verbose:
20 | self.model.summary()
21 | self.verbose = verbose
22 | self.model.save_weights(self.output_directory + '/model_init.hdf5')
23 |
24 | def build_model(self, input_shape, n_classes):
25 | ## input layer
26 | input_layer = Input(input_shape)
27 |
28 | ## convolutional layers
29 | conv_layer1 = Conv2D(filters=8, kernel_size=(3, 3), activation='relu')(input_layer)
30 | conv_layer2 = Conv2D(filters=16, kernel_size=(3, 3), activation='relu')(conv_layer1)
31 |
32 | ## add max pooling to obtain the most informative features
33 | pooling_layer1 = MaxPool2D(pool_size=(2, 2))(conv_layer2)
34 |
35 | conv_layer3 = Conv2D(filters=32, kernel_size=(3, 3), activation='relu')(pooling_layer1)
36 | conv_layer4 = Conv2D(filters=64, kernel_size=(3, 3), activation='relu')(conv_layer3)
37 | pooling_layer2 = MaxPool2D(pool_size=(2, 2))(conv_layer4)
38 |
39 | ## perform batch normalization on the convolution outputs before feeding it to MLP architecture
40 | pooling_layer2 = BatchNormalization()(pooling_layer2)
41 | flatten_layer = Flatten()(pooling_layer2)
42 |
43 | ## create an MLP architecture with dense layers : 2048 -> 512 -> n_classes
44 | ## add dropouts to avoid overfitting / perform regularization
45 | dense_layer1 = Dense(units=2048, activation='relu')(flatten_layer)
46 | dense_layer1 = Dropout(0.4)(dense_layer1)
47 | dense_layer2 = Dense(units=512, activation='relu')(dense_layer1)
48 | dense_layer2 = Dropout(0.4)(dense_layer2)
49 | output_layer = Dense(units=n_classes, activation='softmax')(dense_layer2)
50 |
51 | ## define the model with input layer and output layer
52 | model = Model(inputs=input_layer, outputs=output_layer)
53 | model.summary()
54 |
55 | plot_model(model, to_file=self.output_directory + '/model_graph.png', show_shapes=True, show_layer_names=True)
56 |
57 | model.compile(loss=categorical_crossentropy, optimizer=Adadelta(), metrics=['acc'])
58 |
59 | # model save
60 | file_path = self.output_directory + '/best_model.hdf5'
61 | model_checkpoint = callbacks.ModelCheckpoint(filepath=file_path, monitor='loss', save_best_only=True)
62 |
63 | # Tensorboard log
64 | log_dir = self.output_directory + '/tf_logs'
65 | chk_n_mkdir(log_dir)
66 | tb_cb = TrainValTensorBoard(log_dir=log_dir)
67 |
68 | self.callbacks = [model_checkpoint, tb_cb]
69 | return model
70 |
71 |
72 |
--------------------------------------------------------------------------------
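`run_models.py` appears in the tree but not in this excerpt. Given the constructor and `fit` signatures above, a minimal driver for `CNN2D` might look like the following sketch (dummy random data; `JointSequence` is the hypothetical generator sketched after `base_model.py`):

```python
import numpy as np
from models.cnn_2d import CNN2D

# dummy 128x128 single-channel concatenated joint images, 4 defect classes
x = np.random.rand(64, 128, 128, 1).astype('float32')
y = np.random.randint(0, 4, size=64)
train_gen, test_gen = JointSequence(x, y), JointSequence(x[:16], y[:16])

model = CNN2D(output_directory='saved_models', input_shape=(128, 128, 1), n_classes=4, verbose=True)
model.fit(train_gen, test_gen,
          samples_per_train_epoch=len(train_gen),
          samples_per_test_epoch=len(test_gen),
          epochs=2)
```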
/models/cnn_3d.py:
--------------------------------------------------------------------------------
1 | from keras.layers import Conv3D, MaxPool3D, Flatten, Dense
2 | from keras.layers import Dropout, Input, BatchNormalization
3 | from keras.losses import categorical_crossentropy
4 | from keras.optimizers import Adam, Adadelta
5 | from keras.models import Model
6 | from keras.utils import plot_model
7 | from keras import callbacks
8 |
9 | from utils_datagen import TrainValTensorBoard
10 | from utils_basic import chk_n_mkdir
11 | from models.base_model import BaseModel
12 |
13 |
14 | class CNN3D(BaseModel):
15 | def __init__(self, output_directory, input_shape, n_classes, verbose=False):
16 | self.output_directory = output_directory + '/cnn_3d'
17 | chk_n_mkdir(self.output_directory)
18 | self.model = self.build_model(input_shape, n_classes)
19 | if verbose:
20 | self.model.summary()
21 | self.verbose = verbose
22 | self.model.save_weights(self.output_directory + '/model_init.hdf5')
23 |
24 | def build_model(self, input_shape, n_classes):
25 | ## input layer
26 | input_layer = Input(input_shape)
27 |
28 | ## convolutional layers
29 | conv_layer1 = Conv3D(filters=8, kernel_size=(3, 3, 2), activation='relu')(input_layer)
30 | conv_layer2 = Conv3D(filters=16, kernel_size=(3, 3, 2), activation='relu')(conv_layer1)
31 |
32 | ## add max pooling to obtain the most informative features
33 | pooling_layer1 = MaxPool3D(pool_size=(2, 2, 2))(conv_layer2)
34 |
35 | conv_layer3 = Conv3D(filters=32, kernel_size=(3, 3, 1), activation='relu')(pooling_layer1)
36 | conv_layer4 = Conv3D(filters=64, kernel_size=(3, 3, 1), activation='relu')(conv_layer3)
37 | pooling_layer2 = MaxPool3D(pool_size=(2, 2, 1))(conv_layer4)
38 |
39 | ## perform batch normalization on the convolution outputs before feeding it to MLP architecture
40 | pooling_layer2 = BatchNormalization()(pooling_layer2)
41 | flatten_layer = Flatten()(pooling_layer2)
42 |
43 | ## create an MLP architecture with dense layers : 2048 -> 512 -> n_classes
44 | ## add dropouts to avoid overfitting / perform regularization
45 | dense_layer1 = Dense(units=2048, activation='relu')(flatten_layer)
46 | dense_layer1 = Dropout(0.4)(dense_layer1)
47 | dense_layer2 = Dense(units=512, activation='relu')(dense_layer1)
48 | dense_layer2 = Dropout(0.4)(dense_layer2)
49 | output_layer = Dense(units=n_classes, activation='softmax')(dense_layer2)
50 |
51 | ## define the model with input layer and output layer
52 | model = Model(inputs=input_layer, outputs=output_layer)
53 | model.summary()
54 |
55 | plot_model(model, to_file=self.output_directory + '/model_graph.png', show_shapes=True, show_layer_names=True)
56 |
57 | model.compile(loss=categorical_crossentropy, optimizer=Adadelta(), metrics=['acc'])
58 |
59 | # model save
60 | file_path = self.output_directory + '/best_model.hdf5'
61 | model_checkpoint = callbacks.ModelCheckpoint(filepath=file_path, monitor='loss', save_best_only=True)
62 |
63 | # Tensorboard log
64 | log_dir = self.output_directory + '/tf_logs'
65 | chk_n_mkdir(log_dir)
66 | tb_cb = TrainValTensorBoard(log_dir=log_dir)
67 |
68 | self.callbacks = [model_checkpoint, tb_cb]
69 | return model
70 |
71 |
--------------------------------------------------------------------------------
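The 3D variants treat the slices as a depth axis rather than as channels, so a batch is a 5D tensor. A shape sketch assuming 128x128 images and 4 slices per joint (a depth of 4 is compatible with the `(3, 3, 2)` kernels and `(2, 2, 2)` pooling above):

```python
import numpy as np

# batch of 8 solder joints: (batch, height, width, slices, channels)
volumes = np.random.rand(8, 128, 128, 4, 1).astype('float32')
# CNN3D(output_directory='saved_models', input_shape=(128, 128, 4, 1), n_classes=4)
# would accept this batch in fit/predict
print(volumes.shape)  # (8, 128, 128, 4, 1)
```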
/models/mvcnn_cnn.py:
--------------------------------------------------------------------------------
1 | from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, concatenate, Maximum
2 | from keras.layers import Dropout, Input, BatchNormalization
3 | from keras.losses import categorical_crossentropy
4 | from keras.optimizers import Adam, Adadelta
5 | from keras.models import Model
6 | from keras.utils import plot_model
7 | from keras import callbacks
8 |
9 | from utils_datagen import TrainValTensorBoard
10 | from utils_basic import chk_n_mkdir
11 | from models.base_model import BaseModel
12 |
13 |
14 | class MVCNN_CNN(BaseModel):
15 | def __init__(self, output_directory, input_shape, n_classes, verbose=False):
16 | self.output_directory = output_directory + '/mvcnn_cnn'
17 | chk_n_mkdir(self.output_directory)
18 | self.model = self.build_model(input_shape, n_classes)
19 | if verbose:
20 | self.model.summary()
21 | self.verbose = verbose
22 | self.model.save_weights(self.output_directory + '/model_init.hdf5')
23 |
24 | def cnn(self, input_shape):
25 | # ## input layer
26 | input_layer = Input(input_shape)
27 |
28 | ## convolutional layers
29 | conv_layer1 = Conv2D(filters=8, kernel_size=(3, 3), activation='relu')(input_layer)
30 | conv_layer2 = Conv2D(filters=16, kernel_size=(3, 3), activation='relu')(conv_layer1)
31 |
32 | ## add max pooling to obtain the most informative features
33 | pooling_layer1 = MaxPooling2D(pool_size=(2, 2))(conv_layer2)
34 |
35 | conv_layer3 = Conv2D(filters=32, kernel_size=(3, 3), activation='relu')(pooling_layer1)
36 | conv_layer4 = Conv2D(filters=64, kernel_size=(3, 3), activation='relu')(conv_layer3)
37 | pooling_layer2 = MaxPooling2D(pool_size=(2, 2))(conv_layer4)
38 |
39 | ## perform batch normalization on the convolution outputs before feeding it to MLP architecture
40 | pooling_layer2 = BatchNormalization()(pooling_layer2)
41 | flatten_layer = Flatten()(pooling_layer2)
42 |
43 | # input layer
44 | # input_layer = Input(input_shape)
45 | # conv_layer1 = Conv2D(filters=16, kernel_size=(3, 3), activation='relu')(input_layer)
46 | # conv_layer2 = Conv2D(filters=32, kernel_size=(3, 3), activation='relu')(conv_layer1)
47 | # pooling_layer1 = MaxPooling2D(pool_size=(4, 4))(conv_layer2)
48 | # # dropout_layer1 = Dropout(0.25)(pooling_layer1)
49 | # dropout_layer1 =BatchNormalization()(pooling_layer1)
50 | # flatten_layer = Flatten()(dropout_layer1)
51 |
52 | # Create model.
53 | model = Model(input_layer, flatten_layer)
54 | return model
55 |
56 | def build_model(self, input_shape, n_classes):
57 | cnn_1 = self.cnn(input_shape)
58 | input_1 = cnn_1.input
59 | output_1 = cnn_1.output
60 | cnn_2 = self.cnn(input_shape)
61 | input_2 = cnn_2.input
62 | output_2 = cnn_2.output
63 | cnn_3 = self.cnn(input_shape)
64 | input_3 = cnn_3.input
65 | output_3 = cnn_3.output
66 | cnn_4 = self.cnn(input_shape)
67 | input_4 = cnn_4.input
68 | output_4 = cnn_4.output
69 |
70 | concat_layer = Maximum()([output_1, output_2, output_3, output_4])  # element-wise max over the four view branches (view pooling)
71 | concat_layer = Dropout(0.5)(concat_layer)
72 | dense_layer1 = Dense(units=1024, activation='relu')(concat_layer)
73 | dense_layer1 = Dropout(0.5)(dense_layer1)
74 | output_layer = Dense(n_classes, activation='softmax', name='predictions')(dense_layer1)
75 |
76 | model = Model(inputs=[input_1, input_2, input_3, input_4], outputs=[output_layer])
77 | model.summary()
78 | plot_model(model, to_file=self.output_directory + '/model_graph.png', show_shapes=True, show_layer_names=True)
79 | model.compile(loss=categorical_crossentropy, optimizer=Adadelta(lr=0.01), metrics=['accuracy'])
80 |
81 | # model save
82 | file_path = self.output_directory + '/best_model.hdf5'
83 | model_checkpoint = callbacks.ModelCheckpoint(filepath=file_path, monitor='loss', save_best_only=True)
84 |
85 | # Tensorboard log
86 | log_dir = self.output_directory + '/tf_logs'
87 | chk_n_mkdir(log_dir)
88 | tb_cb = TrainValTensorBoard(log_dir=log_dir)
89 |
90 | self.callbacks = [model_checkpoint, tb_cb]
91 | return model
92 |
93 |
--------------------------------------------------------------------------------
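Unlike the single-input models, `MVCNN_CNN` takes the four slices as four separate inputs and fuses the branch features with an element-wise `Maximum` (MVCNN-style view pooling). A quick shape check under that assumption:

```python
import numpy as np
from models.mvcnn_cnn import MVCNN_CNN

mv = MVCNN_CNN(output_directory='saved_models', input_shape=(128, 128, 1), n_classes=4)
# one sample = a list of four arrays, one per slice/view branch
views = [np.random.rand(1, 128, 128, 1).astype('float32') for _ in range(4)]
probs = mv.model.predict(views)
print(probs.shape)  # (1, 4): softmax over the defect classes
```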
/models/mvcnn_xception.py:
--------------------------------------------------------------------------------
1 | import keras
2 | import os
3 |
4 | from keras.layers import Conv2D, SeparableConv2D, MaxPooling2D, GlobalMaxPooling2D, Flatten, Dense, concatenate, Maximum
5 | from keras.layers import Dropout, Input, BatchNormalization, Activation, add, GlobalAveragePooling2D
6 | from keras.losses import categorical_crossentropy
7 | from keras.optimizers import Adam
8 | from keras.models import Model
9 | from keras.utils import plot_model
10 | import keras.backend as backend
11 | import keras.utils as keras_utils
12 |
13 | from utils_datagen import TrainValTensorBoard
14 | from utils_basic import chk_n_mkdir
15 | from models.base_model import BaseModel
16 |
17 |
18 | class MVCNN_XCEPTION(BaseModel):
19 | def __init__(self, output_directory, input_shape, n_classes, verbose=False):
20 | self.output_directory = output_directory + '/mvcnn_xception'
21 | chk_n_mkdir(self.output_directory)
22 | self.model = self.build_model(input_shape, n_classes)
23 | if verbose:
24 | self.model.summary()
25 | self.verbose = verbose
26 | self.model.save_weights(self.output_directory + '/model_init.hdf5')
27 |
28 | def xception(self, include_top=True, weights='imagenet', input_tensor=None, pooling=None, classes=1000):
29 |
30 | TF_WEIGHTS_PATH = (
31 | 'https://github.com/fchollet/deep-learning-models/releases/download/v0.4/'
32 | 'xception_weights_tf_dim_ordering_tf_kernels.h5')
33 | TF_WEIGHTS_PATH_NO_TOP = (
34 | 'https://github.com/fchollet/deep-learning-models/releases/download/v0.4/'
35 | 'xception_weights_tf_dim_ordering_tf_kernels_notop.h5')
36 |
37 | if not (weights in {'imagenet', None} or os.path.exists(weights)):
38 | raise ValueError('The `weights` argument should be either '
39 | '`None` (random initialization), `imagenet` '
40 | '(pre-training on ImageNet), '
41 | 'or the path to the weights file to be loaded.')
42 |
43 | if weights == 'imagenet' and include_top and classes != 1000:
44 | raise ValueError('If using `weights` as `"imagenet"` with `include_top`'
45 | ' as true, `classes` should be 1000')
46 |
47 | # Xception expects 299x299 RGB inputs; the shape is fixed accordingly
48 | input_shape = (299, 299, 3)
49 |
50 | if input_tensor is None:
51 | img_input = Input(shape=input_shape)
52 | else:
53 | if not backend.is_keras_tensor(input_tensor):
54 | img_input = Input(tensor=input_tensor, shape=input_shape)
55 | else:
56 | img_input = input_tensor
57 |
58 | channel_axis = -1
59 |
60 | x = Conv2D(32, (3, 3), strides=(2, 2), use_bias=False)(img_input)
61 | x = BatchNormalization(axis=channel_axis)(x)
62 | x = Activation('relu')(x)
63 | x = Conv2D(64, (3, 3), use_bias=False)(x)
64 | x = BatchNormalization(axis=channel_axis)(x)
65 | x = Activation('relu')(x)
66 |
67 | residual = Conv2D(128, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
68 | residual = BatchNormalization(axis=channel_axis)(residual)
69 |
70 | x = SeparableConv2D(128, (3, 3), padding='same', use_bias=False)(x)
71 | x = BatchNormalization(axis=channel_axis)(x)
72 | x = Activation('relu')(x)
73 | x = SeparableConv2D(128, (3, 3), padding='same', use_bias=False)(x)
74 | x = BatchNormalization(axis=channel_axis)(x)
75 |
76 | x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
77 | x = add([x, residual])
78 |
79 | residual = Conv2D(256, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
80 | residual = BatchNormalization(axis=channel_axis)(residual)
81 |
82 | x = Activation('relu')(x)
83 | x = SeparableConv2D(256, (3, 3), padding='same', use_bias=False)(x)
84 | x = BatchNormalization(axis=channel_axis)(x)
85 | x = Activation('relu')(x)
86 | x = SeparableConv2D(256, (3, 3), padding='same', use_bias=False)(x)
87 | x = BatchNormalization(axis=channel_axis)(x)
88 |
89 | x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
90 | x = add([x, residual])
91 |
92 | residual = Conv2D(728, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
93 | residual = BatchNormalization(axis=channel_axis)(residual)
94 |
95 | x = Activation('relu')(x)
96 | x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False)(x)
97 | x = BatchNormalization(axis=channel_axis)(x)
98 | x = Activation('relu')(x)
99 | x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False)(x)
100 | x = BatchNormalization(axis=channel_axis)(x)
101 |
102 | x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
103 | x = add([x, residual])
104 |
105 | for i in range(8):
106 | residual = x
107 |
108 | x = Activation('relu')(x)
109 | x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False)(x)
110 | x = BatchNormalization(axis=channel_axis)(x)
111 | x = Activation('relu')(x)
112 | x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False)(x)
113 | x = BatchNormalization(axis=channel_axis)(x)
114 | x = Activation('relu')(x)
115 | x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False)(x)
116 | x = BatchNormalization(axis=channel_axis)(x)
117 |
118 | x = add([x, residual])
119 |
120 | residual = Conv2D(1024, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
121 | residual = BatchNormalization(axis=channel_axis)(residual)
122 |
123 | x = Activation('relu')(x)
124 | x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False)(x)
125 | x = BatchNormalization(axis=channel_axis)(x)
126 | x = Activation('relu')(x)
127 | x = SeparableConv2D(1024, (3, 3), padding='same', use_bias=False)(x)
128 | x = BatchNormalization(axis=channel_axis)(x)
129 |
130 | x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
131 | x = add([x, residual])
132 |
133 | x = SeparableConv2D(1536, (3, 3), padding='same', use_bias=False)(x)
134 | x = BatchNormalization(axis=channel_axis)(x)
135 | x = Activation('relu')(x)
136 |
137 | x = SeparableConv2D(2048, (3, 3), padding='same', use_bias=False)(x)
138 | x = BatchNormalization(axis=channel_axis)(x)
139 | x = Activation('relu')(x)
140 |
141 | if include_top:
142 | x = GlobalAveragePooling2D(name='avg_pool')(x)
143 | x = Dense(classes, activation='softmax', name='predictions')(x)
144 | else:
145 | if pooling == 'avg':
146 | x = GlobalAveragePooling2D()(x)
147 | elif pooling == 'max':
148 | x = GlobalMaxPooling2D()(x)
149 |
150 | # Ensure that the model takes into account
151 | # any potential predecessors of `input_tensor`.
152 | if input_tensor is not None:
153 | inputs = keras_utils.get_source_inputs(input_tensor)
154 | else:
155 | inputs = img_input
156 | # Create model.
157 | model = Model(inputs, x)
158 |
159 | # Load weights.
160 | if weights == 'imagenet':
161 | if include_top:
162 | weights_path = keras_utils.get_file(
163 | 'xception_weights_tf_dim_ordering_tf_kernels.h5',
164 | TF_WEIGHTS_PATH,
165 | cache_subdir='models',
166 | file_hash='0a58e3b7378bc2990ea3b43d5981f1f6')
167 | else:
168 | weights_path = keras_utils.get_file(
169 | 'xception_weights_tf_dim_ordering_tf_kernels_notop.h5',
170 | TF_WEIGHTS_PATH_NO_TOP,
171 | cache_subdir='models',
172 | file_hash='b0042744bf5b25fce3cb969f33bebb97')
173 | model.load_weights(weights_path)
174 | if backend.backend() == 'theano':
175 | keras_utils.convert_all_kernels_in_model(model)
176 | elif weights is not None:
177 | model.load_weights(weights)
178 |
179 | return model
180 |
181 | def build_model(self, input_shape, n_classes):
182 | xception_1 = self.xception(include_top=False, pooling='max')
183 | for layer in xception_1.layers:
184 | layer.trainable = False
185 | input_1 = xception_1.input
186 | output_1 = xception_1.output
187 | xception_2 = self.xception(include_top=False, pooling='max')
188 | for layer in xception_2.layers:
189 | layer.trainable = False
190 | input_2 = xception_2.input
191 | output_2 = xception_2.output
192 | xception_3 = self.xception(include_top=False, pooling='max')
193 | for layer in xception_3.layers:
194 | layer.trainable = False
195 | input_3 = xception_3.input
196 | output_3 = xception_3.output
197 | xception_4 = self.xception(include_top=False, pooling='max')
198 | for layer in xception_4.layers:
199 | layer.trainable = False
200 | input_4 = xception_4.input
201 | output_4 = xception_4.output
202 |
203 | concat_layer = Maximum()([output_1, output_2, output_3, output_4])
204 | concat_layer.trainable = False
205 | # concat_layer = Dropout(0.25)(concat_layer)
206 | # dense_layer1 = Dense(units=1024, activation='relu')(concat_layer)
207 | dense_layer1 = Dropout(0.5)(concat_layer)
208 | output_layer = Dense(n_classes, activation='softmax', name='predictions')(dense_layer1)
209 |
210 | model = Model(inputs=[input_1, input_2, input_3, input_4], outputs=[output_layer])
211 | model.summary()
212 | plot_model(model, to_file=self.output_directory + '/model_graph.png', show_shapes=True, show_layer_names=True)
213 | model.compile(loss=categorical_crossentropy, optimizer=Adam(lr=0.01), metrics=['acc'])
214 |
215 | # model save
216 | file_path = self.output_directory + '/best_model.hdf5'
217 | model_checkpoint = keras.callbacks.ModelCheckpoint(filepath=file_path, monitor='loss', save_best_only=True)
218 |
219 | # Tensorboard log
220 | log_dir = self.output_directory + '/tf_logs'
221 | chk_n_mkdir(log_dir)
222 | tb_cb = TrainValTensorBoard(log_dir=log_dir)
223 |
224 | self.callbacks = [model_checkpoint, tb_cb]
225 | return model
226 |
227 |
228 |
--------------------------------------------------------------------------------
/models/saved_models.dvc:
--------------------------------------------------------------------------------
1 | md5: f2f2f88b356078e7b5147950ccf87b92
2 | outs:
3 | - md5: 1574ce19441fca2c5303c2c59e2e5bc8.dir
4 | path: saved_models
5 | cache: true
6 | metric: false
7 | persist: false
8 |
--------------------------------------------------------------------------------
/models/vcnn1.py:
--------------------------------------------------------------------------------
1 | from keras.layers import Conv2D, MaxPool2D, Flatten, Dense
2 | from keras.layers import Dropout, Input, BatchNormalization
3 | from keras.losses import categorical_crossentropy
4 | from keras.optimizers import Adam
5 | from keras.models import Model
6 | from keras.utils import plot_model
7 | from keras import callbacks
8 |
9 | from utils_datagen import TrainValTensorBoard
10 | from utils_basic import chk_n_mkdir
11 | from models.base_model import BaseModel
12 |
13 |
14 | class VCNN1(BaseModel):
15 | def __init__(self, output_directory, input_shape, n_classes, verbose=False):
16 | self.output_directory = output_directory + '/vcnn1_3d'
17 | chk_n_mkdir(self.output_directory)
18 | self.model = self.build_model(input_shape, n_classes)
19 | if verbose:
20 | self.model.summary()
21 | self.verbose = verbose
22 | self.model.save_weights(self.output_directory + '/model_init.hdf5')
23 |
24 | def build_model(self, input_shape, n_classes):
25 | ## input layer
26 | input_layer = Input(input_shape)
27 |
28 | ## convolutional layers
29 | conv_layer1 = Conv2D(filters=16, kernel_size=(3, 3), activation='relu')(input_layer)
30 | pooling_layer1 = MaxPool2D(pool_size=(2, 2), strides=(2, 2))(conv_layer1)
31 |
32 | conv_layer2 = Conv2D(filters=32, kernel_size=(3, 3), activation='relu')(pooling_layer1)
33 |
34 | conv_layer3 = Conv2D(filters=32, kernel_size=(3, 3), activation='relu')(conv_layer2)
35 | pooling_layer2 = MaxPool2D(pool_size=(2, 2), strides=(2, 2))(conv_layer3)
36 | dropout_layer = Dropout(0.5)(pooling_layer2)
37 |
38 | dense_layer = Dense(units=2048, activation='relu')(Flatten()(dropout_layer))  # flatten before the dense head
39 | output_layer = Dense(units=n_classes, activation='softmax')(dense_layer)
40 |
41 | ## define the model with input layer and output layer
42 | model = Model(inputs=input_layer, outputs=output_layer)
43 | model.summary()
44 |
45 | plot_model(model, to_file=self.output_directory + '/model_graph.png', show_shapes=True, show_layer_names=True)
46 |
47 | model.compile(loss=categorical_crossentropy, optimizer=Adam(), metrics=['acc'])
48 |
49 | # model save
50 | file_path = self.output_directory + '/best_model.hdf5'
51 | model_checkpoint = callbacks.ModelCheckpoint(filepath=file_path, monitor='loss', save_best_only=True)
52 |
53 | # Tensorboard log
54 | log_dir = self.output_directory + '/tf_logs'
55 | chk_n_mkdir(log_dir)
56 | tb_cb = TrainValTensorBoard(log_dir=log_dir)
57 |
58 | self.callbacks = [model_checkpoint, tb_cb]
59 | return model
60 |
61 |
--------------------------------------------------------------------------------
/models/vcnn2.py:
--------------------------------------------------------------------------------
1 | from keras.layers import Conv2D, MaxPool2D, Flatten, Dense, Concatenate, Activation
2 | from keras.layers import Dropout, Input, BatchNormalization
3 | from keras.losses import categorical_crossentropy
4 | from keras.optimizers import Adam
5 | from keras.models import Model
6 | from keras.utils import plot_model
7 | from keras import callbacks
8 |
9 | from utils_datagen import TrainValTensorBoard
10 | from utils_basic import chk_n_mkdir
11 | from models.base_model import BaseModel
12 |
13 |
14 | class VCNN2(BaseModel):
15 | def __init__(self, output_directory, input_shape, n_classes, verbose=False):
16 | self.output_directory = output_directory + '/vcnn2_3d'
17 | chk_n_mkdir(self.output_directory)
18 | self.model = self.build_model(input_shape, n_classes)
19 | if verbose:
20 | self.model.summary()
21 | self.verbose = verbose
22 | self.model.save_weights(self.output_directory + '/model_init.hdf5')
23 |
24 | def build_model(self, input_shape, n_classes):
25 | ## input layer
26 | input_layer = Input(input_shape)
27 |
28 | ## convolutional layers
29 | conv_layer1_1 = Conv2D(filters=16, kernel_size=(1, 1), activation=None, padding='same')(input_layer)
30 | conv_layer1_2 = Conv2D(filters=16, kernel_size=(3, 3), activation=None, padding='same')(input_layer)
31 | conv_layer1_3 = Conv2D(filters=16, kernel_size=(5, 5), activation=None, padding='same')(input_layer)
32 | concat_layer1 = Concatenate()([conv_layer1_1, conv_layer1_2, conv_layer1_3])
33 | activation_layer1 = Activation('relu')(concat_layer1)
34 | dropout_layer1 = Dropout(0.2)(activation_layer1)
35 |
36 | conv_layer2_1 = Conv2D(filters=16, kernel_size=(1, 1), activation=None, padding='same')(dropout_layer1)
37 | conv_layer2_2 = Conv2D(filters=16, kernel_size=(3, 3), activation=None, padding='same')(dropout_layer1)
38 | concat_layer2 = Concatenate()([conv_layer2_1, conv_layer2_2])
39 | activation_layer2 = Activation('relu')(concat_layer2)
40 | dropout_layer2 = Dropout(0.2)(activation_layer2)
41 |
42 | conv_layer2 = Conv2D(filters=32, kernel_size=(3, 3), activation='relu')(dropout_layer2)
43 |
44 | conv_layer3 = Conv2D(filters=32, kernel_size=(3, 3), activation='relu')(conv_layer2)
45 | pooling_layer2 = MaxPool2D(pool_size=(2, 2), strides=(2, 2))(conv_layer3)
46 | dropout_layer = Dropout(0.5)(pooling_layer2)
47 |
48 | dense_layer = Dense(units=2048, activation='relu')(Flatten()(dropout_layer))  # flatten before the dense head
49 | output_layer = Dense(units=n_classes, activation='softmax')(dense_layer)
50 |
51 | ## define the model with input layer and output layer
52 | model = Model(inputs=input_layer, outputs=output_layer)
53 | model.summary()
54 |
55 | plot_model(model, to_file=self.output_directory + '/model_graph.png', show_shapes=True, show_layer_names=True)
56 |
57 | model.compile(loss=categorical_crossentropy, optimizer=Adam(), metrics=['acc'])
58 |
59 | # model save
60 | file_path = self.output_directory + '/best_model.hdf5'
61 | model_checkpoint = callbacks.ModelCheckpoint(filepath=file_path, monitor='loss', save_best_only=True)
62 |
63 | # Tensorboard log
64 | log_dir = self.output_directory + '/tf_logs'
65 | chk_n_mkdir(log_dir)
66 | tb_cb = TrainValTensorBoard(log_dir=log_dir)
67 |
68 | self.callbacks = [model_checkpoint, tb_cb]
69 | return model
70 |
71 |
--------------------------------------------------------------------------------
/models/vgg19.py:
--------------------------------------------------------------------------------
1 | from keras.layers import Conv2D, MaxPool2D, Flatten, Dense
2 | from keras.layers import Dropout, Input, BatchNormalization
3 | from keras.losses import categorical_crossentropy
4 | from keras.optimizers import Adadelta
5 | from keras.models import Model
6 | from keras.utils import plot_model
7 | from keras import callbacks
8 |
9 | from utils_datagen import TrainValTensorBoard
10 | from utils_basic import chk_n_mkdir
11 | from models.base_model import BaseModel
12 |
13 |
14 | class VGG19(BaseModel):
15 | def __init__(self, output_directory, input_shape, n_classes, verbose=False):
16 | self.output_directory = output_directory + '/vgg19'
17 | chk_n_mkdir(self.output_directory)
18 | self.model = self.build_model(input_shape, n_classes)
19 | if verbose:
20 | self.model.summary()
21 | self.verbose = verbose
22 | self.model.save_weights(self.output_directory + '/model_init.hdf5')
23 |
24 | def build_model(self, input_shape, n_classes):
25 | # input layer
26 | input_layer = Input(input_shape)
27 | # Block 1
28 | x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv1')(input_layer)
29 | x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv2')(x)
30 | x = MaxPool2D((2, 2), strides=(2, 2), name='block1_pool')(x)
31 |
32 | # Block 2
33 | x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv1')(x)
34 | x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv2')(x)
35 | x = MaxPool2D((2, 2), strides=(2, 2), name='block2_pool')(x)
36 |
37 | # Block 3
38 | x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv1')(x)
39 | x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv2')(x)
40 | x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv3')(x)
41 | x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv4')(x)
42 | x = MaxPool2D((2, 2), strides=(2, 2), name='block3_pool')(x)
43 |
44 | # Block 4
45 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv1')(x)
46 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv2')(x)
47 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv3')(x)
48 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv4')(x)
49 | x = MaxPool2D((2, 2), strides=(2, 2), name='block4_pool')(x)
50 |
51 | # Block 5
52 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv1')(x)
53 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv2')(x)
54 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv3')(x)
55 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv4')(x)
56 | x = MaxPool2D((2, 2), strides=(2, 2), name='block5_pool')(x)
57 |
58 | # Classification block
59 | x = Flatten(name='flatten')(x)
60 | x = Dense(4096, activation='relu', name='fc1')(x)
61 | x = Dense(4096, activation='relu', name='fc2')(x)
62 | output_layer = Dense(n_classes, activation='softmax', name='predictions')(x)
63 |
64 | ## define the model with input layer and output layer
65 | model = Model(inputs=input_layer, outputs=output_layer)
66 | model.summary()
67 |
68 | plot_model(model, to_file=self.output_directory + '/model_graph.png', show_shapes=True, show_layer_names=True)
69 |
70 | model.compile(loss=categorical_crossentropy, optimizer=Adadelta(lr=0.1), metrics=['acc'])
71 |
72 | # model save
73 | file_path = self.output_directory + '/best_model.hdf5'
74 | model_checkpoint = callbacks.ModelCheckpoint(filepath=file_path, monitor='loss', save_best_only=True)
75 |
76 | # Tensorboard log
77 | log_dir = self.output_directory + '/tf_logs'
78 | chk_n_mkdir(log_dir)
79 | tb_cb = TrainValTensorBoard(log_dir=log_dir)
80 |
81 | self.callbacks = [model_checkpoint, tb_cb]
82 | return model
83 |
--------------------------------------------------------------------------------
/models/vgg19_3d.py:
--------------------------------------------------------------------------------
1 | from keras.layers import Conv3D, MaxPool3D, Flatten, Dense
2 | from keras.layers import Dropout, Input, BatchNormalization
3 | from keras.losses import categorical_crossentropy
4 | from keras.optimizers import Adadelta
5 | from keras.models import Model
6 | from keras.utils import plot_model
7 | from keras import callbacks
8 |
9 | from utils_datagen import TrainValTensorBoard
10 | from utils_basic import chk_n_mkdir
11 | from models.base_model import BaseModel
12 |
13 |
14 | class VGG193D(BaseModel):
15 | def __init__(self, output_directory, input_shape, n_classes, verbose=False):
16 | self.output_directory = output_directory + '/vgg19_3d'
17 | chk_n_mkdir(self.output_directory)
18 | self.model = self.build_model(input_shape, n_classes)
19 | if verbose:
20 | self.model.summary()
21 | self.verbose = verbose
22 | self.model.save_weights(self.output_directory + '/model_init.hdf5')
23 |
24 | def build_model(self, input_shape, n_classes):
25 | # input layer
26 | input_layer = Input(input_shape)
27 | # Block 1
28 | x = Conv3D(64, (3, 3, 3), activation='relu', padding='same', name='block1_conv1')(input_layer)
29 | x = Conv3D(64, (3, 3, 3), activation='relu', padding='same', name='block1_conv2')(x)
30 | x = MaxPool3D((2, 2, 2), strides=(2, 2, 1), name='block1_pool')(x)
31 |
32 | # Block 2
33 | x = Conv3D(128, (3, 3, 3), activation='relu', padding='same', name='block2_conv1')(x)
34 | x = Conv3D(128, (3, 3, 3), activation='relu', padding='same', name='block2_conv2')(x)
35 | x = MaxPool3D((2, 2, 2), strides=(2, 2, 1), name='block2_pool')(x)
36 |
37 | # Block 3
38 | x = Conv3D(256, (3, 3, 2), activation='relu', padding='same', name='block3_conv1')(x)
39 | x = Conv3D(256, (3, 3, 2), activation='relu', padding='same', name='block3_conv2')(x)
40 | x = Conv3D(256, (3, 3, 2), activation='relu', padding='same', name='block3_conv3')(x)
41 | x = Conv3D(256, (3, 3, 2), activation='relu', padding='same', name='block3_conv4')(x)
42 | x = MaxPool3D((2, 2, 2), strides=(2, 2, 1), name='block3_pool')(x)
43 |
44 | # Block 4
45 | x = Conv3D(512, (3, 3, 1), activation='relu', padding='same', name='block4_conv1')(x)
46 | x = Conv3D(512, (3, 3, 1), activation='relu', padding='same', name='block4_conv2')(x)
47 | x = Conv3D(512, (3, 3, 1), activation='relu', padding='same', name='block4_conv3')(x)
48 | x = Conv3D(512, (3, 3, 1), activation='relu', padding='same', name='block4_conv4')(x)
49 | x = MaxPool3D((2, 2, 1), strides=(2, 2, 1), name='block4_pool')(x)
50 |
51 | # Block 5
52 | x = Conv3D(512, (3, 3, 1), activation='relu', padding='same', name='block5_conv1')(x)
53 | x = Conv3D(512, (3, 3, 1), activation='relu', padding='same', name='block5_conv2')(x)
54 | x = Conv3D(512, (3, 3, 1), activation='relu', padding='same', name='block5_conv3')(x)
55 | x = Conv3D(512, (3, 3, 1), activation='relu', padding='same', name='block5_conv4')(x)
56 | x = MaxPool3D((2, 2, 1), strides=(2, 2, 1), name='block5_pool')(x)
57 |
58 | # Classification block
59 | x = Flatten(name='flatten')(x)
60 | x = Dense(4096, activation='relu', name='fc1')(x)
61 | x = Dropout(0.4)(x)
62 | x = Dense(4096, activation='relu', name='fc2')(x)
63 | x = Dropout(0.4)(x)
64 | output_layer = Dense(n_classes, activation='softmax', name='predictions')(x)
65 |
66 | ## define the model with input layer and output layer
67 | model = Model(inputs=input_layer, outputs=output_layer)
68 | model.summary()
69 |
70 | plot_model(model, to_file=self.output_directory + '/model_graph.png', show_shapes=True, show_layer_names=True)
71 |
72 | model.compile(loss=categorical_crossentropy, optimizer=Adadelta(lr=0.1), metrics=['acc'])
73 |
74 | # model save
75 | file_path = self.output_directory + '/best_model.hdf5'
76 | model_checkpoint = callbacks.ModelCheckpoint(filepath=file_path, monitor='loss', save_best_only=True)
77 |
78 | # Tensorboard log
79 | log_dir = self.output_directory + '/tf_logs'
80 | chk_n_mkdir(log_dir)
81 | tb_cb = TrainValTensorBoard(log_dir=log_dir)
82 |
83 | self.callbacks = [model_checkpoint, tb_cb]
84 | return model
85 |
86 |
--------------------------------------------------------------------------------
/models/xception.py:
--------------------------------------------------------------------------------
1 | from keras.layers import Conv2D, SeparableConv2D, MaxPooling2D, Flatten, Dense
2 | from keras.layers import Dropout, Input, BatchNormalization, Activation, add, GlobalAveragePooling2D
3 | from keras.losses import categorical_crossentropy
4 | from keras.optimizers import Adadelta
5 | from keras.models import Model
6 | from keras.utils import plot_model
7 | from keras import callbacks
8 |
9 | from utils_datagen import TrainValTensorBoard
10 | from utils_basic import chk_n_mkdir
11 | from models.base_model import BaseModel
12 |
13 |
14 | class XCEPTION(BaseModel):
15 | def __init__(self, output_directory, input_shape, n_classes, verbose=False):
16 | self.output_directory = output_directory + '/xception'
17 | chk_n_mkdir(self.output_directory)
18 | self.model = self.build_model(input_shape, n_classes)
19 | if verbose:
20 | self.model.summary()
21 | self.verbose = verbose
22 | self.model.save_weights(self.output_directory + '/model_init.hdf5')
23 |
24 | def build_model(self, input_shape, n_classes):
25 | # input layer
26 | input_layer = Input(input_shape)
27 | channel_axis = -1 # channel_axis = 1 if backend.image_data_format() == 'channels_first' else -1
28 | # Block 1
29 | x = Conv2D(32, (3, 3), strides=(2, 2), use_bias=False, name='block1_conv1')(input_layer)
30 | x = BatchNormalization(axis=channel_axis, name='block1_conv1_bn')(x)
31 | x = Activation('relu', name='block1_conv1_act')(x)
32 | x = Conv2D(64, (3, 3), use_bias=False, name='block1_conv2')(x)
33 | x = BatchNormalization(axis=channel_axis, name='block1_conv2_bn')(x)
34 | x = Activation('relu', name='block1_conv2_act')(x)
35 |
36 | residual = Conv2D(128, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
37 | residual = BatchNormalization(axis=channel_axis)(residual)
38 |
39 | # Block 2
40 | x = SeparableConv2D(128, (3, 3), padding='same', use_bias=False, name='block2_sepconv1')(x)
41 | x = BatchNormalization(axis=channel_axis, name='block2_sepconv1_bn')(x)
42 | x = Activation('relu', name='block2_sepconv2_act')(x)
43 | x = SeparableConv2D(128, (3, 3), padding='same', use_bias=False, name='block2_sepconv2')(x)
44 | x = BatchNormalization(axis=channel_axis, name='block2_sepconv2_bn')(x)
45 |
46 | x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='block2_pool')(x)
47 | x = add([x, residual])
48 |
49 | residual = Conv2D(256, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
50 | residual = BatchNormalization(axis=channel_axis)(residual)
51 |
52 | # Block 3
53 | x = Activation('relu', name='block3_sepconv1_act')(x)
54 | x = SeparableConv2D(256, (3, 3), padding='same', use_bias=False, name='block3_sepconv1')(x)
55 | x = BatchNormalization(axis=channel_axis, name='block3_sepconv1_bn')(x)
56 | x = Activation('relu', name='block3_sepconv2_act')(x)
57 | x = SeparableConv2D(256, (3, 3), padding='same', use_bias=False, name='block3_sepconv2')(x)
58 | x = BatchNormalization(axis=channel_axis, name='block3_sepconv2_bn')(x)
59 |
60 | x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='block3_pool')(x)
61 | x = add([x, residual])
62 |
63 | residual = Conv2D(728, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
64 | residual = BatchNormalization(axis=channel_axis)(residual)
65 |
66 | # Block 4
67 | x = Activation('relu', name='block4_sepconv1_act')(x)
68 | x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name='block4_sepconv1')(x)
69 | x = BatchNormalization(axis=channel_axis, name='block4_sepconv1_bn')(x)
70 | x = Activation('relu', name='block4_sepconv2_act')(x)
71 | x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name='block4_sepconv2')(x)
72 | x = BatchNormalization(axis=channel_axis, name='block4_sepconv2_bn')(x)
73 |
74 | x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='block4_pool')(x)
75 | x = add([x, residual])
76 |
77 | # Block 5-12
78 | for i in range(8):
79 | residual = x
80 | prefix = 'block' + str(i + 5)
81 |
82 | x = Activation('relu', name=prefix + '_sepconv1_act')(x)
83 | x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name=prefix + '_sepconv1')(x)
84 | x = BatchNormalization(axis=channel_axis, name=prefix + '_sepconv1_bn')(x)
85 | x = Activation('relu', name=prefix + '_sepconv2_act')(x)
86 | x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name=prefix + '_sepconv2')(x)
87 | x = BatchNormalization(axis=channel_axis, name=prefix + '_sepconv2_bn')(x)
88 | x = Activation('relu', name=prefix + '_sepconv3_act')(x)
89 | x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name=prefix + '_sepconv3')(x)
90 | x = BatchNormalization(axis=channel_axis, name=prefix + '_sepconv3_bn')(x)
91 | x = add([x, residual])
92 |
93 | residual = Conv2D(1024, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
94 | residual = BatchNormalization(axis=channel_axis)(residual)
95 |
96 | # Block 13
97 | x = Activation('relu', name='block13_sepconv1_act')(x)
98 | x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name='block13_sepconv1')(x)
99 | x = BatchNormalization(axis=channel_axis, name='block13_sepconv1_bn')(x)
100 | x = Activation('relu', name='block13_sepconv2_act')(x)
101 | x = SeparableConv2D(1024, (3, 3), padding='same', use_bias=False, name='block13_sepconv2')(x)
102 | x = BatchNormalization(axis=channel_axis, name='block13_sepconv2_bn')(x)
103 |
104 | x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='block13_pool')(x)
105 | x = add([x, residual])
106 |
107 | # Block 14
108 | x = SeparableConv2D(1536, (3, 3), padding='same', use_bias=False, name='block14_sepconv1')(x)
109 | x = BatchNormalization(axis=channel_axis, name='block14_sepconv1_bn')(x)
110 | x = Activation('relu', name='block14_sepconv1_act')(x)
111 | x = SeparableConv2D(2048, (3, 3), padding='same', use_bias=False, name='block14_sepconv2')(x)
112 | x = BatchNormalization(axis=channel_axis, name='block14_sepconv2_bn')(x)
113 | x = Activation('relu', name='block14_sepconv2_act')(x)
114 |
115 | # Classification block
116 | x = GlobalAveragePooling2D(name='avg_pool')(x)
117 | output_layer = Dense(n_classes, activation='softmax', name='predictions')(x)
118 |
119 | # define the model with input layer and output layer
120 | model = Model(inputs=input_layer, outputs=output_layer)
121 | model.summary()
122 |
123 | plot_model(model, to_file=self.output_directory + '/model_graph.png', show_shapes=True, show_layer_names=True)
124 |
125 | model.compile(loss=categorical_crossentropy, optimizer=Adadelta(lr=0.1), metrics=['acc'])
126 |
127 | # model save
128 | file_path = self.output_directory + '/best_model.hdf5'
129 | model_checkpoint = callbacks.ModelCheckpoint(filepath=file_path, monitor='loss', save_best_only=True)
130 |
131 | # Tensorboard log
132 | log_dir = self.output_directory + '/tf_logs'
133 | chk_n_mkdir(log_dir)
134 | tb_cb = TrainValTensorBoard(log_dir=log_dir)
135 |
136 | self.callbacks = [model_checkpoint, tb_cb]
137 | return model
138 |
139 |
--------------------------------------------------------------------------------
/models/xception_3d.py:
--------------------------------------------------------------------------------
1 | from keras.layers import Conv3D, MaxPooling3D, Flatten, Dense
2 | from keras.layers import Dropout, Input, BatchNormalization, Activation, add, GlobalAveragePooling3D
3 | from keras.losses import categorical_crossentropy
4 | from keras.optimizers import Adadelta
5 | from keras.models import Model
6 | from keras.utils import plot_model
7 | from keras import callbacks
8 |
9 | from utils_datagen import TrainValTensorBoard
10 | from utils_basic import chk_n_mkdir
11 | from models.base_model import BaseModel
12 |
13 |
14 | class XCEPTION3D(BaseModel):
15 | def __init__(self, output_directory, input_shape, n_classes, verbose=False):
16 | self.output_directory = output_directory + '/xception_3d'
17 | chk_n_mkdir(self.output_directory)
18 | self.model = self.build_model(input_shape, n_classes)
19 | if verbose:
20 | self.model.summary()
21 | self.verbose = verbose
22 | self.model.save_weights(self.output_directory + '/model_init.hdf5')
23 |
24 | def build_model(self, input_shape, n_classes):
25 | # input layer
26 | input_layer = Input(input_shape)
27 | channel_axis = -1 # channel_axis = 1 if backend.image_data_format() == 'channels_first' else -1
28 |
29 | # Block 1
30 | x = Conv3D(8, (3, 3, 3), use_bias=False, name='block1_conv1')(input_layer)
31 | x = BatchNormalization(axis=channel_axis, name='block1_conv1_bn')(x)
32 | x = Activation('relu', name='block1_conv1_act')(x)
33 | x = Conv3D(8, (3, 3, 2), use_bias=False, name='block1_conv2')(x)
34 | x = BatchNormalization(axis=channel_axis, name='block1_conv2_bn')(x)
35 | x = Activation('relu', name='block1_conv2_act')(x)
36 |
37 | residual = Conv3D(16, (1, 1, 1), strides=(2, 2, 1), padding='same', use_bias=False)(x)
38 | residual = BatchNormalization(axis=channel_axis)(residual)
39 |
40 | # Block 2
41 | x = Conv3D(16, (3, 3, 1), padding='same', use_bias=False, name='block2_conv1')(x)
42 | x = BatchNormalization(axis=channel_axis, name='block2_conv1_bn')(x)
43 |
44 | x = MaxPooling3D((3, 3, 1), strides=(2, 2, 1), padding='same', name='block2_pool')(x)
45 | x = add([x, residual])
46 |
47 | residual = Conv3D(32, (1, 1, 1), strides=(2, 2, 1), padding='same', use_bias=False)(x)
48 | residual = BatchNormalization(axis=channel_axis)(residual)
49 |
50 | # Block 3
51 | x = Activation('relu', name='block3_conv1_act')(x)
52 | x = Conv3D(32, (3, 3, 1), padding='same', use_bias=False, name='block3_conv1')(x)
53 | x = BatchNormalization(axis=channel_axis, name='block3_conv1_bn')(x)
54 |
55 | x = MaxPooling3D((3, 3, 1), strides=(2, 2, 1), padding='same', name='block3_pool')(x)
56 | x = add([x, residual])
57 |
58 | # Block 4
59 | x = Conv3D(64, (3, 3, 1), padding='same', use_bias=False, name='block4_conv1')(x)
60 | x = BatchNormalization(axis=channel_axis, name='block4_conv1_bn')(x)
61 | x = Activation('relu', name='block4_conv1_act')(x)
62 |
63 | # Classification block
64 | x = GlobalAveragePooling3D(name='avg_pool')(x)
65 | output_layer = Dense(n_classes, activation='softmax', name='predictions')(x)
66 |
67 | # ## create an MLP architecture with dense layers : 4096 -> 512 -> 10
68 | # ## add dropouts to avoid overfitting / perform regularization
69 | # dense_layer1 = Dense(units=2048, activation='relu')(x)
70 | # dense_layer1 = Dropout(0.4)(dense_layer1)
71 | # dense_layer2 = Dense(units=512, activation='relu')(dense_layer1)
72 | # dense_layer2 = Dropout(0.4)(dense_layer2)
73 | # output_layer = Dense(units=n_classes, activation='softmax')(dense_layer2)
74 |
75 | # define the model with input layer and output layer
76 | model = Model(inputs=input_layer, outputs=output_layer)
77 | model.summary()
78 |
79 | plot_model(model, to_file=self.output_directory + '/model_graph.png', show_shapes=True, show_layer_names=True)
80 |
81 | model.compile(loss=categorical_crossentropy, optimizer=Adadelta(lr=0.1), metrics=['acc'])
82 |
83 | # model save
84 | file_path = self.output_directory + '/best_model.hdf5'
85 | model_checkpoint = callbacks.ModelCheckpoint(filepath=file_path, monitor='loss', save_best_only=True)
86 |
87 | # Tensorboard log
88 | log_dir = self.output_directory + '/tf_logs'
89 | chk_n_mkdir(log_dir)
90 | tb_cb = TrainValTensorBoard(log_dir=log_dir)
91 |
92 | self.callbacks = [model_checkpoint, tb_cb]
93 | return model
94 |
--------------------------------------------------------------------------------
/models/xception_kapp.py:
--------------------------------------------------------------------------------
1 | from keras.layers import Conv2D, SeparableConv2D, MaxPooling2D, Flatten, Dense
2 | from keras.layers import Dropout, Input, BatchNormalization, Activation, add, GlobalAveragePooling2D
3 | from keras.losses import categorical_crossentropy
4 | from keras.optimizers import Adam
5 | from keras.utils import plot_model
6 | from keras import callbacks
7 | from keras import models
8 | from keras.applications import Xception
9 |
10 | from utils_datagen import TrainValTensorBoard
11 | from utils_basic import chk_n_mkdir
12 | from models.base_model import BaseModel
13 |
14 |
15 | class XCEPTION_APP(BaseModel):
16 | def __init__(self, output_directory, input_shape, n_classes, verbose=False):
17 | self.output_directory = output_directory + '/xception_kapp'
18 | chk_n_mkdir(self.output_directory)
19 | self.model = self.build_model(input_shape, n_classes)
20 | if verbose:
21 | self.model.summary()
22 | self.verbose = verbose
23 | self.model.save_weights(self.output_directory + '/model_init.hdf5')
24 |
25 | def build_model(self, input_shape, n_classes):
26 | # Load the Xception convolutional base pre-trained on ImageNet
27 | xception_conv = Xception(weights='imagenet', include_top=False, input_shape=input_shape)
28 |
29 | # Freeze all layers of the convolutional base
30 | for layer in xception_conv.layers:
31 | layer.trainable = False
32 |
33 | # Create the model
34 | model = models.Sequential()
35 |
36 | # Add the Xception convolutional base model
37 | model.add(xception_conv)
38 | # Add new layers
39 | model.add(Flatten())
40 | model.add(Dense(1024, activation='relu'))
41 | model.add(Dropout(0.5))
42 | model.add(Dense(n_classes, activation='softmax', name='predictions'))
43 |
44 | # define the model with input layer and output layer
45 | model.summary()
46 | plot_model(model, to_file=self.output_directory + '/model_graph.png', show_shapes=True, show_layer_names=True)
47 | model.compile(loss=categorical_crossentropy, optimizer=Adam(lr=0.01), metrics=['acc'])
48 |
49 | # model save
50 | file_path = self.output_directory + '/best_model.hdf5'
51 | model_checkpoint = callbacks.ModelCheckpoint(filepath=file_path, monitor='loss', save_best_only=True)
52 |
53 | # Tensorboard log
54 | log_dir = self.output_directory + '/tf_logs'
55 | chk_n_mkdir(log_dir)
56 | tb_cb = TrainValTensorBoard(log_dir=log_dir)
57 |
58 | self.callbacks = [model_checkpoint, tb_cb]
59 | return model
60 |
61 |
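# A minimal fine-tuning sketch (not part of the original file): after training the new
# head with the base frozen, the deepest Xception block could be unfrozen and the model
# recompiled at a lower learning rate. The 'block14' layer-name prefix follows stock
# keras.applications Xception; treat the exact cut point as an assumption, not the
# author's method.
def unfreeze_top_block(model, xception_conv, lr=1e-4):
    for layer in xception_conv.layers:
        # only the final block's layers become trainable again (assumed cut point)
        layer.trainable = layer.name.startswith('block14')
    # recompile so the new trainable flags take effect
    model.compile(loss=categorical_crossentropy, optimizer=Adam(lr=lr), metrics=['acc'])
    return model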
--------------------------------------------------------------------------------
/non_defective_xml_files.dvc:
--------------------------------------------------------------------------------
1 | md5: 95ba84bc7f7a3f57df1fe8fb88041c81
2 | outs:
3 | - md5: 09e023a9b75c98dbc0445c4852db8bf1.dir
4 | path: non_defective_xml_files
5 | cache: true
6 | metric: false
7 | persist: false
8 |
--------------------------------------------------------------------------------
/result_images.dvc:
--------------------------------------------------------------------------------
1 | md5: 78a149405f6b8b4b47ede83c26222084
2 | outs:
3 | - md5: 819032fe705cf34a03306d8f3bb989ab.dir
4 | path: result_images
5 | cache: true
6 | metric: false
7 | persist: false
8 |
--------------------------------------------------------------------------------
/run_models.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from utils_datagen import DataGenerator, create_partition_label_dict_multi_class, create_partition_label_dict_binary
4 | from models.cnn_2d import CNN2D
5 | from models.cnn_3d import CNN3D
6 | from models.vgg19 import VGG19
7 | from models.vgg19_3d import VGG193D
8 | from models.xception import XCEPTION
9 | from models.xception_3d import XCEPTION3D
10 | from models.mvcnn_xception import MVCNN_XCEPTION
11 | from models.mvcnn_cnn import MVCNN_CNN
12 |
13 | logging.basicConfig(level=logging.DEBUG)
14 |
15 | CLASSIFY_BINARY = True
16 | TRAIN = False
17 | MODEL3D = False
18 |
19 | # model selection
20 | model_types = ['CNN2D', 'CNN3D', 'VGG19', 'VGG193D', 'XCEPTION', 'XCEPTION3D', 'MVCNN_XCEPTION', 'MVCNN_CNN']
21 | model_type = model_types[7]
22 |
23 | # Dataset selection
24 | pickle_files = ['./data/rois_all_slices_2d.p', './data/rois_all_slices_3d.p', './data/rois_all_slices_inverse_3d.p',
25 | './data/rois_first_four_slices_2d.p', './data/rois_first_four_slices_2d_rotated.p',
26 | './data/rois_first_four_slices_2d_more_normal.p', './data/rois_first_four_slices_list_more_normal.p',
27 | './data/rois_first_four_slices_3d.p']
28 | pickle_file = pickle_files[6]
29 | n_slices = 4
30 |
31 | if CLASSIFY_BINARY:
32 | n_classes = 2
33 | partition, labels, integer_mapping_label_dict, image_shape = create_partition_label_dict_binary(pickle_file)
34 | print(integer_mapping_label_dict)
35 | else:
36 | n_classes = 4
37 | partition, labels, integer_mapping_label_dict, image_shape = create_partition_label_dict_multi_class(pickle_file)
38 |
39 | img_width, img_height = image_shape[0], image_shape[1]
40 | if MODEL3D:
41 | input_shape = (img_width, img_height, n_slices, 1)
42 | else:
43 | input_shape = (img_width, img_height, 1)
44 |
45 | # input_shape = (128,128,1)
46 |
47 | if model_type == 'CNN2D':
48 | model = CNN2D('./models/saved_models/', input_shape, n_classes, True)
49 | if model_type == 'CNN3D':
50 | model = CNN3D('./models/saved_models/', input_shape, n_classes, True)
51 | if model_type == 'VGG19':
52 | model = VGG19('./models/saved_models/', input_shape, n_classes, True)
53 | if model_type == 'VGG193D':
54 | model = VGG193D('./models/saved_models/', input_shape, n_classes, True)
55 | if model_type == 'XCEPTION':
56 | model = XCEPTION('./models/saved_models/', input_shape, n_classes, True)
57 | if model_type == 'XCEPTION3D':
58 | model = XCEPTION3D('./models/saved_models/', input_shape, n_classes, True)
59 | if model_type == 'MVCNN_XCEPTION':
60 | model = MVCNN_XCEPTION('./models/saved_models/', input_shape, n_classes, True)
61 | if model_type == 'MVCNN_CNN':
62 | model = MVCNN_CNN('./models/saved_models/', input_shape, n_classes, True)
63 |
64 | # Parameters
65 | n_train_samples = len(partition['train'])
66 | n_test_samples = len(partition['test'])
67 |
68 | batch_size = 32
69 | epochs = 200
70 | samples_per_train_epoch = n_train_samples // batch_size
71 | samples_per_test_epoch = n_test_samples // batch_size
72 |
73 | params = {'dim': input_shape,
74 | 'batch_size': batch_size,
75 | 'n_classes': n_classes,
76 | 'shuffle': False}
77 |
78 | # Generators
79 | testing_generator = DataGenerator(pickle_file, partition['test'], labels, integer_mapping_label_dict, **params)
80 | training_generator = DataGenerator(pickle_file, partition['train'], labels, integer_mapping_label_dict, **params)
81 |
82 | if TRAIN:
83 | model.fit(training_generator, testing_generator, samples_per_train_epoch, samples_per_test_epoch, epochs)
84 |
85 | else:
86 | model.predict(testing_generator, samples_per_test_epoch)
87 |
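# A hedged alternative (a sketch, not the author's code) to the if-chain above:
# dispatching through a name->class dict makes an unknown model_type raise a
# KeyError instead of silently leaving `model` undefined.
MODEL_REGISTRY = {'CNN2D': CNN2D, 'CNN3D': CNN3D, 'VGG19': VGG19, 'VGG193D': VGG193D,
                  'XCEPTION': XCEPTION, 'XCEPTION3D': XCEPTION3D,
                  'MVCNN_XCEPTION': MVCNN_XCEPTION, 'MVCNN_CNN': MVCNN_CNN}
# model = MODEL_REGISTRY[model_type]('./models/saved_models/', input_shape, n_classes, True)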
--------------------------------------------------------------------------------
/run_models_kapp_integrated.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import pickle
3 |
4 | from utils_datagen import DataGeneratorIndividual
5 | from models.mvcnn_xception import MVCNN_XCEPTION
6 |
7 | logging.basicConfig(level=logging.DEBUG)
8 |
9 | CLASSIFY_BINARY = True
10 | TRAIN = False
11 |
12 | # model selection
13 | model_types = ['MVCNN_XCEPTION']
14 | model_type = model_types[0]
15 | n_slices = 4
16 |
17 | if CLASSIFY_BINARY:
18 | n_classes = 2
19 | pickle_file_name = './data/joints_4slices/details_list.p'
20 | with open(pickle_file_name, 'rb') as handle:
21 | details_list = pickle.load(handle)
22 | partition, temp_labels, temp_integer_mapping_label_dict, image_shape = details_list[0], details_list[1], details_list[2], details_list[3]
23 | labels = {}
24 | for image_name, value in temp_labels.items():
25 | if value == 3:
26 | new_value = 1
27 | else:
28 | new_value = 0
29 | labels[image_name] = new_value
30 | integer_mapping_label_dict = {'defect': 0, 'normal': 1}
31 | else:
32 | n_classes = 4
33 | pickle_file_name = './data/joints_4slices/details_list.p'
34 | with open(pickle_file_name, 'rb') as handle:
35 | details_list = pickle.load(handle)
36 | partition, labels, integer_mapping_label_dict, image_shape = details_list[0], details_list[1], details_list[2], details_list[3]
37 |
38 | img_width, img_height = image_shape[0], image_shape[1]
39 | input_shape = image_shape
40 |
41 | if model_type == 'MVCNN_XCEPTION':
42 | model = MVCNN_XCEPTION('./models/saved_models/', input_shape, n_classes, True)
43 |
44 | # Parameters
45 | n_train_samples = len(partition['train'])
46 | n_test_samples = len(partition['test'])
47 |
48 | batch_size = 32
49 | epochs = 1000
50 | samples_per_train_epoch = n_train_samples // batch_size
51 | samples_per_test_epoch = n_test_samples // batch_size
52 |
53 | params = {'dim': input_shape,
54 | 'batch_size': batch_size,
55 | 'n_classes': n_classes,
56 | 'shuffle': False}
57 |
58 |
59 | # Generators
60 | testing_generator = DataGeneratorIndividual('./data/joints_4slices/', partition['test'], labels, integer_mapping_label_dict, **params)
61 | training_generator = DataGeneratorIndividual('./data/joints_4slices/', partition['train'], labels, integer_mapping_label_dict, **params)
62 |
63 |
64 | if TRAIN:
65 | model.fit(training_generator, testing_generator, samples_per_train_epoch, samples_per_test_epoch, epochs)
66 |
67 | else:
68 | model.predict(testing_generator, samples_per_test_epoch)
69 |
--------------------------------------------------------------------------------
/solder_joint.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import logging
4 |
5 |
6 | class SolderJoint:
7 | def __init__(self, component_name, defect_id, defect_name, roi):
8 | self.component_name = component_name
9 | self.defect_id = defect_id
10 | self.defect_name = defect_name
11 | self.roi = roi
12 | self.x_min, self.y_min, self.x_max, self.y_max = roi[0], roi[1], roi[2], roi[3]
13 | self.slice_dict = {}
14 |
15 | logging.info('Created solder joint, component_name:%s defect_id:%d defect_name:%s roi:%d,%d,%d,%d',
16 | self.component_name, self.defect_id, self.defect_name, self.roi[0], self.roi[1], self.roi[2],
17 | self.roi[3])
18 |
19 | def is_square(self):
20 | if (self.x_max - self.x_min) == (self.y_max - self.y_min):
21 | return True
22 | else:
23 | return False
24 |
25 | def add_slice(self, slice_id, file_location):
26 | self.slice_dict[slice_id] = file_location
27 | logging.info('slice id: %d added to the joint, image name: %s', slice_id, file_location)
28 |
29 | # this method concatenates slices 0,1,2,3 in the 2d plane if those slices exist and the roi is square
30 | # to concatenate slices in a different way, write a new function like this one (a sketch of one variant follows this method)
31 | def concat_first_four_slices_2d(self):
32 | logging.debug('start concatenating image, joint type: %s', self.defect_name)
33 |
34 | if not self.is_square():
35 | logging.debug('joint roi is rectangular, canceling concatenation')
36 | return None, None
37 |
38 | if len(self.slice_dict.keys()) < 4:
39 | logging.error('Number of slices < 4, canceling concatenation')
40 | return None, None
41 |
42 | if len(self.slice_dict.keys()) > 4:
43 | logging.error('Number of slices > 4, canceling concatenation')
44 | return None, None
45 |
46 | slices_list = [None, None, None, None]
47 | for slice_id in range(4):
48 | # check whether 1st 4 slices are available
49 | if slice_id not in self.slice_dict.keys():
50 | logging.error('First 4 slices not available, canceling concatenation')
51 | return None, None
52 |
53 | img = cv2.imread(self.slice_dict[slice_id])
54 | # there's a bug here. image slicing doesn't give a perfect square sometimes
55 | img_roi = img[self.y_min:self.y_max, self.x_min:self.x_max]
56 | img_roi_gray = cv2.cvtColor(img_roi, cv2.COLOR_BGR2GRAY)
57 | if img_roi_gray is None:
58 | logging.error('Slice read is None, canceling concatenation')
59 | return None, None
60 | resized_image = cv2.resize(img_roi_gray, (128, 128), interpolation=cv2.INTER_AREA)
61 |
62 | if resized_image is None:
63 | logging.error('Error occurred in OpenCV ROI extraction')
64 | return None, None
65 | slices_list[slice_id] = resized_image
66 |
67 | im_h1 = cv2.hconcat(slices_list[0:2])
68 | im_h2 = cv2.hconcat(slices_list[2:4])
69 | im_concat = cv2.vconcat([im_h1, im_h2])
70 |
71 | if im_concat is None:
72 | logging.error('Error occurred in OpenCV ROI concat, result is None, skipping concatenation')
73 | return None, None
74 |
75 | logging.debug('First 4 slices available, concatenation done')
76 | return im_concat, self.defect_name
77 |
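    # A hedged sketch (not in the original) of "a new function like this": the same
    # availability checks, but stacking the first four slices into one vertical strip
    # instead of a 2x2 grid. The 128-pixel tile size mirrors the method above and is
    # otherwise an arbitrary choice.
    def concat_first_four_slices_2d_vstrip(self):
        if not self.is_square() or len(self.slice_dict.keys()) != 4:
            return None, None
        tiles = []
        for slice_id in range(4):
            if slice_id not in self.slice_dict:
                return None, None
            img = cv2.imread(self.slice_dict[slice_id])
            img_roi = img[self.y_min:self.y_max, self.x_min:self.x_max]
            img_roi_gray = cv2.cvtColor(img_roi, cv2.COLOR_BGR2GRAY)
            tiles.append(cv2.resize(img_roi_gray, (128, 128), interpolation=cv2.INTER_AREA))
        # one 512x128 strip instead of a 256x256 grid
        return cv2.vconcat(tiles), self.defect_name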
78 | def concat_first_four_slices_2d_4rotated(self):
79 | logging.debug('start concatenating image, joint type: %s', self.defect_name)
80 |
81 | if not self.is_square():
82 | logging.debug('joint roi is rectangular, canceling concatenation')
83 | return None, None
84 |
85 | if len(self.slice_dict.keys()) < 4:
86 | logging.error('Number of slices < 4, canceling concatenation')
87 | return None, None
88 |
89 | if len(self.slice_dict.keys()) > 4:
90 | logging.error('Number of slices > 4, canceling concatenation')
91 | return None, None
92 |
93 | slices_list = [None, None, None, None]
94 | try:
95 | for slice_id in range(4):
96 | if slice_id not in self.slice_dict.keys():
97 | logging.error('First 4 slices not available, canceling concatenation')
98 | return None, None
99 |
100 | img = cv2.imread(self.slice_dict[slice_id])
101 | # there's a bug here. image slicing doesn't give a perfect square sometimes
102 | img_roi = img[self.y_min:self.y_max, self.x_min:self.x_max]
103 | img_roi_gray = cv2.cvtColor(img_roi, cv2.COLOR_BGR2GRAY)
104 | resized_image = cv2.resize(img_roi_gray, (64, 64), interpolation=cv2.INTER_LINEAR)
105 | slices_list[slice_id] = resized_image
106 |
107 | im_h1 = cv2.hconcat(slices_list[0:2])
108 | im_h2 = cv2.hconcat(slices_list[2:4])
109 | im_concat = cv2.vconcat([im_h1, im_h2])
110 | im_concat = im_concat.astype(np.float32) / 255.0
111 | im_concat = np.expand_dims(im_concat, axis=2)
112 |
113 | (h, w) = im_concat.shape[:2]
114 | center = (w / 2, h / 2)
115 | rot_mat = cv2.getRotationMatrix2D(center, 90, 1.0)
116 | rotated90 = cv2.warpAffine(im_concat, rot_mat, (h, w))
117 | rotated90 = np.expand_dims(rotated90, axis=2)
118 | rot_mat = cv2.getRotationMatrix2D(center, 180, 1.0)
119 | rotated180 = cv2.warpAffine(im_concat, rot_mat, (w, h))
120 | rotated180 = np.expand_dims(rotated180, axis=2)
121 | rot_mat = cv2.getRotationMatrix2D(center, 270, 1.0)
122 | rotated270 = cv2.warpAffine(im_concat, rot_mat, (h, w))
123 | rotated270 = np.expand_dims(rotated270, axis=2)
124 |
125 | except cv2.error as e:
126 | logging.error('OpenCV Error: %s, canceling concatenation', e)
127 | return None, None
128 |
129 | logging.debug('First 4 slices available, concatenation done')
130 | return [im_concat, rotated90, rotated180, rotated270], self.defect_name
131 |
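    # Hedged aside (not in the original): for multiples of 90 degrees on square images,
    # np.rot90 is exact and avoids warpAffine's interpolation; it rotates in the (H, W)
    # plane and leaves the channel axis untouched.
    @staticmethod
    def _rot90_views(img):
        # returns the image plus its three right-angle rotations
        return [img, np.rot90(img, 1), np.rot90(img, 2), np.rot90(img, 3)]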
132 | def concat_first_four_slices_list_rgb(self):
133 | logging.debug('start concatenating image, joint type: %s', self.defect_name)
134 |
135 | if not self.is_square():
136 | logging.debug('joint roi is rectangular, canceling concatenation')
137 | return None, None
138 |
139 | if len(self.slice_dict.keys()) < 4:
140 | logging.error('Number of slices < 4, canceling concatenation')
141 | return None, None
142 |
143 | if len(self.slice_dict.keys()) > 4:
144 | logging.error('Number of slices > 4, canceling concatenation')
145 | return None, None
146 |
147 | slices_list = [None, None, None, None]
148 | for slice_id in range(4):
149 | # check whether 1st 4 slices are available
150 | if slice_id not in self.slice_dict.keys():
151 | logging.error('First 4 slices not available, canceling concatenation')
152 | return None, None
153 |
154 | try:
155 | img = cv2.imread(self.slice_dict[slice_id])
156 | img_roi = img[self.y_min:self.y_max, self.x_min:self.x_max]
157 | resized_image = cv2.resize(img_roi, (299, 299), interpolation=cv2.INTER_LINEAR)
158 | # resized_image = resized_image.astype(np.float32) // 255.0
159 | slices_list[slice_id] = resized_image
160 | except cv2.error as e:
161 | logging.error('OpenCV Error: %s, canceling concatenation', e)
162 | return None, None
163 |
164 | logging.debug('First 4 slices available, concatenation done')
165 | return slices_list, self.defect_name
166 |
167 | def concat_first_four_slices_list(self):
168 | logging.debug('start concatenating image, joint type: %s', self.defect_name)
169 |
170 | if not self.is_square():
171 | logging.debug('joint roi is rectangular, canceling concatenation')
172 | return None, None
173 |
174 | if len(self.slice_dict.keys()) < 4:
175 | logging.error('Number of slices < 4, canceling concatenation')
176 | return None, None
177 |
178 | if len(self.slice_dict.keys()) > 4:
179 | logging.error('Number of slices > 4, canceling concatenation')
180 | return None, None
181 |
182 | slices_list = [None, None, None, None]
183 | for slice_id in range(4):
184 | # check whether 1st 4 slices are available
185 | if slice_id not in self.slice_dict.keys():
186 | logging.error('First 4 slices not available, canceling concatenation')
187 | return None, None
188 |
189 | try:
190 | img = cv2.imread(self.slice_dict[slice_id])
191 | img_roi = img[self.y_min:self.y_max, self.x_min:self.x_max]
192 | resized_image = cv2.resize(img_roi, (128, 128), interpolation=cv2.INTER_LINEAR)
193 | resized_image = cv2.cvtColor(resized_image, cv2.COLOR_BGR2GRAY)
194 | resized_image = resized_image.astype(np.float32) / 255.0
195 | slices_list[slice_id] = resized_image
196 | except cv2.error as e:
197 | logging.error('OpenCV Error: %s, canceling concatenation', e)
198 | return None, None
199 |
200 | logging.debug('First 4 slices available, concatenation done')
201 | return slices_list, self.defect_name
202 |
203 | def concat_first_four_slices_list_4rotated(self):
204 | logging.debug('start concatenating image, joint type: %s', self.defect_name)
205 |
206 | if not self.is_square():
207 | logging.debug('joint roi is rectangular, canceling concatenation')
208 | return None, None
209 |
210 | if len(self.slice_dict.keys()) < 4:
211 | logging.error('Number of slices < 4, canceling concatenation')
212 | return None, None
213 |
214 | if len(self.slice_dict.keys()) > 4:
215 | logging.error('Number of slices > 4, canceling concatenation')
216 | return None, None
217 |
218 | slices_list = [None, None, None, None]
219 | for slice_id in range(4):
220 | # check whether 1st 4 slices are available
221 | if slice_id not in self.slice_dict.keys():
222 | logging.error('First 4 slices not available, canceling concatenation')
223 | return None, None
224 |
225 | try:
226 | img = cv2.imread(self.slice_dict[slice_id])
227 | img_roi = img[self.y_min:self.y_max, self.x_min:self.x_max]
228 | resized_image = cv2.resize(img_roi, (128, 128), interpolation=cv2.INTER_LINEAR)
229 | resized_image = cv2.cvtColor(resized_image, cv2.COLOR_BGR2GRAY)
230 | resized_image = np.expand_dims(resized_image, axis=2)
231 | resized_image = resized_image.astype(np.float32) / 255.0
232 | slices_list[slice_id] = resized_image
233 | except cv2.error as e:
234 | logging.error('OpenCV Error: %s, canceling concatenation', e)
235 | return None, None
236 |
237 | img_list_of_lists = [slices_list, None, None, None]
238 |
239 | for j, angle in enumerate([90, 180, 270]):
240 | (h, w) = slices_list[0].shape[:2]
241 | center = (w / 2, h / 2)
242 | rot_mat = cv2.getRotationMatrix2D(center, angle, 1.0)
243 | # fresh list per angle: reusing one list would alias all three entries of img_list_of_lists to the same rotations
244 | rotated_list = [None, None, None, None]
245 | for i in range(4):
246 | rotated = cv2.warpAffine(slices_list[i], rot_mat, (h, w))
247 | rotated_list[i] = np.expand_dims(rotated, axis=2)
248 | img_list_of_lists[j + 1] = rotated_list
249 |
250 | logging.debug('First 4 slices available, concatenation done')
251 | return img_list_of_lists, self.defect_name
252 |
253 | def concat_first_four_slices_3d(self):
254 | logging.debug('start concatenating image, joint type: %s', self.defect_name)
255 |
256 | if not self.is_square():
257 | logging.debug('joint roi is rectangular, canceling concatenation')
258 | return None, None
259 |
260 | if len(self.slice_dict.keys()) < 4:
261 | logging.error('Number of slices < 4, canceling concatenation')
262 | return None, None
263 |
264 | if len(self.slice_dict.keys()) > 4:
265 | logging.error('Number of slices > 4, canceling concatenation')
266 | return None, None
267 |
270 | slices_list = [None, None, None, None]
271 | for slice_id in range(4):
272 | # check whether 1st 4 slices are available
273 | if slice_id not in self.slice_dict.keys():
274 | logging.error('First 4 slices not available, canceling concatenation')
275 | return None, None
276 |
277 | img = cv2.imread(self.slice_dict[slice_id])
278 | # there's a bug here. image slicing doesn't give a perfect square sometimes
279 | img_roi = img[self.y_min:self.y_max, self.x_min:self.x_max]
280 | img_roi_gray = cv2.cvtColor(img_roi, cv2.COLOR_BGR2GRAY)
281 | if img_roi_gray is None:
282 | logging.error('Slice read is None, canceling concatenation')
283 | return None, None
284 | resized_image = cv2.resize(img_roi_gray, (128, 128), interpolation=cv2.INTER_AREA)
285 | resized_image = resized_image.astype(np.float32) / 255
286 |
287 | if resized_image is None:
288 | logging.error('Error occurred in OpenCV ROI extraction')
289 | return None, None
290 |
291 | slices_list[slice_id] = resized_image
292 |
293 | # logging.debug(slices_list[0].shape)
294 | stacked_np_array = np.stack(slices_list, axis=2)
295 | # logging.debug(stacked_np_array.shape)
296 | stacked_np_array = np.expand_dims(stacked_np_array, axis=-1)  # axis=4 is out of range for a 3-d array
297 | logging.debug('3d image shape: %s', stacked_np_array.shape)
298 |
299 | logging.debug('First 4 slices available, concatenation done')
300 | return stacked_np_array, self.defect_name
301 |
302 | def concat_pad_all_slices_2d(self):
303 | logging.debug('start concatenating image, joint type: %s', self.defect_name)
304 |
305 | if not self.is_square():
306 | logging.error('joint roi is rectangular, canceling concatenation')
307 | return None, None
308 |
309 | blank_image = np.zeros(shape=[128, 128], dtype=np.uint8)
310 | slices_list = [None, None, None, None, None, None]
311 | for slice_id in range(6):
312 | if slice_id in self.slice_dict.keys():
313 | img = cv2.imread(self.slice_dict[slice_id])
314 | # there's a bug here. image slicing doesn't give a perfect square sometimes
315 | img_roi = img[self.y_min:self.y_max, self.x_min:self.x_max]
316 | img_roi_gray = cv2.cvtColor(img_roi, cv2.COLOR_BGR2GRAY)
317 | if img_roi_gray is None:
318 | logging.error('Slice read is None, canceling concatenation')
319 | return None, None
320 | resized_image = cv2.resize(img_roi_gray, (128, 128), interpolation=cv2.INTER_AREA)
321 |
322 | if resized_image is None:
323 | logging.error('Error occurred in OpenCV ROI extraction')
324 | return None, None
325 | slices_list[slice_id] = resized_image
326 |
327 | else:
328 | slices_list[slice_id] = blank_image
329 | logging.debug('blank slice added to slice: %d', slice_id)
330 |
331 | im_h1 = cv2.hconcat(slices_list[0:3])
332 | im_h2 = cv2.hconcat(slices_list[3:6])
333 | im_concat = cv2.vconcat([im_h1, im_h2])
334 |
335 | if im_concat is None:
336 | logging.error('im_concat is none, skipping concatenation')
337 | return None, None
338 |
339 | logging.debug('concatenation done')
340 | return im_concat, self.defect_name
341 |
342 | def concat_pad_all_slices_3d(self):
343 | logging.debug('start concatenating image, joint type: %s', self.defect_name)
344 |
345 | if not self.is_square():
346 | logging.error('joint roi is rectangular, canceling concatenation')
347 | return None, None
348 |
349 | blank_image = np.zeros(shape=[128, 128], dtype=np.uint8)
350 | slices_list = [None, None, None, None, None, None]
351 | for slice_id in range(6):
352 | if slice_id in self.slice_dict.keys():
353 | img = cv2.imread(self.slice_dict[slice_id])
354 | # there's a bug here. image slicing doesn't give a perfect square sometimes
355 | img_roi = img[self.y_min:self.y_max, self.x_min:self.x_max]
356 | img_roi_gray = cv2.cvtColor(img_roi, cv2.COLOR_BGR2GRAY)
357 | if img_roi_gray is None:
358 | logging.error('Slice read is None, canceling concatenation')
359 | return None, None
360 | resized_image = cv2.resize(img_roi_gray, (128, 128), interpolation=cv2.INTER_AREA)
361 | resized_image = resized_image.astype(np.float32) / 255
362 |
363 | if resized_image is None:
364 | logging.error('Error occurred in OpenCV ROI extraction')
365 | return None, None
366 |
367 | slices_list[slice_id] = resized_image
368 |
369 | else:
370 | slices_list[slice_id] = blank_image
371 | logging.debug('blank slice added to slice: %d', slice_id)
372 |
373 | # logging.debug('xmax, xmin, ymax, ymin: %d, %d, %d, %d', self.x_max, self.x_min, self.y_max, self.y_min)
374 | # logging.debug('Slices shapes: %s, %s, %s, %s, %s, %s', slices_list[0].shape, slices_list[1].shape,
375 | # slices_list[2].shape, slices_list[3].shape, slices_list[4].shape, slices_list[5].shape)
376 |
377 | # logging.debug(slices_list[0].shape)
378 | stacked_np_array = np.stack(slices_list, axis=2)
379 | # logging.debug(stacked_np_array.shape)
380 | stacked_np_array = np.expand_dims(stacked_np_array, axis=-1)  # axis=4 is out of range for a 3-d array
381 | logging.debug('3d image shape: %s', stacked_np_array.shape)
382 |
383 | logging.debug('Padded and concatenated 6 slices in 3d')
384 | return stacked_np_array, self.defect_name
385 |
386 | def concat_pad_all_slices_inverse_3d(self):
387 | logging.debug('start concatenating image, joint type: %s', self.defect_name)
388 |
389 | if not self.is_square():
390 | logging.error('joint roi is rectangular, canceling concatenation')
391 | return None, None
392 |
393 | blank_image = np.ones(shape=[128, 128], dtype=np.uint8)
394 | slices_list = [None, None, None, None, None, None]
395 | for slice_id in range(6):
396 | if slice_id in self.slice_dict.keys():
397 | img = cv2.imread(self.slice_dict[slice_id])
398 | # there's a bug here. image slicing doesn't give a perfect square sometimes
399 | img_roi = img[self.y_min:self.y_max, self.x_min:self.x_max]
400 | img_roi_gray = cv2.cvtColor(img_roi, cv2.COLOR_BGR2GRAY)
401 | img_roi_gray = cv2.bitwise_not(img_roi_gray)
402 | if img_roi_gray is None:
403 | logging.error('Slice read is None, canceling concatenation')
404 | return None, None
405 | resized_image = cv2.resize(img_roi_gray, (128, 128), interpolation=cv2.INTER_AREA)
406 | resized_image = resized_image.astype(np.float32) / 255
407 |
408 | if resized_image is None:
409 | logging.error('Error occurred in OpenCV ROI extraction')
410 | return None, None
411 |
412 | slices_list[slice_id] = resized_image
413 |
414 | else:
415 | slices_list[slice_id] = blank_image
416 | logging.debug('blank slice added to slice: %d', slice_id)
417 |
418 | # logging.debug(slices_list[0].shape)
419 | stacked_np_array = np.stack(slices_list, axis=2)
420 | # logging.debug(stacked_np_array.shape)
421 | stacked_np_array = np.expand_dims(stacked_np_array, axis=-1)  # axis=4 is out of range for a 3-d array
422 | logging.debug('3d image shape: %s', stacked_np_array.shape)
423 |
424 | logging.debug('Padded and concatenated 6 slices in 3d')
425 | return stacked_np_array, self.defect_name
426 |
--------------------------------------------------------------------------------
/solder_joint_container.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import pickle
3 | import os
4 | import csv
5 | import re
6 | import cv2
7 | import sys
8 | import random
9 | import matplotlib.pyplot as plt
10 |
11 | from constants import DEFECT_NAMES_DICT
12 | from utils_basic import chk_n_mkdir
13 |
14 | from board_view import BoardView
15 |
16 |
17 | class SolderJointContainer:
18 | def __init__(self):
19 | self.board_view_dict = {}
20 | self.new_image_name_mapping_dict = {}
21 | self.csv_details_dict = {}
22 | self.incorrect_board_view_ids = []
23 |
24 | with open('original_dataset/PTH2_reviewed.csv', mode='r') as csv_file:
25 | csv_reader = csv.DictReader(csv_file)
26 |
27 | for row in csv_reader:
28 | component_name = row["component"]
29 | defect_id = int(row["defect_type_id"])
30 | defect_name = DEFECT_NAMES_DICT[defect_id]
31 | roi = [int(float(s)) for s in row["roi_original"].split()]  # avoid shadowing the built-in str
32 | file_location = 'original_dataset\\' + row["image_filename"].replace('C:\\Projects\\pia-test\\', '')  # strip() removes a character set, not a prefix
33 | view_identifier = file_location[:-5]
34 |
35 | if view_identifier in self.board_view_dict.keys():
36 | logging.debug('found view_identifier inside the board_view_dict')
37 | board_view_obj = self.board_view_dict[view_identifier]
38 | board_view_obj.add_slice(file_location)
39 | else:
40 | logging.debug('adding new BoardView obj to the board_view_dict')
41 | board_view_obj = BoardView(view_identifier)
42 | self.board_view_dict[view_identifier] = board_view_obj
43 |
44 | board_view_obj.add_solder_joint(component_name, defect_id, defect_name, roi)
45 |
46 | # csv_details_dict is only made for file tracking purpose
47 | if file_location in self.csv_details_dict.keys():
48 | self.csv_details_dict[file_location].append([component_name, defect_id, defect_name, roi])
49 | else:
50 | self.csv_details_dict[file_location] = []
51 | self.csv_details_dict[file_location].append([component_name, defect_id, defect_name, roi])
52 |
53 | logging.debug('csv row details, component_name:%s, defect_name:%s, roi:%d,%d,%d,%d', component_name,
54 | defect_name, roi[0], roi[1], roi[2], roi[3])
55 |
56 | for idx, file_loc in enumerate(self.csv_details_dict.keys()):
57 | raw_image_name = file_loc[-12:]
58 | image_name_with_idx = str(idx) + "_" + raw_image_name
59 | self.new_image_name_mapping_dict[image_name_with_idx] = file_loc
60 |
61 | logging.info('SolderJointContainer obj created')
62 |
63 | # this method creates images in a separate directory, marked with ROIs
64 | def mark_all_images_with_rois(self):
65 | logging.info("Num of images to be marked:%d", len(self.csv_details_dict.keys()))
66 | for idx, file_loc in enumerate(self.csv_details_dict.keys()):
67 | raw_image_name = file_loc[-12:]
68 | destination_path = './images_roi_marked/'
69 | chk_n_mkdir(destination_path)
70 | destination_image_path = destination_path + str(idx) + "_" + raw_image_name
71 |
72 | src_image = cv2.imread(file_loc)
73 | if src_image is None:
74 | logging.error('Could not open or find the image: %s', file_loc)
75 | exit(0)
76 |
77 | for feature_list in self.csv_details_dict[file_loc]:
78 | component_name = feature_list[0]
79 | defect_name = feature_list[2]
80 | roi = feature_list[3]
81 | if defect_name == "missing":
82 | num = '-mis'
83 | elif defect_name == "short":
84 | num = '-sht'
85 | else:
86 | num = '-inf'
87 |
88 | # draw the ROI
89 | cv2.rectangle(src_image, (roi[0], roi[1]), (roi[2], roi[3]), (255, 0, 0), 2)
90 | cv2.putText(src_image, component_name + num, (roi[0], roi[1]), cv2.FONT_HERSHEY_SIMPLEX, 1.0,
91 | (0, 0, 255),
92 | lineType=cv2.LINE_AA)
93 |
94 | # cv2.imshow("Labeled Image", src_image)
95 | # k = cv2.waitKey(5000)
96 | # if k == 27: # If escape was pressed exit
97 | # cv2.destroyAllWindows()
98 | # break
99 | cv2.imwrite(destination_image_path, src_image)
100 | logging.debug('ROI marked image saved: %s', destination_image_path)
101 |
102 | # this method generates new SolderJoint objs with non_defective labels from xml data
103 | def load_non_defective_rois_to_container(self):
104 |
105 | annotation_dir = './non_defective_xml_files'
106 | annotation_files = os.listdir(annotation_dir)
107 |
108 | for file in annotation_files:
109 | fp = open(annotation_dir + '/' + file, 'r')
110 | annotation_file = fp.read()
111 | file_name = annotation_file[annotation_file.index('<filename>') + 10:annotation_file.index('</filename>')]
112 | x_min_start = [m.start() for m in re.finditer('<xmin>', annotation_file)]
113 | x_min_end = [m.start() for m in re.finditer('</xmin>', annotation_file)]
114 | y_min_start = [m.start() for m in re.finditer('<ymin>', annotation_file)]
115 | y_min_end = [m.start() for m in re.finditer('</ymin>', annotation_file)]
116 | x_max_start = [m.start() for m in re.finditer('<xmax>', annotation_file)]
117 | x_max_end = [m.start() for m in re.finditer('</xmax>', annotation_file)]
118 | y_max_start = [m.start() for m in re.finditer('<ymax>', annotation_file)]
119 | y_max_end = [m.start() for m in re.finditer('</ymax>', annotation_file)]
120 |
121 | x_min = [int(annotation_file[x_min_start[j] + 6:x_min_end[j]]) for j in range(len(x_min_start))]
122 | y_min = [int(annotation_file[y_min_start[j] + 6:y_min_end[j]]) for j in range(len(y_min_start))]
123 | x_max = [int(annotation_file[x_max_start[j] + 6:x_max_end[j]]) for j in range(len(x_max_start))]
124 | y_max = [int(annotation_file[y_max_start[j] + 6:y_max_end[j]]) for j in range(len(y_max_start))]
125 | act_file_name = self.new_image_name_mapping_dict[file_name]
126 | view_identifier = act_file_name[:-5]
127 | board_view_obj = self.board_view_dict[view_identifier]
128 |
129 | for i in range(len(x_min)):
130 |
131 | x_min_i = x_min[i]
132 | y_min_i = y_min[i]
133 | x_max_i = x_max[i]
134 | y_max_i = y_max[i]
135 | width = x_max[i] - x_min[i]
136 | height = y_max[i] - y_min[i]
137 |
138 | logging.debug('Start height/width:' + str(height) + '/' + str(width))
139 | if width > height:
140 | threshold = (width - height) / width * 100.0
141 | if threshold > 10.0:
142 | logging.debug('height < width:' + str(height) + '/' + str(width))
143 | continue
144 | else:
145 | if (width - height) % 2 == 0:
146 | y_min_i = y_min_i - (width - height) // 2
147 | y_max_i = y_max_i + (width - height) // 2
148 | else:
149 | y_min_i = y_min_i - (width - height) // 2 - 1
150 | y_max_i = y_max_i + (width - height) // 2
151 |
152 | height = y_max_i - y_min_i
153 | logging.debug('new height/width:' + str(height) + '/' + str(width))
154 |
155 | if width < height:
156 | threshold = (height - width) / height * 100.0
157 | if threshold > 10.0:
158 | logging.debug('height > width:' + str(height) + '/' + str(width))
159 | continue
160 | else:
161 | if (height - width) % 2 == 0:
162 | x_min_i = x_min_i - (height - width) // 2
163 | x_max_i = x_max_i + (height - width) // 2
164 | else:
165 | x_min_i = x_min_i - (height - width) // 2 - 1
166 | x_max_i = x_max_i + (height - width) // 2
167 |
168 | width = x_max_i - x_min_i
169 | logging.debug('new height/width:' + str(height) + '/' + str(width))
170 |
171 | if not (x_max_i - x_min_i) == (y_max_i - y_min_i):
172 | logging.error('w,h,xmin,xmax,ymin,ymax: %d,%d,%d,%d,%d,%d', width, height, x_min_i, x_max_i, y_min_i, y_max_i)
173 | sys.exit()
174 |
175 | board_view_obj.add_solder_joint('unknown', -1, 'normal', [x_min_i, y_min_i, x_max_i, y_max_i])
176 | board_view_obj.add_slices_to_solder_joints()
177 |
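    # A hedged sketch (not part of the original class): the offset arithmetic in
    # load_non_defective_rois_to_container could instead use xml.etree.ElementTree,
    # assuming Pascal-VOC style annotations with the same <filename> and <bndbox>
    # tags parsed above.
    @staticmethod
    def _read_voc_boxes(xml_path):
        import xml.etree.ElementTree as ET
        root = ET.parse(xml_path).getroot()
        file_name = root.findtext('filename')
        boxes = [[int(box.findtext(tag)) for tag in ('xmin', 'ymin', 'xmax', 'ymax')]
                 for box in root.iter('bndbox')]
        return file_name, boxes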
178 | @staticmethod
179 | def create_incorrect_roi_pickle_file():  # a staticmethod takes no self
180 | images_list = os.listdir('./incorrect_roi_images')
181 | with open('incorrect_roi_images.p', 'wb') as filehandle:
182 | pickle.dump(images_list, filehandle)
183 |
184 | def find_flag_incorrect_roi_board_view_objs(self):
185 | with open('incorrect_roi_images.p', 'rb') as filehandle:
186 | incorrect_roi_list = pickle.load(filehandle)
187 | temp_list = []
188 | for image_name in incorrect_roi_list:
189 | temp_list.append(self.new_image_name_mapping_dict[image_name][:-5])
190 |
191 | list_set = set(temp_list)
192 | temp_list = list(list_set)
193 | self.incorrect_board_view_ids.extend(temp_list)
194 |
195 | for board_view_obj in self.board_view_dict.values():
196 | if board_view_obj.view_identifier in self.incorrect_board_view_ids:
197 | board_view_obj.is_incorrect_view = True
198 | logging.debug('marked board view %s as incorrect', board_view_obj.view_identifier)
199 |
200 | def print_container_details(self):
201 | board_views = 0
202 | solder_joints = 0
203 | missing_defects = 0
204 | short_defects = 0
205 | insuf_defects = 0
206 | normal_defects = 0
207 |
208 | joints_with_1_slices = 0
209 | joints_with_2_slices = 0
210 | joints_with_3_slices = 0
211 | joints_with_4_slices = 0
212 | joints_with_5_slices = 0
213 | joints_with_6_slices = 0
214 |
215 | for board_view_obj in self.board_view_dict.values():
216 | if not board_view_obj.is_incorrect_view:
217 | board_views += 1
218 | for solder_joint_obj in board_view_obj.solder_joint_dict.values():
219 | solder_joints += 1
220 |
221 | if len(solder_joint_obj.slice_dict.keys()) == 1:
222 | joints_with_1_slices += 1
223 | if len(solder_joint_obj.slice_dict.keys()) == 2:
224 | joints_with_2_slices += 1
225 | if len(solder_joint_obj.slice_dict.keys()) == 3:
226 | joints_with_3_slices += 1
227 | if len(solder_joint_obj.slice_dict.keys()) == 4:
228 | joints_with_4_slices += 1
229 | if len(solder_joint_obj.slice_dict.keys()) == 5:
230 | joints_with_5_slices += 1
231 | if len(solder_joint_obj.slice_dict.keys()) == 6:
232 | joints_with_6_slices += 1
233 |
234 | label = solder_joint_obj.defect_name
235 | if label == 'normal':
236 | normal_defects += 1
237 | if label == 'missing':
238 | missing_defects += 1
239 | if label == 'insufficient':
240 | insuf_defects += 1
241 | if label == 'short':
242 | short_defects += 1
243 |
244 | print('*****correct view details*****')
245 | print('board_views:', board_views, 'solder_joints:', solder_joints, 'missing_defects:', missing_defects,
246 | 'short_defects:', short_defects, 'insuf_defects:', insuf_defects, 'normal_defects:', normal_defects)
247 | print('joints_with_1_slices:', joints_with_1_slices, 'joints_with_2_slices:', joints_with_2_slices,
248 | 'joints_with_3_slices:', joints_with_3_slices, 'joints_with_4_slices:', joints_with_4_slices,
249 | 'joints_with_5_slices', joints_with_5_slices, 'joints_with_6_slices', joints_with_6_slices)
250 |
251 | board_views = 0
252 | solder_joints = 0
253 | missing_defects = 0
254 | short_defects = 0
255 | insuf_defects = 0
256 | normal_defects = 0
257 |
258 | joints_with_1_slices = 0
259 | joints_with_2_slices = 0
260 | joints_with_3_slices = 0
261 | joints_with_4_slices = 0
262 | joints_with_5_slices = 0
263 | joints_with_6_slices = 0
264 |
265 | for board_view_obj in self.board_view_dict.values():
266 | if board_view_obj.is_incorrect_view:
267 | board_views += 1
268 |
269 | for solder_joint_obj in board_view_obj.solder_joint_dict.values():
270 | solder_joints += 1
271 |
272 | if len(solder_joint_obj.slice_dict.keys()) == 1:
273 | joints_with_1_slices += 1
274 | if len(solder_joint_obj.slice_dict.keys()) == 2:
275 | joints_with_2_slices += 1
276 | if len(solder_joint_obj.slice_dict.keys()) == 3:
277 | joints_with_3_slices += 1
278 | if len(solder_joint_obj.slice_dict.keys()) == 4:
279 | joints_with_4_slices += 1
280 | if len(solder_joint_obj.slice_dict.keys()) == 5:
281 | joints_with_5_slices += 1
282 | if len(solder_joint_obj.slice_dict.keys()) == 6:
283 | joints_with_6_slices += 1
284 |
285 | label = solder_joint_obj.defect_name
286 | if label == 'normal':
287 | normal_defects += 1
288 | if label == 'missing':
289 | missing_defects += 1
290 | if label == 'insufficient':
291 | insuf_defects += 1
292 | if label == 'short':
293 | short_defects += 1
294 |
295 | print('*****incorrect view details*****')
296 | print('board_views:', board_views, 'solder_joints:', solder_joints, 'missing_defects:', missing_defects,
297 | 'short_defects:', short_defects, 'insuf_defects:', insuf_defects, 'normal_defects:', normal_defects)
298 | print('joints_with_1_slices:', joints_with_1_slices, 'joints_with_2_slices:', joints_with_2_slices,
299 | 'joints_with_3_slices:', joints_with_3_slices, 'joints_with_4_slices:', joints_with_4_slices,
300 | 'joints_with_5_slices', joints_with_5_slices, 'joints_with_6_slices', joints_with_6_slices)
301 |
302 | board_views = 0
303 | solder_joints = 0
304 | missing_defects = 0
305 | short_defects = 0
306 | insuf_defects = 0
307 | normal_defects = 0
308 |
309 | joints_with_3_slices = 0
310 | joints_with_4_slices = 0
311 | joints_with_5_slices = 0
312 | joints_with_6_slices = 0
313 |
314 | for board_view_obj in self.board_view_dict.values():
315 | if not board_view_obj.is_incorrect_view:
316 | board_views += 1
317 | for solder_joint_obj in board_view_obj.solder_joint_dict.values():
318 | if solder_joint_obj.is_square():  # is_square is a method; the bare attribute is always truthy
319 | solder_joints += 1
320 |
321 | if len(solder_joint_obj.slice_dict.keys()) == 3:
322 | joints_with_3_slices += 1
323 | if len(solder_joint_obj.slice_dict.keys()) == 4:
324 | joints_with_4_slices += 1
325 | if len(solder_joint_obj.slice_dict.keys()) == 5:
326 | joints_with_5_slices += 1
327 | if len(solder_joint_obj.slice_dict.keys()) == 6:
328 | joints_with_6_slices += 1
329 |
330 | label = solder_joint_obj.defect_name
331 | if label == 'normal':
332 | normal_defects += 1
333 | if label == 'missing':
334 | missing_defects += 1
335 | if label == 'insufficient':
336 | insuf_defects += 1
337 | if label == 'short':
338 | short_defects += 1
339 |
340 | print('*****correct square roi details*****')
341 | print('board_views:', board_views, 'solder_joints:', solder_joints, 'missing_defects:', missing_defects,
342 | 'short_defects:', short_defects, 'insuf_defects:', insuf_defects, 'normal_defects:', normal_defects)
343 | print('joints_with_3_slices:', joints_with_3_slices, 'joints_with_4_slices:', joints_with_4_slices,
344 | 'joints_with_5_slices', joints_with_5_slices, 'joints_with_6_slices', joints_with_6_slices)
345 |
346 | with open('./data/rois_first_four_slices_2d.p', 'rb') as handle:
347 | img_dict = pickle.load(handle)
348 | missing, short, insuf, normal = 0, 0, 0, 0
349 | for val_list in img_dict.values():
350 | if val_list[1] == 'missing':
351 | missing += 1
352 | if val_list[1] == 'short':
353 | short += 1
354 | if val_list[1] == 'insufficient':
355 | insuf += 1
356 | if val_list[1] == 'normal':
357 | normal += 1
358 |
359 | print('after concat 4 slices, missing:', missing, 'short:', short, 'insufficient:', insuf, 'normal:', normal)
360 |
361 | with open('./data/rois_all_slices_3d.p', 'rb') as handle:
362 | img_dict = pickle.load(handle)
363 | missing, short, insuf, normal = 0, 0, 0, 0
364 | for val_list in img_dict.values():
365 | if val_list[1] == 'missing':
366 | missing += 1
367 | if val_list[1] == 'short':
368 | short += 1
369 | if val_list[1] == 'insufficient':
370 | insuf += 1
371 | if val_list[1] == 'normal':
372 | normal += 1
373 |
374 | print('after concat all slices, missing:', missing, 'short:', short, 'insufficient:', insuf, 'normal:', normal)
375 |
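    # Hedged sketch (not in the original): the repeated per-label tallies above can be
    # collapsed with collections.Counter over one pickled image dict.
    @staticmethod
    def _count_labels(img_dict):
        from collections import Counter
        # each value is [image, label]; count the labels
        return Counter(val_list[1] for val_list in img_dict.values())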
376 | def print_solder_joint_resolution_details(self):
377 | res_dict = {}
378 |
379 | for board_view_obj in self.board_view_dict.values():
380 | if not board_view_obj.is_incorrect_view:
381 | for solder_joint_obj in board_view_obj.solder_joint_dict.values():
382 | if solder_joint_obj.is_square():
383 | res = solder_joint_obj.x_max - solder_joint_obj.x_min
384 | if res in res_dict.keys():
385 | res_dict[res] = res_dict[res] + 1
386 | else:
387 | res_dict[res] = 1
388 |
389 | print('*****resolution details*****')
390 | plt.figure()
391 | plt.title('Distribution of resolution values')
392 | plt.xlabel('Resolution')
393 | plt.ylabel('Number of solder joints')
394 | plt.bar(list(res_dict.keys()), list(res_dict.values()), 1.0, color='g')
395 | plt.show()
396 |
397 | total = 0  # avoid shadowing the built-in sum
398 | count = 0
399 | for key in res_dict.keys():
400 | total += key * res_dict[key]
401 | count += res_dict[key]
402 | print('mean:', total / count)
403 |
404 | def save_concat_images_first_four_slices_2d(self):
405 | chk_n_mkdir('./data/roi_concatenated_four_slices_2d/short/')
406 | chk_n_mkdir('./data/roi_concatenated_four_slices_2d/insufficient/')
407 | chk_n_mkdir('./data/roi_concatenated_four_slices_2d/missing/')
408 | chk_n_mkdir('./data/roi_concatenated_four_slices_2d/normal/')
409 | img_count = 0
410 | for board_view_obj in self.board_view_dict.values():
411 | if not board_view_obj.is_incorrect_view:
412 | logging.debug('Concatenating images in board_view_obj: %s', board_view_obj.view_identifier)
413 | for solder_joint_obj in board_view_obj.solder_joint_dict.values():
414 | concat_image, label = solder_joint_obj.concat_first_four_slices_2d()
415 | if concat_image is not None:
416 | img_count += 1
417 | destination_image_path = './data/roi_concatenated_four_slices_2d/' + label + '/' + str(img_count) \
418 | + '.jpg'
419 | cv2.imwrite(destination_image_path, concat_image)
420 | logging.debug('saving concatenated image, joint type: %s', label)
421 | logging.info('saved images of concatenated 4 slices in 2d')
422 |
423 | def save_concat_images_first_four_slices_2d_pickle(self):
424 | image_dict = {}
425 | img_count = 0
426 | for board_view_obj in self.board_view_dict.values():
427 | if not board_view_obj.is_incorrect_view:
428 | logging.debug('Concatenating images in board_view_obj: %s', board_view_obj.view_identifier)
429 | for solder_joint_obj in board_view_obj.solder_joint_dict.values():
430 | concat_image, label = solder_joint_obj.concat_first_four_slices_2d()
431 | if concat_image is not None:
432 | img_count += 1
433 | image_dict[img_count] = [concat_image, label]
434 |
435 | with open('./data/rois_first_four_slices_2d.p', 'wb') as handle:
436 | pickle.dump(image_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
437 | logging.info('saved images of concatenated 4 slices in 2d to pickle')
438 |
439 | def save_concat_images_first_four_slices_2d_more_pickle(self):
440 | image_dict = {}
441 | img_count = 0
442 | for board_view_obj in self.board_view_dict.values():
443 | if not board_view_obj.is_incorrect_view:
444 | logging.debug('Concatenating images in board_view_obj: %s', board_view_obj.view_identifier)
445 | for solder_joint_obj in board_view_obj.solder_joint_dict.values():
446 | concat_images, label = solder_joint_obj.concat_first_four_slices_2d_4rotated()
447 | if concat_images is not None:
448 | img_count += 1
449 | image_dict[img_count] = [concat_images[0], label]
450 | img_count += 1
451 | image_dict[img_count] = [concat_images[1], label]
452 | img_count += 1
453 | image_dict[img_count] = [concat_images[2], label]
454 | img_count += 1
455 | image_dict[img_count] = [concat_images[3], label]
456 |
457 | with open('./data/rois_first_four_slices_2d_rotated.p', 'wb') as handle:
458 | pickle.dump(image_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
459 | logging.info('saved images of concatenated 4 slices in 2d to pickle')
460 |
461 | def save_concat_images_first_four_slices_2d_more_normal_pickle(self):
462 | image_dict = {}
463 | img_count = 0
464 | for board_view_obj in self.board_view_dict.values():
465 | if not board_view_obj.is_incorrect_view:
466 | logging.debug('Concatenating images in board_view_obj: %s', board_view_obj.view_identifier)
467 | for solder_joint_obj in board_view_obj.solder_joint_dict.values():
468 | concat_images, label = solder_joint_obj.concat_first_four_slices_2d_4rotated()
469 | if concat_images is not None:
470 | if label == 'normal':
471 | img_count += 1
472 | image_dict[img_count] = [concat_images[0], label]
473 | img_count += 1
474 | image_dict[img_count] = [concat_images[1], label]
475 | img_count += 1
476 | image_dict[img_count] = [concat_images[2], label]
477 | img_count += 1
478 | image_dict[img_count] = [concat_images[3], label]
479 | else:
480 | img_count += 1
481 | image_dict[img_count] = [concat_images[0], label]
482 |
483 | with open('./data/rois_first_four_slices_2d_more_normal.p', 'wb') as handle:
484 | pickle.dump(image_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
485 | logging.info('saved images of concatenated 4 slices in 2d to pickle')
486 |
487 | def save_concat_images_first_four_slices_list_more_normal_pickle(self):
488 | image_dict = {}
489 | img_count = 0
490 | for board_view_obj in self.board_view_dict.values():
491 | if not board_view_obj.is_incorrect_view:
492 | logging.debug('Concatenating images in board_view_obj: %s', board_view_obj.view_identifier)
493 | for solder_joint_obj in board_view_obj.solder_joint_dict.values():
494 | concat_images, label = solder_joint_obj.concat_first_four_slices_list_4rotated()
495 | if concat_images is not None:
496 |                         # keep all four rotations for 'normal' joints; one for defects
497 |                         kept_images = concat_images if label == 'normal' else concat_images[:1]
498 |                         for image in kept_images:
499 |                             img_count += 1
500 |                             image_dict[img_count] = [image, label]
508 |
509 | with open('./data/rois_first_four_slices_list_more_normal.p', 'wb') as handle:
510 | pickle.dump(image_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
511 |         logging.info('saved class-rebalanced lists of first four slices to pickle')
512 |
513 | def save_concat_images_first_four_slices_list_rgb_pickle(self):
514 | image_dict = {}
515 | img_count = 0
516 | for board_view_obj in self.board_view_dict.values():
517 | if not board_view_obj.is_incorrect_view:
518 | logging.debug('Concatenating images in board_view_obj: %s', board_view_obj.view_identifier)
519 | for solder_joint_obj in board_view_obj.solder_joint_dict.values():
520 | stacked_list, label = solder_joint_obj.concat_first_four_slices_list_rgb()
521 | if stacked_list is not None:
522 | img_count += 1
523 | image_dict[img_count] = [stacked_list, label]
524 |
525 | with open('./data/rois_first_four_slices_list_rgb.p', 'wb') as handle:
526 | pickle.dump(image_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
527 | logging.info('saved images of concatenated 4 slices in a list to pickle')
528 |
529 |     @staticmethod
530 |     def save_4slices_individual_pickle():
531 | chk_n_mkdir('./data/joints_4slices')
532 | with open('./data/rois_first_four_slices_list_rgb.p', 'rb') as handle:
533 | image_dict = pickle.load(handle)
534 |
535 | train_key_list = []
536 | test_key_list = []
537 | partition = {'train': train_key_list, 'test': test_key_list}
538 |
539 | classes = ['missing', 'insufficient', 'short', 'normal']
540 | integer_mapping_dict = {x: i for i, x in enumerate(classes)}
541 | missing_key_list = []
542 | insufficient_key_list = []
543 | short_key_list = []
544 | normal_key_list = []
545 |
546 | labels = {}
547 | for image_key in image_dict.keys():
548 | label = image_dict[image_key][1]
549 | pickle_file_name = str(image_key) + '.p'
550 | labels[pickle_file_name] = integer_mapping_dict[label]
551 | if label == 'missing':
552 | missing_key_list.append(image_key)
553 | elif label == 'insufficient':
554 | insufficient_key_list.append(image_key)
555 | elif label == 'short':
556 | short_key_list.append(image_key)
557 | elif label == 'normal':
558 | normal_key_list.append(image_key)
559 |
560 | pickle_file_name = './data/joints_4slices/' + str(image_key) + '.p'
561 | with open(pickle_file_name, 'wb') as handle:
562 | pickle.dump(image_dict[image_key][0], handle, protocol=pickle.HIGHEST_PROTOCOL)
563 |
564 | for key_list in [missing_key_list, insufficient_key_list, short_key_list, normal_key_list]:
565 | num_train = int(len(key_list) * 0.9)
566 | num_images = len(key_list)
567 | train_indices = random.sample(range(num_images), num_train)
568 | test_indices = []
569 | for index in range(num_images):
570 | if index not in train_indices:
571 | test_indices.append(index)
572 |
573 | for index in train_indices:
574 | pickle_file_name = str(key_list[index]) + '.p'
575 | train_key_list.append(pickle_file_name)
576 |
577 | for index in test_indices:
578 | pickle_file_name = str(key_list[index]) + '.p'
579 | test_key_list.append(pickle_file_name)
580 |
581 | if isinstance(next(iter(image_dict.values()))[0], list):
582 | image_shape = next(iter(image_dict.values()))[0][0].shape
583 | print('slices list selected, image shape:', image_shape)
584 | else:
585 | image_shape = next(iter(image_dict.values()))[0].shape
586 |
587 | pickle_file_name = './data/joints_4slices/details_list.p'
588 | with open(pickle_file_name, 'wb') as handle:
589 | details_list = [partition, labels, integer_mapping_dict, image_shape]
590 | pickle.dump(details_list, handle, protocol=pickle.HIGHEST_PROTOCOL)
591 |
592 | def save_concat_images_first_four_slices_3d_pickle(self):
593 | image_dict = {}
594 | img_count = 0
595 | for board_view_obj in self.board_view_dict.values():
596 | if not board_view_obj.is_incorrect_view:
597 | logging.debug('Concatenating images in board_view_obj: %s', board_view_obj.view_identifier)
598 | for solder_joint_obj in board_view_obj.solder_joint_dict.values():
599 | stacked_np_array, label = solder_joint_obj.concat_first_four_slices_3d()
600 | if stacked_np_array is not None:
601 | img_count += 1
602 | image_dict[img_count] = [stacked_np_array, label]
603 |
604 | with open('./data/rois_first_four_slices_3d.p', 'wb') as handle:
605 | pickle.dump(image_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
606 | logging.info('saved images of concatenated 4 slices in 3d to pickle')
607 |
608 | def save_concat_images_all_slices_2d_pickle(self):
609 | image_dict = {}
610 | img_count = 0
611 | for board_view_obj in self.board_view_dict.values():
612 | if not board_view_obj.is_incorrect_view:
613 | logging.debug('Concatenating images in board_view_obj: %s', board_view_obj.view_identifier)
614 | for solder_joint_obj in board_view_obj.solder_joint_dict.values():
615 | concat_image, label = solder_joint_obj.concat_pad_all_slices_2d()
616 | if concat_image is not None:
617 | img_count += 1
618 | image_dict[img_count] = [concat_image, label]
619 |
620 | with open('./data/rois_all_slices_2d.p', 'wb') as handle:
621 | pickle.dump(image_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
622 | logging.info('saved images of concatenated 6 slices in 2d to pickle')
623 |
624 | def save_concat_images_all_slices_3d_pickle(self):
625 | image_dict = {}
626 | img_count = 0
627 | for board_view_obj in self.board_view_dict.values():
628 | if not board_view_obj.is_incorrect_view:
629 | logging.debug('Concatenating images in board_view_obj: %s', board_view_obj.view_identifier)
630 | for solder_joint_obj in board_view_obj.solder_joint_dict.values():
631 | stacked_np_array, label = solder_joint_obj.concat_pad_all_slices_3d()
632 | if stacked_np_array is not None:
633 | img_count += 1
634 | image_dict[img_count] = [stacked_np_array, label]
635 |
636 | with open('./data/rois_all_slices_3d.p', 'wb') as handle:
637 | pickle.dump(image_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
638 | logging.info('saved images of concatenated 6 slices in 3d to pickle')
639 |
640 | def save_concat_images_all_slices_inverse_3d_pickle(self):
641 | image_dict = {}
642 | img_count = 0
643 | for board_view_obj in self.board_view_dict.values():
644 | if not board_view_obj.is_incorrect_view:
645 | logging.debug('Concatenating images in board_view_obj: %s', board_view_obj.view_identifier)
646 | for solder_joint_obj in board_view_obj.solder_joint_dict.values():
647 | stacked_np_array, label = solder_joint_obj.concat_pad_all_slices_inverse_3d()
648 | if stacked_np_array is not None:
649 | img_count += 1
650 | image_dict[img_count] = [stacked_np_array, label]
651 |
652 | with open('./data/rois_all_slices_inverse_3d.p', 'wb') as handle:
653 | pickle.dump(image_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
654 | logging.info('saved images of concatenated 6 inverse slices in 3d to pickle')
655 |
--------------------------------------------------------------------------------
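A minimal sketch (not part of the repository) of reloading one of the datasets dumped by the methods above; it assumes save_concat_images_first_four_slices_2d_pickle() has already been run so the file exists under ./data:

import pickle
from collections import Counter

with open('./data/rois_first_four_slices_2d.p', 'rb') as handle:
    image_dict = pickle.load(handle)

# Values are [image, label] pairs keyed by a running integer index.
print('total joints:', len(image_dict))
print('per class:', Counter(label for _, label in image_dict.values()))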
/solder_joint_container_obj.p:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chinthysl/AXI_PCB_defect_detection/fe81c1d8e144ce5434aee78548cdc026d0d53d1c/solder_joint_container_obj.p
--------------------------------------------------------------------------------
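A minimal sketch of restoring this pickled container object; pickle can only resolve the class if solder_joint_container.py (which defines it) is importable from the working directory:

import pickle

with open('solder_joint_container_obj.p', 'rb') as handle:
    container = pickle.load(handle)

# Any of the dataset-dumping methods shown above can then be called, e.g.:
container.save_concat_images_first_four_slices_2d_pickle()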
/utils_basic.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 |
6 | from sklearn.metrics import accuracy_score
7 | from sklearn.metrics import precision_score
8 | from sklearn.metrics import recall_score
9 |
10 |
11 | def chk_n_mkdir(path):
12 | if not os.path.exists(path):
13 | os.makedirs(path)
14 |
15 |
16 | def calculate_metrics(y_true, y_pred, duration, y_true_val=None, y_pred_val=None):
17 |     res = pd.DataFrame(data=np.zeros((1, 4), dtype=float), index=[0],
18 | columns=['precision', 'accuracy', 'recall', 'duration'])
19 | res['precision'] = precision_score(y_true, y_pred, average='macro')
20 | res['accuracy'] = accuracy_score(y_true, y_pred)
21 |
22 | if y_true_val is not None:
23 | # this is useful when transfer learning is used with cross validation
24 | res['accuracy_val'] = accuracy_score(y_true_val, y_pred_val)
25 |
26 | res['recall'] = recall_score(y_true, y_pred, average='macro')
27 | res['duration'] = duration
28 | return res
29 |
30 |
31 | def plot_epochs_metric(hist, file_name, metric='loss'):
32 | plt.figure()
33 | plt.plot(hist.history[metric])
34 | plt.plot(hist.history['val_' + metric])
35 | plt.title('model ' + metric)
36 | plt.ylabel(metric, fontsize='large')
37 | plt.xlabel('epoch', fontsize='large')
38 | plt.legend(['train', 'val'], loc='upper left')
39 | plt.savefig(file_name, bbox_inches='tight')
40 | plt.close()
41 |
42 |
43 | def save_logs(output_directory, hist, y_pred, y_true, duration, lr=True, y_true_val=None, y_pred_val=None):
44 | hist_df = pd.DataFrame(hist.history)
45 | hist_df.to_csv(output_directory + '/history.csv', index=False)
46 |
47 | df_metrics = calculate_metrics(y_true, y_pred, duration, y_true_val, y_pred_val)
48 | df_metrics.to_csv(output_directory + '/df_metrics.csv', index=False)
49 |
50 | index_best_model = hist_df['loss'].idxmin()
51 | row_best_model = hist_df.loc[index_best_model]
52 |
53 |     df_best_model = pd.DataFrame(data=np.zeros((1, 6), dtype=float), index=[0],
54 | columns=['best_model_train_loss', 'best_model_val_loss', 'best_model_train_acc',
55 | 'best_model_val_acc', 'best_model_learning_rate', 'best_model_nb_epoch'])
56 |
57 | df_best_model['best_model_train_loss'] = row_best_model['loss']
58 | df_best_model['best_model_val_loss'] = row_best_model['val_loss']
59 | df_best_model['best_model_train_acc'] = row_best_model['acc']
60 | df_best_model['best_model_val_acc'] = row_best_model['val_acc']
61 |     if lr:
62 | df_best_model['best_model_learning_rate'] = row_best_model['lr']
63 | df_best_model['best_model_nb_epoch'] = index_best_model
64 |
65 | df_best_model.to_csv(output_directory + '/df_best_model.csv', index=False)
66 |
67 | # plot losses
68 | plot_epochs_metric(hist, output_directory + '/epochs_loss.png')
69 |
70 | return df_metrics
71 |
--------------------------------------------------------------------------------
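A minimal usage sketch for calculate_metrics, with hypothetical predictions for the four-class problem (duration is wall-clock seconds measured by the caller):

import numpy as np
from utils_basic import calculate_metrics

y_true = np.array([0, 1, 2, 3, 0, 1, 2, 3])
y_pred = np.array([0, 1, 2, 2, 0, 1, 3, 3])
# Returns a one-row DataFrame with macro precision/recall, accuracy and duration.
print(calculate_metrics(y_true, y_pred, duration=42.0))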
/utils_datagen.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import keras
3 | import os
4 | import tensorflow as tf
5 | from keras.callbacks import TensorBoard
6 | import pickle
7 | import random
9 |
10 |
11 | def create_partition_label_dict_multi_class(pickle_file):
12 | with open(pickle_file, 'rb') as handle:
13 | image_dict = pickle.load(handle)
14 |
15 | train_key_list = []
16 | test_key_list = []
17 | partition = {'train': train_key_list, 'test': test_key_list}
18 |
19 | classes = ['missing', 'insufficient', 'short', 'normal']
20 | integer_mapping_dict = {x: i for i, x in enumerate(classes)}
21 | missing_key_list = []
22 | insufficient_key_list = []
23 | short_key_list = []
24 | normal_key_list = []
25 |
26 | labels = {}
27 | for image_key in image_dict.keys():
28 | label = image_dict[image_key][1]
29 | labels[image_key] = integer_mapping_dict[label]
30 | if label == 'missing':
31 | missing_key_list.append(image_key)
32 | elif label == 'insufficient':
33 | insufficient_key_list.append(image_key)
34 | elif label == 'short':
35 | short_key_list.append(image_key)
36 | elif label == 'normal':
37 | normal_key_list.append(image_key)
38 |
39 | for key_list in [missing_key_list, insufficient_key_list, short_key_list, normal_key_list]:
40 | num_train = int(len(key_list) * 0.9)
41 | num_images = len(key_list)
42 | train_indices = random.sample(range(num_images), num_train)
43 | test_indices = []
44 | for index in range(num_images):
45 | if index not in train_indices:
46 | test_indices.append(index)
47 |
48 | for index in train_indices:
49 | train_key_list.append(key_list[index])
50 |
51 | for index in test_indices:
52 | test_key_list.append(key_list[index])
53 |
54 | if isinstance(next(iter(image_dict.values()))[0], list):
55 | image_shape = next(iter(image_dict.values()))[0][0].shape
56 | print('slices list selected, image shape:', image_shape)
57 | else:
58 | image_shape = next(iter(image_dict.values()))[0].shape
59 |
60 | return partition, labels, integer_mapping_dict, image_shape
61 |
62 |
63 | def create_partition_label_dict_binary(pickle_file):
64 | with open(pickle_file, 'rb') as handle:
65 | image_dict = pickle.load(handle)
66 |
67 | train_key_list = []
68 | test_key_list = []
69 | partition = {'train': train_key_list, 'test': test_key_list}
70 |
71 | integer_mapping_dict = {'missing': 0, 'insufficient': 0, 'short': 0, 'normal': 1}
72 | defective_key_list = []
73 | normal_key_list = []
74 |
75 | labels = {}
76 | for image_key in image_dict.keys():
77 | label = image_dict[image_key][1]
78 | labels[image_key] = integer_mapping_dict[label]
79 | if label == 'missing' or label == 'insufficient' or label == 'short':
80 | defective_key_list.append(image_key)
81 | elif label == 'normal':
82 | normal_key_list.append(image_key)
83 |
84 | for key_list in [defective_key_list, normal_key_list]:
85 | num_train = int(len(key_list) * 0.9)
86 | num_images = len(key_list)
87 | train_indices = random.sample(range(num_images), num_train)
88 | test_indices = []
89 | for index in range(num_images):
90 | if index not in train_indices:
91 | test_indices.append(index)
92 |
93 | for index in train_indices:
94 | train_key_list.append(key_list[index])
95 |
96 | for index in test_indices:
97 | test_key_list.append(key_list[index])
98 |
99 | if isinstance(next(iter(image_dict.values()))[0], list):
100 | image_shape = next(iter(image_dict.values()))[0][0].shape
101 | print('slices list selected, image shape:', image_shape)
102 | else:
103 | image_shape = next(iter(image_dict.values()))[0].shape
104 |
105 |     integer_mapping_dict = {'defect': 0, 'normal': 1}  # collapsed names for downstream reporting
106 | return partition, labels, integer_mapping_dict, image_shape
107 |
108 |
109 | class DataGenerator(keras.utils.Sequence):
110 | 'Generates data for Keras'
111 | def __init__(self, pickle_file, image_keys, labels, integer_mapping_dict, batch_size, dim, n_classes, shuffle=True):
112 | 'Initialization'
113 | with open(pickle_file, 'rb') as handle:
114 | image_dict = pickle.load(handle)
115 | self.image_dict = image_dict
116 | self.labels = labels
117 | self.image_keys = image_keys
118 | self.integer_mapping_dict = integer_mapping_dict
119 | self.dim = dim
120 | self.batch_size = batch_size
121 | self.n_classes = n_classes
122 | self.shuffle = shuffle
123 | self.on_epoch_end()
124 |
125 | def gen_class_labels(self):
126 |         classes = []  # aligns with batch order only when shuffle=False
127 | end_idx = len(self.image_keys) // self.batch_size * self.batch_size
128 | for i in self.image_keys[0:end_idx]:
129 | classes.append(self.labels[i])
130 | return classes
131 |
132 | def __len__(self):
133 | 'Denotes the number of batches per epoch'
134 | return int(np.floor(len(self.image_keys) / self.batch_size))
135 |
136 | def __getitem__(self, index):
137 | 'Generate one batch of data'
138 | # Generate indexes of the batch
139 | indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
140 |
141 | # Find list of IDs
142 | image_keys_temp = [self.image_keys[k] for k in indexes]
143 |
144 | # Generate data
145 | X, y = self.__data_generation(image_keys_temp)
146 | return X, y
147 |
151 | def on_epoch_end(self):
152 | 'Updates indexes after each epoch'
153 | self.indexes = np.arange(len(self.image_keys))
154 |         if self.shuffle:
155 | np.random.shuffle(self.indexes)
156 |
157 | def __data_generation(self, image_keys_temp):
158 |         # Generates one batch: four slice tensors X1..X4, each of shape
159 |         # (batch_size, *dim), plus one-hot class labels. Earlier variants
160 |         # (a single stacked input, a cv2-resized 128x128 single-channel
161 |         # input, and per-channel slicing of a (h, w, 4, 1) stack) were
162 |         # superseded by the four-input version below.
163 |
220 | # Initialization
221 | X1 = np.empty((self.batch_size, *self.dim))
222 | X2 = np.empty((self.batch_size, *self.dim))
223 | X3 = np.empty((self.batch_size, *self.dim))
224 | X4 = np.empty((self.batch_size, *self.dim))
225 | y = np.empty(self.batch_size, dtype=int)
226 |
227 | # Generate data
228 | for i, ID in enumerate(image_keys_temp):
229 | # Store sample
230 | slices_list = self.image_dict[ID][0]
231 |
232 | image1 = slices_list[0]
233 | image2 = slices_list[1]
234 | image3 = slices_list[2]
235 | image4 = slices_list[3]
236 | X1[i,] = image1
237 | X2[i,] = image2
238 | X3[i,] = image3
239 | X4[i,] = image4
240 |
241 | # Store class
242 | y[i] = self.labels[ID]
243 |
244 | return [X1, X2, X3, X4], keras.utils.to_categorical(y, num_classes=self.n_classes)
245 |
246 |
247 | class DataGeneratorIndividual(keras.utils.Sequence):
248 | 'Generates data for Keras'
249 | def __init__(self, pickle_file_dir, image_keys, labels, integer_mapping_dict, batch_size, dim, n_classes, shuffle=True):
250 | 'Initialization'
251 | self.pickle_file_dir = pickle_file_dir
252 | self.labels = labels
253 | self.image_keys = image_keys
254 | self.integer_mapping_dict = integer_mapping_dict
255 | self.dim = dim
256 | self.batch_size = batch_size
257 | self.n_classes = n_classes
258 | self.shuffle = shuffle
259 | self.on_epoch_end()
260 |
261 | def gen_class_labels(self):
262 |         classes = []  # aligns with batch order only when shuffle=False
263 | end_idx = len(self.image_keys) // self.batch_size * self.batch_size
264 | for i in self.image_keys[0:end_idx]:
265 | classes.append(self.labels[i])
266 | return classes
267 |
268 | def __len__(self):
269 | 'Denotes the number of batches per epoch'
270 | return int(np.floor(len(self.image_keys) / self.batch_size))
271 |
272 | def __getitem__(self, index):
273 | 'Generate one batch of data'
274 | # Generate indexes of the batch
275 | indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
276 |
277 | # Find list of IDs
278 | image_keys_temp = [self.image_keys[k] for k in indexes]
279 |
280 | # Generate data
281 | X, y = self.__data_generation(image_keys_temp)
282 | return X, y
283 |
287 | def on_epoch_end(self):
288 | 'Updates indexes after each epoch'
289 | self.indexes = np.arange(len(self.image_keys))
290 |         if self.shuffle:
291 | np.random.shuffle(self.indexes)
292 |
293 | def __data_generation(self, image_keys_temp):
294 | X1 = np.empty((self.batch_size, *self.dim))
295 | X2 = np.empty((self.batch_size, *self.dim))
296 | X3 = np.empty((self.batch_size, *self.dim))
297 | X4 = np.empty((self.batch_size, *self.dim))
298 | y = np.empty(self.batch_size, dtype=int)
299 |
300 | # Generate data
301 | for i, ID in enumerate(image_keys_temp):
302 | # Store sample
303 | with open(self.pickle_file_dir + ID, 'rb') as handle:
304 | slices_list = pickle.load(handle)
305 |
306 | image1 = slices_list[0]
307 | image2 = slices_list[1]
308 | image3 = slices_list[2]
309 | image4 = slices_list[3]
310 | X1[i,] = image1
311 | X2[i,] = image2
312 | X3[i,] = image3
313 | X4[i,] = image4
314 |
315 | # Store class
316 | y[i] = self.labels[ID]
317 |
318 | return [X1, X2, X3, X4], keras.utils.to_categorical(y, num_classes=self.n_classes)
319 |
320 |
321 | class TrainValTensorBoard(TensorBoard):
322 |     def __init__(self, log_dir='./log', **kwargs):
323 |         # Make the original `TensorBoard` log to a subdirectory 'training'.
324 |         # Join against the local log_dir (the parent __init__ reassigns
325 |         # self.log_dir); a leading '/' would make os.path.join drop log_dir.
326 |         training_log_dir = os.path.join(log_dir, 'training')
327 |         super(TrainValTensorBoard, self).__init__(training_log_dir, **kwargs)
328 |         # Log the validation metrics to a separate subdirectory
329 |         self.val_log_dir = os.path.join(log_dir, 'validation')
330 |
331 | def set_model(self, model):
332 | # Setup writer for validation metrics
333 | self.val_writer = tf.summary.FileWriter(self.val_log_dir)
334 | super(TrainValTensorBoard, self).set_model(model)
335 |
336 | def on_epoch_end(self, epoch, logs=None):
337 | # Pop the validation logs and handle them separately with
338 | # `self.val_writer`. Also rename the keys so that they can
339 | # be plotted on the same figure with the training metrics
340 | logs = logs or {}
341 | val_logs = {k.replace('val_', ''): v for k, v in logs.items() if k.startswith('val_')}
342 | for name, value in val_logs.items():
343 | summary = tf.Summary()
344 | summary_value = summary.value.add()
345 | summary_value.simple_value = value.item()
346 | summary_value.tag = name
347 | self.val_writer.add_summary(summary, epoch)
348 | self.val_writer.flush()
349 |
350 | # Pass the remaining logs to `TensorBoard.on_epoch_end`
351 | logs = {k: v for k, v in logs.items() if not k.startswith('val_')}
352 | super(TrainValTensorBoard, self).on_epoch_end(epoch, logs)
353 |
354 | def on_train_end(self, logs=None):
355 | super(TrainValTensorBoard, self).on_train_end(logs)
356 | self.val_writer.close()
357 |
--------------------------------------------------------------------------------
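A minimal sketch wiring the helpers above together; the pickle path is an assumption, taken to have been produced by save_concat_images_first_four_slices_list_rgb_pickle():

from utils_datagen import create_partition_label_dict_multi_class, DataGenerator

pickle_file = './data/rois_first_four_slices_list_rgb.p'
partition, labels, mapping, image_shape = create_partition_label_dict_multi_class(pickle_file)

train_gen = DataGenerator(pickle_file, partition['train'], labels, mapping,
                          batch_size=32, dim=image_shape, n_classes=len(mapping))

# Each batch is four parallel slice tensors plus one-hot labels, matching the
# multi-input models under models/.
[x1, x2, x3, x4], y = train_gen[0]
print(x1.shape, y.shape)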