├── dataloaders
│   ├── r
│   ├── PairwiseImg_test.py
│   ├── PairwiseImg_video_test_try.py
│   ├── PairwiseImg_video.py
│   └── PairwiseImg_video_try.py
├── deeplab
│   ├── e
│   ├── __init__.py
│   ├── utils.py
│   ├── siamese_model_conf.py
│   └── siamese_model_conf_try_single.py
├── pretrained
│   └── deep_labv3
│       └── readme.md
├── framework.png
├── README.md
├── densecrf_apply_cvpr2019.py
├── test_iteration_conf_group.py
├── test_coattention_conf.py
├── train_iteration_conf.py
└── train_iteration_conf_group.py
--------------------------------------------------------------------------------
/dataloaders/r:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/deeplab/e:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/deeplab/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/pretrained/deep_labv3/readme.md:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/carrierlxk/COSNet/HEAD/framework.png
--------------------------------------------------------------------------------
/deeplab/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | #from tensorboard_logger import log_value
3 | from torch.autograd import Variable
4 | 
5 | 
6 | def loss_calc(pred, label, ignore_label):
7 |     """
8 |     This function returns the cross-entropy loss for semantic segmentation
9 |     """
10 |     # pred shape: batch_size x channels x h x w
11 |     # label shape: h x w x 1 x batch_size -> batch_size x 1 x h x w
12 |     label = Variable(label.long()).cuda()
13 |     criterion = torch.nn.CrossEntropyLoss(ignore_index=ignore_label).cuda()
14 | 
15 |     return criterion(pred, label)
16 | 
17 | 
18 | def lr_poly(base_lr, iter, max_iter, power):
19 |     return base_lr * ((1 - float(iter) / max_iter) ** power)
20 | 
21 | 
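As an aside (not part of the original file): `lr_poly` implements the standard polynomial learning-rate decay, base_lr * (1 - iter/max_iter)^power. A quick sanity check, with illustrative values rather than the training defaults:

```python
# Poly decay: starts at base_lr, shrinks smoothly, approaches 0 at max_iter.
base_lr, max_iter, power = 1e-3, 30000, 0.9
for it in (0, 15000, 29999):
    print(it, lr_poly(base_lr, it, max_iter, power))
# 0     -> 1.0e-3
# 15000 -> ~5.4e-4   (0.5 ** 0.9 ~= 0.536)
# 29999 -> ~9e-8
```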
22 | def get_1x_lr_params(model):
23 |     """
24 |     This generator returns all the parameters of the net except for
25 |     the last classification layer. Note that for each batchnorm layer,
26 |     requires_grad is set to False in deeplab_resnet.py, therefore this function does not return
27 |     any batchnorm parameters.
28 |     """
29 |     b = [model.conv1, model.bn1, model.layer1, model.layer2, model.layer3, model.layer4]
30 | 
31 |     for i in range(len(b)):
32 |         for j in b[i].modules():
33 |             jj = 0
34 |             for k in j.parameters():
35 |                 jj += 1
36 |                 if k.requires_grad:
37 |                     yield k
38 | 
39 | 
40 | def get_10x_lr_params(model):
41 |     """
42 |     This generator returns all the parameters of the last layer of the net,
43 |     which classifies pixels into classes
44 |     """
45 |     b = [model.layer5.parameters(), model.main_classifier.parameters()]
46 | 
47 |     for j in range(len(b)):
48 |         for i in b[j]:
49 |             yield i
50 | 
51 | 
52 | def adjust_learning_rate(optimizer, i_iter, learning_rate, num_steps, power):
53 |     """Sets the learning rate following the polynomial decay schedule (lr_poly); the second parameter group gets 10x the base rate"""
54 |     lr = lr_poly(learning_rate, i_iter, num_steps, power)
55 |     #log_value('learning', lr, i_iter)
56 |     optimizer.param_groups[0]['lr'] = lr
57 |     optimizer.param_groups[1]['lr'] = lr * 10
58 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # COSNet
2 | Code for the CVPR 2019 paper:
3 | 
4 | [See More, Know More: Unsupervised Video Object Segmentation with
5 | Co-Attention Siamese Networks](http://openaccess.thecvf.com/content_CVPR_2019/papers/Lu_See_More_Know_More_Unsupervised_Video_Object_Segmentation_With_Co-Attention_CVPR_2019_paper.pdf)
6 | 
7 | [Xiankai Lu](https://sites.google.com/site/xiankailu111/), [Wenguan Wang](https://sites.google.com/view/wenguanwang), Chao Ma, Jianbing Shen, Ling Shao, Fatih Porikli
8 | 
9 | ##
10 | 
11 | ![](../master/framework.png)
12 | 
13 | - - -
14 | :new:
15 | 
16 | Our group co-attention achieves a further performance gain (81.1 mean J on the DAVIS-16 dataset); the related code has also been released.
17 | 
18 | The pre-trained model, testing code and training code are provided below.
19 | 
20 | ### Quick Start
21 | 
22 | #### Testing
23 | 
24 | 1. Install PyTorch (version 1.0.1).
25 | 
26 | 2. Download the pretrained model. In 'test_coattention_conf.py', set the DAVIS dataset path, the pretrained model path and the result path.
27 | 
28 | 3. Run command: python test_coattention_conf.py --dataset davis --gpus 0
29 | 
30 | 4. The CRF post-processing code comes from: https://github.com/lucasb-eyer/pydensecrf.
31 | 
32 | The pretrained weights can be downloaded from [GoogleDrive](https://drive.google.com/open?id=14ya3ZkneeHsegCgDrvkuFtGoAfVRgErz) or [BaiduPan](https://pan.baidu.com/s/16oFzRmn4Meuq83fCYr4boQ), pass code: xwup.
33 | 
34 | The segmentation results on DAVIS, FBMS and YouTube-Objects can be downloaded from the [DAVIS benchmark](https://davischallenge.org/davis2016/soa_compare.html),
35 | [GoogleDrive](https://drive.google.com/open?id=1JRPc2kZmzx0b7WLjxTPD-kdgFdXh5gBq) or [BaiduPan](https://pan.baidu.com/s/11n7zAt3Lo2P3-42M2lsw6Q), pass code: q37f.
36 | 
37 | The YouTube-Objects dataset can be downloaded from [here](http://calvin-vision.net/datasets/youtube-objects-dataset/) and its annotations can be found [here](http://vision.cs.utexas.edu/projects/videoseg/data_download_register.html).
38 | 
39 | The FBMS dataset can be downloaded from [here](https://lmb.informatik.uni-freiburg.de/resources/datasets/moseg.en.html).
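For reference, the test script and the dataloaders assume a DAVIS-2016 directory layout along the following lines (a sketch, not an exhaustive listing; the root path is whatever you set in 'test_coattention_conf.py', and each sequence-list txt file holds one sequence name per line):

```
DAVIS-2016/
├── JPEGImages/480p/<sequence>/00000.jpg, 00001.jpg, ...
├── Annotations/480p/<sequence>/00000.png, 00001.png, ...
├── train_seqs.txt    # training sequence names
├── val_seqs.txt      # validation sequence names
└── test_seqs.txt     # e.g. "blackswan", "bmx-bumps", ...
```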
40 | #### Training
41 | 
42 | 1. Download all the training datasets, including the MSRA10K and DUT saliency datasets. Create a folder called images and put the two datasets into it.
43 | 
44 | 2. Download the deeplabv3 model from [GoogleDrive](https://drive.google.com/open?id=1hy0-BAEestT9H4a3Sv78xrHrzmZga9mj). Put it into the folder pretrained/deep_labv3.
45 | 
46 | 3. Change the video path, image path and deeplabv3 path in train_iteration_conf.py. Create two txt files that list the saliency dataset names and the DAVIS-16 training sequence names, and change the txt paths in PairwiseImg_video.py accordingly.
47 | 
48 | 4. Run command: python train_iteration_conf.py --dataset davis --gpus 0,1
49 | 
50 | ### Citation
51 | 
52 | If you find the code and dataset useful in your research, please consider citing:
53 | ```
54 | @InProceedings{Lu_2019_CVPR,
55 | author = {Lu, Xiankai and Wang, Wenguan and Ma, Chao and Shen, Jianbing and Shao, Ling and Porikli, Fatih},
56 | title = {See More, Know More: Unsupervised Video Object Segmentation With Co-Attention Siamese Networks},
57 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
58 | year = {2019}
59 | }
60 | @article{lu2020_pami,
61 | title={Zero-Shot Video Object Segmentation with Co-Attention Siamese Networks},
62 | author={Lu, Xiankai and Wang, Wenguan and Shen, Jianbing and Crandall, David and Luo, Jiebo},
63 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
64 | year={2020},
65 | publisher={IEEE}
66 | }
67 | ```
68 | ### Other related projects/papers:
69 | [Saliency-Aware Geodesic Video Object Segmentation (CVPR15)](https://github.com/wenguanwang/saliencysegment)
70 | 
71 | [Learning Unsupervised Video Primary Object Segmentation through Visual Attention (CVPR19)](https://github.com/wenguanwang/AGS)
72 | 
73 | For any comments, please email: carrierlxk@gmail.com
74 | 
--------------------------------------------------------------------------------
/densecrf_apply_cvpr2019.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Fri Mar  1 20:37:37 2019
5 | 
6 | @author: xiankai
7 | """
8 | 
9 | import pydensecrf.densecrf as dcrf
10 | import numpy as np
11 | import sys
12 | import os
13 | 
14 | 
15 | from skimage.io import imread, imsave
16 | from pydensecrf.utils import unary_from_labels, create_pairwise_bilateral, create_pairwise_gaussian, unary_from_softmax
17 | from os import listdir, makedirs
18 | from os.path import isfile, join
19 | from multiprocessing import Process
20 | 
21 | 
22 | def worker(scale, g_dim, g_factor, s_dim, C_dim, c_factor):
23 |     davis_path = '/home/xiankai/work/DAVIS-2016/JPEGImages/480p' #'/home/ying/tracking/pdb_results/FBMS-results'
24 |     origin_path = '/home/xiankai/work/DAVIS-2016/Results/Segmentations/480p/COS-78.2' #'/home/xiankai/work/DAVIS-2016/Results/Segmentations/480p/ECCV' #'/media/xiankai/Data/segmentation/match-Weaksup_VideoSeg/result/test/davis_iteration_conf_sal_match_scale/COS/'
25 |     out_folder = '/home/xiankai/work/DAVIS-2016/Results/Segmentations/480p/cvpr2019_crfs' #'/media/xiankai/Data/ECCV-crf' #'/home/xiankai/work/DAVIS-2016/Results/Segmentations/480p/davis_ICCV_new/'
26 |     if not os.path.exists(out_folder):
27 |         os.makedirs(out_folder)
28 |     origin_file = listdir(origin_path)
29 |     origin_file.sort()
30 |     for i in range(0, len(origin_file)):
31 |         d = origin_file[i]
32 |         vidDir = join(davis_path, d)
33 |         out_folder1 = join(out_folder, 'f'+str(scale)+str(g_dim)+str(g_factor)+'_'+'s'+str(s_dim)+'_'+'c'+str(C_dim)+str(c_factor))
34 |         resDir = join(out_folder1, d)
35 |         if not os.path.exists(resDir):
36 |             os.makedirs(resDir)
37 |         rgb_file = listdir(vidDir)
38 |         rgb_file.sort()
39 |         for ii in range(0, len(rgb_file)):
40 |             f = rgb_file[ii]
41 |             img = imread(join(vidDir, f))
42 |             segDir = join(origin_path, d)
43 |             frameName = str.split(f, '.')[0]
44 |             anno_rgb = imread(segDir + '/' + frameName + '.png').astype(np.uint32)
45 |             min_val = np.min(anno_rgb.ravel())
46 |             max_val = np.max(anno_rgb.ravel())
47 |             out = (anno_rgb.astype('float') - min_val) / (max_val - min_val)
48 |             labels = np.zeros((2, img.shape[0], img.shape[1]))
49 |             labels[1, :, :] = out
50 |             labels[0, :, :] = 1 - out
51 | 
52 |             colors = [0, 255]
53 |             colorize = np.empty((len(colors), 1), np.uint8)
54 |             colorize[:, 0] = colors
55 |             n_labels = 2
56 | 
57 |             crf = dcrf.DenseCRF(img.shape[1] * img.shape[0], n_labels)
58 | 
59 |             U = unary_from_softmax(labels, scale)
60 |             crf.setUnaryEnergy(U)
61 | 
62 |             feats = create_pairwise_gaussian(sdims=(g_dim, g_dim), shape=img.shape[:2])
63 | 
64 |             crf.addPairwiseEnergy(feats, compat=g_factor,
65 |                                   kernel=dcrf.DIAG_KERNEL,
66 |                                   normalization=dcrf.NORMALIZE_SYMMETRIC)
67 | 
68 |             feats = create_pairwise_bilateral(sdims=(s_dim, s_dim), schan=(C_dim, C_dim, C_dim), # 30,5
69 |                                               img=img, chdim=2)
70 | 
71 |             crf.addPairwiseEnergy(feats, compat=c_factor,
72 |                                   kernel=dcrf.DIAG_KERNEL,
73 |                                   normalization=dcrf.NORMALIZE_SYMMETRIC)
74 | 
75 |             #Q = crf.inference(5)
76 |             Q, tmp1, tmp2 = crf.startInference()
77 |             for step in range(5):  # renamed from 'i' so the sequence index above is not clobbered
78 |                 #print("KL-divergence at {}: {}".format(step, crf.klDivergence(Q)))
79 |                 crf.stepInference(Q, tmp1, tmp2)
80 | 
81 |             MAP = np.argmax(Q, axis=0)
82 |             MAP = colorize[MAP]
83 | 
84 |             imsave(resDir + '/' + frameName + '.png', MAP.reshape(anno_rgb.shape))
85 |             print("Saving: " + resDir + '/' + frameName + '.png')
86 | scales = [1]  #[0.5,1] #[0.1,0.3,0.5,0.6] #[0.5, 1.0]
87 | g_dims = [1]  #[1,3] #[1,3]
88 | g_factors = [5]  #[3,5,10] #[3, 5, 10]
89 | s_dims = [10, 15, 20]  #[5,10,20] #[11, 12, 13] #[9,10,11] #10
90 | Cs = [7]  #[5] #[8] #[7,8,9,10] #8
91 | b_factors = [8, 9, 10]
92 | for scale in scales:
93 |     for g_dim in g_dims:
94 |         for ii in range(0, len(g_factors)):
95 |             g_factor = g_factors[ii]
96 |             for jj in range(0, len(s_dims)):
97 |                 s_dim = s_dims[jj]
98 |                 for cs in Cs:
99 |                     p1 = Process(target=worker, args=(scale, g_dim, g_factor, s_dim, cs, b_factors[0]))
100 |                     p2 = Process(target=worker, args=(scale, g_dim, g_factor, s_dim, cs, b_factors[1]))
101 |                     p3 = Process(target=worker, args=(scale, g_dim, g_factor, s_dim, cs, b_factors[2]))
102 |                     #p4 = Process(target=worker, args=(scale, g_dim, g_factor, s_dim, cs, 4))
103 |                     #p5 = Process(target=worker, args=(scale, g_dim, g_factor, s_dim, cs, 1))
104 |                     #p6 = Process(target=worker, args=(scale, g_dim, g_factor, s_dim, cs, 1))
105 | 
106 |                     p1.start()
107 |                     p2.start()
108 |                     p3.start()
109 |                     #p4.start()
110 |                     #p5.start()
111 |                     #p6.start()
112 |                     # wait for the current batch of three workers before launching the
113 |                     # next grid point; otherwise every combination is forked at once
114 |                     p1.join(); p2.join(); p3.join()
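Stripped of the grid search and the multiprocessing, refining a single soft mask with pydensecrf looks roughly like this. A minimal sketch: the probability construction mirrors the min-max normalisation in `worker`, and the sdims/schan/compat values below are just one point of the grid searched above.

```python
import numpy as np
import pydensecrf.densecrf as dcrf
from pydensecrf.utils import unary_from_softmax, create_pairwise_gaussian, create_pairwise_bilateral

def refine_mask(img, prob_fg, n_iters=5):
    """img: HxWx3 uint8 frame; prob_fg: HxW float array in [0, 1]."""
    probs = np.stack([1.0 - prob_fg, prob_fg])                # (2, H, W): background, foreground
    crf = dcrf.DenseCRF(img.shape[0] * img.shape[1], 2)
    crf.setUnaryEnergy(unary_from_softmax(probs))             # negative log-probabilities
    # location-only smoothness kernel
    crf.addPairwiseEnergy(create_pairwise_gaussian(sdims=(1, 1), shape=img.shape[:2]),
                          compat=5, kernel=dcrf.DIAG_KERNEL,
                          normalization=dcrf.NORMALIZE_SYMMETRIC)
    # colour-dependent (bilateral) kernel
    crf.addPairwiseEnergy(create_pairwise_bilateral(sdims=(10, 10), schan=(7, 7, 7),
                                                    img=img, chdim=2),
                          compat=8, kernel=dcrf.DIAG_KERNEL,
                          normalization=dcrf.NORMALIZE_SYMMETRIC)
    Q = crf.inference(n_iters)
    return np.argmax(Q, axis=0).reshape(img.shape[:2])        # binary label map
```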
--------------------------------------------------------------------------------
/test_iteration_conf_group.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Mon Sep 17 17:53:20 2018
4 | 
5 | @author: carri
6 | """
7 | 
8 | import argparse
9 | import torch
10 | import torch.nn as nn
11 | from torch.utils import data
12 | import numpy as np
13 | import pickle
14 | import cv2
15 | from torch.autograd import Variable
16 | import torch.optim as optim
17 | import scipy.misc
18 | import torch.backends.cudnn as cudnn
19 | import sys
20 | import os
21 | import os.path as osp
22 | from dataloaders import PairwiseImg_video_test_try as db
23 | #from dataloaders import StaticImg as db  # uses the VOC-style dataset layout
24 | import matplotlib.pyplot as plt
25 | import random
26 | import timeit
27 | from PIL import Image
28 | from collections import OrderedDict
29 | import matplotlib.pyplot as plt
30 | import torch.nn as nn
31 | from utils.colorize_mask import cityscapes_colorize_mask, VOCColorize
32 | #import pydensecrf.densecrf as dcrf
33 | #from pydensecrf.utils import unary_from_softmax, create_pairwise_bilateral, create_pairwise_gaussian
34 | from deeplab.siamese_model_conf_try import CoattentionNet
35 | from torchvision.utils import save_image
36 | 
37 | def get_arguments():
38 |     """Parse all the arguments provided from the CLI.
39 | 
40 |     Returns:
41 |       The parsed arguments as an argparse.Namespace.
42 |     """
43 |     parser = argparse.ArgumentParser(description="PSPnet")
44 |     parser.add_argument("--dataset", type=str, default='cityscapes',
45 |                         help="voc12, cityscapes, or pascal-context")
46 | 
47 |     # GPU configuration
48 |     parser.add_argument("--cuda", default=True, help="Run on CPU or GPU")
49 |     parser.add_argument("--gpus", type=str, default="0",
50 |                         help="choose gpu device.")
51 |     parser.add_argument("--seq_name", default='bmx-bumps')
52 |     parser.add_argument("--use_crf", default='True')
53 |     parser.add_argument("--sample_range", default=3)
54 | 
55 |     return parser.parse_args()
56 | 
57 | def configure_dataset_model(args):
58 | 
59 |     args.batch_size = 1  # 1 card: 5, 2 cards: 10; number of images sent to the network in one step, 16 in the paper
60 |     args.maxEpoches = 15  # 1 card: 15, 2 cards: 15 epochs, equal to 30k iterations; max iterations = maxEpoches*len(train_aug)/batch_size_per_gpu
61 |     args.data_dir = '/home/xiankai/work/DAVIS-2016'  # 37572 image pairs
62 |     args.data_list = '/home/xiankai/work/DAVIS-2016/test_seqs.txt'  # Path to the file listing the images in the dataset
63 |     args.ignore_label = 255  # The index of the label to ignore during the training
64 |     args.input_size = '473,473'  # Comma-separated string with height and width of images
65 |     args.num_classes = 2  # Number of classes to predict (including background)
66 |     args.img_mean = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32)  # BGR mean subtracted from the input
67 |     args.restore_from = './co_attention_davis_43.pth'  #resnet50-19c8e357.pth''/home/xiankai/PSPNet_PyTorch/snapshots/davis/psp_davis_0.pth' #
68 |     args.snapshot_dir = './snapshots/davis_iteration/'  # Where to save snapshots of the model
69 |     args.save_segimage = True
70 |     args.seg_save_dir = "./result/test/davis_iteration_conf_try"
71 |     args.vis_save_dir = "./result/test/davis_vis"
72 |     args.corp_size = (473, 473)
73 | 
74 | def convert_state_dict(state_dict):
75 |     """Converts a state dict saved from a DataParallel module to a normal
76 |     module state_dict, in place.
77 |     :param state_dict: the loaded DataParallel model state.
78 |     If the model was saved with nn.DataParallel, its weights are stored under a 'module.' prefix; to load it
79 |     without DataParallel you can either wrap the network in nn.DataParallel temporarily for loading, or you can
80 |     load the weights file, create a new ordered dict without the module prefix, and load that back.
81 |     """
82 |     state_dict_new = OrderedDict()
83 |     #print(type(state_dict))
84 |     for k, v in state_dict.items():
85 |         #print(k)
86 |         name = k[7:]  # remove the 'module.' prefix
87 |         # note: PyTorch does not strip this prefix automatically, so it is done by hand here
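        # e.g. 'module.linear_e.weight' -> 'linear_e.weight' (len('module.') == 7)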
88 |         state_dict_new[name] = v
89 |         if name == 'linear_e.weight':
90 |             np.save('weight_matrix.npy', v.cpu().numpy())
91 |     return state_dict_new
92 | 
93 | def sigmoid(inX):
94 |     return 1.0/(1+np.exp(-inX))  # sigmoid: 1/(1 + e^-x)
95 | 
96 | def main():
97 |     args = get_arguments()
98 |     print("=====> Configure dataset and model")
99 |     configure_dataset_model(args)
100 |     print(args)
101 | 
102 |     print("=====> Set GPU for training")
103 |     if args.cuda:
104 |         print("====> Use gpu id: '{}'".format(args.gpus))
105 |         os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
106 |         if not torch.cuda.is_available():
107 |             raise Exception("No GPU found or wrong gpu id, please run without --cuda")
108 |     model = CoattentionNet(num_classes=args.num_classes, nframes=args.sample_range)
109 |     for param in model.parameters():
110 |         param.requires_grad = False
111 | 
112 |     saved_state_dict = torch.load(args.restore_from, map_location=lambda storage, loc: storage)
113 |     #print(saved_state_dict.keys())
114 |     #model.load_state_dict({k.replace('pspmodule.',''):v for k,v in torch.load(args.restore_from)['state_dict'].items()})
115 |     model.load_state_dict(convert_state_dict(saved_state_dict["model"]))
116 | 
117 |     model.eval()
118 |     model.cuda()
119 | 
120 |     db_test = db.PairwiseImg(train=False, inputRes=(473,473), db_root_dir=args.data_dir, transform=None, seq_name=None, sample_range=args.sample_range)  # db_root_dir --> '/path/to/DAVIS-2016'
121 |     testloader = data.DataLoader(db_test, batch_size=1, shuffle=False, num_workers=0)
122 |     voc_colorize = VOCColorize()
123 | 
124 |     data_list = []
125 | 
126 |     if args.save_segimage:
127 |         if not os.path.exists(args.seg_save_dir) and not os.path.exists(args.vis_save_dir):
128 |             os.makedirs(args.seg_save_dir)
129 |             #os.makedirs(args.vis_save_dir)
130 |     print("======> test set size:", len(testloader))
131 |     my_index = 0
132 |     old_temp = ''
133 |     for index, batch in enumerate(testloader):
134 |         print('%d processed'%(index))
135 |         target = batch['target']
136 |         np.save('target.npy', target.float().data)
137 |         #search = batch['search']
138 |         temp = batch['seq_name']
139 |         args.seq_name = temp[0]
140 |         print(args.seq_name)
141 |         if old_temp == args.seq_name:
142 |             my_index = my_index+1
143 |         else:
144 |             my_index = 0
145 |         output_sum = 0
146 |         for i in range(0, 1):
147 |             search = batch['search'+'_'+str(i)]
148 |             search_im = search
149 |             print('input size:', search_im.size(), len(search.size()))
150 |             if len(search.size()) < 5:
151 |                 search_im = search_im.unsqueeze(0)
152 |             output = model(Variable(target, volatile=True).cuda(), Variable(search_im, volatile=True).cuda())
153 |             #print(output[0])  # the model returns two outputs
154 |             output_sum = output_sum + output[0].data[0,0].cpu().numpy()  # result of the segmentation branch
155 | 
156 |             #np.save('infer'+str(i)+'.npy',output1)
157 |             #output2 = output[1].data[0, 0].cpu().numpy() #interp'
158 | 
159 |         output1 = output_sum  #/args.sample_range (a single grouped search tensor is used here, so no averaging)
160 |         #target_mask = output[3].data[0,0].cpu().numpy()
161 |         #print('output size:', output1.shape, type(output1))
162 |         first_image = np.array(Image.open(args.data_dir+'/JPEGImages/480p/blackswan/00000.jpg'))
163 |         original_shape = first_image.shape
164 |         output1 = cv2.resize(output1, (original_shape[1], original_shape[0]))
165 |         #output2 = cv2.resize(target_mask, (original_shape[1], original_shape[0]))
166 |         mask = (output1*255).astype(np.uint8)
167 |         #target_mask = (output2*255).astype(np.uint8)
168 |         mask = Image.fromarray(mask)
169 |         #target_mask = Image.fromarray(target_mask)
170 | 
171 | 
save_dir_res = os.path.join(args.seg_save_dir, 'Results', args.seq_name) 172 | old_temp=args.seq_name 173 | if not os.path.exists(save_dir_res): 174 | os.makedirs(save_dir_res) 175 | if args.save_segimage: 176 | my_index1 = str(my_index).zfill(5) 177 | seg_filename = os.path.join(save_dir_res, '{}.png'.format(my_index1)) 178 | gate_filename = os.path.join(save_dir_res, '{}_gate.png'.format(my_index1)) 179 | mask.save(seg_filename) 180 | #target_mask.save(gate_filename) 181 | 182 | if __name__ == '__main__': 183 | main() 184 | -------------------------------------------------------------------------------- /dataloaders/PairwiseImg_test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Sep 12 11:39:54 2018 4 | 5 | @author: carri 6 | """ 7 | # for testing case 8 | from __future__ import division 9 | 10 | import os 11 | import numpy as np 12 | import cv2 13 | from scipy.misc import imresize 14 | import scipy.misc 15 | import random 16 | 17 | #from dataloaders.helpers import * 18 | from torch.utils.data import Dataset 19 | 20 | def flip(I,flip_p): 21 | if flip_p>0.5: 22 | return np.fliplr(I) 23 | else: 24 | return I 25 | 26 | def scale_im(img_temp,scale): 27 | new_dims = ( int(img_temp.shape[0]*scale), int(img_temp.shape[1]*scale) ) 28 | return cv2.resize(img_temp,new_dims).astype(float) 29 | 30 | def scale_gt(img_temp,scale): 31 | new_dims = ( int(img_temp.shape[0]*scale), int(img_temp.shape[1]*scale) ) 32 | return cv2.resize(img_temp,new_dims,interpolation = cv2.INTER_NEAREST).astype(float) 33 | 34 | def my_crop(img,gt): 35 | H = int(0.9 * img.shape[0]) 36 | W = int(0.9 * img.shape[1]) 37 | H_offset = random.choice(range(img.shape[0] - H)) 38 | W_offset = random.choice(range(img.shape[1] - W)) 39 | H_slice = slice(H_offset, H_offset + H) 40 | W_slice = slice(W_offset, W_offset + W) 41 | img = img[H_slice, W_slice, :] 42 | gt = gt[H_slice, W_slice] 43 | 44 | return img, gt 45 | 46 | class PairwiseImg(Dataset): 47 | """DAVIS 2016 dataset constructed using the PyTorch built-in functionalities""" 48 | 49 | def __init__(self, train=True, 50 | inputRes=None, 51 | db_root_dir='/DAVIS-2016', 52 | transform=None, 53 | meanval=(104.00699, 116.66877, 122.67892), 54 | seq_name=None, sample_range=10): 55 | """Loads image to label pairs for tool pose estimation 56 | db_root_dir: dataset directory with subfolders "JPEGImages" and "Annotations" 57 | """ 58 | self.train = train 59 | self.range = sample_range 60 | self.inputRes = inputRes 61 | self.db_root_dir = db_root_dir 62 | self.transform = transform 63 | self.meanval = meanval 64 | self.seq_name = seq_name 65 | 66 | if self.train: 67 | fname = 'train_seqs' 68 | else: 69 | fname = 'val_seqs' 70 | 71 | if self.seq_name is None: #所有的数据集都参与训练 72 | with open(os.path.join(db_root_dir, fname + '.txt')) as f: 73 | seqs = f.readlines() 74 | img_list = [] 75 | labels = [] 76 | Index = {} 77 | for seq in seqs: 78 | images = np.sort(os.listdir(os.path.join(db_root_dir, 'JPEGImages/480p/', seq.strip('\n')))) 79 | images_path = list(map(lambda x: os.path.join('JPEGImages/480p/', seq.strip(), x), images)) 80 | start_num = len(img_list) 81 | img_list.extend(images_path) 82 | end_num = len(img_list) 83 | Index[seq.strip('\n')]= np.array([start_num, end_num]) 84 | lab = np.sort(os.listdir(os.path.join(db_root_dir, 'Annotations/480p/', seq.strip('\n')))) 85 | lab_path = list(map(lambda x: os.path.join('Annotations/480p/', seq.strip(), x), lab)) 86 | labels.extend(lab_path) 87 | else: 
#针对所有的训练样本, img_list存放的是图片的路径 88 | 89 | # Initialize the per sequence images for online training 90 | names_img = np.sort(os.listdir(os.path.join(db_root_dir, str(seq_name)))) 91 | img_list = list(map(lambda x: os.path.join(( str(seq_name)), x), names_img)) 92 | #name_label = np.sort(os.listdir(os.path.join(db_root_dir, str(seq_name)))) 93 | labels = [os.path.join( (str(seq_name)+'/saliencymaps'), names_img[0])] 94 | labels.extend([None]*(len(names_img)-1)) #在labels这个列表后面添加元素None 95 | if self.train: 96 | img_list = [img_list[0]] 97 | labels = [labels[0]] 98 | 99 | assert (len(labels) == len(img_list)) 100 | 101 | self.img_list = img_list 102 | self.labels = labels 103 | self.Index = Index 104 | #img_files = open('all_im.txt','w+') 105 | 106 | def __len__(self): 107 | return len(self.img_list) 108 | 109 | def __getitem__(self, idx): 110 | target, target_gt,sequence_name = self.make_img_gt_pair(idx) #测试时候要分割的帧 111 | target_id = idx 112 | seq_name1 = self.img_list[target_id].split('/')[-2] #获取视频名称 113 | sample = {'target': target, 'target_gt': target_gt, 'seq_name': sequence_name, 'search_0': None} 114 | if self.range>=1: 115 | my_index = self.Index[seq_name1] 116 | search_num = list(range(my_index[0], my_index[1])) 117 | search_ids = random.sample(search_num, self.range)#min(len(self.img_list)-1, target_id+np.random.randint(1,self.range+1)) 118 | 119 | for i in range(0,self.range): 120 | search_id = search_ids[i] 121 | search, search_gt,sequence_name = self.make_img_gt_pair(search_id) 122 | if sample['search_0'] is None: 123 | sample['search_0'] = search 124 | else: 125 | sample['search'+'_'+str(i)] = search 126 | #np.save('search1.npy',search) 127 | #np.save('search_gt.npy',search_gt) 128 | if self.seq_name is not None: 129 | fname = os.path.join(self.seq_name, "%05d" % idx) 130 | sample['fname'] = fname 131 | 132 | else: 133 | img, gt = self.make_img_gt_pair(idx) 134 | sample = {'image': img, 'gt': gt} 135 | if self.seq_name is not None: 136 | fname = os.path.join(self.seq_name, "%05d" % idx) 137 | sample['fname'] = fname 138 | 139 | return sample #这个类最后的输出 140 | 141 | def make_img_gt_pair(self, idx): #这个函数存在的意义是为了getitem函数服务的 142 | """ 143 | Make the image-ground-truth pair 144 | """ 145 | img = cv2.imread(os.path.join(self.db_root_dir, self.img_list[idx]), cv2.IMREAD_COLOR) 146 | if self.labels[idx] is not None and self.train: 147 | label = cv2.imread(os.path.join(self.db_root_dir, self.labels[idx]), cv2.IMREAD_GRAYSCALE) 148 | #print(os.path.join(self.db_root_dir, self.labels[idx])) 149 | else: 150 | gt = np.zeros(img.shape[:-1], dtype=np.uint8) 151 | 152 | ## 已经读取了image以及对应的ground truth可以进行data augmentation了 153 | if self.train: #scaling, cropping and flipping 154 | img, label = my_crop(img,label) 155 | scale = random.uniform(0.7, 1.3) 156 | flip_p = random.uniform(0, 1) 157 | img_temp = scale_im(img,scale) 158 | img_temp = flip(img_temp,flip_p) 159 | gt_temp = scale_gt(label,scale) 160 | gt_temp = flip(gt_temp,flip_p) 161 | 162 | img = img_temp 163 | label = gt_temp 164 | 165 | if self.inputRes is not None: 166 | img = imresize(img, self.inputRes) 167 | #print('ok1') 168 | #scipy.misc.imsave('label.png',label) 169 | #scipy.misc.imsave('img.png',img) 170 | if self.labels[idx] is not None and self.train: 171 | label = imresize(label, self.inputRes, interp='nearest') 172 | 173 | img = np.array(img, dtype=np.float32) 174 | #img = img[:, :, ::-1] 175 | img = np.subtract(img, np.array(self.meanval, dtype=np.float32)) 176 | img = img.transpose((2, 0, 1)) # NHWC -> NCHW 177 | 178 | if 
self.labels[idx] is not None and self.train: 179 | gt = np.array(label, dtype=np.int32) 180 | gt[gt!=0]=1 181 | #gt = gt/np.max([gt.max(), 1e-8]) 182 | #np.save('gt.npy') 183 | sequence_name = self.img_list[idx].split('/')[2] 184 | return img, gt, sequence_name 185 | 186 | def get_img_size(self): 187 | img = cv2.imread(os.path.join(self.db_root_dir, self.img_list[0])) 188 | 189 | return list(img.shape[:2]) 190 | 191 | 192 | if __name__ == '__main__': 193 | import custom_transforms as tr 194 | import torch 195 | from torchvision import transforms 196 | from matplotlib import pyplot as plt 197 | 198 | transforms = transforms.Compose([tr.RandomHorizontalFlip(), tr.Resize(scales=[0.5, 0.8, 1]), tr.ToTensor()]) 199 | 200 | #dataset = DAVIS2016(db_root_dir='/media/eec/external/Databases/Segmentation/DAVIS-2016', 201 | # train=True, transform=transforms) 202 | #dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=1) 203 | # 204 | # for i, data in enumerate(dataloader): 205 | # plt.figure() 206 | # plt.imshow(overlay_mask(im_normalize(tens2image(data['image'])), tens2image(data['gt']))) 207 | # if i == 10: 208 | # break 209 | # 210 | # plt.show(block=True) 211 | -------------------------------------------------------------------------------- /dataloaders/PairwiseImg_video_test_try.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Sep 12 11:39:54 2018 4 | 5 | @author: carri 6 | """ 7 | # for testing case 8 | from __future__ import division 9 | 10 | import os 11 | import numpy as np 12 | import cv2 13 | from scipy.misc import imresize 14 | import scipy.misc 15 | import random 16 | import torch 17 | from dataloaders.helpers import * 18 | from torch.utils.data import Dataset 19 | 20 | def flip(I,flip_p): 21 | if flip_p>0.5: 22 | return np.fliplr(I) 23 | else: 24 | return I 25 | 26 | def scale_im(img_temp,scale): 27 | new_dims = ( int(img_temp.shape[0]*scale), int(img_temp.shape[1]*scale) ) 28 | return cv2.resize(img_temp,new_dims).astype(float) 29 | 30 | def scale_gt(img_temp,scale): 31 | new_dims = ( int(img_temp.shape[0]*scale), int(img_temp.shape[1]*scale) ) 32 | return cv2.resize(img_temp,new_dims,interpolation = cv2.INTER_NEAREST).astype(float) 33 | 34 | def my_crop(img,gt): 35 | H = int(0.9 * img.shape[0]) 36 | W = int(0.9 * img.shape[1]) 37 | H_offset = random.choice(range(img.shape[0] - H)) 38 | W_offset = random.choice(range(img.shape[1] - W)) 39 | H_slice = slice(H_offset, H_offset + H) 40 | W_slice = slice(W_offset, W_offset + W) 41 | img = img[H_slice, W_slice, :] 42 | gt = gt[H_slice, W_slice] 43 | 44 | return img, gt 45 | 46 | class PairwiseImg(Dataset): 47 | """DAVIS 2016 dataset constructed using the PyTorch built-in functionalities""" 48 | 49 | def __init__(self, train=True, 50 | inputRes=None, 51 | db_root_dir='/DAVIS-2016', 52 | transform=None, 53 | meanval=(104.00699, 116.66877, 122.67892), 54 | seq_name=None, sample_range=10): 55 | """Loads image to label pairs for tool pose estimation 56 | db_root_dir: dataset directory with subfolders "JPEGImages" and "Annotations" 57 | """ 58 | self.train = train 59 | self.range = sample_range 60 | self.inputRes = inputRes 61 | self.db_root_dir = db_root_dir 62 | self.transform = transform 63 | self.meanval = meanval 64 | self.seq_name = seq_name 65 | 66 | if self.train: 67 | fname = 'train_seqs' 68 | else: 69 | fname = 'val_seqs' 70 | 71 | if self.seq_name is None: #所有的数据集都参与训练 72 | with open(os.path.join(db_root_dir, 
fname + '.txt')) as f: 73 | seqs = f.readlines() 74 | img_list = [] 75 | labels = [] 76 | Index = {} 77 | for seq in seqs: 78 | images = np.sort(os.listdir(os.path.join(db_root_dir, 'JPEGImages/480p/', seq.strip('\n')))) 79 | images_path = list(map(lambda x: os.path.join('JPEGImages/480p/', seq.strip(), x), images)) 80 | start_num = len(img_list) 81 | img_list.extend(images_path) 82 | end_num = len(img_list) 83 | Index[seq.strip('\n')]= np.array([start_num, end_num]) 84 | lab = np.sort(os.listdir(os.path.join(db_root_dir, 'Annotations/480p/', seq.strip('\n')))) 85 | lab_path = list(map(lambda x: os.path.join('Annotations/480p/', seq.strip(), x), lab)) 86 | labels.extend(lab_path) 87 | else: #针对所有的训练样本, img_list存放的是图片的路径 88 | 89 | # Initialize the per sequence images for online training 90 | names_img = np.sort(os.listdir(os.path.join(db_root_dir, str(seq_name)))) 91 | img_list = list(map(lambda x: os.path.join(( str(seq_name)), x), names_img)) 92 | #name_label = np.sort(os.listdir(os.path.join(db_root_dir, str(seq_name)))) 93 | labels = [os.path.join( (str(seq_name)+'/saliencymaps'), names_img[0])] 94 | labels.extend([None]*(len(names_img)-1)) #在labels这个列表后面添加元素None 95 | if self.train: 96 | img_list = [img_list[0]] 97 | labels = [labels[0]] 98 | 99 | assert (len(labels) == len(img_list)) 100 | 101 | self.img_list = img_list 102 | self.labels = labels 103 | self.Index = Index 104 | #img_files = open('all_im.txt','w+') 105 | 106 | def __len__(self): 107 | return len(self.img_list) 108 | 109 | def __getitem__(self, idx): 110 | target, target_grt,sequence_name = self.make_img_gt_pair(idx) #测试时候要分割的帧 111 | target_id = idx 112 | seq_name1 = self.img_list[target_id].split('/')[-2] #获取视频名称 113 | 114 | #target_grts = torch.stack((torch.from_numpy(target_grt), torch.from_numpy(target_grt_1))) 115 | #print('video name', seq_name1 ) 116 | sample = {'target': target, 'target_gt': target_grt, 'seq_name': sequence_name, 'search_0': None} 117 | if self.range>=1: 118 | my_index = self.Index[seq_name1] 119 | search_num = list(range(my_index[0], my_index[1])) 120 | search_ids = random.sample(search_num, self.range)#min(len(self.img_list)-1, target_id+np.random.randint(1,self.range+1)) 121 | searchs=[] 122 | for i in range(0,self.range): 123 | 124 | search_id = search_ids[i] 125 | search, search_grt,sequence_name = self.make_img_gt_pair(search_id) 126 | searchs.append(torch.from_numpy(search)) 127 | #search_grts = torch.stack((torch.from_numpy(search_grt), torch.from_numpy(search_grt_1))) 128 | if sample['search_0'] is None: 129 | sample['search_0'] = torch.stack(searchs,dim=0) 130 | else: 131 | sample['search'+'_'+str(i)] = torch.stack(searchs) 132 | #np.save('search1.npy',search) 133 | #np.save('search_gt.npy',search_gt) 134 | if self.seq_name is not None: 135 | fname = os.path.join(self.seq_name, "%05d" % idx) 136 | sample['fname'] = fname 137 | 138 | else: 139 | img, gt = self.make_img_gt_pair(idx) 140 | sample = {'image': img, 'gt': gt} 141 | if self.seq_name is not None: 142 | fname = os.path.join(self.seq_name, "%05d" % idx) 143 | sample['fname'] = fname 144 | 145 | return sample #这个类最后的输出 146 | 147 | def make_img_gt_pair(self, idx): #这个函数存在的意义是为了getitem函数服务的 148 | """ 149 | Make the image-ground-truth pair 150 | """ 151 | img = cv2.imread(os.path.join(self.db_root_dir, self.img_list[idx]), cv2.IMREAD_COLOR) 152 | if self.labels[idx] is not None and self.train: 153 | label = cv2.imread(os.path.join(self.db_root_dir, self.labels[idx]), cv2.IMREAD_GRAYSCALE) 154 | #print(os.path.join(self.db_root_dir, 
self.labels[idx])) 155 | else: 156 | gt = np.zeros(img.shape[:-1], dtype=np.uint8) 157 | 158 | ## 已经读取了image以及对应的ground truth可以进行data augmentation了 159 | if self.train: #scaling, cropping and flipping 160 | img, label = my_crop(img,label) 161 | scale = random.uniform(0.7, 1.3) 162 | flip_p = random.uniform(0, 1) 163 | img_temp = scale_im(img,scale) 164 | img_temp = flip(img_temp,flip_p) 165 | gt_temp = scale_gt(label,scale) 166 | gt_temp = flip(gt_temp,flip_p) 167 | 168 | img = img_temp 169 | label = gt_temp 170 | 171 | if self.inputRes is not None: 172 | img = imresize(img, self.inputRes) 173 | #print('ok1') 174 | #scipy.misc.imsave('label.png',label) 175 | #scipy.misc.imsave('img.png',img) 176 | if self.labels[idx] is not None and self.train: 177 | label = imresize(label, self.inputRes, interp='nearest') 178 | 179 | img = np.array(img, dtype=np.float32) 180 | #img = img[:, :, ::-1] 181 | img = np.subtract(img, np.array(self.meanval, dtype=np.float32)) 182 | img = img.transpose((2, 0, 1)) # NHWC -> NCHW 183 | 184 | if self.labels[idx] is not None and self.train: 185 | gt = np.array(label, dtype=np.int32) 186 | gt[gt!=0]=1 187 | #gt = gt/np.max([gt.max(), 1e-8]) 188 | #np.save('gt.npy') 189 | sequence_name = self.img_list[idx].split('/')[2] 190 | return img, gt, sequence_name 191 | 192 | def get_img_size(self): 193 | img = cv2.imread(os.path.join(self.db_root_dir, self.img_list[0])) 194 | 195 | return list(img.shape[:2]) 196 | 197 | 198 | if __name__ == '__main__': 199 | import custom_transforms as tr 200 | import torch 201 | from torchvision import transforms 202 | from matplotlib import pyplot as plt 203 | 204 | transforms = transforms.Compose([tr.RandomHorizontalFlip(), tr.Resize(scales=[0.5, 0.8, 1]), tr.ToTensor()]) 205 | 206 | #dataset = DAVIS2016(db_root_dir='/media/eec/external/Databases/Segmentation/DAVIS-2016', 207 | # train=True, transform=transforms) 208 | #dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=1) 209 | # 210 | # for i, data in enumerate(dataloader): 211 | # plt.figure() 212 | # plt.imshow(overlay_mask(im_normalize(tens2image(data['image'])), tens2image(data['gt']))) 213 | # if i == 10: 214 | # break 215 | # 216 | # plt.show(block=True) 217 | -------------------------------------------------------------------------------- /test_coattention_conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon Sep 17 17:53:20 2018 4 | 5 | @author: carri 6 | """ 7 | 8 | import argparse 9 | import torch 10 | import torch.nn as nn 11 | from torch.utils import data 12 | import numpy as np 13 | import pickle 14 | import cv2 15 | from torch.autograd import Variable 16 | import torch.optim as optim 17 | import scipy.misc 18 | import torch.backends.cudnn as cudnn 19 | import sys 20 | import os 21 | import os.path as osp 22 | from dataloaders import PairwiseImg_test as db 23 | #from dataloaders import StaticImg as db #采用voc dataset的数据设置格式方法 24 | import matplotlib.pyplot as plt 25 | import random 26 | import timeit 27 | from PIL import Image 28 | from collections import OrderedDict 29 | import matplotlib.pyplot as plt 30 | import torch.nn as nn 31 | #from utils.colorize_mask import cityscapes_colorize_mask, VOCColorize 32 | #import pydensecrf.densecrf as dcrf 33 | #from pydensecrf.utils import unary_from_softmax, create_pairwise_bilateral, create_pairwise_gaussian 34 | from deeplab.siamese_model_conf import CoattentionNet 35 | from torchvision.utils import 
save_image
36 | 
37 | def get_arguments():
38 |     """Parse all the arguments provided from the CLI.
39 | 
40 |     Returns:
41 |       The parsed arguments as an argparse.Namespace.
42 |     """
43 |     parser = argparse.ArgumentParser(description="PSPnet")
44 |     parser.add_argument("--dataset", type=str, default='cityscapes',
45 |                         help="voc12, cityscapes, or pascal-context")
46 | 
47 |     # GPU configuration
48 |     parser.add_argument("--cuda", default=True, help="Run on CPU or GPU")
49 |     parser.add_argument("--gpus", type=str, default="0",
50 |                         help="choose gpu device.")
51 |     parser.add_argument("--seq_name", default='bmx-bumps')
52 |     parser.add_argument("--use_crf", default='True')
53 |     parser.add_argument("--sample_range", default=5)
54 | 
55 |     return parser.parse_args()
56 | 
57 | def configure_dataset_model(args):
58 |     if args.dataset == 'voc12':
59 |         args.data_dir = '/home/wty/AllDataSet/VOC2012'  # Path to the directory containing the PASCAL VOC dataset
60 |         args.data_list = './dataset/list/VOC2012/test.txt'  # Path to the file listing the images in the dataset
61 |         args.img_mean = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32)
62 |         # RGB mean; the mean is subtracted first, then the image is changed to BGR
63 |         args.ignore_label = 255  # The index of the label to ignore during the training
64 |         args.num_classes = 21  # Number of classes to predict (including background)
65 |         args.restore_from = './snapshots/voc12/psp_voc12_14.pth'  # Where to restore model parameters from
66 |         args.save_segimage = True
67 |         args.seg_save_dir = "./result/test/VOC2012"
68 |         args.corp_size = (505, 505)
69 | 
70 |     elif args.dataset == 'davis':
71 |         args.batch_size = 1  # 1 card: 5, 2 cards: 10; number of images sent to the network in one step, 16 in the paper
72 |         args.maxEpoches = 15  # 1 card: 15, 2 cards: 15 epochs, equal to 30k iterations; max iterations = maxEpoches*len(train_aug)/batch_size_per_gpu
73 |         args.data_dir = 'your_path/DAVIS-2016'  # 37572 image pairs
74 |         args.data_list = 'your_path/DAVIS-2016/test_seqs.txt'  # Path to the file listing the images in the dataset
75 |         args.ignore_label = 255  # The index of the label to ignore during the training
76 |         args.input_size = '473,473'  # Comma-separated string with height and width of images
77 |         args.num_classes = 2  # Number of classes to predict (including background)
78 |         args.img_mean = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32)  # BGR mean subtracted from the input
79 |         args.restore_from = './your_path.pth'  #resnet50-19c8e357.pth''/home/xiankai/PSPNet_PyTorch/snapshots/davis/psp_davis_0.pth' #
80 |         args.snapshot_dir = './snapshots/davis_iteration/'  # Where to save snapshots of the model
81 |         args.save_segimage = True
82 |         args.seg_save_dir = "./result/test/davis_iteration_conf"
83 |         args.vis_save_dir = "./result/test/davis_vis"
84 |         args.corp_size = (473, 473)
85 | 
86 |     else:
87 |         print("dataset error")
88 | 
89 | def convert_state_dict(state_dict):
90 |     """Converts a state dict saved from a DataParallel module to a normal
91 |     module state_dict, in place.
92 |     :param state_dict: the loaded DataParallel model state.
93 |     If the model was saved with nn.DataParallel, its weights are stored under a 'module.' prefix; to load it
94 |     without DataParallel you can either wrap the network in nn.DataParallel temporarily for loading, or you can
95 |     load the weights file, create a new ordered dict without the module prefix, and load that back.
96 |     """
97 |     state_dict_new = OrderedDict()
98 |     #print(type(state_dict))
99 |     for k, v in state_dict.items():
100 |         #print(k)
101 |         name = k[7:]  # remove the 'module.' prefix
102 |         # note: PyTorch does not strip this prefix automatically, so it is done by hand here
103 |         state_dict_new[name] = v
104 |         if name == 'linear_e.weight':
105 |             np.save('weight_matrix.npy', v.cpu().numpy())
106 |     return state_dict_new
107 | 
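The docstring above mentions two ways to load such a checkpoint; `convert_state_dict` implements the prefix-stripping one. A minimal sketch of the other option, wrapping the model in nn.DataParallel so the 'module.'-prefixed keys match ('checkpoint.pth' is a placeholder path):

```python
model = CoattentionNet(num_classes=2)
model = torch.nn.DataParallel(model)                 # keys now expect the 'module.' prefix
saved = torch.load('checkpoint.pth', map_location=lambda storage, loc: storage)
model.load_state_dict(saved["model"])                # loads without renaming any keys
model = model.module                                 # unwrap back to the plain model
```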
108 | def sigmoid(inX):
109 |     return 1.0/(1+np.exp(-inX))  # sigmoid: 1/(1 + e^-x)
110 | 
111 | def main():
112 |     args = get_arguments()
113 |     print("=====> Configure dataset and model")
114 |     configure_dataset_model(args)
115 |     print(args)
116 |     model = CoattentionNet(num_classes=args.num_classes)
117 | 
118 |     saved_state_dict = torch.load(args.restore_from, map_location=lambda storage, loc: storage)
119 |     #print(saved_state_dict.keys())
120 |     #model.load_state_dict({k.replace('pspmodule.',''):v for k,v in torch.load(args.restore_from)['state_dict'].items()})
121 |     model.load_state_dict(convert_state_dict(saved_state_dict["model"]))
122 | 
123 |     model.eval()
124 |     model.cuda()
125 |     if args.dataset == 'voc12':
126 |         testloader = data.DataLoader(VOCDataTestSet(args.data_dir, args.data_list, crop_size=(505, 505), mean=args.img_mean),  # note: VOCDataTestSet is not imported in this script
127 |                                      batch_size=1, shuffle=False, pin_memory=True)
128 |         interp = nn.Upsample(size=(505, 505), mode='bilinear')
129 |         voc_colorize = VOCColorize()
130 | 
131 |     elif args.dataset == 'davis':  # for DAVIS 2016
132 |         db_test = db.PairwiseImg(train=False, inputRes=(473,473), db_root_dir=args.data_dir, transform=None, seq_name=None, sample_range=args.sample_range)  # db_root_dir --> '/path/to/DAVIS-2016'
133 |         testloader = data.DataLoader(db_test, batch_size=1, shuffle=False, num_workers=0)
134 |         #voc_colorize = VOCColorize()
135 |     else:
136 |         print("dataset error")
137 | 
138 |     data_list = []
139 | 
140 |     if args.save_segimage:
141 |         if not os.path.exists(args.seg_save_dir) and not os.path.exists(args.vis_save_dir):
142 |             os.makedirs(args.seg_save_dir)
143 |             os.makedirs(args.vis_save_dir)
144 |     print("======> test set size:", len(testloader))
145 |     my_index = 0
146 |     old_temp = ''
147 |     for index, batch in enumerate(testloader):
148 |         print('%d processed'%(index))
149 |         target = batch['target']
150 |         #search = batch['search']
151 |         temp = batch['seq_name']
152 |         args.seq_name = temp[0]
153 |         print(args.seq_name)
154 |         if old_temp == args.seq_name:
155 |             my_index = my_index+1
156 |         else:
157 |             my_index = 0
158 |         output_sum = 0
159 |         for i in range(0, args.sample_range):
160 |             search = batch['search'+'_'+str(i)]
161 |             search_im = search
162 |             #print(search_im.size())
163 |             output = model(Variable(target, volatile=True).cuda(), Variable(search_im, volatile=True).cuda())
164 |             #print(output[0])  # the model returns two outputs
165 |             output_sum = output_sum + output[0].data[0,0].cpu().numpy()  # result of the segmentation branch
166 |             #np.save('infer'+str(i)+'.npy',output1)
167 |             #output2 = output[1].data[0, 0].cpu().numpy() #interp'
168 | 
169 |         output1 = output_sum/args.sample_range  # average the predictions over the sampled reference frames
170 | 
171 |         first_image = np.array(Image.open(args.data_dir+'/JPEGImages/480p/blackswan/00000.jpg'))
172 |         original_shape = first_image.shape
173 |         output1 = cv2.resize(output1, (original_shape[1], original_shape[0]))
174 | 
175 |         mask = (output1*255).astype(np.uint8)
176 | 
#print(mask.shape[0]) 177 | mask = Image.fromarray(mask) 178 | 179 | 180 | if args.dataset == 'voc12': 181 | print(output.shape) 182 | print(size) 183 | output = output[:,:size[0],:size[1]] 184 | output = output.transpose(1,2,0) 185 | output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8) 186 | if args.save_segimage: 187 | seg_filename = os.path.join(args.seg_save_dir, '{}.png'.format(name[0])) 188 | color_file = Image.fromarray(voc_colorize(output).transpose(1, 2, 0), 'RGB') 189 | color_file.save(seg_filename) 190 | 191 | elif args.dataset == 'davis': 192 | 193 | save_dir_res = os.path.join(args.seg_save_dir, 'Results', args.seq_name) 194 | old_temp=args.seq_name 195 | if not os.path.exists(save_dir_res): 196 | os.makedirs(save_dir_res) 197 | if args.save_segimage: 198 | my_index1 = str(my_index).zfill(5) 199 | seg_filename = os.path.join(save_dir_res, '{}.png'.format(my_index1)) 200 | #color_file = Image.fromarray(voc_colorize(output).transpose(1, 2, 0), 'RGB') 201 | mask.save(seg_filename) 202 | #np.concatenate((torch.zeros(1, 473, 473), mask, torch.zeros(1, 512, 512)),axis = 0) 203 | #save_image(output1 * 0.8 + target.data, args.vis_save_dir, normalize=True) 204 | else: 205 | print("dataset error") 206 | 207 | 208 | if __name__ == '__main__': 209 | main() 210 | -------------------------------------------------------------------------------- /dataloaders/PairwiseImg_video.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Sep 12 11:39:54 2018 4 | 5 | @author: carri 6 | """ 7 | 8 | from __future__ import division 9 | 10 | import os 11 | import numpy as np 12 | import cv2 13 | from scipy.misc import imresize 14 | import scipy.misc 15 | import random 16 | 17 | from dataloaders.helpers import * 18 | from torch.utils.data import Dataset 19 | 20 | def flip(I,flip_p): 21 | if flip_p>0.5: 22 | return np.fliplr(I) 23 | else: 24 | return I 25 | 26 | def scale_im(img_temp,scale): 27 | new_dims = ( int(img_temp.shape[0]*scale), int(img_temp.shape[1]*scale) ) 28 | return cv2.resize(img_temp,new_dims).astype(float) 29 | 30 | def scale_gt(img_temp,scale): 31 | new_dims = ( int(img_temp.shape[0]*scale), int(img_temp.shape[1]*scale) ) 32 | return cv2.resize(img_temp,new_dims,interpolation = cv2.INTER_NEAREST).astype(float) 33 | 34 | def my_crop(img,gt): 35 | H = int(0.9 * img.shape[0]) 36 | W = int(0.9 * img.shape[1]) 37 | H_offset = random.choice(range(img.shape[0] - H)) 38 | W_offset = random.choice(range(img.shape[1] - W)) 39 | H_slice = slice(H_offset, H_offset + H) 40 | W_slice = slice(W_offset, W_offset + W) 41 | img = img[H_slice, W_slice, :] 42 | gt = gt[H_slice, W_slice] 43 | 44 | return img, gt 45 | 46 | class PairwiseImg(Dataset): 47 | """DAVIS 2016 dataset constructed using the PyTorch built-in functionalities""" 48 | 49 | def __init__(self, train=True, 50 | inputRes=None, 51 | db_root_dir='/DAVIS-2016', 52 | img_root_dir = None, 53 | transform=None, 54 | meanval=(104.00699, 116.66877, 122.67892), 55 | seq_name=None, sample_range=10): 56 | """Loads image to label pairs for tool pose estimation 57 | db_root_dir: dataset directory with subfolders "JPEGImages" and "Annotations" 58 | """ 59 | self.train = train 60 | self.range = sample_range 61 | self.inputRes = inputRes 62 | self.img_root_dir = img_root_dir 63 | self.db_root_dir = db_root_dir 64 | self.transform = transform 65 | self.meanval = meanval 66 | self.seq_name = seq_name 67 | 68 | if self.train: 69 | fname = 'train_seqs' 70 | else: 71 | 
fname = 'val_seqs' 72 | 73 | if self.seq_name is None: #所有的数据集都参与训练 74 | with open(os.path.join(db_root_dir, fname + '.txt')) as f: 75 | seqs = f.readlines() 76 | video_list = [] 77 | labels = [] 78 | Index = {} 79 | image_list = [] 80 | im_label = [] 81 | for seq in seqs: 82 | images = np.sort(os.listdir(os.path.join(db_root_dir, 'JPEGImages/480p/', seq.strip('\n')))) 83 | images_path = list(map(lambda x: os.path.join('JPEGImages/480p/', seq.strip(), x), images)) 84 | start_num = len(video_list) 85 | video_list.extend(images_path) 86 | end_num = len(video_list) 87 | Index[seq.strip('\n')]= np.array([start_num, end_num]) 88 | lab = np.sort(os.listdir(os.path.join(db_root_dir, 'Annotations/480p/', seq.strip('\n')))) 89 | lab_path = list(map(lambda x: os.path.join('Annotations/480p/', seq.strip(), x), lab)) 90 | labels.extend(lab_path) 91 | 92 | with open('/home/ubuntu/xiankai/saliency_data.txt') as f: 93 | seqs = f.readlines() 94 | #data_list = np.sort(os.listdir(db_root_dir)) 95 | for seq in seqs: #所有数据集 96 | seq = seq.strip('\n') 97 | images = np.sort(os.listdir(os.path.join(img_root_dir,seq.strip())+'/images/'))#针对某个数据集,比如DUT 98 | # Initialize the original DAVIS splits for training the parent network 99 | images_path = list(map(lambda x: os.path.join((seq +'/images'), x), images)) 100 | image_list.extend(images_path) 101 | lab = np.sort(os.listdir(os.path.join(img_root_dir,seq.strip())+'/saliencymaps')) 102 | lab_path = list(map(lambda x: os.path.join((seq +'/saliencymaps'),x), lab)) 103 | im_label.extend(lab_path) 104 | else: #针对所有的训练样本, video_list存放的是图片的路径 105 | 106 | # Initialize the per sequence images for online training 107 | names_img = np.sort(os.listdir(os.path.join(db_root_dir, str(seq_name)))) 108 | video_list = list(map(lambda x: os.path.join(( str(seq_name)), x), names_img)) 109 | #name_label = np.sort(os.listdir(os.path.join(db_root_dir, str(seq_name)))) 110 | labels = [os.path.join( (str(seq_name)+'/saliencymaps'), names_img[0])] 111 | labels.extend([None]*(len(names_img)-1)) #在labels这个列表后面添加元素None 112 | if self.train: 113 | video_list = [video_list[0]] 114 | labels = [labels[0]] 115 | 116 | assert (len(labels) == len(video_list)) 117 | 118 | self.video_list = video_list 119 | self.labels = labels 120 | self.image_list = image_list 121 | self.img_labels = im_label 122 | self.Index = Index 123 | #img_files = open('all_im.txt','w+') 124 | 125 | def __len__(self): 126 | print(len(self.video_list), len(self.image_list)) 127 | return len(self.video_list) 128 | 129 | def __getitem__(self, idx): 130 | target, target_gt = self.make_video_gt_pair(idx) 131 | target_id = idx 132 | img_idx = np.random.randint(1,len(self.image_list)-1) 133 | seq_name1 = self.video_list[idx].split('/')[-2] #获取视频名称 134 | if self.train: 135 | my_index = self.Index[seq_name1] 136 | search_id = np.random.randint(my_index[0], my_index[1])#min(len(self.video_list)-1, target_id+np.random.randint(1,self.range+1)) 137 | if search_id == target_id: 138 | search_id = np.random.randint(my_index[0], my_index[1]) 139 | search, search_gt = self.make_video_gt_pair(search_id) 140 | img, img_gt = self.make_img_gt_pair(img_idx) 141 | sample = {'target': target, 'target_gt': target_gt, 'search': search, 'search_gt': search_gt, \ 142 | 'img': img, 'img_gt': img_gt} 143 | #np.save('search1.npy',search) 144 | #np.save('search_gt.npy',search_gt) 145 | if self.seq_name is not None: 146 | fname = os.path.join(self.seq_name, "%05d" % idx) 147 | sample['fname'] = fname 148 | 149 | if self.transform is not None: 150 | sample = 
self.transform(sample) 151 | 152 | else: 153 | img, gt = self.make_video_gt_pair(idx) 154 | sample = {'image': img, 'gt': gt} 155 | if self.seq_name is not None: 156 | fname = os.path.join(self.seq_name, "%05d" % idx) 157 | sample['fname'] = fname 158 | 159 | 160 | 161 | return sample #这个类最后的输出 162 | 163 | def make_video_gt_pair(self, idx): #这个函数存在的意义是为了getitem函数服务的 164 | """ 165 | Make the image-ground-truth pair 166 | """ 167 | img = cv2.imread(os.path.join(self.db_root_dir, self.video_list[idx]), cv2.IMREAD_COLOR) 168 | if self.labels[idx] is not None and self.train: 169 | label = cv2.imread(os.path.join(self.db_root_dir, self.labels[idx]), cv2.IMREAD_GRAYSCALE) 170 | #print(os.path.join(self.db_root_dir, self.labels[idx])) 171 | else: 172 | gt = np.zeros(img.shape[:-1], dtype=np.uint8) 173 | 174 | ## 已经读取了image以及对应的ground truth可以进行data augmentation了 175 | if self.train: #scaling, cropping and flipping 176 | img, label = my_crop(img,label) 177 | scale = random.uniform(0.7, 1.3) 178 | flip_p = random.uniform(0, 1) 179 | img_temp = scale_im(img,scale) 180 | img_temp = flip(img_temp,flip_p) 181 | gt_temp = scale_gt(label,scale) 182 | gt_temp = flip(gt_temp,flip_p) 183 | 184 | img = img_temp 185 | label = gt_temp 186 | 187 | if self.inputRes is not None: 188 | img = imresize(img, self.inputRes) 189 | #print('ok1') 190 | #scipy.misc.imsave('label.png',label) 191 | #scipy.misc.imsave('img.png',img) 192 | if self.labels[idx] is not None and self.train: 193 | label = imresize(label, self.inputRes, interp='nearest') 194 | 195 | img = np.array(img, dtype=np.float32) 196 | #img = img[:, :, ::-1] 197 | img = np.subtract(img, np.array(self.meanval, dtype=np.float32)) 198 | img = img.transpose((2, 0, 1)) # NHWC -> NCHW 199 | 200 | if self.labels[idx] is not None and self.train: 201 | gt = np.array(label, dtype=np.int32) 202 | gt[gt!=0]=1 203 | #gt = gt/np.max([gt.max(), 1e-8]) 204 | #np.save('gt.npy') 205 | return img, gt 206 | 207 | def get_img_size(self): 208 | img = cv2.imread(os.path.join(self.db_root_dir, self.video_list[0])) 209 | 210 | return list(img.shape[:2]) 211 | 212 | def make_img_gt_pair(self, idx): #这个函数存在的意义是为了getitem函数服务的 213 | """ 214 | Make the image-ground-truth pair 215 | """ 216 | img = cv2.imread(os.path.join(self.img_root_dir, self.image_list[idx]),cv2.IMREAD_COLOR) 217 | #print(os.path.join(self.db_root_dir, self.img_list[idx])) 218 | if self.img_labels[idx] is not None and self.train: 219 | label = cv2.imread(os.path.join(self.img_root_dir, self.img_labels[idx]),cv2.IMREAD_GRAYSCALE) 220 | #print(os.path.join(self.db_root_dir, self.labels[idx])) 221 | else: 222 | gt = np.zeros(img.shape[:-1], dtype=np.uint8) 223 | 224 | if self.inputRes is not None: 225 | img = imresize(img, self.inputRes) 226 | if self.img_labels[idx] is not None and self.train: 227 | label = imresize(label, self.inputRes, interp='nearest') 228 | 229 | img = np.array(img, dtype=np.float32) 230 | #img = img[:, :, ::-1] 231 | img = np.subtract(img, np.array(self.meanval, dtype=np.float32)) 232 | img = img.transpose((2, 0, 1)) # NHWC -> NCHW 233 | 234 | if self.img_labels[idx] is not None and self.train: 235 | gt = np.array(label, dtype=np.int32) 236 | gt[gt!=0]=1 237 | #gt = gt/np.max([gt.max(), 1e-8]) 238 | #np.save('gt.npy') 239 | return img, gt 240 | 241 | if __name__ == '__main__': 242 | import custom_transforms as tr 243 | import torch 244 | from torchvision import transforms 245 | from matplotlib import pyplot as plt 246 | 247 | transforms = transforms.Compose([tr.RandomHorizontalFlip(), 
tr.Resize(scales=[0.5, 0.8, 1]), tr.ToTensor()]) 248 | 249 | #dataset = DAVIS2016(db_root_dir='/media/eec/external/Databases/Segmentation/DAVIS-2016', 250 | # train=True, transform=transforms) 251 | #dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=1) 252 | # 253 | # for i, data in enumerate(dataloader): 254 | # plt.figure() 255 | # plt.imshow(overlay_mask(im_normalize(tens2image(data['image'])), tens2image(data['gt']))) 256 | # if i == 10: 257 | # break 258 | # 259 | # plt.show(block=True) -------------------------------------------------------------------------------- /dataloaders/PairwiseImg_video_try.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Sep 12 11:39:54 2018 4 | 5 | @author: carri 6 | """ 7 | 8 | from __future__ import division 9 | 10 | import os 11 | import numpy as np 12 | import cv2 13 | from scipy.misc import imresize 14 | import scipy.misc 15 | import random 16 | import torch 17 | from dataloaders.helpers import * 18 | from torch.utils.data import Dataset 19 | 20 | def flip(I,flip_p): 21 | if flip_p>0.5: 22 | return np.fliplr(I) 23 | else: 24 | return I 25 | 26 | def scale_im(img_temp,scale): 27 | new_dims = ( int(img_temp.shape[0]*scale), int(img_temp.shape[1]*scale) ) 28 | return cv2.resize(img_temp,new_dims).astype(float) 29 | 30 | def scale_gt(img_temp,scale): 31 | new_dims = ( int(img_temp.shape[0]*scale), int(img_temp.shape[1]*scale) ) 32 | return cv2.resize(img_temp,new_dims,interpolation = cv2.INTER_NEAREST).astype(float) 33 | 34 | def my_crop(img,gt): 35 | H = int(0.9 * img.shape[0]) 36 | W = int(0.9 * img.shape[1]) 37 | H_offset = random.choice(range(img.shape[0] - H)) 38 | W_offset = random.choice(range(img.shape[1] - W)) 39 | H_slice = slice(H_offset, H_offset + H) 40 | W_slice = slice(W_offset, W_offset + W) 41 | img = img[H_slice, W_slice, :] 42 | gt = gt[H_slice, W_slice] 43 | 44 | return img, gt 45 | 46 | class PairwiseImg(Dataset): 47 | """DAVIS 2016 dataset constructed using the PyTorch built-in functionalities""" 48 | 49 | def __init__(self, train=True, 50 | inputRes=None, 51 | db_root_dir='/DAVIS-2016', 52 | img_root_dir = None, 53 | transform=None, 54 | meanval=(104.00699, 116.66877, 122.67892), 55 | seq_name=None, sample_range=10): 56 | """Loads image to label pairs for tool pose estimation 57 | db_root_dir: dataset directory with subfolders "JPEGImages" and "Annotations" 58 | """ 59 | self.train = train 60 | self.range = sample_range 61 | self.inputRes = inputRes 62 | self.img_root_dir = img_root_dir 63 | self.db_root_dir = db_root_dir 64 | self.transform = transform 65 | self.meanval = meanval 66 | self.seq_name = seq_name 67 | 68 | if self.train: 69 | fname = 'train_seqs' 70 | else: 71 | fname = 'val_seqs' 72 | 73 | if self.seq_name is None: #所有的数据集都参与训练 74 | with open(os.path.join(db_root_dir, fname + '.txt')) as f: 75 | seqs = f.readlines() 76 | video_list = [] 77 | labels = [] 78 | Index = {} 79 | image_list = [] 80 | im_label = [] 81 | for seq in seqs: 82 | images = np.sort(os.listdir(os.path.join(db_root_dir, 'JPEGImages/480p/', seq.strip('\n')))) 83 | images_path = list(map(lambda x: os.path.join('JPEGImages/480p/', seq.strip(), x), images)) 84 | start_num = len(video_list) 85 | video_list.extend(images_path) 86 | end_num = len(video_list) 87 | Index[seq.strip('\n')]= np.array([start_num, end_num]) 88 | lab = np.sort(os.listdir(os.path.join(db_root_dir, 'Annotations/480p/', seq.strip('\n')))) 89 | lab_path = 
list(map(lambda x: os.path.join('Annotations/480p/', seq.strip(), x), lab)) 90 | labels.extend(lab_path) 91 | 92 | with open('/home/ubuntu/xiankai/saliency_data.txt') as f: 93 | seqs = f.readlines() 94 | #data_list = np.sort(os.listdir(db_root_dir)) 95 | for seq in seqs: #所有数据集 96 | seq = seq.strip('\n') 97 | images = np.sort(os.listdir(os.path.join(img_root_dir,seq.strip())+'/images/'))#针对某个数据集,比如DUT 98 | # Initialize the original DAVIS splits for training the parent network 99 | images_path = list(map(lambda x: os.path.join((seq +'/images'), x), images)) 100 | image_list.extend(images_path) 101 | lab = np.sort(os.listdir(os.path.join(img_root_dir,seq.strip())+'/saliencymaps')) 102 | lab_path = list(map(lambda x: os.path.join((seq +'/saliencymaps'),x), lab)) 103 | im_label.extend(lab_path) 104 | else: #针对所有的训练样本, video_list存放的是图片的路径 105 | 106 | # Initialize the per sequence images for online training 107 | names_img = np.sort(os.listdir(os.path.join(db_root_dir, str(seq_name)))) 108 | video_list = list(map(lambda x: os.path.join(( str(seq_name)), x), names_img)) 109 | #name_label = np.sort(os.listdir(os.path.join(db_root_dir, str(seq_name)))) 110 | labels = [os.path.join( (str(seq_name)+'/saliencymaps'), names_img[0])] 111 | labels.extend([None]*(len(names_img)-1)) #在labels这个列表后面添加元素None 112 | if self.train: 113 | video_list = [video_list[0]] 114 | labels = [labels[0]] 115 | 116 | assert (len(labels) == len(video_list)) 117 | 118 | self.video_list = video_list 119 | self.labels = labels 120 | self.image_list = image_list 121 | self.img_labels = im_label 122 | self.Index = Index 123 | #img_files = open('all_im.txt','w+') 124 | 125 | def __len__(self): 126 | print(len(self.video_list), len(self.image_list)) 127 | return len(self.video_list) 128 | 129 | def __getitem__(self, idx): 130 | target, target_grt = self.make_video_gt_pair(idx) 131 | target_id = idx 132 | img_idx = random.sample([my_i for my_i in range(0,len(self.image_list))],2) 133 | 134 | seq_name1 = self.video_list[idx].split('/')[-2] #获取视频名称 135 | my_index = self.Index[seq_name1] 136 | video_idx = random.sample([my_i for my_i in range(my_index[0],my_index[1])],3) 137 | target_1, target_grt_1 = self.make_video_gt_pair(video_idx[0]) 138 | #print('type:', type(target)) 139 | 140 | #targets = torch.stack((torch.from_numpy(target),torch.from_numpy(target_1))) 141 | #target_grts = torch.stack((torch.from_numpy(target_grt),torch.from_numpy(target_grt_1))) 142 | #print('size:', torch.from_numpy(target_grt).size(), torch.from_numpy(target_grt_1).size()) 143 | if self.train: 144 | #my_index = self.Index[seq_name1] 145 | search, search_grt = self.make_video_gt_pair(video_idx[1]) 146 | search_1, search_grt_1 = self.make_video_gt_pair(video_idx[2]) 147 | searchs = torch.stack((torch.from_numpy(search), torch.from_numpy(search_1))) 148 | search_grts = torch.stack((torch.from_numpy(search_grt), torch.from_numpy(search_grt_1))) 149 | img, img_grt = self.make_img_gt_pair(img_idx[0]) 150 | #img_1, img_grt_1 = self.make_img_gt_pair(img_idx[1]) 151 | #imgs = torch.stack((torch.from_numpy(img), torch.from_numpy(img_1))) 152 | #img_grts = torch.stack((torch.torch.from_numpy(img_grt), torch.from_numpy(img_grt_1))) 153 | sample = {'target': target, 'target_grt': target_grt, 'search': searchs, 'search_grt': search_grts, \ 154 | 'img': img, 'img_grt': img_grt} 155 | #np.save('search1.npy',search) 156 | if self.seq_name is not None: 157 | fname = os.path.join(self.seq_name, "%05d" % idx) 158 | sample['fname'] = fname 159 | 160 | if self.transform is not 
None: 161 | sample = self.transform(sample) 162 | 163 | else: 164 | img, gt = self.make_video_gt_pair(idx) 165 | sample = {'image': img, 'gt': gt} 166 | if self.seq_name is not None: 167 | fname = os.path.join(self.seq_name, "%05d" % idx) 168 | sample['fname'] = fname 169 | 170 | 171 | 172 | return sample #这个类最后的输出 173 | 174 | def make_video_gt_pair(self, idx): #这个函数存在的意义是为了getitem函数服务的 175 | """ 176 | Make the image-ground-truth pair 177 | """ 178 | img = cv2.imread(os.path.join(self.db_root_dir, self.video_list[idx]), cv2.IMREAD_COLOR) 179 | if self.labels[idx] is not None and self.train: 180 | label = cv2.imread(os.path.join(self.db_root_dir, self.labels[idx]), cv2.IMREAD_GRAYSCALE) 181 | #print(os.path.join(self.db_root_dir, self.labels[idx])) 182 | else: 183 | gt = np.zeros(img.shape[:-1], dtype=np.uint8) 184 | 185 | ## 已经读取了image以及对应的ground truth可以进行data augmentation了 186 | if self.train: #scaling, cropping and flipping 187 | img, label = my_crop(img,label) 188 | scale = random.uniform(0.7, 1.3) 189 | flip_p = random.uniform(0, 1) 190 | img_temp = scale_im(img,scale) 191 | img_temp = flip(img_temp,flip_p) 192 | gt_temp = scale_gt(label,scale) 193 | gt_temp = flip(gt_temp,flip_p) 194 | 195 | img = img_temp 196 | label = gt_temp 197 | 198 | if self.inputRes is not None: 199 | img = imresize(img, self.inputRes) 200 | #print('ok1') 201 | #scipy.misc.imsave('label.png',label) 202 | #scipy.misc.imsave('img.png',img) 203 | if self.labels[idx] is not None and self.train: 204 | label = imresize(label, self.inputRes, interp='nearest') 205 | 206 | img = np.array(img, dtype=np.float32) 207 | #img = img[:, :, ::-1] 208 | img = np.subtract(img, np.array(self.meanval, dtype=np.float32)) 209 | img = img.transpose((2, 0, 1)) # NHWC -> NCHW 210 | 211 | if self.labels[idx] is not None and self.train: 212 | gt = np.array(label, dtype=np.int32) 213 | gt[gt!=0]=1 214 | #gt = gt/np.max([gt.max(), 1e-8]) 215 | #np.save('gt.npy') 216 | return img, gt 217 | 218 | def get_img_size(self): 219 | img = cv2.imread(os.path.join(self.db_root_dir, self.video_list[0])) 220 | 221 | return list(img.shape[:2]) 222 | 223 | def make_img_gt_pair(self, idx): #这个函数存在的意义是为了getitem函数服务的 224 | """ 225 | Make the image-ground-truth pair 226 | """ 227 | img = cv2.imread(os.path.join(self.img_root_dir, self.image_list[idx]),cv2.IMREAD_COLOR) 228 | #print(os.path.join(self.db_root_dir, self.img_list[idx])) 229 | if self.img_labels[idx] is not None and self.train: 230 | label = cv2.imread(os.path.join(self.img_root_dir, self.img_labels[idx]),cv2.IMREAD_GRAYSCALE) 231 | #print(os.path.join(self.db_root_dir, self.labels[idx])) 232 | else: 233 | gt = np.zeros(img.shape[:-1], dtype=np.uint8) 234 | 235 | if self.inputRes is not None: 236 | img = imresize(img, self.inputRes) 237 | if self.img_labels[idx] is not None and self.train: 238 | label = imresize(label, self.inputRes, interp='nearest') 239 | 240 | img = np.array(img, dtype=np.float32) 241 | #img = img[:, :, ::-1] 242 | img = np.subtract(img, np.array(self.meanval, dtype=np.float32)) 243 | img = img.transpose((2, 0, 1)) # NHWC -> NCHW 244 | 245 | if self.img_labels[idx] is not None and self.train: 246 | gt = np.array(label, dtype=np.int32) 247 | gt[gt!=0]=1 248 | #gt = gt/np.max([gt.max(), 1e-8]) 249 | #np.save('gt.npy') 250 | return img, gt 251 | 252 | if __name__ == '__main__': 253 | import custom_transforms as tr 254 | import torch 255 | from torchvision import transforms 256 | from matplotlib import pyplot as plt 257 | 258 | transforms = 
transforms.Compose([tr.RandomHorizontalFlip(), tr.Resize(scales=[0.5, 0.8, 1]), tr.ToTensor()]) 259 | 260 | #dataset = DAVIS2016(db_root_dir='/media/eec/external/Databases/Segmentation/DAVIS-2016', 261 | # train=True, transform=transforms) 262 | #dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=1) 263 | # 264 | # for i, data in enumerate(dataloader): 265 | # plt.figure() 266 | # plt.imshow(overlay_mask(im_normalize(tens2image(data['image'])), tens2image(data['gt']))) 267 | # if i == 10: 268 | # break 269 | # 270 | # plt.show(block=True) -------------------------------------------------------------------------------- /deeplab/siamese_model_conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sun Sep 16 10:01:14 2018 4 | 5 | @author: carri 6 | """ 7 | 8 | import torch.nn as nn 9 | import torch 10 | import torch.nn.functional as F 11 | from torch.nn import init 12 | affine_par = True 13 | #区别于siamese_model_concat的地方就是采用的最标准的deeplab_v3的基础网络,然后加上了非对称的分支 14 | 15 | def conv3x3(in_planes, out_planes, stride=1): 16 | """3x3 convolution with padding""" 17 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 18 | padding=1, bias=False) 19 | 20 | 21 | class BasicBlock(nn.Module): 22 | expansion = 1 23 | 24 | def __init__(self, inplanes, planes, stride=1, downsample=None): 25 | super(BasicBlock, self).__init__() 26 | self.conv1 = conv3x3(inplanes, planes, stride) 27 | self.bn1 = nn.BatchNorm2d(planes, affine=affine_par) 28 | self.relu = nn.ReLU(inplace=True) 29 | self.conv2 = conv3x3(planes, planes) 30 | self.bn2 = nn.BatchNorm2d(planes, affine=affine_par) 31 | self.downsample = downsample 32 | self.stride = stride 33 | 34 | def forward(self, x): 35 | residual = x 36 | 37 | out = self.conv1(x) 38 | out = self.bn1(out) 39 | out = self.relu(out) 40 | 41 | out = self.conv2(out) 42 | out = self.bn2(out) 43 | 44 | if self.downsample is not None: 45 | residual = self.downsample(x) 46 | 47 | out += residual 48 | out = self.relu(out) 49 | 50 | return out 51 | 52 | 53 | class Bottleneck(nn.Module): 54 | expansion = 4 55 | 56 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None): 57 | super(Bottleneck, self).__init__() 58 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False) # change 59 | self.bn1 = nn.BatchNorm2d(planes, affine=affine_par) 60 | padding = dilation 61 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, # change 62 | padding=padding, bias=False, dilation=dilation) 63 | self.bn2 = nn.BatchNorm2d(planes, affine=affine_par) 64 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 65 | self.bn3 = nn.BatchNorm2d(planes * 4, affine=affine_par) 66 | self.relu = nn.ReLU(inplace=True) 67 | self.downsample = downsample 68 | self.stride = stride 69 | 70 | def forward(self, x): 71 | residual = x 72 | 73 | out = self.conv1(x) 74 | out = self.bn1(out) 75 | out = self.relu(out) 76 | 77 | out = self.conv2(out) 78 | out = self.bn2(out) 79 | out = self.relu(out) 80 | 81 | out = self.conv3(out) 82 | out = self.bn3(out) 83 | 84 | if self.downsample is not None: 85 | residual = self.downsample(x) 86 | 87 | out += residual 88 | out = self.relu(out) 89 | 90 | return out 91 | 92 | 93 | class ASPP(nn.Module): 94 | def __init__(self, dilation_series, padding_series, depth): 95 | super(ASPP, self).__init__() 96 | self.mean = nn.AdaptiveAvgPool2d((1,1)) 97 | self.conv= nn.Conv2d(2048, depth, 
1,1) 98 | self.bn_x = nn.BatchNorm2d(depth) 99 | self.conv2d_0 = nn.Conv2d(2048, depth, kernel_size=1, stride=1) 100 | self.bn_0 = nn.BatchNorm2d(depth) 101 | self.conv2d_1 = nn.Conv2d(2048, depth, kernel_size=3, stride=1, padding=padding_series[0], dilation=dilation_series[0]) 102 | self.bn_1 = nn.BatchNorm2d(depth) 103 | self.conv2d_2 = nn.Conv2d(2048, depth, kernel_size=3, stride=1, padding=padding_series[1], dilation=dilation_series[1]) 104 | self.bn_2 = nn.BatchNorm2d(depth) 105 | self.conv2d_3 = nn.Conv2d(2048, depth, kernel_size=3, stride=1, padding=padding_series[2], dilation=dilation_series[2]) 106 | self.bn_3 = nn.BatchNorm2d(depth) 107 | self.relu = nn.ReLU(inplace=True) 108 | self.bottleneck = nn.Conv2d( depth*5, 256, kernel_size=3, padding=1 ) #512 1x1Conv 109 | self.bn = nn.BatchNorm2d(256) 110 | self.prelu = nn.PReLU() 111 | #for m in self.conv2d_list: 112 | # m.weight.data.normal_(0, 0.01) 113 | for m in self.modules(): 114 | if isinstance(m, nn.Conv2d): 115 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 116 | m.weight.data.normal_(0, 0.01) 117 | elif isinstance(m, nn.BatchNorm2d): 118 | m.weight.data.fill_(1) 119 | m.bias.data.zero_() 120 | 121 | def _make_stage_(self, dilation1, padding1): 122 | Conv = nn.Conv2d(2048, 256, kernel_size=3, stride=1, padding=padding1, dilation=dilation1, bias=True)#classes 123 | Bn = nn.BatchNorm2d(256) 124 | Relu = nn.ReLU(inplace=True) 125 | return nn.Sequential(Conv, Bn, Relu) 126 | 127 | 128 | def forward(self, x): 129 | #out = self.conv2d_list[0](x) 130 | #mulBranches = [conv2d_l(x) for conv2d_l in self.conv2d_list] 131 | size=x.shape[2:] 132 | image_features=self.mean(x) 133 | image_features=self.conv(image_features) 134 | image_features = self.bn_x(image_features) 135 | image_features = self.relu(image_features) 136 | image_features=F.upsample(image_features, size=size, mode='bilinear', align_corners=True) 137 | out_0 = self.conv2d_0(x) 138 | out_0 = self.bn_0(out_0) 139 | out_0 = self.relu(out_0) 140 | out_1 = self.conv2d_1(x) 141 | out_1 = self.bn_1(out_1) 142 | out_1 = self.relu(out_1) 143 | out_2 = self.conv2d_2(x) 144 | out_2 = self.bn_2(out_2) 145 | out_2 = self.relu(out_2) 146 | out_3 = self.conv2d_3(x) 147 | out_3 = self.bn_3(out_3) 148 | out_3 = self.relu(out_3) 149 | out = torch.cat([image_features, out_0, out_1, out_2, out_3], 1) 150 | out = self.bottleneck(out) 151 | out = self.bn(out) 152 | out = self.prelu(out) 153 | #for i in range(len(self.conv2d_list) - 1): 154 | # out += self.conv2d_list[i + 1](x) 155 | 156 | return out 157 | 158 | 159 | 160 | class ResNet(nn.Module): 161 | def __init__(self, block, layers, num_classes): 162 | self.inplanes = 64 163 | super(ResNet, self).__init__() 164 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 165 | self.bn1 = nn.BatchNorm2d(64, affine=affine_par) 166 | self.relu = nn.ReLU(inplace=True) 167 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True) # change 168 | self.layer1 = self._make_layer(block, 64, layers[0]) 169 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 170 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2) 171 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4) 172 | self.layer5 = self._make_pred_layer(ASPP, [ 6, 12, 18], [6, 12, 18], 512) 173 | self.main_classifier = nn.Conv2d(256, num_classes, kernel_size=1) 174 | self.softmax = nn.Sigmoid()#nn.Softmax() 175 | 176 | for m in self.modules(): 177 | if isinstance(m, nn.Conv2d): 
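# --- illustrative sketch (annotation, not part of the original file) ---
# Quick shape check for the ASPP head defined above: it takes the 2048-channel
# ResNet features and fuses five depth-channel branches (image-level pooling,
# a 1x1 conv, and three dilated 3x3 convs) down to 256 channels at the same
# spatial resolution. Assumes only the ASPP class from this file.
import torch
aspp = ASPP(dilation_series=[6, 12, 18], padding_series=[6, 12, 18], depth=512)
feats = torch.randn(2, 2048, 60, 60)   # layer4 output for a 473x473 input (473/8 ~= 60)
out = aspp(feats)
print(out.shape)                       # torch.Size([2, 256, 60, 60])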
178 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 179 | m.weight.data.normal_(0, 0.01) 180 | elif isinstance(m, nn.BatchNorm2d): 181 | m.weight.data.fill_(1) 182 | m.bias.data.zero_() 183 | 184 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1): 185 | downsample = None 186 | if stride != 1 or self.inplanes != planes * block.expansion or dilation == 2 or dilation == 4: 187 | downsample = nn.Sequential( 188 | nn.Conv2d(self.inplanes, planes * block.expansion, 189 | kernel_size=1, stride=stride, bias=False), 190 | nn.BatchNorm2d(planes * block.expansion, affine=affine_par)) 191 | for i in downsample._modules['1'].parameters(): 192 | i.requires_grad = False 193 | layers = [] 194 | layers.append(block(self.inplanes, planes, stride, dilation=dilation, downsample=downsample)) 195 | self.inplanes = planes * block.expansion 196 | for i in range(1, blocks): 197 | layers.append(block(self.inplanes, planes, dilation=dilation)) 198 | 199 | return nn.Sequential(*layers) 200 | 201 | def _make_pred_layer(self, block, dilation_series, padding_series, num_classes): 202 | return block(dilation_series, padding_series, num_classes) 203 | 204 | def forward(self, x): 205 | input_size = x.size()[2:] 206 | x = self.conv1(x) 207 | x = self.bn1(x) 208 | x = self.relu(x) 209 | x = self.maxpool(x) 210 | x = self.layer1(x) 211 | x = self.layer2(x) 212 | x = self.layer3(x) 213 | x = self.layer4(x) 214 | fea = self.layer5(x) 215 | x = self.main_classifier(fea) 216 | #print("before upsample, tensor size:", x.size()) 217 | x = F.upsample(x, input_size, mode='bilinear') #upsample to the size of input image, scale=8 218 | #print("after upsample, tensor size:", x.size()) 219 | x = self.softmax(x) 220 | return fea, x 221 | 222 | class CoattentionModel(nn.Module): 223 | def __init__(self, block, layers, num_classes, all_channel=256, all_dim=60*60): #473./8=60 224 | super(CoattentionModel, self).__init__() 225 | self.encoder = ResNet(block, layers, num_classes) 226 | self.linear_e = nn.Linear(all_channel, all_channel,bias = False) 227 | self.channel = all_channel 228 | self.dim = all_dim 229 | self.gate = nn.Conv2d(all_channel, 1, kernel_size = 1, bias = False) 230 | self.gate_s = nn.Sigmoid() 231 | self.conv1 = nn.Conv2d(all_channel*2, all_channel, kernel_size=3, padding=1, bias = False) 232 | self.conv2 = nn.Conv2d(all_channel*2, all_channel, kernel_size=3, padding=1, bias = False) 233 | self.bn1 = nn.BatchNorm2d(all_channel) 234 | self.bn2 = nn.BatchNorm2d(all_channel) 235 | self.prelu = nn.ReLU(inplace=True) 236 | self.main_classifier1 = nn.Conv2d(all_channel, num_classes, kernel_size=1, bias = True) 237 | self.main_classifier2 = nn.Conv2d(all_channel, num_classes, kernel_size=1, bias = True) 238 | self.softmax = nn.Sigmoid() 239 | 240 | for m in self.modules(): 241 | if isinstance(m, nn.Conv2d): 242 | #n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 243 | m.weight.data.normal_(0, 0.01) 244 | #init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 245 | #init.xavier_normal(m.weight.data) 246 | #m.bias.data.fill_(0) 247 | elif isinstance(m, nn.BatchNorm2d): 248 | m.weight.data.fill_(1) 249 | m.bias.data.zero_() 250 | 251 | 252 | def forward(self, input1, input2): #注意input2 可以是多帧图像 253 | 254 | #input1_att, input2_att = self.coattention(input1, input2) 255 | input_size = input1.size()[2:] 256 | exemplar, temp = self.encoder(input1) 257 | query, temp = self.encoder(input2) 258 | fea_size = query.size()[2:] 259 | all_dim = fea_size[0]*fea_size[1] 260 | exemplar_flat = 
exemplar.view(-1, query.size()[1], all_dim) #N,C,H*W 261 | query_flat = query.view(-1, query.size()[1], all_dim) 262 | exemplar_t = torch.transpose(exemplar_flat,1,2).contiguous() #batch size x dim x num 263 | exemplar_corr = self.linear_e(exemplar_t) # 264 | A = torch.bmm(exemplar_corr, query_flat) 265 | A1 = F.softmax(A.clone(), dim = 1) # 266 | B = F.softmax(torch.transpose(A,1,2),dim=1) 267 | query_att = torch.bmm(exemplar_flat, A1).contiguous() #注意我们这个地方要不要用交互以及Residual的结构 268 | exemplar_att = torch.bmm(query_flat, B).contiguous() 269 | 270 | input1_att = exemplar_att.view(-1, query.size()[1], fea_size[0], fea_size[1]) 271 | input2_att = query_att.view(-1, query.size()[1], fea_size[0], fea_size[1]) 272 | input1_mask = self.gate(input1_att) 273 | input2_mask = self.gate(input2_att) 274 | input1_mask = self.gate_s(input1_mask) 275 | input2_mask = self.gate_s(input2_mask) 276 | input1_att = input1_att * input1_mask 277 | input2_att = input2_att * input2_mask 278 | input1_att = torch.cat([input1_att, exemplar],1) 279 | input2_att = torch.cat([input2_att, query],1) 280 | input1_att = self.conv1(input1_att ) 281 | input2_att = self.conv2(input2_att ) 282 | input1_att = self.bn1(input1_att ) 283 | input2_att = self.bn2(input2_att ) 284 | input1_att = self.prelu(input1_att ) 285 | input2_att = self.prelu(input2_att ) 286 | x1 = self.main_classifier1(input1_att) 287 | x2 = self.main_classifier2(input2_att) 288 | x1 = F.upsample(x1, input_size, mode='bilinear') #upsample to the size of input image, scale=8 289 | x2 = F.upsample(x2, input_size, mode='bilinear') #upsample to the size of input image, scale=8 290 | #print("after upsample, tensor size:", x.size()) 291 | x1 = self.softmax(x1) 292 | x2 = self.softmax(x2) 293 | 294 | # x1 = self.softmax(x1) 295 | # x2 = self.softmax(x2) 296 | return x1, x2, temp #shape: NxCx 297 | 298 | 299 | def Res_Deeplab(num_classes=2): 300 | model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes-1) 301 | return model 302 | 303 | def CoattentionNet(num_classes=2): 304 | model = CoattentionModel(Bottleneck,[3, 4, 23, 3], num_classes-1) 305 | 306 | return model 307 | -------------------------------------------------------------------------------- /deeplab/siamese_model_conf_try_single.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sun Sep 16 10:01:14 2018 4 | 5 | @author: carri 6 | """ 7 | 8 | import torch.nn as nn 9 | import torch 10 | import torch.nn.functional as F 11 | from torch.nn import init 12 | affine_par = True 13 | import numpy as np 14 | #区别于siamese_model_concat的地方就是采用的最标准的deeplab_v3的基础网络,然后加上了非对称的分支 15 | 16 | def conv3x3(in_planes, out_planes, stride=1): 17 | """3x3 convolution with padding""" 18 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 19 | padding=1, bias=False) 20 | 21 | 22 | class BasicBlock(nn.Module): 23 | expansion = 1 24 | 25 | def __init__(self, inplanes, planes, stride=1, downsample=None): 26 | super(BasicBlock, self).__init__() 27 | self.conv1 = conv3x3(inplanes, planes, stride) 28 | self.bn1 = nn.BatchNorm2d(planes, affine=affine_par) 29 | self.relu = nn.ReLU(inplace=True) 30 | self.conv2 = conv3x3(planes, planes) 31 | self.bn2 = nn.BatchNorm2d(planes, affine=affine_par) 32 | self.downsample = downsample 33 | self.stride = stride 34 | 35 | def forward(self, x): 36 | residual = x 37 | 38 | out = self.conv1(x) 39 | out = self.bn1(out) 40 | out = self.relu(out) 41 | 42 | out = self.conv2(out) 43 | out = self.bn2(out) 44 | 45 | 
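# --- illustrative sketch (annotation, not part of the original file) ---
# Standalone version of the co-attention affinity computed in
# CoattentionModel.forward of siamese_model_conf.py above: a learned linear
# map builds an HW x HW affinity between exemplar and query features, and
# column-wise softmaxes turn it into attention in both directions.
import torch
import torch.nn.functional as F

bsz, C, H, W = 1, 256, 60, 60
exemplar, query = torch.randn(bsz, C, H, W), torch.randn(bsz, C, H, W)
linear_e = torch.nn.Linear(C, C, bias=False)        # plays the role of self.linear_e

ex_flat = exemplar.view(bsz, C, -1)                 # B x C x HW
q_flat = query.view(bsz, C, -1)
A = torch.bmm(linear_e(ex_flat.transpose(1, 2)), q_flat)                # B x HW x HW
query_att = torch.bmm(ex_flat, F.softmax(A, dim=1))                     # attend exemplar -> query
exemplar_att = torch.bmm(q_flat, F.softmax(A.transpose(1, 2), dim=1))   # attend query -> exemplar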
if self.downsample is not None: 46 | residual = self.downsample(x) 47 | 48 | out += residual 49 | out = self.relu(out) 50 | 51 | return out 52 | 53 | 54 | class Bottleneck(nn.Module): 55 | expansion = 4 56 | 57 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None): 58 | super(Bottleneck, self).__init__() 59 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False) # change 60 | self.bn1 = nn.BatchNorm2d(planes, affine=affine_par) 61 | padding = dilation 62 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, # change 63 | padding=padding, bias=False, dilation=dilation) 64 | self.bn2 = nn.BatchNorm2d(planes, affine=affine_par) 65 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 66 | self.bn3 = nn.BatchNorm2d(planes * 4, affine=affine_par) 67 | self.relu = nn.ReLU(inplace=True) 68 | self.downsample = downsample 69 | self.stride = stride 70 | 71 | def forward(self, x): 72 | residual = x 73 | 74 | out = self.conv1(x) 75 | out = self.bn1(out) 76 | out = self.relu(out) 77 | 78 | out = self.conv2(out) 79 | out = self.bn2(out) 80 | out = self.relu(out) 81 | 82 | out = self.conv3(out) 83 | out = self.bn3(out) 84 | 85 | if self.downsample is not None: 86 | residual = self.downsample(x) 87 | 88 | out += residual 89 | out = self.relu(out) 90 | 91 | return out 92 | 93 | 94 | class ASPP(nn.Module): 95 | def __init__(self, dilation_series, padding_series, depth): 96 | super(ASPP, self).__init__() 97 | self.mean = nn.AdaptiveAvgPool2d((1,1)) 98 | self.conv= nn.Conv2d(2048, depth, 1,1) 99 | self.bn_x = nn.BatchNorm2d(depth) 100 | self.conv2d_0 = nn.Conv2d(2048, depth, kernel_size=1, stride=1) 101 | self.bn_0 = nn.BatchNorm2d(depth) 102 | self.conv2d_1 = nn.Conv2d(2048, depth, kernel_size=3, stride=1, padding=padding_series[0], dilation=dilation_series[0]) 103 | self.bn_1 = nn.BatchNorm2d(depth) 104 | self.conv2d_2 = nn.Conv2d(2048, depth, kernel_size=3, stride=1, padding=padding_series[1], dilation=dilation_series[1]) 105 | self.bn_2 = nn.BatchNorm2d(depth) 106 | self.conv2d_3 = nn.Conv2d(2048, depth, kernel_size=3, stride=1, padding=padding_series[2], dilation=dilation_series[2]) 107 | self.bn_3 = nn.BatchNorm2d(depth) 108 | self.relu = nn.ReLU(inplace=True) 109 | self.bottleneck = nn.Conv2d( depth*5, 256, kernel_size=3, padding=1 ) #512 1x1Conv 110 | self.bn = nn.BatchNorm2d(256) 111 | self.prelu = nn.PReLU() 112 | #for m in self.conv2d_list: 113 | # m.weight.data.normal_(0, 0.01) 114 | for m in self.modules(): 115 | if isinstance(m, nn.Conv2d): 116 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 117 | m.weight.data.normal_(0, 0.01) 118 | elif isinstance(m, nn.BatchNorm2d): 119 | m.weight.data.fill_(1) 120 | m.bias.data.zero_() 121 | 122 | def _make_stage_(self, dilation1, padding1): 123 | Conv = nn.Conv2d(2048, 256, kernel_size=3, stride=1, padding=padding1, dilation=dilation1, bias=True)#classes 124 | Bn = nn.BatchNorm2d(256) 125 | Relu = nn.ReLU(inplace=True) 126 | return nn.Sequential(Conv, Bn, Relu) 127 | 128 | 129 | def forward(self, x): 130 | #out = self.conv2d_list[0](x) 131 | #mulBranches = [conv2d_l(x) for conv2d_l in self.conv2d_list] 132 | size=x.shape[2:] 133 | image_features=self.mean(x) 134 | image_features=self.conv(image_features) 135 | image_features = self.bn_x(image_features) 136 | image_features = self.relu(image_features) 137 | image_features=F.upsample(image_features, size=size, mode='bilinear', align_corners=True) 138 | out_0 = self.conv2d_0(x) 139 | out_0 = self.bn_0(out_0) 140 | 
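# --- annotation (not part of the original file) ---
# F.upsample, called a few lines above and elsewhere in this file, has been
# deprecated since PyTorch 0.4.1; on current PyTorch the equivalent call is
#   image_features = F.interpolate(image_features, size=size, mode='bilinear', align_corners=True)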
out_0 = self.relu(out_0) 141 | out_1 = self.conv2d_1(x) 142 | out_1 = self.bn_1(out_1) 143 | out_1 = self.relu(out_1) 144 | out_2 = self.conv2d_2(x) 145 | out_2 = self.bn_2(out_2) 146 | out_2 = self.relu(out_2) 147 | out_3 = self.conv2d_3(x) 148 | out_3 = self.bn_3(out_3) 149 | out_3 = self.relu(out_3) 150 | out = torch.cat([image_features, out_0, out_1, out_2, out_3], 1) 151 | out = self.bottleneck(out) 152 | out = self.bn(out) 153 | out = self.prelu(out) 154 | #for i in range(len(self.conv2d_list) - 1): 155 | # out += self.conv2d_list[i + 1](x) 156 | 157 | return out 158 | 159 | 160 | 161 | class ResNet(nn.Module): 162 | def __init__(self, block, layers, num_classes): 163 | self.inplanes = 64 164 | super(ResNet, self).__init__() 165 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 166 | self.bn1 = nn.BatchNorm2d(64, affine=affine_par) 167 | self.relu = nn.ReLU(inplace=True) 168 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True) # change 169 | self.layer1 = self._make_layer(block, 64, layers[0]) 170 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 171 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2) 172 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4) 173 | self.layer5 = self._make_pred_layer(ASPP, [ 6, 12, 18], [6, 12, 18], 512) 174 | self.main_classifier = nn.Conv2d(256, num_classes, kernel_size=1) 175 | self.softmax = nn.Sigmoid()#nn.Softmax() 176 | 177 | for m in self.modules(): 178 | if isinstance(m, nn.Conv2d): 179 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 180 | m.weight.data.normal_(0, 0.01) 181 | elif isinstance(m, nn.BatchNorm2d): 182 | m.weight.data.fill_(1) 183 | m.bias.data.zero_() 184 | 185 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1): 186 | downsample = None 187 | if stride != 1 or self.inplanes != planes * block.expansion or dilation == 2 or dilation == 4: 188 | downsample = nn.Sequential( 189 | nn.Conv2d(self.inplanes, planes * block.expansion, 190 | kernel_size=1, stride=stride, bias=False), 191 | nn.BatchNorm2d(planes * block.expansion, affine=affine_par)) 192 | for i in downsample._modules['1'].parameters(): 193 | i.requires_grad = False 194 | layers = [] 195 | layers.append(block(self.inplanes, planes, stride, dilation=dilation, downsample=downsample)) 196 | self.inplanes = planes * block.expansion 197 | for i in range(1, blocks): 198 | layers.append(block(self.inplanes, planes, dilation=dilation)) 199 | 200 | return nn.Sequential(*layers) 201 | 202 | def _make_pred_layer(self, block, dilation_series, padding_series, num_classes): 203 | return block(dilation_series, padding_series, num_classes) 204 | 205 | def forward(self, x): 206 | input_size = x.size()[2:] 207 | x = self.conv1(x) 208 | x = self.bn1(x) 209 | x = self.relu(x) 210 | x = self.maxpool(x) 211 | x = self.layer1(x) 212 | x = self.layer2(x) 213 | x = self.layer3(x) 214 | x = self.layer4(x) 215 | fea = self.layer5(x) 216 | x = self.main_classifier(fea) 217 | #print("before upsample, tensor size:", x.size()) 218 | x = F.upsample(x, input_size, mode='bilinear') #upsample to the size of input image, scale=8 219 | #print("after upsample, tensor size:", x.size()) 220 | x = self.softmax(x) 221 | return fea, x 222 | 223 | class CoattentionModel(nn.Module): 224 | def __init__(self, block, layers, num_classes, all_channel=256, all_dim=60*60): #473./8=60 225 | super(CoattentionModel, self).__init__() 226 | self.nframes = 2 227 | self.encoder = 
ResNet(block, layers, num_classes) 228 | self.linear_e = nn.Linear(all_channel, all_channel,bias = False) 229 | self.channel = all_channel 230 | self.dim = all_dim 231 | self.gate = nn.Conv2d(all_channel, 1, kernel_size = 1, bias = False) 232 | self.gate_s = nn.Sigmoid() 233 | self.conv1 = nn.Conv2d(all_channel*2, all_channel, kernel_size=3, padding=1, bias = False) 234 | self.conv2 = nn.Conv2d(all_channel*2, all_channel, kernel_size=3, padding=1, bias = False) 235 | self.bn1 = nn.BatchNorm2d(all_channel, affine=affine_par) 236 | self.bn2 = nn.BatchNorm2d(all_channel, affine=affine_par) 237 | self.prelu = nn.ReLU(inplace=True) 238 | self.main_classifier1 = nn.Conv2d(all_channel, num_classes, kernel_size=1, bias = True) 239 | self.main_classifier2 = nn.Conv2d(all_channel, num_classes, kernel_size=1, bias = True) 240 | self.softmax = nn.Sigmoid() 241 | 242 | for m in self.modules(): 243 | if isinstance(m, nn.Conv2d): 244 | #n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 245 | m.weight.data.normal_(0, 0.01) 246 | #init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 247 | #init.xavier_normal(m.weight.data) 248 | #m.bias.data.fill_(0) 249 | elif isinstance(m, nn.BatchNorm2d): 250 | m.weight.data.fill_(1) 251 | m.bias.data.zero_() 252 | 253 | 254 | def forward(self, input1, input2): #注意input2 可以是多帧图像 255 | 256 | #input1_att, input2_att = self.coattention(input1, input2) 257 | exemplar, temp = self.encoder(input1) 258 | 259 | #print('feature size:', input1.size()) 260 | if len(input2.size() )>4: 261 | B, N, C, H, W = input2.size() # 2,2,3,473,473 262 | input_size = [H, W] 263 | video_frames2 = [elem.view(B, C, H, W) for elem in input2.split(1, dim=1)] 264 | # the length of exemplars is equal to the nframes 265 | querys= [self.encoder(video_frames2[frame]) for frame in range(0,self.nframes)] 266 | #query = torch.cat([querys[][0]],dim=1) 267 | #query1 = torch.cat([querys[1]], dim=1) 268 | query = torch.cat(([querys[frame][0] for frame in range(0,self.nframes )]), dim=2) 269 | #print('query size:', query.size()) 2*512*49*49 270 | predict_mask = torch.cat(([querys[frame][1] for frame in range(0,self.nframes )]), dim=1) 271 | #print('feature size:', exemplar.size()) 272 | fea_size = exemplar.size()[2:] 273 | exemplar_flat = exemplar.view(-1, self.channel, fea_size[0]*fea_size[1]) #N,C,H*W 274 | exemplar_t = torch.transpose(exemplar_flat, 1, 2).contiguous() # batch size x dim x num 275 | exemplar_corr = self.linear_e(exemplar_t) 276 | #coattention_fea = 0 277 | query_flat = query.view(-1, self.channel, self.nframes*fea_size[0]*fea_size[1]) 278 | A = torch.bmm(exemplar_corr, query_flat) 279 | A = F.softmax(A, dim = 1) # 280 | B = F.softmax(torch.transpose(A,1,2),dim=1) 281 | query_att = torch.bmm(exemplar_flat, A).contiguous() #注意我们这个地方要不要用交互以及Residual的结构 282 | exemplar_att = torch.bmm(query_flat, B).contiguous() 283 | 284 | input1_att = exemplar_att.view(-1, self.channel, fea_size[0], fea_size[1]) 285 | input2_att = query_att.view(-1, self.channel, self.nframes*fea_size[0], fea_size[1]) 286 | input1_mask = self.gate(input1_att) 287 | #input2_mask = self.gate(input2_att) 288 | input1_mask = self.gate_s(input1_mask) 289 | #input2_mask = self.gate_s(input2_mask) 290 | input1_att_org = input1_att * input1_mask 291 | #coattention_fea = coattention_fea + input1_att_org 292 | 293 | #print('h_v size, h_v_org size:', torch.max(input1_att), torch.max(exemplar), torch.min(input1_att), torch.max(exemplar)) 294 | input1_att = torch.cat([input1_att_org, exemplar],1) 295 | input1_att = 
self.conv1(input1_att ) 296 | input1_att = self.bn1(input1_att ) 297 | input1_att = self.prelu(input1_att ) 298 | x1 = self.main_classifier1(input1_att) 299 | x1 = F.upsample(x1, input_size, mode='bilinear') #upsample to the size of input image, scale=8 300 | #upsample to the size of input image, scale=8 301 | #print("after upsample, tensor size:", x.size()) 302 | x1 = self.softmax(x1) 303 | else: 304 | x1 = exemplar 305 | 306 | return x1, temp #shape: NxCx 307 | 308 | 309 | def CoattentionNet(num_classes=2,nframes=2): 310 | model = CoattentionModel(Bottleneck,[3, 4, 23, 3], num_classes-1) 311 | 312 | return model -------------------------------------------------------------------------------- /train_iteration_conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sat Sep 15 10:52:26 2018 4 | 5 | @author: carri 6 | """ 7 | #区别于deeplab_co_attention_concat在于采用了新的model(siamese_model_concat_new)来train 8 | 9 | import argparse 10 | import torch 11 | import torch.nn as nn 12 | from torch.utils import data 13 | import numpy as np 14 | import pickle 15 | import cv2 16 | from torch.autograd import Variable 17 | import torch.optim as optim 18 | import scipy.misc 19 | import torch.backends.cudnn as cudnn 20 | import sys 21 | import os 22 | #from utils.balanced_BCE import class_balanced_cross_entropy_loss 23 | import os.path as osp 24 | #from psp.model import PSPNet 25 | #from dataloaders import davis_2016 as db 26 | from dataloaders import PairwiseImg_video as db #采用voc dataset的数据设置格式方法 27 | import matplotlib.pyplot as plt 28 | import random 29 | import timeit 30 | #from psp.model1 import CoattentionNet #基于pspnet搭建的co-attention 模型 31 | from deeplab.siamese_model_conf import CoattentionNet #siame_model 是直接将attend的model之后的结果输出 32 | #from deeplab.utils import get_1x_lr_params, get_10x_lr_params#, adjust_learning_rate #, loss_calc 33 | start = timeit.default_timer() 34 | 35 | def get_arguments(): 36 | """Parse all the arguments provided from the CLI. 37 | 38 | Returns: 39 | A list of parsed arguments. 
40 | """ 41 | parser = argparse.ArgumentParser(description="PSPnet Network") 42 | 43 | # optimatization configuration 44 | parser.add_argument("--is-training", action="store_true", 45 | help="Whether to updates the running means and variances during the training.") 46 | parser.add_argument("--learning-rate", type=float, default= 0.00025, 47 | help="Base learning rate for training with polynomial decay.") #0.001 48 | parser.add_argument("--weight-decay", type=float, default= 0.0005, 49 | help="Regularization parameter for L2-loss.") # 0.0005 50 | parser.add_argument("--momentum", type=float, default= 0.9, 51 | help="Momentum component of the optimiser.") 52 | parser.add_argument("--power", type=float, default= 0.9, 53 | help="Decay parameter to compute the learning rate.") 54 | # dataset information 55 | parser.add_argument("--dataset", type=str, default='cityscapes', 56 | help="voc12, cityscapes, or pascal-context.") 57 | parser.add_argument("--random-mirror", action="store_true", 58 | help="Whether to randomly mirror the inputs during the training.") 59 | parser.add_argument("--random-scale", action="store_true", 60 | help="Whether to randomly scale the inputs during the training.") 61 | 62 | parser.add_argument("--not-restore-last", action="store_true", 63 | help="Whether to not restore last (FC) layers.") 64 | parser.add_argument("--random-seed", type=int, default= 1234, 65 | help="Random seed to have reproducible results.") 66 | parser.add_argument('--logFile', default='log.txt', 67 | help='File that stores the training and validation logs') 68 | # GPU configuration 69 | parser.add_argument("--cuda", default=True, help="Run on CPU or GPU") 70 | parser.add_argument("--gpus", type=str, default="3", help="choose gpu device.") #使用3号GPU 71 | 72 | 73 | return parser.parse_args() 74 | 75 | args = get_arguments() 76 | 77 | 78 | def configure_dataset_init_model(args): 79 | if args.dataset == 'voc12': 80 | 81 | args.batch_size = 10# 1 card: 5, 2 cards: 10 Number of images sent to the network in one step, 16 on paper 82 | args.maxEpoches = 15 # 1 card: 15, 2 cards: 15 epoches, equal to 30k iterations, max iterations= maxEpoches*len(train_aug)/batch_size_per_gpu'), 83 | args.data_dir = '/home/wty/AllDataSet/VOC2012' # Path to the directory containing the PASCAL VOC dataset 84 | args.data_list = './dataset/list/VOC2012/train_aug.txt' # Path to the file listing the images in the dataset 85 | args.ignore_label = 255 #The index of the label to ignore during the training 86 | args.input_size = '473,473' #Comma-separated string with height and width of images 87 | args.num_classes = 21 #Number of classes to predict (including background) 88 | 89 | args.img_mean = np.array((104.00698793,116.66876762,122.67891434), dtype=np.float32) 90 | # saving model file and log record during the process of training 91 | 92 | #Where restore model pretrained on other dataset, such as COCO.") 93 | args.restore_from = './pretrained/MS_DeepLab_resnet_pretrained_COCO_init.pth' 94 | args.snapshot_dir = './snapshots/voc12/' #Where to save snapshots of the model 95 | args.resume = './snapshots/voc12/psp_voc12_3.pth' #checkpoint log file, helping recovering training 96 | 97 | elif args.dataset == 'davis': 98 | args.batch_size = 16# 1 card: 5, 2 cards: 10 Number of images sent to the network in one step, 16 on paper 99 | args.maxEpoches = 60 # 1 card: 15, 2 cards: 15 epoches, equal to 30k iterations, max iterations= maxEpoches*len(train_aug)/batch_size_per_gpu'), 100 | args.data_dir = '/home/ubuntu/xiankai/dataset/DAVIS-2016' # 
37572 image pairs 101 | args.img_dir = '/home/ubuntu/xiankai/dataset/images' 102 | args.data_list = './dataset/list/VOC2012/train_aug.txt' # Path to the file listing the images in the dataset 103 | args.ignore_label = 255 #The index of the label to ignore during the training 104 | args.input_size = '473,473' #Comma-separated string with height and width of images 105 | args.num_classes = 2 #Number of classes to predict (including background) 106 | args.img_mean = np.array((104.00698793,116.66876762,122.67891434), dtype=np.float32) # saving model file and log record during the process of training 107 | #Where restore model pretrained on other dataset, such as COCO.") 108 | args.restore_from = './pretrained/deep_labv3/deeplab_davis_12_0.pth' #resnet50-19c8e357.pth''/home/xiankai/PSPNet_PyTorch/snapshots/davis/psp_davis_0.pth' # 109 | args.snapshot_dir = './snapshots/davis_iteration_conf/' #Where to save snapshots of the model 110 | args.resume = './snapshots/davis/co_attention_davis_124.pth' #checkpoint log file, helping recovering training 111 | 112 | elif args.dataset == 'cityscapes': 113 | args.batch_size = 8 #Number of images sent to the network in one step, batch_size/num_GPU=2 114 | args.maxEpoches = 60 #epoch nums, 60 epoches is equal to 90k iterations, max iterations= maxEpoches*len(train)/batch_size') 115 | # 60x2975/2=89250 ~= 90k, single_GPU_batch_size=2 116 | args.data_dir = '/home/wty/AllDataSet/CityScapes' # Path to the directory containing the PASCAL VOC dataset 117 | args.data_list = './dataset/list/Cityscapes/cityscapes_train_list.txt' # Path to the file listing the images in the dataset 118 | args.ignore_label = 255 #The index of the label to ignore during the training 119 | args.input_size = '720,720' #Comma-separated string with height and width of images 120 | args.num_classes = 19 #Number of classes to predict (including background) 121 | 122 | args.img_mean = np.array((73.15835921, 82.90891754, 72.39239876), dtype=np.float32) 123 | # saving model file and log record during the process of training 124 | 125 | #Where restore model pretrained on other dataset, such as coarse cityscapes 126 | args.restore_from = './pretrained/resnet101_pretrained_for_cityscapes.pth' 127 | args.snapshot_dir = './snapshots/cityscapes/' #Where to save snapshots of the model 128 | args.resume = './snapshots/cityscapes/psp_cityscapes_12_3.pth' #checkpoint log file, helping recovering training 129 | 130 | else: 131 | print("dataset error") 132 | 133 | def adjust_learning_rate(optimizer, i_iter, epoch, max_iter): 134 | """Sets the learning rate to the initial LR divided by 5 at 60th, 120th and 160th epochs""" 135 | 136 | lr = lr_poly(args.learning_rate, i_iter, max_iter, args.power, epoch) 137 | optimizer.param_groups[0]['lr'] = lr 138 | if i_iter%3 ==0: 139 | optimizer.param_groups[0]['lr'] = lr 140 | optimizer.param_groups[1]['lr'] = 0 141 | else: 142 | optimizer.param_groups[0]['lr'] = 0.01*lr 143 | optimizer.param_groups[1]['lr'] = lr * 10 144 | 145 | return lr 146 | 147 | def loss_calc1(pred, label): 148 | """ 149 | This function returns cross entropy loss for semantic segmentation 150 | """ 151 | labels = torch.ge(label, 0.5).float() 152 | # 153 | batch_size = label.size() 154 | #print(batch_size) 155 | num_labels_pos = torch.sum(labels) 156 | # 157 | batch_1 = batch_size[0]* batch_size[2] 158 | batch_1 = batch_1* batch_size[3] 159 | weight_1 = torch.div(num_labels_pos, batch_1) # pos ratio 160 | weight_1 = torch.reciprocal(weight_1) 161 | #print(num_labels_pos, batch_1) 162 | weight_2 = 
torch.div(batch_1-num_labels_pos, batch_1) 163 | #print('postive ratio', weight_2, weight_1) 164 | weight_22 = torch.mul(weight_1, torch.ones(batch_size[0], batch_size[1], batch_size[2], batch_size[3]).cuda()) 165 | #weight_11 = torch.mul(weight_1, torch.ones(batch_size[0], batch_size[1], batch_size[2]).cuda()) 166 | criterion = torch.nn.BCELoss(weight = weight_22)#weight = torch.Tensor([0,1]) .cuda() #torch.nn.CrossEntropyLoss(ignore_index=args.ignore_label).cuda() 167 | #loss = class_balanced_cross_entropy_loss(pred, label).cuda() 168 | 169 | return criterion(pred, label) 170 | 171 | def loss_calc2(pred, label): 172 | """ 173 | This function returns cross entropy loss for semantic segmentation 174 | """ 175 | # out shape batch_size x channels x h x w -> batch_size x channels x h x w 176 | # label shape h x w x 1 x batch_size -> batch_size x 1 x h x w 177 | # Variable(label.long()).cuda() 178 | criterion = torch.nn.L1Loss()#.cuda() #torch.nn.CrossEntropyLoss(ignore_index=args.ignore_label).cuda() 179 | 180 | return criterion(pred, label) 181 | 182 | 183 | 184 | def get_1x_lr_params(model): 185 | """ 186 | This generator returns all the parameters of the net except for 187 | the last classification layer. Note that for each batchnorm layer, 188 | requires_grad is set to False in deeplab_resnet.py, therefore this function does not return 189 | any batchnorm parameter 190 | """ 191 | b = [] 192 | if torch.cuda.device_count() == 1: 193 | #b.append(model.encoder.conv1) 194 | #b.append(model.encoder.bn1) 195 | #b.append(model.encoder.layer1) 196 | #b.append(model.encoder.layer2) 197 | #b.append(model.encoder.layer3) 198 | #b.append(model.encoder.layer4) 199 | b.append(model.encoder.layer5) 200 | else: 201 | b.append(model.module.encoder.conv1) 202 | b.append(model.module.encoder.bn1) 203 | b.append(model.module.encoder.layer1) 204 | b.append(model.module.encoder.layer2) 205 | b.append(model.module.encoder.layer3) 206 | b.append(model.module.encoder.layer4) 207 | b.append(model.module.encoder.layer5) 208 | b.append(model.module.encoder.main_classifier) 209 | for i in range(len(b)): 210 | for j in b[i].modules(): 211 | jj = 0 212 | for k in j.parameters(): 213 | jj+=1 214 | if k.requires_grad: 215 | yield k 216 | 217 | 218 | def get_10x_lr_params(model): 219 | """ 220 | This generator returns all the parameters for the last layer of the net, 221 | which does the classification of pixel into classes 222 | """ 223 | b = [] 224 | if torch.cuda.device_count() == 1: 225 | b.append(model.linear_e.parameters()) 226 | b.append(model.main_classifier.parameters()) 227 | else: 228 | #b.append(model.module.encoder.layer5.parameters()) 229 | b.append(model.module.linear_e.parameters()) 230 | b.append(model.module.conv1.parameters()) 231 | b.append(model.module.conv2.parameters()) 232 | b.append(model.module.gate.parameters()) 233 | b.append(model.module.bn1.parameters()) 234 | b.append(model.module.bn2.parameters()) 235 | b.append(model.module.main_classifier1.parameters()) 236 | b.append(model.module.main_classifier2.parameters()) 237 | 238 | for j in range(len(b)): 239 | for i in b[j]: 240 | yield i 241 | 242 | def lr_poly(base_lr, iter, max_iter, power, epoch): 243 | if epoch<=2: 244 | factor = 1 245 | elif epoch>2 and epoch< 6: 246 | factor = 1 247 | else: 248 | factor = 0.5 249 | return base_lr*factor*((1-float(iter)/max_iter)**(power)) 250 | 251 | 252 | def netParams(model): 253 | ''' 254 | Computing total network parameters 255 | Args: 256 | model: model 257 | return: total network parameters 258 | ''' 
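# --- annotation (not part of the original file) ---
# The explicit loop below is equivalent to the idiomatic one-liner
#   return sum(p.numel() for p in model.parameters())
# since p.numel() already multiplies out all dimensions of each parameter tensor.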
259 | total_paramters = 0 260 | for parameter in model.parameters(): 261 | i = len(parameter.size()) 262 | #print(parameter.size()) 263 | p = 1 264 | for j in range(i): 265 | p *= parameter.size(j) 266 | total_paramters += p 267 | 268 | return total_paramters 269 | 270 | def main(): 271 | 272 | 273 | print("=====> Configure dataset and pretrained model") 274 | configure_dataset_init_model(args) 275 | print(args) 276 | 277 | print(" current dataset: ", args.dataset) 278 | print(" init model: ", args.restore_from) 279 | print("=====> Set GPU for training") 280 | if args.cuda: 281 | print("====> Use gpu id: '{}'".format(args.gpus)) 282 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus 283 | if not torch.cuda.is_available(): 284 | raise Exception("No GPU found or Wrong gpu id, please run without --cuda") 285 | # Select which GPU, -1 if CPU 286 | #gpu_id = args.gpus 287 | #device = torch.device("cuda:"+str(gpu_id) if torch.cuda.is_available() else "cpu") 288 | print("=====> Random Seed: ", args.random_seed) 289 | torch.manual_seed(args.random_seed) 290 | if args.cuda: 291 | torch.cuda.manual_seed(args.random_seed) 292 | 293 | h, w = map(int, args.input_size.split(',')) 294 | input_size = (h, w) 295 | 296 | cudnn.enabled = True 297 | 298 | print("=====> Building network") 299 | saved_state_dict = torch.load(args.restore_from) 300 | model = CoattentionNet(num_classes=args.num_classes) 301 | #print(model) 302 | new_params = model.state_dict().copy() 303 | for i in saved_state_dict["model"]: 304 | #Scale.layer5.conv2d_list.3.weight 305 | i_parts = i.split('.') # 针对多GPU的情况 306 | #i_parts.pop(1) 307 | #print('i_parts: ', '.'.join(i_parts[1:-1])) 308 | #if not i_parts[1]=='main_classifier': #and not '.'.join(i_parts[1:-1]) == 'layer5.bottleneck' and not '.'.join(i_parts[1:-1]) == 'layer5.bn': #init model pretrained on COCO, class name=21, layer5 is ASPP 309 | new_params['encoder'+'.'+'.'.join(i_parts[1:])] = saved_state_dict["model"][i] 310 | #print('copy {}'.format('.'.join(i_parts[1:]))) 311 | 312 | 313 | print("=====> Loading init weights, pretrained COCO for VOC2012, and pretrained Coarse cityscapes for cityscapes") 314 | 315 | 316 | model.load_state_dict(new_params) #只用到resnet的第5个卷积层的参数 317 | #print(model.keys()) 318 | if args.cuda: 319 | #model.to(device) 320 | if torch.cuda.device_count()>1: 321 | print("torch.cuda.device_count()=",torch.cuda.device_count()) 322 | model = torch.nn.DataParallel(model).cuda() #multi-card data parallel 323 | else: 324 | print("single GPU for training") 325 | model = model.cuda() #1-card data parallel 326 | start_epoch=0 327 | 328 | print("=====> Whether resuming from a checkpoint, for continuing training") 329 | if args.resume: 330 | if os.path.isfile(args.resume): 331 | print("=> loading checkpoint '{}'".format(args.resume)) 332 | checkpoint = torch.load(args.resume) 333 | start_epoch = checkpoint["epoch"] 334 | model.load_state_dict(checkpoint["model"]) 335 | else: 336 | print("=> no checkpoint found at '{}'".format(args.resume)) 337 | 338 | 339 | model.train() 340 | cudnn.benchmark = True 341 | 342 | if not os.path.exists(args.snapshot_dir): 343 | os.makedirs(args.snapshot_dir) 344 | 345 | print('=====> Computing network parameters') 346 | total_paramters = netParams(model) 347 | print('Total network parameters: ' + str(total_paramters)) 348 | 349 | print("=====> Preparing training data") 350 | if args.dataset == 'voc12': 351 | trainloader = data.DataLoader(VOCDataSet(args.data_dir, args.data_list, max_iters=None, crop_size=input_size, 352 | 
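# --- annotation (not part of the original file) ---
# VOCDataSet and CityscapesDataSet are referenced in these branches but never
# imported in this script, so with the imports as written only the
# `--dataset davis` branch below actually runs; the other branches would
# raise a NameError here.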
scale=args.random_scale, mirror=args.random_mirror, mean=args.img_mean), 353 | batch_size= args.batch_size, shuffle=True, num_workers=0, pin_memory=True, drop_last=True) 354 | elif args.dataset == 'cityscapes': 355 | trainloader = data.DataLoader(CityscapesDataSet(args.data_dir, args.data_list, max_iters=None, crop_size=input_size, 356 | scale=args.random_scale, mirror=args.random_mirror, mean=args.img_mean), 357 | batch_size = args.batch_size, shuffle=True, num_workers=0, pin_memory=True, drop_last=True) 358 | elif args.dataset == 'davis': #for davis 2016 359 | db_train = db.PairwiseImg(train=True, inputRes=input_size, db_root_dir=args.data_dir, img_root_dir=args.img_dir, transform=None) #db_root_dir() --> '/path/to/DAVIS-2016' train path 360 | trainloader = data.DataLoader(db_train, batch_size= args.batch_size, shuffle=True, num_workers=0) 361 | else: 362 | print("dataset error") 363 | 364 | optimizer = optim.SGD([{'params': get_1x_lr_params(model), 'lr': 1*args.learning_rate }, #针对特定层进行学习,有些层不学习 365 | {'params': get_10x_lr_params(model), 'lr': 10*args.learning_rate}], 366 | lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay) 367 | optimizer.zero_grad() 368 | 369 | 370 | 371 | logFileLoc = args.snapshot_dir + args.logFile 372 | if os.path.isfile(logFileLoc): 373 | logger = open(logFileLoc, 'a') 374 | else: 375 | logger = open(logFileLoc, 'w') 376 | logger.write("Parameters: %s" % (str(total_paramters))) 377 | logger.write("\n%s\t\t%s" % ('iter', 'Loss(train)\n')) 378 | logger.flush() 379 | 380 | print("=====> Begin to train") 381 | train_len=len(trainloader) 382 | print(" iteration numbers of per epoch: ", train_len) 383 | print(" epoch num: ", args.maxEpoches) 384 | print(" max iteration: ", args.maxEpoches*train_len) 385 | 386 | for epoch in range(start_epoch, int(args.maxEpoches)): 387 | 388 | np.random.seed(args.random_seed + epoch) 389 | for i_iter, batch in enumerate(trainloader,0): #i_iter from 0 to len-1 390 | #print("i_iter=", i_iter, "epoch=", epoch) 391 | target, target_gt, search, search_gt = batch['target'], batch['target_gt'], batch['search'], batch['search_gt'] 392 | images, labels = batch['img'], batch['img_gt'] 393 | #print(labels.size()) 394 | images.requires_grad_() 395 | images = Variable(images).cuda() 396 | labels = Variable(labels.float().unsqueeze(1)).cuda() 397 | 398 | target.requires_grad_() 399 | target = Variable(target).cuda() 400 | target_gt = Variable(target_gt.float().unsqueeze(1)).cuda() 401 | 402 | search.requires_grad_() 403 | search = Variable(search).cuda() 404 | search_gt = Variable(search_gt.float().unsqueeze(1)).cuda() 405 | 406 | optimizer.zero_grad() 407 | 408 | lr = adjust_learning_rate(optimizer, i_iter+epoch*train_len, epoch, 409 | max_iter = args.maxEpoches * train_len) 410 | #print(images.size()) 411 | if i_iter%3 ==0: #对于静态图片的训练 412 | 413 | pred1, pred2, pred3 = model(images, images) 414 | loss = 0.1*(loss_calc1(pred3, labels) + 0.8* loss_calc2(pred3, labels) ) 415 | loss.backward() 416 | 417 | else: 418 | 419 | pred1, pred2, pred3 = model(target, search) 420 | loss = loss_calc1(pred1, target_gt) + 0.8* loss_calc2(pred1, target_gt) + loss_calc1(pred2, search_gt) + 0.8* loss_calc2(pred2, search_gt)#class_balanced_cross_entropy_loss(pred, labels, size_average=False) 421 | loss.backward() 422 | 423 | optimizer.step() 424 | 425 | print("===> Epoch[{}]({}/{}): Loss: {:.10f} lr: {:.5f}".format(epoch, i_iter, train_len, loss.data, lr)) 426 | logger.write("Epoch[{}]({}/{}): Loss: {:.10f} lr: {:.5f}\n".format(epoch, 
i_iter, train_len, loss.data, lr)) 427 | logger.flush() 428 | 429 | print("=====> saving model") 430 | state={"epoch": epoch+1, "model": model.state_dict()} 431 | torch.save(state, osp.join(args.snapshot_dir, 'co_attention_'+str(args.dataset)+"_"+str(epoch)+'.pth')) 432 | 433 | 434 | end = timeit.default_timer() 435 | print( float(end-start)/3600, 'h') 436 | logger.write("total training time: {:.2f} h\n".format(float(end-start)/3600)) 437 | logger.close() 438 | 439 | 440 | if __name__ == '__main__': 441 | main() 442 | -------------------------------------------------------------------------------- /train_iteration_conf_group.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sat Sep 15 10:52:26 2018 4 | 5 | @author: carri 6 | """ 7 | # unlike deeplab_co_attention_concat, this script trains with the new model (siamese_model_concat_new) 8 | 9 | import argparse 10 | import torch 11 | import torch.nn as nn 12 | from torch.utils import data 13 | import numpy as np 14 | import pickle 15 | import cv2 16 | from torch.autograd import Variable 17 | import torch.optim as optim 18 | import scipy.misc 19 | import torch.backends.cudnn as cudnn 20 | import sys 21 | import os 22 | from utils.balanced_BCE import class_balanced_cross_entropy_loss 23 | import os.path as osp 24 | #from psp.model import PSPNet 25 | #from dataloaders import davis_2016 as db 26 | from dataloaders import PairwiseImg_video_try as db # follows the VOC dataset's data layout conventions 27 | import matplotlib.pyplot as plt 28 | import random 29 | import timeit 30 | #from psp.model1 import CoattentionNet # co-attention model built on top of PSPNet 31 | from deeplab.siamese_model_conf_try import CoattentionNet # siamese model that directly outputs the attended results (NOTE: the repo tree lists siamese_model_conf_try_single.py, so this module name may need adjusting) 32 | #from deeplab.utils import get_1x_lr_params, get_10x_lr_params#, adjust_learning_rate #, loss_calc 33 | start = timeit.default_timer() 34 | 35 | def get_arguments(): 36 | """Parse all the arguments provided from the CLI. 37 | 38 | Returns: 39 | A list of parsed arguments. 
40 | """ 41 | parser = argparse.ArgumentParser(description="PSPnet Network") 42 | 43 | # optimatization configuration 44 | parser.add_argument("--is-training", action="store_true", 45 | help="Whether to updates the running means and variances during the training.") 46 | parser.add_argument("--learning-rate", type=float, default= 0.00025, 47 | help="Base learning rate for training with polynomial decay.") #0.001 48 | parser.add_argument("--weight-decay", type=float, default= 0.0005, 49 | help="Regularization parameter for L2-loss.") # 0.0005 50 | parser.add_argument("--momentum", type=float, default= 0.9, 51 | help="Momentum component of the optimiser.") 52 | parser.add_argument("--power", type=float, default= 0.9, 53 | help="Decay parameter to compute the learning rate.") 54 | # dataset information 55 | parser.add_argument("--dataset", type=str, default='cityscapes', 56 | help="voc12, cityscapes, or pascal-context.") 57 | parser.add_argument("--random-mirror", action="store_true", 58 | help="Whether to randomly mirror the inputs during the training.") 59 | parser.add_argument("--random-scale", action="store_true", 60 | help="Whether to randomly scale the inputs during the training.") 61 | 62 | parser.add_argument("--not-restore-last", action="store_true", 63 | help="Whether to not restore last (FC) layers.") 64 | parser.add_argument("--random-seed", type=int, default= 1234, 65 | help="Random seed to have reproducible results.") 66 | parser.add_argument('--logFile', default='log.txt', 67 | help='File that stores the training and validation logs') 68 | # GPU configuration 69 | parser.add_argument("--cuda", default=True, help="Run on CPU or GPU") 70 | parser.add_argument("--gpus", type=str, default="3", help="choose gpu device.") #使用3号GPU 71 | 72 | 73 | return parser.parse_args() 74 | 75 | args = get_arguments() 76 | 77 | 78 | def configure_dataset_init_model(args): 79 | if args.dataset == 'voc12': 80 | 81 | args.batch_size = 10# 1 card: 5, 2 cards: 10 Number of images sent to the network in one step, 16 on paper 82 | args.maxEpoches = 15 # 1 card: 15, 2 cards: 15 epoches, equal to 30k iterations, max iterations= maxEpoches*len(train_aug)/batch_size_per_gpu'), 83 | args.data_dir = '/home/wty/AllDataSet/VOC2012' # Path to the directory containing the PASCAL VOC dataset 84 | args.data_list = './dataset/list/VOC2012/train_aug.txt' # Path to the file listing the images in the dataset 85 | args.ignore_label = 255 #The index of the label to ignore during the training 86 | args.input_size = '473,473' #Comma-separated string with height and width of images 87 | args.num_classes = 21 #Number of classes to predict (including background) 88 | 89 | args.img_mean = np.array((104.00698793,116.66876762,122.67891434), dtype=np.float32) 90 | # saving model file and log record during the process of training 91 | 92 | #Where restore model pretrained on other dataset, such as COCO.") 93 | args.restore_from = './pretrained/MS_DeepLab_resnet_pretrained_COCO_init.pth' 94 | args.snapshot_dir = './snapshots/voc12/' #Where to save snapshots of the model 95 | args.resume = './snapshots/voc12/psp_voc12_3.pth' #checkpoint log file, helping recovering training 96 | 97 | elif args.dataset == 'davis': 98 | args.batch_size = 16# 1 card: 5, 2 cards: 10 Number of images sent to the network in one step, 16 on paper 99 | args.maxEpoches = 60 # 1 card: 15, 2 cards: 15 epoches, equal to 30k iterations, max iterations= maxEpoches*len(train_aug)/batch_size_per_gpu'), 100 | args.data_dir = '/home/ubuntu/xiankai/dataset/DAVIS-2016' # 
def configure_dataset_init_model(args):
    if args.dataset == 'voc12':
        args.batch_size = 10   # 1 card: 5, 2 cards: 10; images per step (16 in the paper)
        args.maxEpoches = 15   # ~30k iterations; max iterations = maxEpoches*len(train_aug)/batch_size_per_gpu
        args.data_dir = '/home/wty/AllDataSet/VOC2012'           # directory containing the PASCAL VOC dataset
        args.data_list = './dataset/list/VOC2012/train_aug.txt'  # file listing the images in the dataset
        args.ignore_label = 255      # label index to ignore during training
        args.input_size = '473,473'  # comma-separated height and width of the input images
        args.num_classes = 21        # number of classes to predict (including background)
        args.img_mean = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32)
        # where to restore a model pretrained on another dataset, such as COCO
        args.restore_from = './pretrained/MS_DeepLab_resnet_pretrained_COCO_init.pth'
        args.snapshot_dir = './snapshots/voc12/'           # where to save snapshots of the model
        args.resume = './snapshots/voc12/psp_voc12_3.pth'  # checkpoint file for resuming training

    elif args.dataset == 'davis':
        args.batch_size = 16  # images per step
        args.maxEpoches = 60
        args.data_dir = '/home/ubuntu/xiankai/dataset/DAVIS-2016'  # 37572 image pairs
        args.img_dir = '/home/ubuntu/xiankai/dataset/images'
        args.data_list = './dataset/list/VOC2012/train_aug.txt'    # file listing the images in the dataset
        args.ignore_label = 255       # label index to ignore during training
        args.input_size = '378, 378'  # comma-separated height and width of the input images
        args.num_classes = 2          # number of classes to predict (including background)
        args.img_mean = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32)
        # where to restore a model pretrained on another dataset
        args.restore_from = './pretrained/deep_labv3/deeplab_davis_12_0.pth'
        args.snapshot_dir = './snapshots/davis_iteration_conf_try/'   # where to save snapshots of the model
        args.resume = './snapshots/davis/co_attention_davis_124.pth'  # checkpoint file for resuming training

    elif args.dataset == 'cityscapes':
        args.batch_size = 8   # images per step; batch_size/num_GPU = 2
        args.maxEpoches = 60  # ~90k iterations (60*2975/2 = 89250), single-GPU batch size = 2
        args.data_dir = '/home/wty/AllDataSet/CityScapes'  # directory containing the Cityscapes dataset
        args.data_list = './dataset/list/Cityscapes/cityscapes_train_list.txt'  # file listing the images in the dataset
        args.ignore_label = 255      # label index to ignore during training
        args.input_size = '720,720'  # comma-separated height and width of the input images
        args.num_classes = 19        # number of classes to predict (including background)
        args.img_mean = np.array((73.15835921, 82.90891754, 72.39239876), dtype=np.float32)
        # where to restore a model pretrained on another dataset, such as coarse Cityscapes
        args.restore_from = './pretrained/resnet101_pretrained_for_cityscapes.pth'
        args.snapshot_dir = './snapshots/cityscapes/'                 # where to save snapshots of the model
        args.resume = './snapshots/cityscapes/psp_cityscapes_12_3.pth'  # checkpoint file for resuming training

    else:
        print("dataset error")


def adjust_learning_rate(optimizer, i_iter, epoch, max_iter):
    """Polynomially decays the learning rate and alternates which parameter
    group is trained: on every third iteration (the static-image steps) only
    the backbone group is updated, otherwise the co-attention/classifier
    group is trained at 10x the base rate."""
    lr = lr_poly(args.learning_rate, i_iter, max_iter, args.power, epoch)
    if i_iter % 3 == 0:
        optimizer.param_groups[0]['lr'] = lr
        optimizer.param_groups[1]['lr'] = 0
    else:
        optimizer.param_groups[0]['lr'] = 0.01 * lr
        optimizer.param_groups[1]['lr'] = lr * 10

    return lr
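# Worked example of the schedule above, assuming the davis defaults
# (base lr = 2.5e-4, power = 0.9): at epoch 0, i_iter 0, lr_poly returns
# 2.5e-4 * (1 - 0/max_iter)**0.9 = 2.5e-4. Since 0 % 3 == 0 this is a
# static-image step, so the backbone group trains at 2.5e-4 while the
# co-attention/classifier group is frozen (lr = 0); on the following video
# step the backbone drops to 0.01 * lr and the other group runs at 10 * lr.
# From epoch 6 onward lr_poly additionally halves the base rate.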
def loss_calc1(pred, label):
    """
    Class-balanced binary cross-entropy loss: every pixel is weighted by the
    reciprocal of the foreground ratio in the batch, so sparse foreground
    objects are not overwhelmed by background pixels.
    """
    labels = torch.ge(label, 0.5).float()
    batch_size = label.size()           # (N, C, H, W)
    num_labels_pos = torch.sum(labels)  # number of foreground pixels
    num_total = batch_size[0] * batch_size[2] * batch_size[3]
    weight_1 = torch.reciprocal(torch.div(num_labels_pos, num_total))  # 1 / foreground ratio
    weight_map = torch.mul(weight_1, torch.ones(batch_size[0], batch_size[1],
                                                batch_size[2], batch_size[3]).cuda())
    criterion = torch.nn.BCELoss(weight=weight_map)

    return criterion(pred, label)


def loss_calc2(pred, label):
    """
    Returns the L1 loss between the predicted mask and the ground-truth mask.
    """
    criterion = torch.nn.L1Loss()

    return criterion(pred, label)
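# Shape sanity check for the two losses (illustrative values only; loss_calc1
# expects sigmoid probabilities in [0, 1] because it uses nn.BCELoss):
#   pred  = torch.sigmoid(torch.randn(4, 1, 378, 378)).cuda()
#   label = torch.randint(0, 2, (4, 1, 378, 378)).float().cuda()
#   total = loss_calc1(pred, label) + 0.8 * loss_calc2(pred, label)  # 0-dim tensor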
def get_1x_lr_params(model):
    """
    This generator returns all the parameters of the net except for the last
    classification layer. Note that for each batchnorm layer, requires_grad is
    set to False in deeplab_resnet.py, therefore this function does not return
    any batchnorm parameter.
    """
    b = []
    if torch.cuda.device_count() == 1:
        b.append(model.encoder.layer5)
    else:
        b.append(model.module.encoder.conv1)
        b.append(model.module.encoder.bn1)
        b.append(model.module.encoder.layer1)
        b.append(model.module.encoder.layer2)
        b.append(model.module.encoder.layer3)
        b.append(model.module.encoder.layer4)
        b.append(model.module.encoder.layer5)
        b.append(model.module.encoder.main_classifier)

    for i in range(len(b)):
        for j in b[i].modules():
            for k in j.parameters():
                if k.requires_grad:
                    yield k


def get_10x_lr_params(model):
    """
    This generator returns all the parameters of the last layers of the net,
    which perform the pixel-wise classification.
    """
    b = []
    if torch.cuda.device_count() == 1:
        b.append(model.linear_e.parameters())
        b.append(model.main_classifier.parameters())
    else:
        b.append(model.module.linear_e.parameters())
        b.append(model.module.conv1.parameters())
        b.append(model.module.gate.parameters())
        b.append(model.module.bn1.parameters())
        b.append(model.module.main_classifier1.parameters())

    for j in range(len(b)):
        for i in b[j]:
            yield i


def lr_poly(base_lr, iter, max_iter, power, epoch):
    # polynomial decay, with the base rate halved from epoch 6 onward
    factor = 1 if epoch < 6 else 0.5
    return base_lr * factor * ((1 - float(iter) / max_iter) ** power)


def netParams(model):
    '''
    Computes the total number of network parameters.
    Args:
        model: the model
    Returns:
        total number of network parameters
    '''
    total_parameters = 0
    for parameter in model.parameters():
        p = 1
        for j in range(len(parameter.size())):
            p *= parameter.size(j)
        total_parameters += p

    return total_parameters
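# Equivalent idiomatic one-liner for the parameter count above:
#   total_parameters = sum(p.numel() for p in model.parameters())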
def main():

    print("=====> Configure dataset and pretrained model")
    configure_dataset_init_model(args)
    print(args)

    print(" current dataset: ", args.dataset)
    print(" init model: ", args.restore_from)
    print("=====> Set GPU for training")
    if args.cuda:
        print("====> Use gpu id: '{}'".format(args.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        if not torch.cuda.is_available():
            raise Exception("No GPU found or wrong gpu id; please run without --cuda")
    print("=====> Random Seed: ", args.random_seed)
    torch.manual_seed(args.random_seed)
    if args.cuda:
        torch.cuda.manual_seed(args.random_seed)

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True

    print("=====> Building network")
    saved_state_dict = torch.load(args.restore_from)
    model = CoattentionNet(num_classes=args.num_classes)
    new_params = model.state_dict().copy()
    for i in saved_state_dict["model"]:
        # checkpoint keys look like 'module.layer5.conv2d_list.3.weight' when
        # saved from a multi-GPU (DataParallel) model, so strip the first
        # component and prepend 'encoder.'
        i_parts = i.split('.')
        new_params['encoder' + '.' + '.'.join(i_parts[1:])] = saved_state_dict["model"][i]

    print("=====> Loading init weights, pretrained COCO for VOC2012, and pretrained coarse Cityscapes for cityscapes")

    model.load_state_dict(new_params)  # only the ResNet encoder weights are taken from the checkpoint
    if args.cuda:
        if torch.cuda.device_count() > 1:
            print("torch.cuda.device_count()=", torch.cuda.device_count())
            model = torch.nn.DataParallel(model).cuda()  # multi-card data parallel
        else:
            print("single GPU for training")
            model = model.cuda()
    start_epoch = 0

    print("=====> Whether resuming from a checkpoint, for continuing training")
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint["epoch"]
            model.load_state_dict(checkpoint["model"])
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    model.train()
    cudnn.benchmark = True

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    print('=====> Computing network parameters')
    total_parameters = netParams(model)
    print('Total network parameters: ' + str(total_parameters))

    print("=====> Preparing training data")
    if args.dataset == 'voc12':
        trainloader = data.DataLoader(VOCDataSet(args.data_dir, args.data_list, max_iters=None, crop_size=input_size,
                                                 scale=args.random_scale, mirror=args.random_mirror, mean=args.img_mean),
                                      batch_size=args.batch_size, shuffle=True, num_workers=0, pin_memory=True, drop_last=True)
    elif args.dataset == 'cityscapes':
        trainloader = data.DataLoader(CityscapesDataSet(args.data_dir, args.data_list, max_iters=None, crop_size=input_size,
                                                        scale=args.random_scale, mirror=args.random_mirror, mean=args.img_mean),
                                      batch_size=args.batch_size, shuffle=True, num_workers=0, pin_memory=True, drop_last=True)
    elif args.dataset == 'davis':  # for DAVIS 2016
        db_train = db.PairwiseImg(train=True, inputRes=input_size, db_root_dir=args.data_dir,
                                  img_root_dir=args.img_dir, transform=None)  # db_root_dir -> '/path/to/DAVIS-2016' train path
        trainloader = data.DataLoader(db_train, batch_size=args.batch_size, shuffle=True, num_workers=0)
    else:
        print("dataset error")

    # train only selected layers: the pretrained backbone at the base rate,
    # the new co-attention/classifier layers at 10x; some layers stay frozen
    optimizer = optim.SGD([{'params': get_1x_lr_params(model), 'lr': 1 * args.learning_rate},
                           {'params': get_10x_lr_params(model), 'lr': 10 * args.learning_rate}],
                          lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer.zero_grad()

    logFileLoc = args.snapshot_dir + args.logFile
    if os.path.isfile(logFileLoc):
        logger = open(logFileLoc, 'a')
    else:
        logger = open(logFileLoc, 'w')
        logger.write("Parameters: %s" % (str(total_parameters)))
        logger.write("\n%s\t\t%s" % ('iter', 'Loss(train)\n'))
    logger.flush()

    print("=====> Begin to train")
    train_len = len(trainloader)
    print(" iterations per epoch: ", train_len)
    print(" epoch num: ", args.maxEpoches)
    print(" max iteration: ", args.maxEpoches * train_len)
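    # The loop below alternates between two training modes: every third
    # iteration feeds the same static saliency image to both branches of the
    # Siamese network (model(images, images)), while the remaining iterations
    # feed a pair of frames sampled from one video (model(target, search)).
    # Both modes use the balanced-BCE + 0.8 * L1 objective; the static-image
    # loss is additionally scaled by 0.1, and adjust_learning_rate() switches
    # the parameter groups in step with the same i_iter % 3 schedule.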
    for epoch in range(start_epoch, int(args.maxEpoches)):

        np.random.seed(args.random_seed + epoch)
        for i_iter, batch in enumerate(trainloader, 0):  # i_iter runs from 0 to train_len-1
            target, target_gt, search, search_gt = batch['target'], batch['target_grt'], batch['search'], batch['search_grt']
            images, labels = batch['img'], batch['img_grt']

            images.requires_grad_()
            images = Variable(images).cuda()
            labels = Variable(labels.float().unsqueeze(1)).cuda()

            target.requires_grad_()
            target = Variable(target).cuda()
            target_gt = Variable(target_gt.float().unsqueeze(1)).cuda()

            search.requires_grad_()
            search = Variable(search).cuda()
            search_gt = Variable(search_gt.float().unsqueeze(1)).cuda()

            optimizer.zero_grad()

            lr = adjust_learning_rate(optimizer, i_iter + epoch * train_len, epoch,
                                      max_iter=args.maxEpoches * train_len)

            if i_iter % 3 == 0:  # training on static images
                pred1, pred2 = model(images, images)
                loss = 0.1 * (loss_calc1(pred2, labels) + 0.8 * loss_calc2(pred2, labels))
                loss.backward()
            else:  # training on a pair of video frames
                pred1, pred2 = model(target, search)
                loss = loss_calc1(pred1, target_gt) + 0.8 * loss_calc2(pred1, target_gt)
                loss.backward()

            optimizer.step()

            print("===> Epoch[{}]({}/{}): Loss: {:.10f} lr: {:.5f}".format(epoch, i_iter, train_len, loss.item(), lr))
            logger.write("Epoch[{}]({}/{}): Loss: {:.10f} lr: {:.5f}\n".format(epoch, i_iter, train_len, loss.item(), lr))
            logger.flush()

        print("=====> saving model")
        state = {"epoch": epoch + 1, "model": model.state_dict()}
        torch.save(state, osp.join(args.snapshot_dir, 'co_attention_' + str(args.dataset) + "_" + str(epoch) + '.pth'))

    end = timeit.default_timer()
    print(float(end - start) / 3600, 'h')
    logger.write("total training time: {:.2f} h\n".format(float(end - start) / 3600))
    logger.close()


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------