├── .gitignore
├── README.md
├── keras
│   ├── draw_heatmap.py
│   ├── imgs
│   │   ├── 0.jpg
│   │   ├── 1.jpg
│   │   ├── 2.jpg
│   │   ├── 3.jpg
│   │   ├── 4.jpg
│   │   ├── 5.jpg
│   │   ├── 6.jpg
│   │   ├── 7.jpg
│   │   ├── 8.jpg
│   │   └── 9.jpg
│   ├── model.py
│   └── trainer.py
├── pytorch
│   ├── draw_heatmap.py
│   ├── imgs
│   │   ├── 0.jpg
│   │   ├── 1.jpg
│   │   ├── 2.jpg
│   │   ├── 3.jpg
│   │   ├── 4.jpg
│   │   ├── 5.jpg
│   │   ├── 6.jpg
│   │   ├── 7.jpg
│   │   ├── 8.jpg
│   │   └── 9.jpg
│   ├── model.py
│   └── trainer.py
└── result_pic
    ├── 1.jpg
    ├── 2.jpg
    ├── 3.jpg
    ├── 4.jpg
    ├── 5.jpg
    ├── 6.jpg
    └── tt

/.gitignore:
--------------------------------------------------------------------------------
.vscode
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Weakly-supervised-object-localization
A simple weakly-supervised object detector.

We provide two implementations:
- keras + tensorflow2
- pytorch

# dataset
CIFAR-10, which both Keras and PyTorch can load with a single line.
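For example, this is all the two trainer.py scripts need (mirroring the loaders they actually use):

```python
# Keras (see keras/trainer.py)
from tensorflow.keras.datasets import cifar10
(train_x, train_y), (test_x, test_y) = cifar10.load_data()

# PyTorch (see pytorch/trainer.py)
import torchvision
train_set = torchvision.datasets.CIFAR10(root='data/', train=True, download=True)
```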
# author's paper
[Learning Deep Features for Discriminative Localization](https://arxiv.org/pdf/1512.04150.pdf)

# code
The Keras version uses InceptionV3 to extract conv features (the PyTorch version uses ResNet50). After 3 epochs, training accuracy reaches 99.87% and testing accuracy reaches 95.68%.
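Both draw_heatmap.py scripts implement the paper's class activation mapping (CAM): take the feature map before global average pooling, weight each channel by the dense-layer weight of the predicted class, and sum over the channels. A minimal NumPy sketch of that computation (shapes follow the code in this repo):

```python
import numpy as np

def class_activation_map(conv_feat, w, class_idx):
    """conv_feat: (H, W, C) features before global average pooling;
    w: (C, num_classes) weights of the final dense/fc layer."""
    # weight each feature channel by its contribution to the chosen class
    heatmap = (conv_feat * w[:, class_idx].reshape(1, 1, -1)).sum(axis=2)
    # normalize to [0, 1] so the map can be overlaid on the image
    return (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min())
```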

Step:
(1) run trainer.py to train the classifier and save a checkpoint under ./checkpoint/
(2) run draw_heatmap.py to load the checkpoint and visualize the heatmap

Also, some tricks are borrowed from https://github.com/jazzsaxmafia/Weakly_detector


# Results
Here we just show some samples...


Sample 1:
![sample and heatmap 1](https://github.com/ray0809/weakly-supervised-object-localization/blob/master/result_pic/1.jpg)
![combined 1](https://github.com/ray0809/weakly-supervised-object-localization/blob/master/result_pic/2.jpg)


Sample 2:
![sample and heatmap 2](https://github.com/ray0809/weakly-supervised-object-localization/blob/master/result_pic/3.jpg)
![combined 2](https://github.com/ray0809/weakly-supervised-object-localization/blob/master/result_pic/4.jpg)

Sample 3:
![sample and heatmap 3](https://github.com/ray0809/weakly-supervised-object-localization/blob/master/result_pic/5.jpg)
![combined 3](https://github.com/ray0809/weakly-supervised-object-localization/blob/master/result_pic/6.jpg)
--------------------------------------------------------------------------------
/keras/draw_heatmap.py:
--------------------------------------------------------------------------------
import cv2
import numpy as np
import matplotlib.pylab as plt

import tensorflow as tf
from tensorflow.keras.applications.inception_v3 import preprocess_input

from model import *


class WeaklyLocation():
    def __init__(self, multi_model, w):
        self.multi_model = multi_model
        self.w = w

    def _getOutputs(self, inp):
        # the feature map before global pooling, plus the softmax output
        conv_feat, softmax_prob = self.multi_model.predict(inp)
        return conv_feat[0], softmax_prob[0]

    def _preprocess(self, img):
        img = np.expand_dims(img, axis=0)
        img = preprocess_input(img)
        return img

    def getHeatmap(self, img):
        # processes one image at a time
        img = self._preprocess(img)
        conv_feat, softmax_prob = self._getOutputs(img)

        # weights of the predicted class, reshaped for broadcasting
        max_prob_idx = np.argmax(softmax_prob)
        w = self.w[:, max_prob_idx]
        w = w.reshape(1, 1, -1)

        heatmap = (conv_feat * w).sum(axis=2)
        return heatmap


if __name__ == "__main__":
    # the pretrained weight path
    h5_file = './checkpoint/best.h5'

    # load the pretrained model
    m = create_inceptionv3(32, 256, 10)
    m.load_weights(h5_file)
    mm, w = create_multi_inceptionv3(m, 32, 256, 10)
    net = WeaklyLocation(mm, w)

    # predict the heatmap
    img = cv2.imread('./imgs/7.jpg', 1)
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    heatmap = net.getHeatmap(img_rgb)
    heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min())
    img = cv2.resize(img, (256, 256))

    # draw the heatmap next to the original image
    plt.figure(1)
    plt.subplot(1, 2, 1)
    plt.imshow(img[:, :, ::-1])

    plt.subplot(1, 2, 2)
    plt.imshow(img[:, :, ::-1])
    plt.imshow(heatmap, cmap=plt.cm.jet, alpha=0.5, interpolation='nearest')
    plt.show()
--------------------------------------------------------------------------------
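For orientation, a rough shape trace through WeaklyLocation.getHeatmap above, assuming the 32/256/10 settings used in its `__main__` (the 2048 channel count is InceptionV3's final conv depth, per the comment in keras/model.py):

```python
# img:            (32, 32, 3)       RGB CIFAR-10-sized input
# _preprocess ->  (1, 32, 32, 3)    batched, scaled to [-1, 1]
# conv_feat:      (256, 256, 2048)  final conv features, resized in-graph
# softmax_prob:   (10,)             class probabilities
# w[:, argmax]:   (2048,) -> reshaped to (1, 1, 2048) for broadcasting
# heatmap:        (256, 256)        weighted channel sum
```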
/keras/imgs/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ray0809/weakly-supervised-object-localization/d5eaf4017a14f7b94ebaf2cc62e803a7f9223391/keras/imgs/0.jpg
--------------------------------------------------------------------------------
/keras/imgs/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ray0809/weakly-supervised-object-localization/d5eaf4017a14f7b94ebaf2cc62e803a7f9223391/keras/imgs/1.jpg
--------------------------------------------------------------------------------
/keras/imgs/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ray0809/weakly-supervised-object-localization/d5eaf4017a14f7b94ebaf2cc62e803a7f9223391/keras/imgs/2.jpg
--------------------------------------------------------------------------------
/keras/imgs/3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ray0809/weakly-supervised-object-localization/d5eaf4017a14f7b94ebaf2cc62e803a7f9223391/keras/imgs/3.jpg
--------------------------------------------------------------------------------
/keras/imgs/4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ray0809/weakly-supervised-object-localization/d5eaf4017a14f7b94ebaf2cc62e803a7f9223391/keras/imgs/4.jpg
--------------------------------------------------------------------------------
/keras/imgs/5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ray0809/weakly-supervised-object-localization/d5eaf4017a14f7b94ebaf2cc62e803a7f9223391/keras/imgs/5.jpg
--------------------------------------------------------------------------------
/keras/imgs/6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ray0809/weakly-supervised-object-localization/d5eaf4017a14f7b94ebaf2cc62e803a7f9223391/keras/imgs/6.jpg
--------------------------------------------------------------------------------
/keras/imgs/7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ray0809/weakly-supervised-object-localization/d5eaf4017a14f7b94ebaf2cc62e803a7f9223391/keras/imgs/7.jpg
--------------------------------------------------------------------------------
/keras/imgs/8.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ray0809/weakly-supervised-object-localization/d5eaf4017a14f7b94ebaf2cc62e803a7f9223391/keras/imgs/8.jpg
--------------------------------------------------------------------------------
/keras/imgs/9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ray0809/weakly-supervised-object-localization/d5eaf4017a14f7b94ebaf2cc62e803a7f9223391/keras/imgs/9.jpg
--------------------------------------------------------------------------------
/keras/model.py:
--------------------------------------------------------------------------------
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Lambda, GlobalAveragePooling2D, Dense
from tensorflow.keras.applications import InceptionV3


def Resize(x):
    # fixed 256x256 in-graph resize (matches the rz_size used below)
    y = tf.image.resize(x, size=(256, 256))
    return y


def create_inceptionv3(inp_size, rz_size, class_num):
    inputs = Input(shape=(inp_size, inp_size, 3))
    resize = Lambda(Resize, (rz_size, rz_size, 3))(inputs)
    base_model = InceptionV3(include_top=False)
    conv = base_model(resize)
    GAV = GlobalAveragePooling2D()(conv)
    outputs = Dense(class_num, activation='softmax')(GAV)
    model = Model(inputs, outputs)
    return model


def create_multi_inceptionv3(inceptionv3_model, inp_size, rz_size, class_num):
    inputs = Input(shape=(inp_size, inp_size, 3))
    resize = Lambda(Resize, (rz_size, rz_size, 3))(inputs)

    inception_v3 = inceptionv3_model.get_layer('inception_v3')
    conv = inception_v3(resize)

    # resize the conv features back to the original picture size so the
    # heatmap can be overlaid on the image for imshow
    resized_conv = Lambda(Resize, (rz_size, rz_size, 3))(conv)

    GAV = GlobalAveragePooling2D()(conv)

    dense = inceptionv3_model.get_layer('dense')
    outputs = dense(GAV)
    middle_model = Model(inputs, [resized_conv, outputs])

    # the last dense layer's weights, shape 2048 x class_num
    w = middle_model.get_layer('dense').weights[0].numpy()
    return middle_model, w


if __name__ == "__main__":
    m = create_inceptionv3(32, 256, 10)
    mm, w = create_multi_inceptionv3(m, 32, 256, 10)

    import numpy as np
    inp = np.zeros((1, 32, 32, 3))
    rz_conv, output = mm(inp)
    print(rz_conv.shape)
    print(output)
--------------------------------------------------------------------------------
/keras/trainer.py:
--------------------------------------------------------------------------------
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import tensorflow as tf
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.applications.inception_v3 import preprocess_input
from model import create_inceptionv3


def train():
    # load the dataset
    cifar_data = cifar10.load_data()
    train_data = preprocess_input((cifar_data[0][0]).astype('float32'))
    train_label = cifar_data[0][1]
    test_data = preprocess_input((cifar_data[1][0]).astype('float32'))
    test_label = cifar_data[1][1]
    train_label = to_categorical(train_label, 10)
    test_label = to_categorical(test_label, 10)

    # init the classifier
    inception_model = create_inceptionv3(32, 256, 10)
    inception_model.compile(optimizer='sgd',
                            loss='categorical_crossentropy',
                            metrics=['accuracy'])

    # train the classifier
    if not os.path.isdir('./checkpoint'):
        os.makedirs('./checkpoint')
    earlystopping = EarlyStopping(patience=2)
    modelcheckpoint = ModelCheckpoint('./checkpoint/best.h5',
                                      save_best_only=True,
                                      save_weights_only=True)
    inception_model.fit(train_data, train_label,
                        batch_size=64,
                        epochs=10,
                        validation_data=(test_data, test_label),
                        callbacks=[earlystopping, modelcheckpoint])


if __name__ == "__main__":
    train()
--------------------------------------------------------------------------------
/pytorch/draw_heatmap.py:
--------------------------------------------------------------------------------
import cv2
import torch
import numpy as np
import matplotlib.pylab as plt

import torchvision.transforms as transforms
from model import ResNet50Net


class WeaklyLocation():
    def __init__(self, net):
        self.net = net
        self.trans = transforms.Compose([transforms.ToTensor()])
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        self.net = self.net.to(self.device)
        self.net.eval()

        # fc weights transposed to shape 2048 x 10
        self.w = self.net.resnet.fc.weight.data.cpu().numpy().T

    def _getOutputs(self, inp):
        # the feature map before global pooling, plus the softmax output
        conv_feat, softmax_prob = self.net.inference(inp)
        return conv_feat[0], softmax_prob[0]

    def _preprocess(self, img):
        img = self.trans(img)
        img = img.unsqueeze(0)
        img = img.to(self.device)
        return img

    def getHeatmap(self, img):
        # processes one image at a time
        img = self._preprocess(img)
        with torch.no_grad():
            conv_feat, softmax_prob = self._getOutputs(img)
        conv_feat = conv_feat.data.cpu().numpy()
        softmax_prob = softmax_prob.data.cpu().numpy()

        # CHW -> HWC, then weight the channels of the predicted class
        conv_feat = conv_feat.transpose(1, 2, 0)
        max_prob_idx = np.argmax(softmax_prob)
        w = self.w[:, max_prob_idx]
        w = w.reshape(1, 1, -1)

        heatmap = (conv_feat * w).sum(axis=2)
        return heatmap


if __name__ == "__main__":
    rz_size = 224
    num_class = 10
    m = ResNet50Net(rz_size, num_class)
    m.load_state_dict(torch.load('./checkpoint/resnet50.pkl', map_location='cpu'))
    net = WeaklyLocation(m)

    # predict the heatmap
    img = cv2.imread('./imgs/0.jpg', 1)
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    heatmap = net.getHeatmap(img_rgb)
    heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min())
    img = cv2.resize(img, (rz_size, rz_size))

    # draw the heatmap next to the original image
    plt.figure(1)
    plt.subplot(1, 2, 1)
    plt.imshow(img[:, :, ::-1])

    plt.subplot(1, 2, 2)
    plt.imshow(img[:, :, ::-1])
    plt.imshow(heatmap, cmap=plt.cm.jet, alpha=0.5, interpolation='nearest')
    plt.show()
--------------------------------------------------------------------------------
/pytorch/imgs/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ray0809/weakly-supervised-object-localization/d5eaf4017a14f7b94ebaf2cc62e803a7f9223391/pytorch/imgs/0.jpg
--------------------------------------------------------------------------------
/pytorch/imgs/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ray0809/weakly-supervised-object-localization/d5eaf4017a14f7b94ebaf2cc62e803a7f9223391/pytorch/imgs/1.jpg
--------------------------------------------------------------------------------
/pytorch/imgs/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ray0809/weakly-supervised-object-localization/d5eaf4017a14f7b94ebaf2cc62e803a7f9223391/pytorch/imgs/2.jpg
--------------------------------------------------------------------------------
/pytorch/imgs/3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ray0809/weakly-supervised-object-localization/d5eaf4017a14f7b94ebaf2cc62e803a7f9223391/pytorch/imgs/3.jpg
--------------------------------------------------------------------------------
/pytorch/imgs/4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ray0809/weakly-supervised-object-localization/d5eaf4017a14f7b94ebaf2cc62e803a7f9223391/pytorch/imgs/4.jpg
--------------------------------------------------------------------------------
/pytorch/imgs/5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ray0809/weakly-supervised-object-localization/d5eaf4017a14f7b94ebaf2cc62e803a7f9223391/pytorch/imgs/5.jpg
--------------------------------------------------------------------------------
/pytorch/imgs/6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ray0809/weakly-supervised-object-localization/d5eaf4017a14f7b94ebaf2cc62e803a7f9223391/pytorch/imgs/6.jpg
--------------------------------------------------------------------------------
/pytorch/imgs/7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ray0809/weakly-supervised-object-localization/d5eaf4017a14f7b94ebaf2cc62e803a7f9223391/pytorch/imgs/7.jpg
--------------------------------------------------------------------------------
/pytorch/imgs/8.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ray0809/weakly-supervised-object-localization/d5eaf4017a14f7b94ebaf2cc62e803a7f9223391/pytorch/imgs/8.jpg
--------------------------------------------------------------------------------
/pytorch/imgs/9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ray0809/weakly-supervised-object-localization/d5eaf4017a14f7b94ebaf2cc62e803a7f9223391/pytorch/imgs/9.jpg
--------------------------------------------------------------------------------
/pytorch/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision


class ResNet50Net(nn.Module):
    def __init__(self, rz_size, class_num):
        super().__init__()

        resnet = torchvision.models.resnet50(pretrained=True)
        self.upsample = nn.Upsample((rz_size, rz_size))
        resnet.fc = nn.Linear(resnet.fc.in_features, class_num, bias=True)
        self.resnet = resnet

        self._extract_feat_layers = ['layer4']
        self.rz_size = rz_size

    def forward(self, x):
        x = self.upsample(x)
        x = self.resnet(x)
        return x

    def inference(self, x):
        # replay the backbone layer by layer so the final conv feature map
        # can be captured before global pooling
        outputs = []
        x = self.upsample(x)
        for name, layer in self.resnet.named_children():
            x = layer(x)
            if name in self._extract_feat_layers:
                out = F.interpolate(x, (self.rz_size, self.rz_size), mode='bilinear')
                outputs.append(out)
            if name == 'avgpool':
                x = torch.flatten(x, 1)
        x = F.softmax(x, dim=1)
        outputs.append(x)
        return outputs


class Inceptionv3Net(nn.Module):
    def __init__(self, rz_size, class_num):
        super().__init__()

        inception = torchvision.models.inception_v3(pretrained=True, aux_logits=False)
        self.upsample = nn.Upsample((rz_size, rz_size))
        inception.fc = nn.Linear(inception.fc.in_features, class_num, bias=False)
        self.inception = inception

        self._extract_feat_layers = ['Mixed_7c']
        self.rz_size = rz_size

    def forward(self, x):
        x = self.upsample(x)
        x = self.inception(x)
        return x

    def inference(self, x):
        # same layer-by-layer replay as ResNet50Net.inference; this assumes a
        # torchvision version where the maxpool/avgpool stages are registered
        # as named children of inception_v3
        outputs = []
        x = self.upsample(x)
        for name, layer in self.inception.named_children():
            x = layer(x)
            if name in self._extract_feat_layers:
                out = F.interpolate(x, (self.rz_size, self.rz_size), mode='bilinear')
                outputs.append(out)
            if name == 'avgpool':
                x = torch.flatten(x, 1)
        x = F.softmax(x, dim=1)
        outputs.append(x)
        return outputs


if __name__ == "__main__":
    m = ResNet50Net(224, 10)
    m = m.eval()

    inp = torch.randn(1, 3, 32, 32)
    out = m(inp)
    print(out.shape)

    out1 = m.inference(inp)
    print(out1[0].shape, out1[1].shape)

    w = m.resnet.fc.weight.data.cpu().numpy().T
    print(w.shape)

    mm = Inceptionv3Net(299, 10)
    mm = mm.eval()
    out = mm(inp)
    print(mm)
    print(out.shape)
--------------------------------------------------------------------------------
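One caveat on the inference methods above: they replay the backbone by iterating named_children, which only matches the real forward pass when the pooling stages are registered as child modules (true for torchvision's resnet50, and for inception_v3 on recent torchvision versions). A quick sanity check, a minimal sketch assuming a recent torchvision:

```python
import torch
from model import ResNet50Net

m = ResNet50Net(224, 10).eval()
x = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    cam_feats, probs = m.inference(x)   # replayed forward, softmax output
    logits = m(x)                       # the ordinary forward pass
# both paths should agree once the logits are passed through softmax
assert torch.allclose(probs, torch.softmax(logits, dim=1), atol=1e-6)
print(cam_feats.shape, probs.shape)     # (1, 2048, 224, 224) (1, 10)
```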
/pytorch/trainer.py:
--------------------------------------------------------------------------------
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms

from model import ResNet50Net
from tqdm import tqdm

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


def create_cifar10_dataloader():
    train_dataset = torchvision.datasets.CIFAR10(root='data/',
                                                 train=True,
                                                 transform=transforms.ToTensor())

    valid_dataset = torchvision.datasets.CIFAR10(root='data/',
                                                 train=False,
                                                 transform=transforms.ToTensor())

    trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=64,
                                              shuffle=True, num_workers=8)
    validloader = torch.utils.data.DataLoader(valid_dataset, batch_size=64,
                                              shuffle=False, num_workers=8)

    return trainloader, validloader


def train():
    # init the net
    net = ResNet50Net(224, 10)
    optimizer = optim.Adam(net.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    net = net.to(device)

    # load the dataset
    trainloader, validloader = create_cifar10_dataloader()

    # start training
    if not os.path.isdir('./checkpoint'):
        os.makedirs('./checkpoint')
    epochs = 10
    best_acc = 0.0
    for i in range(epochs):
        net.train()

        train_correct = 0
        train_total = 0
        for (imgs, labels) in tqdm(trainloader):
            imgs = imgs.to(device)
            labels = labels.to(device).long()  # works on CPU as well as GPU

            output = net(imgs)
            _, predicted = torch.max(output.data, 1)
            train_correct += (predicted == labels).sum().item()
            train_total += labels.shape[0]
            loss = criterion(output, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print('############## training ##############')
        print('{}/{}, loss:{}, training Accuracy:{}'.format(i + 1, epochs, loss.item(), train_correct / train_total))
        print('############## training ##############')

        net.eval()
        with torch.no_grad():
            correct = 0
            total = 0
            for imgs, labels in tqdm(validloader):
                imgs = imgs.to(device)
                labels = labels.to(device).long()
                output = net(imgs)
                _, predicted = torch.max(output.data, 1)
                correct += (predicted == labels).sum().item()
                total += labels.shape[0]
            valid_acc = correct / total

        print('############## testing ##############')
        print('testing Average Accuracy is {}'.format(valid_acc))
        print('############## testing ##############')

        if valid_acc > best_acc:
            best_acc = valid_acc
            torch.save(net.state_dict(), './checkpoint/resnet50.pkl')


if __name__ == "__main__":
    train()
--------------------------------------------------------------------------------
/result_pic/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ray0809/weakly-supervised-object-localization/d5eaf4017a14f7b94ebaf2cc62e803a7f9223391/result_pic/1.jpg
--------------------------------------------------------------------------------
/result_pic/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ray0809/weakly-supervised-object-localization/d5eaf4017a14f7b94ebaf2cc62e803a7f9223391/result_pic/2.jpg
--------------------------------------------------------------------------------
/result_pic/3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ray0809/weakly-supervised-object-localization/d5eaf4017a14f7b94ebaf2cc62e803a7f9223391/result_pic/3.jpg
--------------------------------------------------------------------------------
/result_pic/4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ray0809/weakly-supervised-object-localization/d5eaf4017a14f7b94ebaf2cc62e803a7f9223391/result_pic/4.jpg
--------------------------------------------------------------------------------
/result_pic/5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ray0809/weakly-supervised-object-localization/d5eaf4017a14f7b94ebaf2cc62e803a7f9223391/result_pic/5.jpg
--------------------------------------------------------------------------------
/result_pic/6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ray0809/weakly-supervised-object-localization/d5eaf4017a14f7b94ebaf2cc62e803a7f9223391/result_pic/6.jpg
--------------------------------------------------------------------------------
/result_pic/tt:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------