├── Models
    ├── README.md
    └── logs.json
├── .gitignore
├── requirements.txt
├── config.py
├── Training
    ├── README.md
    └── Train.ipynb
├── Dataset
    └── README.md
├── api_server.py
├── infer.py
├── README.md
├── Model.py
└── Infer.ipynb

--------------------------------------------------------------------------------
/Models/README.md:
--------------------------------------------------------------------------------
Pretrained weights can be downloaded from [here](----)

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*.pt
.idea
__pycache__
.ipynb_checkpoints
Sample/

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torchvision==0.6.0
torch==1.5.0
Flask==1.1.2
Pillow==8.2.0

--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
# Full path to the model weights (.pt file) on your machine
MODEL_PATH = '/home/jeetu/Project/VehicleColorA/Exp1/model_3.pt'

--------------------------------------------------------------------------------
/Training/README.md:
--------------------------------------------------------------------------------
Run `Train.ipynb` to train the vehicle color recognition model.

Note: Before training, set the following paths to match your own machine:

- Dataset directory
- Path for saving the trained model

--------------------------------------------------------------------------------
/Dataset/README.md:
--------------------------------------------------------------------------------
Download the dataset and extract it into this directory.
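After extraction, `Train.ipynb` labels each image with `encode_label_from_path`, which returns the index of the first of the eight color names that appears anywhere in the image's full path. Because the match is a plain substring test, a color word in a directory above the images (for example `red` inside `fred`) can produce the wrong label. The sketch below reproduces that behaviour and adds a stricter alternative. It assumes a `Dataset/<color>/<image>` layout; both the layout and the helper `label_from_parent_dir` are assumptions for illustration, not part of this repository.

```python
from pathlib import Path
from typing import Optional

# Same label order as infer.py and Train.ipynb; the list index is the class id.
LABELS = ['black', 'blue', 'cyan', 'gray', 'green', 'red', 'white', 'yellow']


def encode_label_from_path(path: str) -> Optional[int]:
    """Substring match over the whole path, as done in Train.ipynb."""
    for index, value in enumerate(LABELS):
        if value in path:
            return index
    return None  # no color name found anywhere in the path


def label_from_parent_dir(path: str) -> int:
    """Hypothetical stricter variant: only the image's parent folder name counts."""
    color = Path(path).parent.name.lower()
    if color not in LABELS:
        raise ValueError(f"unexpected class folder {color!r} in {path}")
    return LABELS.index(color)


if __name__ == "__main__":
    p = "/home/fred/Dataset/white/0001.jpg"
    print(encode_label_from_path(p))  # 5: 'red' matches inside 'fred'
    print(label_from_parent_dir(p))   # 6: 'white'
```

Raising an error for unknown folders also surfaces stray files early, whereas the substring version silently returns `None` for them.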
- Dataset - [Link](https://drive.google.com/file/d/1n8Ja6g5eO82mbRlsTXkdVXNMPpApLd5K/view?usp=sharing)

Link to Original Project - [Link](http://cloud.eic.hust.edu.cn:8071/~pchen/project.html)

Original Paper

P. Chen, X. Bai and W. Liu, "Vehicle Color Recognition on Urban Road by Feature Context," in IEEE Transactions on Intelligent Transportation Systems, vol. 15, no. 5, pp. 2340-2346, Oct. 2014, doi: 10.1109/TITS.2014.2308897.

--------------------------------------------------------------------------------
/api_server.py:
--------------------------------------------------------------------------------
import argparse
import io

from flask import Flask, jsonify, request
from PIL import Image

from infer import infer

app = Flask(__name__)


@app.route('/', methods=['POST'])
def image_handler():
    # The image must be uploaded as multipart/form-data under the field name 'image'
    if 'image' not in request.files:
        return jsonify({'error': "missing multipart file field 'image'"}), 400
    bio = io.BytesIO()
    request.files['image'].save(bio)
    bio.seek(0)
    # The model expects exactly 3 input channels, so convert grayscale,
    # palette and RGBA (e.g. PNG with alpha) uploads to RGB first
    image = Image.open(bio).convert('RGB')
    result = infer(image)
    return jsonify({'result': result})


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--port', default=8769, type=int)
    # Debug mode enables the Werkzeug interactive debugger, which can run arbitrary
    # Python code; keep it off whenever the server listens on 0.0.0.0
    parser.add_argument('--debug', action='store_true',
                        help='enable Flask debug mode (local development only)')
    args = parser.parse_args()
    app.run('0.0.0.0', args.port, debug=args.debug)


if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/infer.py:
--------------------------------------------------------------------------------
import torch
from PIL import Image
from torchvision.transforms import CenterCrop, Compose, Resize, ToTensor

from config import MODEL_PATH
from Model import VehicleColorModel

# Class labels, in the index order used during training
labels = ['black', 'blue', 'cyan', 'gray', 'green', 'red', 'white', 'yellow']


def decode_label(index):
    return labels[index]


def encode_label_from_path(path):
    # Index of the first label name found in the path (None if there is no match)
    for index, value in enumerate(labels):
        if value in path:
            return index


model = VehicleColorModel()
model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device("cpu")))
# Inference mode: disables Dropout and makes BatchNorm use its running statistics
model.eval()

# Same preprocessing as in Training/Train.ipynb
transforms = Compose([Resize(224), CenterCrop(224), ToTensor()])


def infer(image):
    """Return the predicted color name for an RGB PIL image."""
    image = transforms(image).unsqueeze(0)  # add a batch dimension: (1, 3, 224, 224)
    with torch.no_grad():
        pred = model(image).argmax(dim=1)
    return decode_label(pred.item())


if __name__ == "__main__":
    image = Image.open('/path/to/an/image').convert('RGB')
    print(infer(image))

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Vehicle Color Recognition using CNN
Vehicle information recognition is a key component of Intelligent Transportation Systems, and color plays an important role in vehicle identification. Because a vehicle has internal structure (windows, lights, tires, and so on), the main challenge of vehicle color recognition is selecting the region of interest (ROI) from which to recognize its dominant color.

## Training
> Note: If you only need inference, the pretrained weights can be [downloaded](https://drive.google.com/drive/folders/1iBAn9IwWXY8Ur4JA89ZkIOP4MSjtDea0?usp=sharing) and training is not required.

To train the model yourself (preferably with additional data), edit and run the `Train.ipynb` notebook in the `Training` directory.

## Running
1. Download the pretrained weights from [here](https://drive.google.com/drive/folders/1iBAn9IwWXY8Ur4JA89ZkIOP4MSjtDea0?usp=sharing).
2. Edit `config.py` and set `MODEL_PATH` to the full path of the downloaded weights file.
3. Start the server with `python3 api_server.py --port 1234` (the default port is 8769). Add `--debug` only for local development.

> Note: The server expects the image as a multipart/form-data file upload with the field name `image`, and responds with JSON of the form `{"result": "<color>"}`.
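A minimal Python client sketch for the endpoint described above, using only the standard library so that no extra package (such as `requests`) is needed. It assumes the server runs locally on port 1234, as in the command above; the helper names `build_multipart` and `classify` and the file name `car.jpg` are illustrative and not part of this repository.

```python
import json
import os
import urllib.request
import uuid


def build_multipart(field, filename, data, content_type="application/octet-stream"):
    """Encode one file as a multipart/form-data body (hypothetical helper).

    Returns (body_bytes, content_type_header).
    """
    # Random boundary; astronomically unlikely to occur inside the image bytes
    boundary = uuid.uuid4().hex
    head = (
        f"--{boundary}\r\n"
        f'Content-Disposition: form-data; name="{field}"; filename="{filename}"\r\n'
        f"Content-Type: {content_type}\r\n"
        "\r\n"
    ).encode()
    tail = f"\r\n--{boundary}--\r\n".encode()
    return head + data + tail, f"multipart/form-data; boundary={boundary}"


def classify(path, url="http://localhost:1234/"):
    """POST an image to api_server.py and return the predicted color name."""
    with open(path, "rb") as f:
        # The server reads the upload from the 'image' field
        body, header = build_multipart("image", os.path.basename(path), f.read())
    req = urllib.request.Request(
        url, data=body, headers={"Content-Type": header}, method="POST"
    )
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read().decode("utf-8"))["result"]
```

For example, `classify("car.jpg")` returns the color string from the server's `{"result": ...}` response. The generic `application/octet-stream` part type is enough here because Pillow detects the image format from the bytes. From a shell, the equivalent request is `curl -F "image=@car.jpg" http://localhost:1234/`.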
## References

Link to Original Project - [Link](http://cloud.eic.hust.edu.cn:8071/~pchen/project.html)

P. Chen, X. Bai and W. Liu, "Vehicle Color Recognition on Urban Road by Feature Context," in IEEE Transactions on Intelligent Transportation Systems, vol. 15, no. 5, pp. 2340-2346, Oct. 2014, doi: 10.1109/TITS.2014.2308897.

R. F. Rachmadi and I. K. E. Purnama, "Vehicle Color Recognition using Convolutional Neural Network," arXiv:1510.07391. https://arxiv.org/abs/1510.07391

--------------------------------------------------------------------------------
/Model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class VehicleColorModel(nn.Module):
    """Two-stream CNN that classifies a 224x224 RGB image into 8 vehicle colors.

    The top and bottom streams receive the same input and share the same layout,
    but have independent weights. After conv1 and conv3, each stream splits its
    feature map channel-wise into two halves that are processed by separate
    convolutions. Do not rename attributes: they are the keys of the saved
    state_dict loaded by infer.py.
    """

    def __init__(self):
        super(VehicleColorModel, self).__init__()

        # ================================ top ================================

        # conv1: 3 -> 48 channels; 224x224 input -> 26x26 after pooling
        self.top_conv1 = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=(11, 11), stride=(4, 4)),
            nn.ReLU(),
            nn.BatchNorm2d(48),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )

        # conv2: one branch per 24-channel half of conv1; 26x26 -> 12x12 after pooling
        self.top_top_conv2 = nn.Sequential(
            nn.Conv2d(24, 64, kernel_size=(3, 3), stride=(1, 1), padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )
        self.top_bot_conv2 = nn.Sequential(
            nn.Conv2d(24, 64, kernel_size=(3, 3), stride=(1, 1), padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )

        # conv3: runs on both conv2 outputs concatenated (2 x 64 = 128 channels)
        self.top_conv3 = nn.Sequential(
            nn.Conv2d(128, 192, kernel_size=(3, 3), stride=(1, 1), padding=1),
            nn.ReLU()
        )

        # conv4: one branch per 96-channel half of conv3
        self.top_top_conv4 = nn.Sequential(
            nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=1),
            nn.ReLU()
        )
        self.top_bot_conv4 = nn.Sequential(
            nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=1),
            nn.ReLU()
        )

        # conv5: 96 -> 64 channels per branch; 12x12 -> 5x5 after pooling
        self.top_top_conv5 = nn.Sequential(
            nn.Conv2d(96, 64, kernel_size=(3, 3), stride=(1, 1), padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )
        self.top_bot_conv5 = nn.Sequential(
            nn.Conv2d(96, 64, kernel_size=(3, 3), stride=(1, 1), padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )

        # =============================== bottom ===============================
        # Same layout as the top stream.

        self.bottom_conv1 = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=(11, 11), stride=(4, 4)),
            nn.ReLU(),
            nn.BatchNorm2d(48),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )

        self.bottom_top_conv2 = nn.Sequential(
            nn.Conv2d(24, 64, kernel_size=(3, 3), stride=(1, 1), padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )
        self.bottom_bot_conv2 = nn.Sequential(
            nn.Conv2d(24, 64, kernel_size=(3, 3), stride=(1, 1), padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )

        self.bottom_conv3 = nn.Sequential(
            nn.Conv2d(128, 192, kernel_size=(3, 3), stride=(1, 1), padding=1),
            nn.ReLU()
        )

        self.bottom_top_conv4 = nn.Sequential(
            nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=1),
            nn.ReLU()
        )
        self.bottom_bot_conv4 = nn.Sequential(
            nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=1),
            nn.ReLU()
        )

        self.bottom_top_conv5 = nn.Sequential(
            nn.Conv2d(96, 64, kernel_size=(3, 3), stride=(1, 1), padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )
        self.bottom_bot_conv5 = nn.Sequential(
            nn.Conv2d(96, 64, kernel_size=(3, 3), stride=(1, 1), padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )

        # Fully-connected classifier: 4 branches x 64 channels x 5 x 5 = 6400 inputs
        self.classifier = nn.Sequential(
            nn.Linear(5*5*64*4, 4096),
            nn.ReLU(),
            nn.Dropout(0.7),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(0.6),
            nn.Linear(4096, 8)
        )

    def forward(self, x):
        # ---- top stream ----
        x_top = self.top_conv1(x)
        x_top_conv = torch.split(x_top, 24, 1)  # two 24-channel halves

        x_top_top_conv2 = self.top_top_conv2(x_top_conv[0])
        x_top_bot_conv2 = self.top_bot_conv2(x_top_conv[1])
        x_top_cat1 = torch.cat([x_top_top_conv2, x_top_bot_conv2], 1)

        x_top_conv3 = self.top_conv3(x_top_cat1)
        x_top_conv3 = torch.split(x_top_conv3, 96, 1)  # two 96-channel halves

        x_top_top_conv4 = self.top_top_conv4(x_top_conv3[0])
        x_top_bot_conv4 = self.top_bot_conv4(x_top_conv3[1])

        x_top_top_conv5 = self.top_top_conv5(x_top_top_conv4)
        x_top_bot_conv5 = self.top_bot_conv5(x_top_bot_conv4)

        # ---- bottom stream ----
        x_bottom = self.bottom_conv1(x)
        x_bottom_conv = torch.split(x_bottom, 24, 1)

        x_bottom_top_conv2 = self.bottom_top_conv2(x_bottom_conv[0])
        x_bottom_bot_conv2 = self.bottom_bot_conv2(x_bottom_conv[1])
        x_bottom_cat1 = torch.cat([x_bottom_top_conv2, x_bottom_bot_conv2], 1)

        x_bottom_conv3 = self.bottom_conv3(x_bottom_cat1)
        x_bottom_conv3 = torch.split(x_bottom_conv3, 96, 1)

        x_bottom_top_conv4 = self.bottom_top_conv4(x_bottom_conv3[0])
        x_bottom_bot_conv4 = self.bottom_bot_conv4(x_bottom_conv3[1])

        x_bottom_top_conv5 = self.bottom_top_conv5(x_bottom_top_conv4)
        x_bottom_bot_conv5 = self.bottom_bot_conv5(x_bottom_bot_conv4)

        # Concatenate the four branches: (N, 256, 5, 5) -> flatten to (N, 6400)
        x_cat = torch.cat([x_top_top_conv5, x_top_bot_conv5,
                           x_bottom_top_conv5, x_bottom_bot_conv5], 1)
        flatten = x_cat.view(x_cat.size(0), -1)

        # Raw logits: nn.CrossEntropyLoss (used in Train.ipynb) applies log-softmax itself
        output = self.classifier(flatten)
        return output

--------------------------------------------------------------------------------
/Models/logs.json:
--------------------------------------------------------------------------------
1 | { 2 | "epoch": 79, 3 | "test_acc": [ 4 | 0.3860613810741688, 5 | 0.47340153452685424, 6 | 0.568542199488491, 7 | 0.6520460358056266, 8 | 0.7159846547314578, 9 | 0.7734015345268542, 10 | 0.7612531969309463, 11 | 0.7856777493606139, 12 | 0.7774936061381074, 13 | 0.8067774936061382, 14 | 0.7260869565217392, 15 | 0.7838874680306905, 16 | 0.8227621483375959, 17 | 0.8222506393861893, 18 | 0.8136828644501278, 19 | 0.8089514066496164, 20 | 0.8297953964194373, 21 | 0.8634271099744245, 22 | 0.8707161125319693, 23 | 0.8822250639386189, 24 | 0.8638107416879796, 25 | 0.8398976982097187, 26 | 0.8755754475703325, 27 | 0.8806905370843989, 28 | 0.8781329923273657, 29 | 0.891687979539642, 30 | 0.8703324808184143, 31 | 0.8594629156010231, 32 | 0.881074168797954, 33 | 0.8951406649616368, 34 | 0.8856777493606138, 35 | 0.8997442455242967, 36 | 0.8852941176470588, 37 | 0.8823529411764706, 38 | 0.8950127877237851, 39 | 0.8664961636828644, 40 | 0.8863171355498721, 41 | 0.8658567774936061, 42 | 0.8920716112531969, 43 | 0.889769820971867, 44 | 0.8447570332480818, 45 | 0.8838874680306905, 46 | 0.8955242966751918, 47 | 0.8859335038363171, 48 | 0.8927109974424552, 49 | 0.8622762148337596, 50 | 0.8786445012787724, 51 | 0.8870843989769821, 52 | 0.8831202046035805, 53 | 0.8914322250639386, 54 | 0.8453964194373401, 55 | 0.8762148337595907, 56 | 0.8979539641943735, 57 | 0.9002557544757033, 58 | 0.9074168797953964, 59 | 0.8732736572890025, 60 | 0.9028132992327366, 61 | 0.8855498721227621, 62 | 0.9016624040920717, 63 | 0.8997442455242967, 64 | 0.9026854219948849, 65 | 0.8984654731457801, 66 | 0.9014066496163683, 67 | 0.9020460358056266, 68 | 0.9076726342710998, 69 | 0.8842710997442456, 70 | 0.9043478260869565, 71 | 0.889002557544757, 72 | 0.8998721227621483, 73
| 0.8566496163682864, 74 | 0.8945012787723785, 75 | 0.897314578005115, 76 | 0.9069053708439898, 77 | 0.8956521739130435, 78 | 0.8927109974424552, 79 | 0.8780051150895141, 80 | 0.869693094629156, 81 | 0.8717391304347826, 82 | 0.8643222506393862, 83 | 0.8991048593350384 84 | ], 85 | "test_loss": [ 86 | 1.6941967010498047, 87 | 1.413282746777815, 88 | 1.2300743241520489, 89 | 1.0186388185795616, 90 | 0.8052589972229565, 91 | 0.6959069020607892, 92 | 0.6545220179592862, 93 | 0.5863275900483131, 94 | 0.5774378241861567, 95 | 0.5302328051013105, 96 | 0.8679919847670723, 97 | 0.6133233887307784, 98 | 0.5330383909099242, 99 | 0.5709917242912685, 100 | 0.5444423876264516, 101 | 0.5323077063350117, 102 | 0.4967346903594101, 103 | 0.4001205657773158, 104 | 0.3681556407143088, 105 | 0.3345503132132923, 106 | 0.42683300638900085, 107 | 0.4592500037568457, 108 | 0.35102029122850475, 109 | 0.34043378623969417, 110 | 0.35698582341565804, 111 | 0.3133265166817343, 112 | 0.36764419714317603, 113 | 0.4082748420974788, 114 | 0.3464647836106665, 115 | 0.3021334641996552, 116 | 0.34193275901762876, 117 | 0.2912965695209363, 118 | 0.3407611415228423, 119 | 0.34141917526721954, 120 | 0.31408849818741574, 121 | 0.4057917531360598, 122 | 0.36029354088446675, 123 | 0.41712552550084453, 124 | 0.33051284545046444, 125 | 0.334898144883268, 126 | 0.4874957107445773, 127 | 0.3422568205963163, 128 | 0.32231982074239673, 129 | 0.32411665162619424, 130 | 0.33481639316853357, 131 | 0.43111707752241807, 132 | 0.3802726919598439, 133 | 0.3482595243874718, 134 | 0.3684794045984745, 135 | 0.3503421111141934, 136 | 0.47689756904454794, 137 | 0.43856004187289405, 138 | 0.3392842488692087, 139 | 0.37621956844540205, 140 | 0.3219831580405726, 141 | 0.4732762347249424, 142 | 0.32707333236056213, 143 | 0.3794022645143902, 144 | 0.35445414867032976, 145 | 0.3346128492232631, 146 | 0.34618619588368077, 147 | 0.3812798371209818, 148 | 0.35725006658364744, 149 | 0.40792761217145357, 150 | 0.36428572500453277, 151 
| 0.46748057424145584, 152 | 0.36997203978107257, 153 | 0.4077695794403553, 154 | 0.4040851356352077, 155 | 0.5517700842636473, 156 | 0.46853991028140574, 157 | 0.3906759590129642, 158 | 0.38165930340833526, 159 | 0.4145576044478837, 160 | 0.4339598141172353, 161 | 0.5431482366779271, 162 | 0.6054621858193594, 163 | 0.6075730393914616, 164 | 0.5816612491274581, 165 | 0.40753378346562386 166 | ], 167 | "train_loss": [ 168 | 1.826199992614634, 169 | 1.5579509384491865, 170 | 1.3277928057838888, 171 | 1.1492610468583948, 172 | 0.9276555724003736, 173 | 0.7631941420190475, 174 | 0.6498804372899672, 175 | 0.587474809411694, 176 | 0.544483536744819, 177 | 0.5106695272466716, 178 | 0.48976775521741195, 179 | 0.4548078871825162, 180 | 0.44559210275902467, 181 | 0.4194487832924899, 182 | 0.40714043466483846, 183 | 0.3901485848952742, 184 | 0.3770808058188242, 185 | 0.3520677920211764, 186 | 0.3344503982978709, 187 | 0.31735584534266414, 188 | 0.30753152795574246, 189 | 0.30266259281950836, 190 | 0.2929159453248276, 191 | 0.28656001695815253, 192 | 0.2640054938109482, 193 | 0.2598270363010028, 194 | 0.2505674537490396, 195 | 0.24568976878243334, 196 | 0.2431361524059492, 197 | 0.23904529104337974, 198 | 0.2260390775387778, 199 | 0.2175775353303727, 200 | 0.20651827884071014, 201 | 0.2069938584943028, 202 | 0.1934815933380057, 203 | 0.19195803868419983, 204 | 0.18418305575409355, 205 | 0.18580856086576686, 206 | 0.17852679629097967, 207 | 0.18493566804510705, 208 | 0.18375280652852619, 209 | 0.17664210388765617, 210 | 0.17616673153551185, 211 | 0.16857664396657662, 212 | 0.16199297490803635, 213 | 0.1510674771578873, 214 | 0.16028491374762618, 215 | 0.15404618339722648, 216 | 0.13661601821727612, 217 | 0.13583267967709722, 218 | 0.13400652182891087, 219 | 0.12695811935426557, 220 | 0.12378869633025982, 221 | 0.12411187857608585, 222 | 0.12153887951417881, 223 | 0.11661709390361519, 224 | 0.10970952706959318, 225 | 0.10776158225010424, 226 | 0.10281199540066369, 227 | 
0.1005824633380946, 228 | 0.09917867298731033, 229 | 0.09451699684209683, 230 | 0.09425465154516346, 231 | 0.09052921834347002, 232 | 0.08508123591651812, 233 | 0.08681707933325977, 234 | 0.08809264727375087, 235 | 0.07728377096902798, 236 | 0.08336889094618313, 237 | 0.08333722477340523, 238 | 0.08863679344272789, 239 | 0.09242369724875864, 240 | 0.07940908880246912, 241 | 0.07954263366649256, 242 | 0.07616484332281877, 243 | 0.07269402531742611, 244 | 0.07131834188476205, 245 | 0.07526747205787722, 246 | 0.07800157392835792, 247 | 0.07384552036905113 248 | ] 249 | } -------------------------------------------------------------------------------- /Training/Train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "\n", 12 | "import torch\n", 13 | "import torch.nn as nn\n", 14 | "import os\n", 15 | "import numpy as np\n", 16 | "import pandas as pd\n", 17 | "from torch.utils.data import Dataset, DataLoader\n", 18 | "from PIL import Image\n", 19 | "from torchvision.transforms import ToTensor, Compose, Resize, CenterCrop\n", 20 | "import glob" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 2, 26 | "outputs": [], 27 | "source": [ 28 | "labels = ['black', 'blue' , 'cyan' , 'gray' , 'green' , 'red' , 'white' , 'yellow']\n", 29 | "def decode_label(index):\n", 30 | " return labels[index]\n", 31 | "\n", 32 | "def encode_label_from_path(path):\n", 33 | " for index,value in enumerate(labels):\n", 34 | " if value in path:\n", 35 | " return index" 36 | ], 37 | "metadata": { 38 | "collapsed": false, 39 | "pycharm": { 40 | "name": "#%%\n" 41 | } 42 | } 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 3, 47 | "outputs": [], 48 | "source": [ 49 | "from sklearn.model_selection import train_test_split\n", 50 | "\n", 51 | "path = 
'/home/jeetu/Project/VehicleColor/Dataset/'\n", 52 | "image_list = glob.glob(path + '**/*')\n", 53 | "class_list = [encode_label_from_path(item) for item in image_list]\n", 54 | "x_train, x_test , y_train , y_test = train_test_split(image_list, class_list, train_size= 0.5 , stratify=class_list , shuffle=True, random_state=42)" 55 | ], 56 | "metadata": { 57 | "collapsed": false, 58 | "pycharm": { 59 | "name": "#%%\n" 60 | } 61 | } 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 11, 66 | "outputs": [], 67 | "source": [ 68 | "class VehicleColorDataset(Dataset):\n", 69 | " def __init__(self, image_list, class_list, transforms = None):\n", 70 | " self.transform = transforms\n", 71 | " self.image_list = image_list\n", 72 | " self.class_list = class_list\n", 73 | " self.data_len = len(self.image_list)\n", 74 | "\n", 75 | " def __len__(self):\n", 76 | " return self.data_len\n", 77 | "\n", 78 | " def __getitem__(self, index):\n", 79 | " image_path = self.image_list[index]\n", 80 | " image = Image.open(image_path).convert('RGB')\n", 81 | " if self.transform:\n", 82 | " image = self.transform(image)\n", 83 | " return image, self.class_list[index]\n", 84 | "\n", 85 | "transforms=Compose([Resize(224), CenterCrop(224) , ToTensor()])\n", 86 | "train_dataset = VehicleColorDataset( x_train , y_train , transforms)\n", 87 | "train_data_loader = DataLoader(train_dataset,batch_size=115 )\n", 88 | "test_dataset = VehicleColorDataset(x_test, y_test,transforms)\n", 89 | "test_data_loader = DataLoader(test_dataset, batch_size=115)\n" 90 | ], 91 | "metadata": { 92 | "collapsed": false, 93 | "pycharm": { 94 | "name": "#%%\n" 95 | } 96 | } 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 5, 101 | "outputs": [ 102 | { 103 | "name": "stdout", 104 | "output_type": "stream", 105 | "text": [ 106 | "Using cuda device\n" 107 | ] 108 | } 109 | ], 110 | "source": [ 111 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 112 | "print('Using {} 
device'.format(device))" 113 | ], 114 | "metadata": { 115 | "collapsed": false, 116 | "pycharm": { 117 | "name": "#%%\n" 118 | } 119 | } 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 6, 124 | "outputs": [], 125 | "source": [ 126 | "#Define Model\n", 127 | "class VehicleColorModel(nn.Module):\n", 128 | " def __init__(self):\n", 129 | " super(VehicleColorModel, self).__init__()\n", 130 | "\n", 131 | " self.top_conv1 = nn.Sequential(\n", 132 | " nn.Conv2d(3,48, kernel_size=(11,11) , stride=(4,4)),\n", 133 | " nn.ReLU(),\n", 134 | " nn.BatchNorm2d(48),\n", 135 | " nn.MaxPool2d(kernel_size=3 , stride=2)\n", 136 | " )\n", 137 | "\n", 138 | " # first top convolution layer after split\n", 139 | " self.top_top_conv2 = nn.Sequential(\n", 140 | "\n", 141 | " # 1-1 conv layer\n", 142 | " nn.Conv2d(24, 64, kernel_size=(3,3), stride=(1,1),padding=1),\n", 143 | " nn.ReLU(),\n", 144 | " nn.BatchNorm2d(64),\n", 145 | " nn.MaxPool2d(kernel_size=3, stride=2)\n", 146 | " )\n", 147 | "\n", 148 | " self.top_bot_conv2 = nn.Sequential(\n", 149 | "\n", 150 | " # 1-1 conv layer\n", 151 | " nn.Conv2d(24, 64, kernel_size=(3,3), stride=(1,1),padding=1),\n", 152 | " nn.ReLU(),\n", 153 | " nn.BatchNorm2d(64),\n", 154 | " nn.MaxPool2d(kernel_size=3, stride=2)\n", 155 | " )\n", 156 | "\n", 157 | "\n", 158 | " # need a concat\n", 159 | "\n", 160 | " # after concat\n", 161 | " self.top_conv3 = nn.Sequential(\n", 162 | " # 1-1 conv layer\n", 163 | " nn.Conv2d(128, 192, kernel_size=(3,3), stride=(1,1),padding=1),\n", 164 | " nn.ReLU()\n", 165 | " )\n", 166 | "\n", 167 | " # fourth top convolution layer\n", 168 | " # split feature map by half\n", 169 | " self.top_top_conv4 = nn.Sequential(\n", 170 | " # 1-1 conv layer\n", 171 | " nn.Conv2d(96, 96, kernel_size=(3,3), stride=(1,1),padding=1),\n", 172 | " nn.ReLU()\n", 173 | " )\n", 174 | "\n", 175 | " self.top_bot_conv4 = nn.Sequential(\n", 176 | " # 1-1 conv layer\n", 177 | " nn.Conv2d(96, 96, kernel_size=(3,3), 
stride=(1,1),padding=1),\n", 178 | " nn.ReLU()\n", 179 | " )\n", 180 | "\n", 181 | "\n", 182 | " # fifth top convolution layer\n", 183 | " self.top_top_conv5 = nn.Sequential(\n", 184 | " # 1-1 conv layer\n", 185 | " nn.Conv2d(96, 64, kernel_size=(3,3), stride=(1,1),padding=1),\n", 186 | " nn.ReLU(),\n", 187 | " nn.MaxPool2d(kernel_size=3, stride=2)\n", 188 | " )\n", 189 | " self.top_bot_conv5 = nn.Sequential(\n", 190 | " # 1-1 conv layer\n", 191 | " nn.Conv2d(96, 64, kernel_size=(3,3), stride=(1,1),padding=1),\n", 192 | " nn.ReLU(),\n", 193 | " nn.MaxPool2d(kernel_size=3, stride=2)\n", 194 | " )\n", 195 | "\n", 196 | "# # =============================== bottom ================================\n", 197 | "\n", 198 | "\n", 199 | "# # first bottom convolution layer\n", 200 | " self.bottom_conv1 = nn.Sequential(\n", 201 | "\n", 202 | " # 1-1 conv layer\n", 203 | " nn.Conv2d(3, 48, kernel_size=(11,11), stride=(4,4)),\n", 204 | " nn.ReLU(),\n", 205 | " nn.BatchNorm2d(48),\n", 206 | " nn.MaxPool2d(kernel_size=3, stride=2)\n", 207 | " )\n", 208 | "\n", 209 | "\n", 210 | " # first top convolution layer after split\n", 211 | " self.bottom_top_conv2 = nn.Sequential(\n", 212 | "\n", 213 | " # 1-1 conv layer\n", 214 | " nn.Conv2d(24, 64, kernel_size=(3,3), stride=(1,1),padding=1),\n", 215 | " nn.ReLU(),\n", 216 | " nn.BatchNorm2d(64),\n", 217 | " nn.MaxPool2d(kernel_size=3, stride=2)\n", 218 | " )\n", 219 | "\n", 220 | " self.bottom_bot_conv2 = nn.Sequential(\n", 221 | "\n", 222 | " # 1-1 conv layer\n", 223 | " nn.Conv2d(24, 64, kernel_size=(3,3), stride=(1,1),padding=1),\n", 224 | " nn.ReLU(),\n", 225 | " nn.BatchNorm2d(64),\n", 226 | " nn.MaxPool2d(kernel_size=3, stride=2)\n", 227 | " )\n", 228 | "\n", 229 | "\n", 230 | " # need a concat\n", 231 | "\n", 232 | " # after concat\n", 233 | " self.bottom_conv3 = nn.Sequential(\n", 234 | " # 1-1 conv layer\n", 235 | " nn.Conv2d(128, 192, kernel_size=(3,3), stride=(1,1),padding=1),\n", 236 | " nn.ReLU()\n", 237 | " )\n", 238 | "\n", 
239 | " # fourth top convolution layer\n", 240 | " # split feature map by half\n", 241 | " self.bottom_top_conv4 = nn.Sequential(\n", 242 | " # 1-1 conv layer\n", 243 | " nn.Conv2d(96, 96, kernel_size=(3,3), stride=(1,1),padding=1),\n", 244 | " nn.ReLU()\n", 245 | " )\n", 246 | "\n", 247 | " self.bottom_bot_conv4 = nn.Sequential(\n", 248 | " # 1-1 conv layer\n", 249 | " nn.Conv2d(96, 96, kernel_size=(3,3), stride=(1,1),padding=1),\n", 250 | " nn.ReLU()\n", 251 | " )\n", 252 | "\n", 253 | "\n", 254 | " # fifth top convolution layer\n", 255 | " self.bottom_top_conv5 = nn.Sequential(\n", 256 | " # 1-1 conv layer\n", 257 | " nn.Conv2d(96, 64, kernel_size=(3,3), stride=(1,1),padding=1),\n", 258 | " nn.ReLU(),\n", 259 | " nn.MaxPool2d(kernel_size=3, stride=2)\n", 260 | " )\n", 261 | " self.bottom_bot_conv5 = nn.Sequential(\n", 262 | " # 1-1 conv layer\n", 263 | " nn.Conv2d(96, 64, kernel_size=(3,3), stride=(1,1),padding=1),\n", 264 | " nn.ReLU(),\n", 265 | " nn.MaxPool2d(kernel_size=3, stride=2)\n", 266 | " )\n", 267 | "\n", 268 | " # Fully-connected layer\n", 269 | " self.classifier = nn.Sequential(\n", 270 | " nn.Linear(5*5*64*4, 4096),\n", 271 | " nn.ReLU(),\n", 272 | " nn.Dropout(0.7),\n", 273 | " nn.Linear(4096, 4096),\n", 274 | " nn.ReLU(),\n", 275 | " nn.Dropout(0.6),\n", 276 | " nn.Linear(4096, 8)\n", 277 | " )\n", 278 | "\n", 279 | " def forward(self,x):\n", 280 | " # print(x.shape)\n", 281 | " x_top = self.top_conv1(x)\n", 282 | " # print(x_top.shape)\n", 283 | "\n", 284 | " x_top_conv = torch.split(x_top, 24, 1)\n", 285 | "\n", 286 | " x_top_top_conv2 = self.top_top_conv2(x_top_conv[0])\n", 287 | " x_top_bot_conv2 = self.top_bot_conv2(x_top_conv[1])\n", 288 | "\n", 289 | " x_top_cat1 = torch.cat([x_top_top_conv2,x_top_bot_conv2],1)\n", 290 | "\n", 291 | " x_top_conv3 = self.top_conv3(x_top_cat1)\n", 292 | "\n", 293 | " x_top_conv3 = torch.split(x_top_conv3, 96, 1)\n", 294 | "\n", 295 | " x_top_top_conv4 = self.top_top_conv4(x_top_conv3[0])\n", 296 | " 
x_top_bot_conv4 = self.top_bot_conv4(x_top_conv3[1])\n", 297 | "\n", 298 | " x_top_top_conv5 = self.top_top_conv5(x_top_top_conv4)\n", 299 | " x_top_bot_conv5 = self.top_bot_conv5(x_top_bot_conv4)\n", 300 | "\n", 301 | " x_bottom = self.bottom_conv1(x)\n", 302 | "\n", 303 | " x_bottom_conv = torch.split(x_bottom, 24, 1)\n", 304 | "\n", 305 | " x_bottom_top_conv2 = self.bottom_top_conv2(x_bottom_conv[0])\n", 306 | " x_bottom_bot_conv2 = self.bottom_bot_conv2(x_bottom_conv[1])\n", 307 | "\n", 308 | " x_bottom_cat1 = torch.cat([x_bottom_top_conv2,x_bottom_bot_conv2],1)\n", 309 | "\n", 310 | " x_bottom_conv3 = self.bottom_conv3(x_bottom_cat1)\n", 311 | "\n", 312 | " x_bottom_conv3 = torch.split(x_bottom_conv3, 96, 1)\n", 313 | "\n", 314 | " x_bottom_top_conv4 = self.bottom_top_conv4(x_bottom_conv3[0])\n", 315 | " x_bottom_bot_conv4 = self.bottom_bot_conv4(x_bottom_conv3[1])\n", 316 | "\n", 317 | " x_bottom_top_conv5 = self.bottom_top_conv5(x_bottom_top_conv4)\n", 318 | " x_bottom_bot_conv5 = self.bottom_bot_conv5(x_bottom_bot_conv4)\n", 319 | "\n", 320 | " x_cat = torch.cat([x_top_top_conv5,x_top_bot_conv5,x_bottom_top_conv5,x_bottom_bot_conv5],1)\n", 321 | "\n", 322 | "\n", 323 | " flatten = x_cat.view(x_cat.size(0), -1)\n", 324 | "\n", 325 | " output = self.classifier(flatten)\n", 326 | "\n", 327 | " #output = F.softmax(output)\n", 328 | "\n", 329 | "\n", 330 | " return output\n", 331 | "\n" 332 | ], 333 | "metadata": { 334 | "collapsed": false, 335 | "pycharm": { 336 | "name": "#%%\n" 337 | } 338 | } 339 | }, 340 | { 341 | "cell_type": "code", 342 | "execution_count": 43, 343 | "outputs": [], 344 | "source": [ 345 | "# Save Model\n", 346 | "Model_Path = '/home/jeetu/Project/VehicleColorA'\n", 347 | "logger = Logger(Model_Path, \"Exp1\", 1)" 348 | ], 349 | "metadata": { 350 | "collapsed": false, 351 | "pycharm": { 352 | "name": "#%%\n" 353 | } 354 | } 355 | }, 356 | { 357 | "cell_type": "code", 358 | "source": [ 359 | "model = VehicleColorModel()\n", 360 | 
"model.cuda()\n", 361 | "opt = torch.optim.SGD(model.parameters(), momentum=0.9, lr = 0.001 )\n", 362 | "lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, patience=20,min_lr=1e-08,factor=0.1,verbose=True)\n", 363 | "loss_fn = nn.CrossEntropyLoss()" 364 | ], 365 | "metadata": { 366 | "collapsed": false, 367 | "pycharm": { 368 | "name": "#%%\n" 369 | } 370 | }, 371 | "execution_count": 7, 372 | "outputs": [] 373 | }, 374 | { 375 | "cell_type": "code", 376 | "execution_count": 44, 377 | "outputs": [ 378 | { 379 | "name": "stderr", 380 | "output_type": "stream", 381 | "text": [ 382 | "Train | Epoch 0: 100%|██████████| 68/68 [01:11<00:00, 1.05batch/s, loss=0.384]\n", 383 | "Test | Epoch 0: 100%|██████████| 68/68 [00:59<00:00, 1.17batch/s, accuracy=84.7, loss=0.413]\n", 384 | "Train | Epoch 1: 100%|██████████| 68/68 [01:11<00:00, 1.05batch/s, loss=0.368]\n", 385 | "Test | Epoch 1: 100%|██████████| 68/68 [00:59<00:00, 1.22batch/s, accuracy=85.6, loss=0.396]\n", 386 | "Train | Epoch 2: 100%|██████████| 68/68 [01:11<00:00, 1.01s/batch, loss=0.355]\n", 387 | "Test | Epoch 2: 100%|██████████| 68/68 [00:59<00:00, 1.20batch/s, accuracy=85.2, loss=0.398]\n", 388 | "Train | Epoch 3: 44%|████▍ | 30/68 [00:33<00:42, 1.12s/batch, loss=0.358]\n" 389 | ] 390 | }, 391 | { 392 | "name": "stdout", 393 | "output_type": "stream", 394 | "text": [ 395 | "Saving Model...\n", 396 | "Saving Model...\n", 397 | "Saving Model...\n" 398 | ] 399 | } 419 | ], 420 | "source": [ 421 | "from tqdm import tqdm\n", 422 | "\n", 423 | "epochs = 5\n", 424 | "for epoch in range(epochs):\n", 425 | " with tqdm(train_data_loader, unit=\"batch\") as tepoch:\n", 426 | " model.train()\n", 427 | " running_loss = 0\n", 428 | " batch_ = 0\n", 429 | " for X,y in tepoch:\n", 430 | " tepoch.set_description(f\"Train | Epoch {epoch}\")\n", 431 | " X = X.to('cuda')\n", 432 | " y = y.to('cuda')\n", 433 | " pred = model(X)\n", 434 | " loss_value = loss_fn(pred, y)\n", 435 | " loss_value.backward()\n", 436 | " opt.step()\n", 437 | " opt.zero_grad()\n", 438 | " batch_ +=1\n", 439 | " running_loss += loss_value.item()\n", 440 | " tepoch.set_postfix(loss = running_loss/batch_)\n", 441 | " logger.log('train_loss', running_loss/batch_)\n", 442 | " with 
torch.no_grad():\n", 443 | " model.eval()\n", 444 | " with tqdm(test_data_loader, unit=\"batch\") as tepoch:\n", 445 | " tepoch.set_description(f\"Test | Epoch {epoch}\")\n", 446 | " correct, n_seen, n_batch, running_loss = 0, 0, 0, 0.0\n", 447 | " for X,y in tepoch:\n", 448 | " X, y = X.to('cuda') , y.to('cuda')\n", 449 | " pred = model(X)\n", 450 | " pred_class = pred.argmax(dim = 1)\n", 451 | " loss_value = loss_fn(pred,y)\n", 452 | " running_loss += loss_value.item()\n", 453 | " correct += (pred_class == y).sum().item()\n", 454 | " n_seen += y.size(0)  # real batch size; the last batch can be smaller\n", 455 | " n_batch +=1\n", 456 | " tepoch.set_postfix(loss = running_loss/n_batch , accuracy = correct / n_seen * 100)\n", 457 | " logger.log('test_loss', running_loss/n_batch)\n", 458 | " logger.log('test_acc', correct / n_seen)\n", 459 | " # lower the learning rate when the mean test loss plateaus\n", 460 | " lr_scheduler.step(running_loss / n_batch)\n", 461 | " logger.checkpoint(model)" 462 | ], 463 | "metadata": { 464 | "collapsed": false, 465 | "pycharm": { 466 | "name": "#%%\n" 467 | } 468 | } 469 | }, 470 | { 471 | "cell_type": "code", 472 | "execution_count": 42, 473 | "outputs": [], 474 | "source": [ 475 | "from collections import defaultdict\n", 476 | "import json\n", 477 | "class Logger(object):\n", 478 | " def __init__(self, log_dir, name, chkpt_interval):\n", 479 | " super(Logger,self).__init__()\n", 480 | " self.chkpt_interval = chkpt_interval\n", 481 | " self.log_dir = log_dir\n", 482 | " self.name = name\n", 483 | " os.makedirs(os.path.join(log_dir, name), exist_ok= True)\n", 484 | " self.log_path = os.path.join(log_dir, name, 'logs.json')\n", 485 | " self.model_path = os.path.join(log_dir, name, 'model.pt')\n", 486 | " self.logs = defaultdict(list)\n", 487 | " self.logs['epoch'] = 0\n", 488 | "\n", 489 | " def log(self, key, value ):\n", 490 | " if isinstance(value, dict):\n", 491 | " for k,v in value.items():\n", 492 | " self.log(f'{key}.{k}',v)\n", 493 | " else:\n", 494 | " self.logs[key].append(value)\n", 495 | "\n", 496 | " def checkpoint(self, model):\n", 497 | " if 
(self.logs['epoch'] + 1 ) % self.chkpt_interval == 0:\n", 498 | " self.save(model)\n", 499 | " self.logs['epoch'] +=1\n", 500 | "\n", 501 | " def save(self, model):\n", 502 | " print(\"Saving Model...\")\n", 503 | " with open(self.log_path, 'w') as f:\n", 504 | " json.dump(self.logs, f, sort_keys=True, indent=4)\n", 505 | " epoch = self.logs['epoch'] + 1\n", 506 | " torch.save(model.state_dict(), os.path.join(self.log_dir, self.name, f'model_{epoch}.pt'))\n", 507 | "\n" 508 | ], 509 | "metadata": { 510 | "collapsed": false, 511 | "pycharm": { 512 | "name": "#%%\n" 513 | } 514 | } 515 | } 516 | ], 517 | "metadata": { 518 | "kernelspec": { 519 | "name": "pycharm-72b5a700", 520 | "language": "python", 521 | "display_name": "PyCharm (pythonProject)" 522 | }, 523 | "language_info": { 524 | "codemirror_mode": { 525 | "name": "ipython", 526 | "version": 3 527 | }, 528 | "file_extension": ".py", 529 | "mimetype": "text/x-python", 530 | "name": "python", 531 | "nbconvert_exporter": "python", 532 | "pygments_lexer": "ipython3", 533 | "version": "3.6.9" 534 | } 535 | }, 536 | "nbformat": 4, 537 | "nbformat_minor": 0 538 | } -------------------------------------------------------------------------------- /Infer.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 37, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import numpy as np\n", 12 | "import torch\n", 13 | "from PIL import Image\n", 14 | "from torchvision.transforms import ToTensor, Resize, CenterCrop, Compose\n", 15 | "import torch.nn as nn\n", 16 | "import matplotlib.pyplot as plt" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 2, 23 | "metadata": { 24 | "pycharm": { 25 | "name": "#%%\n" 26 | } 27 | }, 28 | "outputs": [], 29 | "source": [ 30 | "labels = ['black', 'blue' , 'cyan' , 'gray' , 'green' , 'red' , 'white' , 
'yellow']\n", 31 | "def decode_label(index):\n", 32 | " return labels[index]\n", 33 | "\n", 34 | "def encode_label_from_path(path):\n", 35 | " for index,value in enumerate(labels):\n", 36 | " if value in path:\n", 37 | " return index" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 35, 43 | "metadata": { 44 | "pycharm": { 45 | "name": "#%%\n" 46 | } 47 | }, 48 | "outputs": [], 62 | "source": [ 63 | "# import plotly.express as px\n", 64 | "# from skimage import io\n", 65 | "# img = io.imread('https://upload.wikimedia.org/wikipedia/commons/thumb/0/00/Crab_Nebula.jpg/240px-Crab_Nebula.jpg')\n", 66 | "# fig = px.imshow(image.squeeze_())\n", 67 | "# fig.show()" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 54, 73 | "metadata": { 74 | "pycharm": { 75 | "name": "#%%\n" 76 | } 77 | }, 78 | "outputs": [ 79 | { 80 | "name": "stdout", 81 | "output_type": "stream", 82 | "text": [ 83 | "(3, 224, 224)\n", 84 | "torch.Size([3, 224, 224])\n", 85 | "blue\n" 86 | ] 87 | } 100 | ], 101 | "source": [ 102 | "# image_path = '/home/jeetu/Desktop/red.jpg'\n", 103 | "# image_path = '/home/jeetu/Desktop/blue.png'\n", 104 | "image_path = '/home/jeetu/Desktop/b2.png'\n", 105 | "# image_path = '/home/jeetu/Desktop/AA.jpg'\n", 106 | "model_path = '/home/jeetu/Project/VehicleColorA/Exp1/model_3.pt'\n", 107 | "image = Image.open(image_path).convert('RGB')\n", 108 | "\n", 109 | "transforms = Compose([Resize(224), CenterCrop(224) , ToTensor()])\n", 110 | "image = transforms(image)\n", 111 | "model = VehicleColorModel().eval()  # eval mode: BatchNorm uses running stats, Dropout is off\n", 112 | "model.load_state_dict(torch.load(model_path, map_location='cpu'))\n", 113 | "t_img = image.numpy()\n", 114 | "print(t_img.shape)\n", 115 | "plt.imshow(image.permute(1,2,0))\n", 116 | "print(image.shape)\n", 117 | "image = image.unsqueeze(0)\n", 118 | "pred = model(image).argmax(dim=1)\n", 119 | "class_label = decode_label(pred.item())\n", 120 | "print(class_label)" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 7, 126 | "metadata": { 127 | "pycharm": { 128 | "name": "#%%\n" 129 | } 130 | }, 131 | "outputs": [], 132 | "source": [ 133 | "class VehicleColorModel(nn.Module):\n", 134 | " def __init__(self):\n", 135 | " super(VehicleColorModel, self).__init__()\n", 136 | " # first top convolution layer: 11x11 conv, stride 4, no padding\n", 137 | " self.top_conv1 = nn.Sequential(\n", 138 | " nn.Conv2d(3,48, kernel_size=(11,11) , stride=(4,4)),\n", 139 | " nn.ReLU(),\n", 140 | " nn.BatchNorm2d(48),\n", 141 | " nn.MaxPool2d(kernel_size=3 , stride=2)\n", 142 | " )\n", 143 | "\n", 144 | " # second top convolution layer: each branch takes one 24-channel half of conv1's output\n", 145 | " self.top_top_conv2 = nn.Sequential(\n", 146 | "\n", 147 | " # 3x3 conv, padding=1 keeps the spatial size\n", 148 | " nn.Conv2d(24, 64, kernel_size=(3,3), stride=(1,1),padding=1),\n", 149 | " nn.ReLU(),\n", 150 | " nn.BatchNorm2d(64),\n", 151 | " nn.MaxPool2d(kernel_size=3, stride=2)\n", 152 | " )\n", 153 | "\n", 154 | " self.top_bot_conv2 = 
nn.Sequential(\n", 155 | "\n", 156 | " # 3x3 conv, padding=1 keeps the spatial size\n", 157 | " nn.Conv2d(24, 64, kernel_size=(3,3), stride=(1,1),padding=1),\n", 158 | " nn.ReLU(),\n", 159 | " nn.BatchNorm2d(64),\n", 160 | " nn.MaxPool2d(kernel_size=3, stride=2)\n", 161 | " )\n", 162 | "\n", 163 | "\n", 164 | " # forward() concatenates the two 64-channel branch outputs here (128 channels)\n", 165 | "\n", 166 | " # third top convolution layer (after the concat)\n", 167 | " self.top_conv3 = nn.Sequential(\n", 168 | " # 3x3 conv, padding=1 keeps the spatial size\n", 169 | " nn.Conv2d(128, 192, kernel_size=(3,3), stride=(1,1),padding=1),\n", 170 | " nn.ReLU()\n", 171 | " )\n", 172 | "\n", 173 | " # fourth top convolution layer\n", 174 | " # each branch takes one 96-channel half of conv3's output\n", 175 | " self.top_top_conv4 = nn.Sequential(\n", 176 | " # 3x3 conv, padding=1 keeps the spatial size\n", 177 | " nn.Conv2d(96, 96, kernel_size=(3,3), stride=(1,1),padding=1),\n", 178 | " nn.ReLU()\n", 179 | " )\n", 180 | "\n", 181 | " self.top_bot_conv4 = nn.Sequential(\n", 182 | " # 3x3 conv, padding=1 keeps the spatial size\n", 183 | " nn.Conv2d(96, 96, kernel_size=(3,3), stride=(1,1),padding=1),\n", 184 | " nn.ReLU()\n", 185 | " )\n", 186 | "\n", 187 | "\n", 188 | " # fifth top convolution layer\n", 189 | " self.top_top_conv5 = nn.Sequential(\n", 190 | " # 3x3 conv, padding=1 keeps the spatial size\n", 191 | " nn.Conv2d(96, 64, kernel_size=(3,3), stride=(1,1),padding=1),\n", 192 | " nn.ReLU(),\n", 193 | " nn.MaxPool2d(kernel_size=3, stride=2)\n", 194 | " )\n", 195 | " self.top_bot_conv5 = nn.Sequential(\n", 196 | " # 3x3 conv, padding=1 keeps the spatial size\n", 197 | " nn.Conv2d(96, 64, kernel_size=(3,3), stride=(1,1),padding=1),\n", 198 | " nn.ReLU(),\n", 199 | " nn.MaxPool2d(kernel_size=3, stride=2)\n", 200 | " )\n", 201 | "\n", 202 | "# =============================== bottom ================================\n", 203 | "\n", 204 | "\n", 205 | " # first bottom convolution layer\n", 206 | " self.bottom_conv1 = nn.Sequential(\n", 207 | "\n", 208 | " # 11x11 conv, stride 4, no padding\n", 209 | " nn.Conv2d(3, 48, kernel_size=(11,11), stride=(4,4)),\n", 210 | " nn.ReLU(),\n", 211 | " nn.BatchNorm2d(48),\n", 212 | " nn.MaxPool2d(kernel_size=3, stride=2)\n", 213 | " )\n", 214 | "\n", 215 | "\n", 216 | " # second bottom convolution layer: each branch takes one 24-channel half of conv1's output\n", 217 | " self.bottom_top_conv2 = nn.Sequential(\n", 218 | "\n", 219 | " # 3x3 conv, padding=1 keeps the spatial size\n", 220 | " nn.Conv2d(24, 64, kernel_size=(3,3), stride=(1,1),padding=1),\n", 221 | " nn.ReLU(),\n", 222 | " nn.BatchNorm2d(64),\n", 223 | " nn.MaxPool2d(kernel_size=3, stride=2)\n", 224 | " )\n", 225 | "\n", 226 | " self.bottom_bot_conv2 = nn.Sequential(\n", 227 | "\n", 228 | " # 3x3 conv, padding=1 keeps the spatial size\n", 229 | " nn.Conv2d(24, 64, kernel_size=(3,3), stride=(1,1),padding=1),\n", 230 | " nn.ReLU(),\n", 231 | " nn.BatchNorm2d(64),\n", 232 | " nn.MaxPool2d(kernel_size=3, stride=2)\n", 233 | " )\n", 234 | "\n", 235 | "\n", 236 | " # forward() concatenates the two 64-channel branch outputs here (128 channels)\n", 237 | "\n", 238 | " # third bottom convolution layer (after the concat)\n", 239 | " self.bottom_conv3 = nn.Sequential(\n", 240 | " # 3x3 conv, padding=1 keeps the spatial size\n", 241 | " nn.Conv2d(128, 192, kernel_size=(3,3), stride=(1,1),padding=1),\n", 242 | " nn.ReLU()\n", 243 | " )\n", 244 | "\n", 245 | " # fourth bottom convolution layer\n", 246 | " # each branch takes one 96-channel half of conv3's output\n", 247 | " self.bottom_top_conv4 = nn.Sequential(\n", 248 | " # 3x3 conv, padding=1 keeps the spatial size\n", 249 | " nn.Conv2d(96, 96, kernel_size=(3,3), stride=(1,1),padding=1),\n", 250 | " nn.ReLU()\n", 251 | " )\n", 252 | "\n", 253 | " self.bottom_bot_conv4 = nn.Sequential(\n", 254 | " # 3x3 conv, padding=1 keeps the spatial size\n", 255 | " nn.Conv2d(96, 96, kernel_size=(3,3), stride=(1,1),padding=1),\n", 256 | " nn.ReLU()\n", 257 | " )\n", 258 | "\n", 259 | "\n", 260 | " # fifth bottom convolution layer\n", 261 | " self.bottom_top_conv5 = nn.Sequential(\n", 262 | " # 3x3 conv, padding=1 keeps the spatial size\n", 263 | " nn.Conv2d(96, 64, kernel_size=(3,3), stride=(1,1),padding=1),\n", 264 | " nn.ReLU(),\n", 265 | " nn.MaxPool2d(kernel_size=3, stride=2)\n", 266 | " )\n", 267 | " self.bottom_bot_conv5 = nn.Sequential(\n", 268 | " # 3x3 conv, padding=1 keeps the spatial size\n", 269 | " nn.Conv2d(96, 64, kernel_size=(3,3), stride=(1,1),padding=1),\n", 270 | " nn.ReLU(),\n", 271 | " nn.MaxPool2d(kernel_size=3, stride=2)\n", 272 | " )\n", 273 | "\n", 274 | " # 
Fully-connected layer\n", 275 | " self.classifier = nn.Sequential(\n", 276 | " nn.Linear(5*5*64*4, 4096),\n", 277 | " nn.ReLU(),\n", 278 | " nn.Dropout(0.7),\n", 279 | " nn.Linear(4096, 4096),\n", 280 | " nn.ReLU(),\n", 281 | " nn.Dropout(0.6),\n", 282 | " nn.Linear(4096, 8)\n", 283 | " )\n", 284 | "\n", 285 | " def forward(self,x):\n", 286 | " # print(x.shape)\n", 287 | " x_top = self.top_conv1(x)\n", 288 | " # print(x_top.shape)\n", 289 | "\n", 290 | " x_top_conv = torch.split(x_top, 24, 1)\n", 291 | "\n", 292 | " x_top_top_conv2 = self.top_top_conv2(x_top_conv[0])\n", 293 | " x_top_bot_conv2 = self.top_bot_conv2(x_top_conv[1])\n", 294 | "\n", 295 | " x_top_cat1 = torch.cat([x_top_top_conv2,x_top_bot_conv2],1)\n", 296 | "\n", 297 | " x_top_conv3 = self.top_conv3(x_top_cat1)\n", 298 | "\n", 299 | " x_top_conv3 = torch.split(x_top_conv3, 96, 1)\n", 300 | "\n", 301 | " x_top_top_conv4 = self.top_top_conv4(x_top_conv3[0])\n", 302 | " x_top_bot_conv4 = self.top_bot_conv4(x_top_conv3[1])\n", 303 | "\n", 304 | " x_top_top_conv5 = self.top_top_conv5(x_top_top_conv4)\n", 305 | " x_top_bot_conv5 = self.top_bot_conv5(x_top_bot_conv4)\n", 306 | "\n", 307 | " x_bottom = self.bottom_conv1(x)\n", 308 | "\n", 309 | " x_bottom_conv = torch.split(x_bottom, 24, 1)\n", 310 | "\n", 311 | " x_bottom_top_conv2 = self.bottom_top_conv2(x_bottom_conv[0])\n", 312 | " x_bottom_bot_conv2 = self.bottom_bot_conv2(x_bottom_conv[1])\n", 313 | "\n", 314 | " x_bottom_cat1 = torch.cat([x_bottom_top_conv2,x_bottom_bot_conv2],1)\n", 315 | "\n", 316 | " x_bottom_conv3 = self.bottom_conv3(x_bottom_cat1)\n", 317 | "\n", 318 | " x_bottom_conv3 = torch.split(x_bottom_conv3, 96, 1)\n", 319 | "\n", 320 | " x_bottom_top_conv4 = self.bottom_top_conv4(x_bottom_conv3[0])\n", 321 | " x_bottom_bot_conv4 = self.bottom_bot_conv4(x_bottom_conv3[1])\n", 322 | "\n", 323 | " x_bottom_top_conv5 = self.bottom_top_conv5(x_bottom_top_conv4)\n", 324 | " x_bottom_bot_conv5 = self.bottom_bot_conv5(x_bottom_bot_conv4)\n", 325 
| "\n", 326 | " x_cat = torch.cat([x_top_top_conv5,x_top_bot_conv5,x_bottom_top_conv5,x_bottom_bot_conv5],1)\n", 327 | "\n", 328 | "\n", 329 | " flatten = x_cat.view(x_cat.size(0), -1)\n", 330 | "\n", 331 | " output = self.classifier(flatten)\n", 332 | "\n", 333 | " #output = F.softmax(output)\n", 334 | "\n", 335 | "\n", 336 | " return output\n" 337 | ] 338 | } 339 | ], 340 | "metadata": { 341 | "kernelspec": { 342 | "display_name": "PyCharm (pythonProject)", 343 | "language": "python", 344 | "name": "pycharm-72b5a700" 345 | }, 346 | "language_info": { 347 | "codemirror_mode": { 348 | "name": "ipython", 349 | "version": 3 350 | }, 351 | "file_extension": ".py", 352 | "mimetype": "text/x-python", 353 | "name": "python", 354 | "nbconvert_exporter": "python", 355 | "pygments_lexer": "ipython3", 356 | "version": "3.6.9" 357 | } 358 | }, 359 | "nbformat": 4, 360 | "nbformat_minor": 1 361 | } 362 | --------------------------------------------------------------------------------
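The first classifier layer, `nn.Linear(5*5*64*4, 4096)`, only fits a 224×224 input. The sketch below traces the spatial size through one stream of `VehicleColorModel` and assumes PyTorch's usual output-size formula for `Conv2d` and `MaxPool2d` (dilation 1, `ceil_mode=False`). The helpers `conv_out` and `stream_spatial_size` are hypothetical names and are not part of this repository.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length along one axis for Conv2d/MaxPool2d (dilation=1, floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1


def stream_spatial_size(size=224):
    """Trace one stream (top or bottom) of VehicleColorModel for a square input."""
    size = conv_out(size, kernel=11, stride=4)   # conv1: 11x11, stride 4, no padding
    size = conv_out(size, kernel=3, stride=2)    # max-pool 3, stride 2
    size = conv_out(size, kernel=3, padding=1)   # conv2: 3x3, padding 1
    size = conv_out(size, kernel=3, stride=2)    # max-pool 3, stride 2
    size = conv_out(size, kernel=3, padding=1)   # conv3: 3x3, padding 1
    size = conv_out(size, kernel=3, padding=1)   # conv4: 3x3, padding 1
    size = conv_out(size, kernel=3, padding=1)   # conv5: 3x3, padding 1
    size = conv_out(size, kernel=3, stride=2)    # max-pool 3, stride 2
    return size


side = stream_spatial_size(224)
# forward() concatenates four 64-channel branches (top/bottom x two halves) before flattening
flat_features = side * side * 64 * 4
print(side, flat_features)  # 5 6400
```

Each stream ends at 5×5, which gives 5 · 5 · 64 · 4 = 6400 flattened features. If the `Resize(224)` / `CenterCrop(224)` transforms were changed, calling `stream_spatial_size` with the new side length would give the `in_features` the classifier needs.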