├── info ├── SMB.jpg ├── result_1.JPG ├── resutl_5.JPG └── main_structure.jpg ├── .gitattributes ├── data ├── splits │ ├── coco │ │ ├── trn │ │ │ ├── fold0.pkl │ │ │ ├── fold1.pkl │ │ │ ├── fold2.pkl │ │ │ └── fold3.pkl │ │ └── val │ │ │ ├── fold0.pkl │ │ │ ├── fold1.pkl │ │ │ ├── fold2.pkl │ │ │ └── fold3.pkl │ ├── fss │ │ ├── val.txt │ │ ├── test.txt │ │ └── trn.txt │ └── pascal │ │ ├── val │ │ ├── fold0.txt │ │ ├── fold3.txt │ │ ├── fold1.txt │ │ └── fold2.txt │ │ └── trn │ │ └── fold3.txt ├── calc.py ├── dataset.py ├── fss.py ├── pascal.py └── coco.py ├── common ├── utils.py ├── evaluation.py ├── vis.py └── logger.py ├── model ├── base │ ├── correlation.py │ ├── feature.py │ ├── merge_cor.py │ ├── merge_pro.py │ └── merge.py └── mshnet.py ├── README.md ├── test.py ├── train.py └── train_coco.py /info/SMB.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alex-ShiLei/MSHNet/HEAD/info/SMB.jpg -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /info/result_1.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alex-ShiLei/MSHNet/HEAD/info/result_1.JPG -------------------------------------------------------------------------------- /info/resutl_5.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alex-ShiLei/MSHNet/HEAD/info/resutl_5.JPG -------------------------------------------------------------------------------- /info/main_structure.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Alex-ShiLei/MSHNet/HEAD/info/main_structure.jpg -------------------------------------------------------------------------------- /data/splits/coco/trn/fold0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alex-ShiLei/MSHNet/HEAD/data/splits/coco/trn/fold0.pkl -------------------------------------------------------------------------------- /data/splits/coco/trn/fold1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alex-ShiLei/MSHNet/HEAD/data/splits/coco/trn/fold1.pkl -------------------------------------------------------------------------------- /data/splits/coco/trn/fold2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alex-ShiLei/MSHNet/HEAD/data/splits/coco/trn/fold2.pkl -------------------------------------------------------------------------------- /data/splits/coco/trn/fold3.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alex-ShiLei/MSHNet/HEAD/data/splits/coco/trn/fold3.pkl -------------------------------------------------------------------------------- /data/splits/coco/val/fold0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alex-ShiLei/MSHNet/HEAD/data/splits/coco/val/fold0.pkl -------------------------------------------------------------------------------- /data/splits/coco/val/fold1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alex-ShiLei/MSHNet/HEAD/data/splits/coco/val/fold1.pkl -------------------------------------------------------------------------------- /data/splits/coco/val/fold2.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Alex-ShiLei/MSHNet/HEAD/data/splits/coco/val/fold2.pkl -------------------------------------------------------------------------------- /data/splits/coco/val/fold3.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alex-ShiLei/MSHNet/HEAD/data/splits/coco/val/fold3.pkl -------------------------------------------------------------------------------- /common/utils.py: -------------------------------------------------------------------------------- 1 | r""" Helper functions """ 2 | import random 3 | 4 | import torch 5 | import numpy as np 6 | 7 | 8 | def fix_randseed(seed): 9 | r""" Set random seeds for reproducibility """ 10 | if seed is None: 11 | seed = int(random.random() * 1e5) 12 | np.random.seed(seed) 13 | torch.manual_seed(seed) 14 | torch.cuda.manual_seed(seed) 15 | torch.cuda.manual_seed_all(seed) 16 | torch.backends.cudnn.benchmark = False 17 | torch.backends.cudnn.deterministic = True 18 | 19 | 20 | def mean(x): 21 | return sum(x) / len(x) if len(x) > 0 else 0.0 22 | 23 | 24 | def to_cuda(batch): 25 | for key, value in batch.items(): 26 | if isinstance(value, torch.Tensor): 27 | batch[key] = value.cuda() 28 | return batch 29 | 30 | 31 | def to_cpu(tensor): 32 | return tensor.detach().clone().cpu() 33 | -------------------------------------------------------------------------------- /data/calc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | 5 | filepath = '/home/kexin/Project/DataSet/VOCdevkit/VOC2012/JPEGImages' # 数据集目录 6 | nameFile='/home/kexin/Project/sxw/hsnet-main/data/splits/pascal/trn/fold0.txt' 7 | namesFile=open(nameFile) 8 | pathDir=namesFile.readlines() 9 | pathDir=[name.split('__')[0]+'.jpg' for name in pathDir] 10 | numImg=len(pathDir) 11 | 12 | R_channel = 0.0 13 | G_channel = 0.0 14 | B_channel = 0.0 15 | for idx in range(numImg): 16 | filename = 
pathDir[idx] 17 | print('m',idx,filename) 18 | img = cv2.imread(os.path.join(filepath, filename)) 19 | img=cv2.resize(img,(473,473)) 20 | img=img/255.0 21 | R_channel = R_channel + np.sum(img[:, :, 2]) 22 | G_channel = G_channel + np.sum(img[:, :, 1]) 23 | B_channel = B_channel + np.sum(img[:, :, 0]) 24 | 25 | num = numImg * 473 * 473 26 | R_mean = R_channel / num 27 | G_mean = G_channel / num 28 | B_mean = B_channel / num 29 | 30 | R_channel = 0 31 | G_channel = 0 32 | B_channel = 0 33 | for idx in range(numImg): 34 | filename = pathDir[idx] 35 | print('d', idx, filename) 36 | img = cv2.imread(os.path.join(filepath, filename)) 37 | img=cv2.resize(img,(473,473)) 38 | img=img/255.0 39 | R_channel = R_channel + np.sum((img[:, :, 2] - R_mean) ** 2) 40 | G_channel = G_channel + np.sum((img[:, :, 1] - G_mean) ** 2) 41 | B_channel = B_channel + np.sum((img[:, :, 0] - B_mean) ** 2) 42 | 43 | R_var = R_channel / num 44 | G_var = G_channel / num 45 | B_var = B_channel / num 46 | print("RGB_mean [%f,%f,%f]" % (R_mean, G_mean, B_mean)) 47 | print("RGB_std [%f,%f,%f]" % (R_var**(0.5), G_var**(0.5), B_var**0.5)) -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | r""" Dataloader builder for few-shot semantic segmentation dataset """ 2 | from torchvision import transforms 3 | from torch.utils.data import DataLoader 4 | 5 | from data.pascal import DatasetPASCAL 6 | from data.coco import DatasetCOCO 7 | from data.fss import DatasetFSS 8 | 9 | 10 | class FSSDataset: 11 | 12 | @classmethod 13 | def initialize(cls, img_size, datapath, use_original_imgsize): 14 | 15 | cls.datasets = { 16 | 'pascal': DatasetPASCAL, 17 | 'coco': DatasetCOCO, 18 | 'fss': DatasetFSS, 19 | } 20 | 21 | cls.img_mean =[0.485,0.456,0.406] #[0.4485, 0.4196, 0.3810] 22 | cls.img_std = [0.229,0.224,0.225]#[0.2687, 0.2641, 0.2719] 23 | cls.datapath = datapath 24 | 
cls.use_original_imgsize = use_original_imgsize 25 | 26 | cls.transform = transforms.Compose([transforms.Resize(size=(img_size, img_size)), 27 | transforms.ToTensor(), 28 | transforms.Normalize(cls.img_mean, cls.img_std)]) 29 | 30 | @classmethod 31 | def build_dataloader(cls, benchmark, bsz, nworker, fold, split, shot=1): 32 | # Force randomness during training for diverse episode combinations 33 | # Freeze randomness during testing for reproducibility 34 | shuffle = split == 'trn' 35 | nworker = nworker if split == 'trn' else 0 36 | 37 | dataset = cls.datasets[benchmark](cls.datapath, fold=fold, transform=cls.transform, split=split, shot=shot, use_original_imgsize=cls.use_original_imgsize) 38 | dataloader = DataLoader(dataset, batch_size=bsz, shuffle=shuffle, num_workers=nworker) 39 | 40 | return dataloader 41 | -------------------------------------------------------------------------------- /common/evaluation.py: -------------------------------------------------------------------------------- 1 | r""" Evaluate mask prediction """ 2 | import torch 3 | 4 | 5 | class Evaluator: 6 | r""" Computes intersection and union between prediction and ground-truth """ 7 | @classmethod 8 | def initialize(cls): 9 | cls.ignore_index = 255 10 | 11 | @classmethod 12 | def classify_prediction(cls, pred_mask, batch): 13 | gt_mask = batch.get('query_mask') 14 | 15 | # Apply ignore_index in PASCAL-5i masks (following evaluation scheme in PFE-Net (TPAMI 2020)) 16 | query_ignore_idx = batch.get('query_ignore_idx') 17 | if query_ignore_idx is not None: 18 | assert torch.logical_and(query_ignore_idx, gt_mask).sum() == 0 19 | query_ignore_idx *= cls.ignore_index 20 | gt_mask = gt_mask + query_ignore_idx 21 | pred_mask[gt_mask == cls.ignore_index] = cls.ignore_index 22 | 23 | # compute intersection and union of each episode in a batch 24 | area_inter, area_pred, area_gt = [], [], [] 25 | for _pred_mask, _gt_mask in zip(pred_mask, gt_mask): 26 | _inter = _pred_mask[_pred_mask == _gt_mask] 27 
| if _inter.size(0) == 0: # as torch.histc returns error if it gets empty tensor (pytorch 1.5.1) 28 | _area_inter = torch.tensor([0, 0], device=_pred_mask.device) 29 | else: 30 | _area_inter = torch.histc(_inter, bins=2, min=0, max=1) 31 | area_inter.append(_area_inter) 32 | area_pred.append(torch.histc(_pred_mask, bins=2, min=0, max=1)) 33 | area_gt.append(torch.histc(_gt_mask, bins=2, min=0, max=1)) 34 | area_inter = torch.stack(area_inter).t() 35 | area_pred = torch.stack(area_pred).t() 36 | area_gt = torch.stack(area_gt).t() 37 | area_union = area_pred + area_gt - area_inter 38 | 39 | return area_inter, area_union 40 | -------------------------------------------------------------------------------- /model/base/correlation.py: -------------------------------------------------------------------------------- 1 | r""" Provides functions that builds/manipulates correlation tensors """ 2 | import torch 3 | 4 | 5 | class Correlation: 6 | 7 | @classmethod 8 | def multilayer_correlation(cls, query_feats, support_feats, stack_ids): 9 | eps = 1e-5 10 | corrs = [] 11 | sups=[] 12 | for idx, (query_feat, support_feat) in enumerate(zip(query_feats, support_feats)): 13 | queryShape = query_feat.shape#b,c,h,w 14 | corrI=[] 15 | realSupI=[] 16 | for j in range(len(support_feat)):#b 17 | queryIJ=query_feat[j].flatten(start_dim=1)#c,hw 18 | queryIJNorm=queryIJ/(queryIJ.norm(dim=0, p=2, keepdim=True) + eps) 19 | supIJ=support_feat[j]#c,hw 20 | supIJNorm=supIJ/(supIJ.norm(dim=0, p=2, keepdim=True) + eps) 21 | corr=(queryIJNorm.permute(1,0)).matmul(supIJNorm) 22 | corr = corr.clamp(min=0) 23 | corr=corr.mean(dim=1,keepdim=True) 24 | corr=(corr.permute(1,0)).unsqueeze(0)#1,1,hw 25 | corrI.append(corr)#b,1,hw 26 | resupJ=supIJ.mean(dim=1,keepdim=True) 27 | resupJsum=resupJ.sum() 28 | resupJ=resupJ.unsqueeze(0).expand(-1,-1,queryIJ.shape[-1])#1,c,hw 29 | queryIJ=queryIJ.unsqueeze(0)#1,c,hw 30 | if resupJsum==0: 31 | queryIJ=queryIJ*resupJ 32 | 
resupJ=torch.cat([queryIJ,resupJ],dim=1)#1,2c,hw 33 | realSupI.append(resupJ)#b,2c,hw 34 | corrI=torch.cat(corrI,dim=0)#b,1,h,w 35 | corrI=corrI.reshape((corrI.shape[0],corrI.shape[1],queryShape[-2],queryShape[-1]))#b,1,h,w 36 | realSupI=torch.cat(realSupI,dim=0)#b,2c,h,w 37 | realSupI=realSupI.reshape((realSupI.shape[0],realSupI.shape[1],queryShape[-2],queryShape[-1])) 38 | corrs.append(corrI)#n,b,1,h,w 39 | sups.append(realSupI)#n,b,c,h,w 40 | 41 | corr_l4 = torch.cat(corrs[-stack_ids[0]:],dim=1).contiguous()#b,n,h,w 42 | corr_l3 = torch.cat(corrs[-stack_ids[1]:-stack_ids[0]],dim=1).contiguous() 43 | corr_l2 = torch.cat(corrs[-stack_ids[2]:-stack_ids[1]],dim=1).contiguous() 44 | 45 | sup_l4=sups[-stack_ids[0]:]#n,b,2c,h,w 46 | sup_l3=sups[-stack_ids[1]:-stack_ids[0]] 47 | sup_l2=sups[-stack_ids[2]:-stack_ids[1]] 48 | #print(corr_l4.shape,corr_l3.shape,corr_l2.shape)#n,b,1,h,wtorch.Size([13, 3, 15, 15]) 49 | #print(len(sup_l4), len(sup_l3), len(sup_l2)) 50 | return [corr_l4, corr_l3, corr_l2],[sup_l4,sup_l3,sup_l2] 51 | -------------------------------------------------------------------------------- /model/base/feature.py: -------------------------------------------------------------------------------- 1 | r""" Extracts intermediate features from given backbone network & layer ids """ 2 | 3 | 4 | def extract_feat_vgg(img, backbone, feat_ids, bottleneck_ids=None, lids=None): 5 | r""" Extract intermediate features from VGG """ 6 | feats = [] 7 | feat = img 8 | for lid, module in enumerate(backbone.features): 9 | feat = module(feat) 10 | if lid in feat_ids: 11 | feats.append(feat.clone()) 12 | return feats 13 | 14 | 15 | def extract_feat_res(img, backbone, feat_ids, bottleneck_ids, lids): 16 | r""" Extract intermediate features from ResNet""" 17 | feats = [] 18 | 19 | # Layer 0 20 | feat = backbone.conv1.forward(img) 21 | feat = backbone.bn1.forward(feat) 22 | feat = backbone.relu.forward(feat) 23 | feat = backbone.maxpool.forward(feat) 24 | 25 | # Layer 1-4 26 | for 
hid, (bid, lid) in enumerate(zip(bottleneck_ids, lids)): 27 | res = feat 28 | feat = backbone.__getattr__('layer%d' % lid)[bid].conv1.forward(feat) 29 | feat = backbone.__getattr__('layer%d' % lid)[bid].bn1.forward(feat) 30 | feat = backbone.__getattr__('layer%d' % lid)[bid].relu.forward(feat) 31 | feat = backbone.__getattr__('layer%d' % lid)[bid].conv2.forward(feat) 32 | feat = backbone.__getattr__('layer%d' % lid)[bid].bn2.forward(feat) 33 | feat = backbone.__getattr__('layer%d' % lid)[bid].relu.forward(feat) 34 | feat = backbone.__getattr__('layer%d' % lid)[bid].conv3.forward(feat) 35 | feat = backbone.__getattr__('layer%d' % lid)[bid].bn3.forward(feat) 36 | 37 | if bid == 0: 38 | res = backbone.__getattr__('layer%d' % lid)[bid].downsample.forward(res) 39 | 40 | feat += res 41 | 42 | if hid + 1 in feat_ids: 43 | feats.append(feat.clone()) 44 | 45 | feat = backbone.__getattr__('layer%d' % lid)[bid].relu.forward(feat) 46 | 47 | return feats 48 | def extract_feat_res_sup(img, backbone, feat_ids, bottleneck_ids, lids,shot=1): 49 | r""" Extract intermediate features from ResNet""" 50 | feats = [] 51 | 52 | # Layer 0 53 | feat = backbone.conv1.forward(img) 54 | feat = backbone.bn1.forward(feat) 55 | feat = backbone.relu.forward(feat) 56 | feat = backbone.maxpool.forward(feat) 57 | 58 | # Layer 1-4 59 | for hid, (bid, lid) in enumerate(zip(bottleneck_ids, lids)): 60 | res = feat 61 | feat = backbone.__getattr__('layer%d' % lid)[bid].conv1.forward(feat) 62 | feat = backbone.__getattr__('layer%d' % lid)[bid].bn1.forward(feat) 63 | feat = backbone.__getattr__('layer%d' % lid)[bid].relu.forward(feat) 64 | feat = backbone.__getattr__('layer%d' % lid)[bid].conv2.forward(feat) 65 | feat = backbone.__getattr__('layer%d' % lid)[bid].bn2.forward(feat) 66 | feat = backbone.__getattr__('layer%d' % lid)[bid].relu.forward(feat) 67 | feat = backbone.__getattr__('layer%d' % lid)[bid].conv3.forward(feat) 68 | feat = backbone.__getattr__('layer%d' % lid)[bid].bn3.forward(feat) 69 | 70 | 
if bid == 0: 71 | res = backbone.__getattr__('layer%d' % lid)[bid].downsample.forward(res) 72 | 73 | feat += res#bchw 74 | 75 | if hid + 1 in feat_ids: 76 | feats.append(feat.clone()) 77 | 78 | feat = backbone.__getattr__('layer%d' % lid)[bid].relu.forward(feat) 79 | 80 | return feats -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multi-Similarity Based Hyperrelation Network for Few-Shot Segmentation 2 | 3 | This is the implementation of the paper ["Multi-Similarity Based Hyperrelation Network for Few-Shot Segmentation"](https://arxiv.org/abs/2203.09550). 4 | 5 | Implemented on Python 3.7 and PyTorch 1.8. 6 | The main structure of the network is as follows: 7 |

8 | 9 |

10 | 11 | The following figure shows our proposed Symmetric Merging Block (SMB): 12 |

13 | 14 |

15 | 16 | ## Requirements 17 | 18 | - Python 3.7 19 | - PyTorch 1.8 20 | - cuda 11.1 21 | - opencv 4.3 22 | - tensorboard 1.14 23 | 24 | ## Preparing Few-Shot Segmentation Datasets 25 | Download the following datasets: 26 | 27 | > #### 1. PASCAL-5i 28 | > Download PASCAL VOC2012 devkit (train/val data): 29 | > ```bash 30 | > wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar 31 | > ``` 32 | > Download PASCAL VOC2012 SDS extended mask annotations from [[Google Drive](https://drive.google.com/file/d/10zxG2VExoEZUeyQl_uXga2OWHjGeZaf2/view?usp=sharing)]. It was created by Juhong Min et al. 33 | 34 | > #### 2. COCO-20i 35 | > Download COCO2014 train/val images and annotations: 36 | > ```bash 37 | > wget http://images.cocodataset.org/zips/train2014.zip 38 | > wget http://images.cocodataset.org/zips/val2014.zip 39 | > wget http://images.cocodataset.org/annotations/annotations_trainval2014.zip 40 | > ``` 41 | 42 | The Dataset directory is the folder where you put the datasets. Its structure is as follows: 43 | 44 | ── Dataset/ 45 | ├── VOC2012/ # PASCAL VOC2012 devkit 46 | │ ├── Annotations/ 47 | │ ├── ImageSets/ 48 | │ ├── ... 49 | │ └── SegmentationClassAug/ 50 | ├── COCO2014/ 51 | ├── annotations/ 52 | │ └── ..some json files.. 53 | ├── train2014/ 54 | └── val2014/ 55 | 56 | 57 | ## Training 58 | > ### 1. PASCAL-5i 59 | >Set the parameters in train.py and run the following command: 60 | >`python train.py` 61 | 62 | > ### 2. COCO-20i 63 | >Set the parameters in train_coco.py and run the following command: 64 | > `python train_coco.py` 65 | 66 | 67 | ## Testing 68 | 69 | > Pretrained models are available on our [[Baidu Netdisk](https://pan.baidu.com/s/1nUUpWlRUaJ9Kq95M18DipA?pwd=gjpt)].
70 | 71 | >Set the parameters in test.py and execute: 72 | > `python test.py` 73 | 74 | ## Visualization 75 | 76 | * To visualize mask predictions, add the command-line argument **--visualize**: 77 | (prediction results will be saved under the vis/ directory) 78 | ```bash 79 | python test.py '...other arguments...' --visualize 80 | ``` 81 | ## Single Similarity 82 | To train and test with a single similarity, replace `model.mshnet` in the `from model.mshnet import ...` line of the training/testing scripts with `model.mshnet_cor` or `model.mshnet_pt`. 83 | #### Example qualitative results (5-shot and 1-shot): 84 | 85 |

86 | 87 |

88 | 89 | ## Acknowledgment 90 | Thanks to Juhong Min, Dahyun Kang and Minsu Cho for their contributions; much of our code is based on their shared [HSNet](https://github.com/juhongm999/hsnet). 91 | ## BibTeX 92 | If you use this code for your research, please consider citing: 93 | ````BibTeX 94 | @InProceedings{shi2022multi, 95 | title={Multi-Similarity Based Hyperrelation Network for Few-Shot Segmentation}, 96 | author={Xiangwen Shi and Shaobing Zhang and Miao Cheng and Lian He and Zhe Cui and Xianghong Tang}, 97 | } 98 | ```` 99 | -------------------------------------------------------------------------------- /data/splits/fss/val.txt: -------------------------------------------------------------------------------- 1 | handcuff 2 | mortar 3 | matchstick 4 | wine_bottle 5 | dowitcher 6 | triumphal_arch 7 | gyromitra 8 | hatchet 9 | airliner 10 | broccoli 11 | olive 12 | pubg_lvl3backpack 13 | calculator 14 | toucan 15 | shovel 16 | sewing_machine 17 | icecream 18 | woodpecker 19 | pig 20 | relay_stick 21 | mcdonald_sign 22 | cpu 23 | peanut 24 | pumpkin 25 | sturgeon 26 | hammer 27 | hami_melon 28 | squirrel_monkey 29 | shuriken 30 | power_drill 31 | pingpong_ball 32 | crocodile 33 | carambola 34 | monarch_butterfly 35 | drum 36 | water_tower 37 | panda 38 | toilet_brush 39 | pay_phone 40 | yonex_icon 41 | cricketball 42 | revolver 43 | chimpanzee 44 | crab 45 | corn 46 | baseball 47 | rabbit 48 | croquet_ball 49 | artichoke 50 | abacus 51 | harp 52 | bell 53 | gas_tank 54 | scissors 55 | vase 56 | upright_piano 57 | typewriter 58 | bittern 59 | impala 60 | tray 61 | fire_hydrant 62 | beer_bottle 63 | sock 64 | soup_bowl 65 | spider 66 | cherry 67 | macaw 68 | toilet_seat 69 | fire_balloon 70 | french_ball 71 | fox_squirrel 72 | volleyball 73 | cornmeal 74 | folding_chair 75 | pubg_airdrop 76 | beagle 77 | skateboard 78 | narcissus 79 | whiptail 80 | cup 81 | arabian_camel 82 | badger 83 | stopwatch 84 | ab_wheel 85 | ox 86 | lettuce 87 | monocycle 88 | redshank 89 | vulture 90 | whistle 91 |
smoothing_iron 92 | mashed_potato 93 | conveyor 94 | yoga_pad 95 | tow_truck 96 | siamese_cat 97 | cigar 98 | white_stork 99 | sniper_rifle 100 | stretcher 101 | tulip 102 | handkerchief 103 | basset 104 | iceberg 105 | gibbon 106 | lacewing 107 | thrush 108 | cheetah 109 | bighorn_sheep 110 | espresso_maker 111 | pretzel 112 | english_setter 113 | sandbar 114 | cheese 115 | daisy 116 | arctic_fox 117 | briard 118 | colubus 119 | balance_beam 120 | coffeepot 121 | soap_dispenser 122 | yawl 123 | consomme 124 | parking_meter 125 | cactus 126 | turnstile 127 | taro 128 | fire_screen 129 | digital_clock 130 | rose 131 | pomegranate 132 | bee_eater 133 | schooner 134 | ski_mask 135 | jay_bird 136 | plaice 137 | red_fox 138 | syringe 139 | camomile 140 | pickelhaube 141 | blenheim_spaniel 142 | pear 143 | parachute 144 | common_newt 145 | bowtie 146 | cigarette 147 | oscilloscope 148 | laptop 149 | african_crocodile 150 | apron 151 | coconut 152 | sandal 153 | kwanyin 154 | lion 155 | eel 156 | balloon 157 | crepe 158 | armadillo 159 | kazoo 160 | lemon 161 | spider_monkey 162 | tape_player 163 | ipod 164 | bee 165 | sea_cucumber 166 | suitcase 167 | television 168 | pillow 169 | banjo 170 | rock_snake 171 | partridge 172 | platypus 173 | lycaenid_butterfly 174 | pinecone 175 | conversion_plug 176 | wolf 177 | frying_pan 178 | timber_wolf 179 | bluetick 180 | crayon 181 | giant_schnauzer 182 | orang 183 | scarerow 184 | kobe_logo 185 | loguat 186 | saxophone 187 | ceiling_fan 188 | cardoon 189 | equestrian_helmet 190 | louvre_pyramid 191 | hotdog 192 | ironing_board 193 | razor 194 | nagoya_castle 195 | loggerhead_turtle 196 | lipstick 197 | cradle 198 | strongbox 199 | raven 200 | kit_fox 201 | albatross 202 | flat-coated_retriever 203 | beer_glass 204 | ice_lolly 205 | sungnyemun 206 | totem_pole 207 | vacuum 208 | bolete 209 | mango 210 | ginger 211 | weasel 212 | cabbage 213 | refrigerator 214 | school_bus 215 | hippo 216 | tiger_cat 217 | saltshaker 218 | 
piano_keyboard 219 | windsor_tie 220 | sea_urchin 221 | microsd 222 | barbell 223 | swim_ring 224 | bulbul_bird 225 | water_ouzel 226 | ac_ground 227 | sweatshirt 228 | umbrella 229 | hair_drier 230 | hammerhead_shark 231 | tomato 232 | projector 233 | cushion 234 | dishwasher 235 | three-toed_sloth 236 | tiger_shark 237 | har_gow 238 | baby 239 | thor's_hammer 240 | nike_logo 241 | -------------------------------------------------------------------------------- /data/splits/fss/test.txt: -------------------------------------------------------------------------------- 1 | bus 2 | hotel_slipper 3 | burj_al 4 | reflex_camera 5 | abe's_flyingfish 6 | oiltank_car 7 | doormat 8 | fish_eagle 9 | barber_shaver 10 | motorbike 11 | feather_clothes 12 | wandering_albatross 13 | rice_cooker 14 | delta_wing 15 | fish 16 | nintendo_switch 17 | bustard 18 | diver 19 | minicooper 20 | cathedrale_paris 21 | big_ben 22 | combination_lock 23 | villa_savoye 24 | american_alligator 25 | gym_ball 26 | andean_condor 27 | leggings 28 | pyramid_cube 29 | jet_aircraft 30 | meatloaf 31 | reel 32 | swan 33 | osprey 34 | crt_screen 35 | microscope 36 | rubber_eraser 37 | arrow 38 | monkey 39 | mitten 40 | spiderman 41 | parthenon 42 | bat 43 | chess_king 44 | sulphur_butterfly 45 | quail_egg 46 | oriole 47 | iron_man 48 | wooden_boat 49 | anise 50 | steering_wheel 51 | groenendael 52 | dwarf_beans 53 | pteropus 54 | chalk_brush 55 | bloodhound 56 | moon 57 | english_foxhound 58 | boxing_gloves 59 | peregine_falcon 60 | pyraminx 61 | cicada 62 | screw 63 | shower_curtain 64 | tredmill 65 | bulb 66 | bell_pepper 67 | lemur_catta 68 | doughnut 69 | twin_tower 70 | astronaut 71 | nintendo_3ds 72 | fennel_bulb 73 | indri 74 | captain_america_shield 75 | kunai 76 | broom 77 | iphone 78 | earphone1 79 | flying_squirrel 80 | onion 81 | vinyl 82 | sydney_opera_house 83 | oyster 84 | harmonica 85 | egg 86 | breast_pump 87 | guitar 88 | potato_chips 89 | tunnel 90 | cuckoo 91 | rubick_cube 92 | 
plastic_bag 93 | phonograph 94 | net_surface_shoes 95 | goldfinch 96 | ipad 97 | mite_predator 98 | coffee_mug 99 | golden_plover 100 | f1_racing 101 | lapwing 102 | nintendo_gba 103 | pizza 104 | rally_car 105 | drilling_platform 106 | cd 107 | fly 108 | magpie_bird 109 | leaf_fan 110 | little_blue_heron 111 | carriage 112 | moist_proof_pad 113 | flying_snakes 114 | dart_target 115 | warehouse_tray 116 | nintendo_wiiu 117 | chiffon_cake 118 | bath_ball 119 | manatee 120 | cloud 121 | marimba 122 | eagle 123 | ruler 124 | soymilk_machine 125 | sled 126 | seagull 127 | glider_flyingfish 128 | doublebus 129 | transport_helicopter 130 | window_screen 131 | truss_bridge 132 | wasp 133 | snowman 134 | poached_egg 135 | strawberry 136 | spinach 137 | earphone2 138 | downy_pitch 139 | taj_mahal 140 | rocking_chair 141 | cablestayed_bridge 142 | sealion 143 | banana_boat 144 | pheasant 145 | stone_lion 146 | electronic_stove 147 | fox 148 | iguana 149 | rugby_ball 150 | hang_glider 151 | water_buffalo 152 | lotus 153 | paper_plane 154 | missile 155 | flamingo 156 | american_chamelon 157 | kart 158 | chinese_knot 159 | cabbage_butterfly 160 | key 161 | church 162 | tiltrotor 163 | helicopter 164 | french_fries 165 | water_heater 166 | snow_leopard 167 | goblet 168 | fan 169 | snowplow 170 | leafhopper 171 | pspgo 172 | black_bear 173 | quail 174 | condor 175 | chandelier 176 | hair_razor 177 | white_wolf 178 | toaster 179 | pidan 180 | pyramid 181 | chicken_leg 182 | letter_opener 183 | apple_icon 184 | porcupine 185 | chicken 186 | stingray 187 | warplane 188 | windmill 189 | bamboo_slip 190 | wig 191 | flying_geckos 192 | stonechat 193 | haddock 194 | australian_terrier 195 | hover_board 196 | siamang 197 | canton_tower 198 | santa_sledge 199 | arch_bridge 200 | curlew 201 | sushi 202 | beet_root 203 | accordion 204 | leaf_egg 205 | stealth_aircraft 206 | stork 207 | bucket 208 | hawk 209 | chess_queen 210 | ocarina 211 | knife 212 | whippet 213 | cantilever_bridge 214 | 
may_bug 215 | wagtail 216 | leather_shoes 217 | wheelchair 218 | shumai 219 | speedboat 220 | vacuum_cup 221 | chess_knight 222 | pumpkin_pie 223 | wooden_spoon 224 | bamboo_dragonfly 225 | ganeva_chair 226 | soap 227 | clearwing_flyingfish 228 | pencil_sharpener1 229 | cricket 230 | photocopier 231 | nintendo_sp 232 | samarra_mosque 233 | clam 234 | charge_battery 235 | flying_frog 236 | ferrari911 237 | polo_shirt 238 | echidna 239 | coin 240 | tower_pisa 241 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | r""" MSHNet testing code """ 2 | import argparse 3 | 4 | import torch.nn.functional as F 5 | import torch.nn as nn 6 | import torch 7 | 8 | from model.mshnet import MsimilarityHyperrelationNetwork 9 | from common.logger import Logger, AverageMeter 10 | from common.vis import Visualizer 11 | from common.evaluation import Evaluator 12 | from common import utils 13 | from data.dataset import FSSDataset 14 | import os 15 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 16 | 17 | def test(model, dataloader, nshot): 18 | r""" Test MSHNet """ 19 | 20 | # Freeze randomness during testing for reproducibility 21 | utils.fix_randseed(0) 22 | average_meter = AverageMeter(dataloader.dataset) 23 | 24 | for idx, batch in enumerate(dataloader): 25 | 26 | # 1. Hypercorrelation Squeeze Networks forward pass 27 | batch = utils.to_cuda(batch) 28 | pred_mask = model.predict_mask_nshot(batch, nshot=nshot) 29 | 30 | assert pred_mask.size() == batch['query_mask'].size() 31 | 32 | # 2. 
Evaluate prediction 33 | area_inter, area_union = Evaluator.classify_prediction(pred_mask.clone(), batch) 34 | average_meter.update(area_inter, area_union, batch['class_id'], loss=None) 35 | average_meter.write_process(idx, len(dataloader), epoch=-1, write_batch_idx=1) 36 | 37 | # Visualize predictions 38 | if Visualizer.visualize: 39 | Visualizer.visualize_prediction_batch(batch['support_imgs'], batch['support_masks'], 40 | batch['query_img'], batch['query_mask'], 41 | pred_mask, batch['class_id'], idx, 42 | area_inter[1].float() / area_union[1].float()) 43 | 44 | # Write evaluation results 45 | average_meter.write_result('Test', 0) 46 | miou, fb_iou = average_meter.compute_iou() 47 | 48 | return miou, fb_iou 49 | 50 | 51 | if __name__ == '__main__': 52 | 53 | # Arguments parsing 54 | parser = argparse.ArgumentParser(description='MSHNet Pytorch Implementation') 55 | parser.add_argument('--datapath', type=str, default='/home/alex/pytorch/data/VOCdevkit') 56 | #parser.add_argument('--datapath', type=str, default='/home/alex/pytorch/data') 57 | parser.add_argument('--benchmark', type=str, default='pascal', choices=['pascal', 'coco', 'fss']) 58 | parser.add_argument('--logpath', type=str, default='') 59 | parser.add_argument('--bsz', type=int, default=1) 60 | parser.add_argument('--nworker', type=int, default=1) 61 | parser.add_argument('--load', type=str, default='./resume_pascal50_0.pth') 62 | parser.add_argument('--fold', type=int, default=0, choices=[0, 1, 2, 3]) 63 | parser.add_argument('--nshot', type=int, default=1) 64 | parser.add_argument('--backbone', type=str, default='resnet50', choices=['vgg16', 'resnet50', 'resnet101']) 65 | parser.add_argument('--visualize', action='store_true') 66 | parser.add_argument('--use_original_imgsize', action='store_true') 67 | args = parser.parse_args() 68 | Logger.initialize(args, training=False) 69 | 70 | # Model initialization 71 | model = MsimilarityHyperrelationNetwork(args.backbone, 
args.use_original_imgsize,shot=args.nshot) 72 | model.eval() 73 | Logger.log_params(model) 74 | 75 | # Device setup 76 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 77 | Logger.info('# available GPUs: %d' % torch.cuda.device_count()) 78 | model = model.cuda() 79 | #model.to(device) 80 | 81 | # Load trained model 82 | if args.load == '': raise Exception('Pretrained model not specified.') 83 | model.load_state_dict(torch.load(args.load)) 84 | 85 | # Helper classes (for testing) initialization 86 | Evaluator.initialize() 87 | Visualizer.initialize(args.visualize) 88 | 89 | # Dataset initialization 90 | FSSDataset.initialize(img_size=473, datapath=args.datapath, use_original_imgsize=args.use_original_imgsize) 91 | dataloader_test = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'test', args.nshot) 92 | 93 | # Test MSHNet 94 | with torch.no_grad(): 95 | test_miou, test_fb_iou = test(model, dataloader_test, args.nshot) 96 | Logger.info('Fold %d mIoU: %5.2f \t FB-IoU: %5.2f' % (args.fold, test_miou.item(), test_fb_iou.item())) 97 | Logger.info('==================== Finished Testing ====================') 98 | -------------------------------------------------------------------------------- /common/vis.py: -------------------------------------------------------------------------------- 1 | r""" Visualize model predictions """ 2 | import os 3 | 4 | from PIL import Image 5 | import numpy as np 6 | import torchvision.transforms as transforms 7 | 8 | from . 
import utils 9 | 10 | 11 | class Visualizer: 12 | 13 | @classmethod 14 | def initialize(cls, visualize): 15 | cls.visualize = visualize 16 | if not visualize: 17 | return 18 | 19 | cls.colors = {'red': (255, 50, 50), 'blue': (102, 140, 255)} 20 | for key, value in cls.colors.items(): 21 | cls.colors[key] = tuple([c / 255 for c in cls.colors[key]]) 22 | 23 | cls.mean_img = [0.485, 0.456, 0.406] 24 | cls.std_img = [0.229, 0.224, 0.225] 25 | cls.to_pil = transforms.ToPILImage() 26 | cls.vis_path = './vis/' 27 | if not os.path.exists(cls.vis_path): os.makedirs(cls.vis_path) 28 | 29 | @classmethod 30 | def visualize_prediction_batch(cls, spt_img_b, spt_mask_b, qry_img_b, qry_mask_b, pred_mask_b, cls_id_b, batch_idx, iou_b=None): 31 | spt_img_b = utils.to_cpu(spt_img_b) 32 | spt_mask_b = utils.to_cpu(spt_mask_b) 33 | qry_img_b = utils.to_cpu(qry_img_b) 34 | qry_mask_b = utils.to_cpu(qry_mask_b) 35 | pred_mask_b = utils.to_cpu(pred_mask_b) 36 | cls_id_b = utils.to_cpu(cls_id_b) 37 | 38 | for sample_idx, (spt_img, spt_mask, qry_img, qry_mask, pred_mask, cls_id) in \ 39 | enumerate(zip(spt_img_b, spt_mask_b, qry_img_b, qry_mask_b, pred_mask_b, cls_id_b)): 40 | iou = iou_b[sample_idx] if iou_b is not None else None 41 | cls.visualize_prediction(spt_img, spt_mask, qry_img, qry_mask, pred_mask, cls_id, batch_idx, sample_idx, True, iou) 42 | 43 | @classmethod 44 | def to_numpy(cls, tensor, type): 45 | if type == 'img': 46 | return np.array(cls.to_pil(cls.unnormalize(tensor))).astype(np.uint8) 47 | elif type == 'mask': 48 | return np.array(tensor).astype(np.uint8) 49 | else: 50 | raise Exception('Undefined tensor type: %s' % type) 51 | 52 | @classmethod 53 | def visualize_prediction(cls, spt_imgs, spt_masks, qry_img, qry_mask, pred_mask, cls_id, batch_idx, sample_idx, label, iou=None): 54 | 55 | spt_color = cls.colors['blue'] 56 | qry_color = cls.colors['red'] 57 | pred_color = cls.colors['red'] 58 | 59 | spt_imgs = [cls.to_numpy(spt_img, 'img') for spt_img in spt_imgs] 60 | 
spt_pils = [cls.to_pil(spt_img) for spt_img in spt_imgs] 61 | spt_masks = [cls.to_numpy(spt_mask, 'mask') for spt_mask in spt_masks] 62 | spt_masked_pils = [Image.fromarray(cls.apply_mask(spt_img, spt_mask, spt_color)) for spt_img, spt_mask in zip(spt_imgs, spt_masks)] 63 | 64 | qry_img = cls.to_numpy(qry_img, 'img') 65 | qry_pil = cls.to_pil(qry_img) 66 | qry_mask = cls.to_numpy(qry_mask, 'mask') 67 | pred_mask = cls.to_numpy(pred_mask, 'mask') 68 | pred_masked_pil = Image.fromarray(cls.apply_mask(qry_img.astype(np.uint8), pred_mask.astype(np.uint8), pred_color)) 69 | qry_masked_pil = Image.fromarray(cls.apply_mask(qry_img.astype(np.uint8), qry_mask.astype(np.uint8), qry_color)) 70 | 71 | merged_pil = cls.merge_image_pair(spt_masked_pils + [pred_masked_pil, qry_masked_pil]) 72 | 73 | iou = iou.item() if iou else 0.0 74 | merged_pil.save(cls.vis_path + '%d_%d_class-%d_iou-%.2f' % (batch_idx, sample_idx, cls_id, iou) + '.jpg') 75 | 76 | @classmethod 77 | def merge_image_pair(cls, pil_imgs): 78 | r""" Horizontally aligns a pair of pytorch tensor images (3, H, W) and returns PIL object """ 79 | 80 | canvas_width = sum([pil.size[0] for pil in pil_imgs]) 81 | canvas_height = max([pil.size[1] for pil in pil_imgs]) 82 | canvas = Image.new('RGB', (canvas_width, canvas_height)) 83 | 84 | xpos = 0 85 | for pil in pil_imgs: 86 | canvas.paste(pil, (xpos, 0)) 87 | xpos += pil.size[0] 88 | 89 | return canvas 90 | 91 | @classmethod 92 | def apply_mask(cls, image, mask, color, alpha=0.5): 93 | r""" Apply mask to the given image. 
""" 94 | for c in range(3): 95 | image[:, :, c] = np.where(mask == 1, 96 | image[:, :, c] * 97 | (1 - alpha) + alpha * color[c] * 255, 98 | image[:, :, c]) 99 | return image 100 | 101 | @classmethod 102 | def unnormalize(cls, img): 103 | img = img.clone() 104 | for im_channel, mean, std in zip(img, cls.mean_img, cls.std_img): 105 | im_channel.mul_(std).add_(mean) 106 | return img 107 | -------------------------------------------------------------------------------- /data/fss.py: -------------------------------------------------------------------------------- 1 | r""" FSS-1000 few-shot semantic segmentation dataset """ 2 | import os 3 | import glob 4 | 5 | from torch.utils.data import Dataset 6 | import torch.nn.functional as F 7 | import torch 8 | import PIL.Image as Image 9 | import numpy as np 10 | 11 | 12 | class DatasetFSS(Dataset): 13 | def __init__(self, datapath, fold, transform, split, shot): 14 | self.split = split 15 | self.benchmark = 'fss' 16 | self.shot = shot 17 | 18 | self.base_path = os.path.join(datapath, 'FSS-1000') 19 | 20 | # Given predefined test split, load randomly generated training/val splits: 21 | # (reference regarding trn/val/test splits: https://github.com/HKUSTCV/FSS-1000/issues/7)) 22 | with open('./data/splits/fss/%s.txt' % split, 'r') as f: 23 | self.categories = f.read().split('\n')[:-1] 24 | self.categories = sorted(self.categories) 25 | 26 | self.class_ids = self.build_class_ids() 27 | self.img_metadata = self.build_img_metadata() 28 | 29 | self.transform = transform 30 | 31 | def __len__(self): 32 | return len(self.img_metadata) 33 | 34 | def __getitem__(self, idx): 35 | query_name, support_names, class_sample = self.sample_episode(idx) 36 | query_img, query_mask, support_imgs, support_masks = self.load_frame(query_name, support_names) 37 | 38 | query_img = self.transform(query_img) 39 | query_mask = F.interpolate(query_mask.unsqueeze(0).unsqueeze(0).float(), query_img.size()[-2:], mode='nearest').squeeze() 40 | 41 | 
support_imgs = torch.stack([self.transform(support_img) for support_img in support_imgs]) 42 | 43 | support_masks_tmp = [] 44 | for smask in support_masks: 45 | smask = F.interpolate(smask.unsqueeze(0).unsqueeze(0).float(), support_imgs.size()[-2:], mode='nearest').squeeze() 46 | support_masks_tmp.append(smask) 47 | support_masks = torch.stack(support_masks_tmp) 48 | 49 | batch = {'query_img': query_img, 50 | 'query_mask': query_mask, 51 | 'query_name': query_name, 52 | 53 | 'support_imgs': support_imgs, 54 | 'support_masks': support_masks, 55 | 'support_names': support_names, 56 | 57 | 'class_id': torch.tensor(class_sample)} 58 | 59 | return batch 60 | 61 | def load_frame(self, query_name, support_names): 62 | query_img = Image.open(query_name).convert('RGB') 63 | support_imgs = [Image.open(name).convert('RGB') for name in support_names] 64 | 65 | query_id = query_name.split('/')[-1].split('.')[0] 66 | query_name = os.path.join(os.path.dirname(query_name), query_id) + '.png' 67 | support_ids = [name.split('/')[-1].split('.')[0] for name in support_names] 68 | support_names = [os.path.join(os.path.dirname(name), sid) + '.png' for name, sid in zip(support_names, support_ids)] 69 | 70 | query_mask = self.read_mask(query_name) 71 | support_masks = [self.read_mask(name) for name in support_names] 72 | 73 | return query_img, query_mask, support_imgs, support_masks 74 | 75 | def read_mask(self, img_name): 76 | mask = torch.tensor(np.array(Image.open(img_name).convert('L'))) 77 | mask[mask < 128] = 0 78 | mask[mask >= 128] = 1 79 | return mask 80 | 81 | def sample_episode(self, idx): 82 | query_name = self.img_metadata[idx] 83 | class_sample = self.categories.index(query_name.split('/')[-2]) 84 | if self.split == 'val': 85 | class_sample += 520 86 | elif self.split == 'test': 87 | class_sample += 760 88 | 89 | support_names = [] 90 | while True: # keep sampling support set if query == support 91 | support_name = np.random.choice(range(1, 11), 1, replace=False)[0] 92 | 
support_name = os.path.join(os.path.dirname(query_name), str(support_name)) + '.jpg' 93 | if query_name != support_name: support_names.append(support_name) 94 | if len(support_names) == self.shot: break 95 | 96 | return query_name, support_names, class_sample 97 | 98 | def build_class_ids(self): 99 | if self.split == 'trn': 100 | class_ids = range(0, 520) 101 | elif self.split == 'val': 102 | class_ids = range(520, 760) 103 | elif self.split == 'test': 104 | class_ids = range(760, 1000) 105 | return class_ids 106 | 107 | def build_img_metadata(self): 108 | img_metadata = [] 109 | for cat in self.categories: 110 | img_paths = sorted([path for path in glob.glob('%s/*' % os.path.join(self.base_path, cat))]) 111 | for img_path in img_paths: 112 | if os.path.basename(img_path).split('.')[1] == 'jpg': 113 | img_metadata.append(img_path) 114 | return img_metadata 115 | -------------------------------------------------------------------------------- /common/logger.py: -------------------------------------------------------------------------------- 1 | r""" Logging during training/testing """ 2 | import datetime 3 | import logging 4 | import os 5 | 6 | from tensorboardX import SummaryWriter 7 | import torch 8 | 9 | 10 | class AverageMeter: 11 | r""" Stores loss, evaluation results """ 12 | def __init__(self, dataset): 13 | self.benchmark = dataset.benchmark 14 | self.class_ids_interest = dataset.class_ids 15 | self.class_ids_interest = torch.tensor(self.class_ids_interest).cuda() 16 | 17 | if self.benchmark == 'pascal': 18 | self.nclass = 20 19 | elif self.benchmark == 'coco': 20 | self.nclass = 80 21 | elif self.benchmark == 'fss': 22 | self.nclass = 1000 23 | 24 | self.intersection_buf = torch.zeros([2, self.nclass]).float().cuda() 25 | self.union_buf = torch.zeros([2, self.nclass]).float().cuda() 26 | self.ones = torch.ones_like(self.union_buf) 27 | self.loss_buf = [] 28 | 29 | def update(self, inter_b, union_b, class_id, loss): 30 | self.intersection_buf.index_add_(1, 
class_id, inter_b.float()) 31 | self.union_buf.index_add_(1, class_id, union_b.float()) 32 | if loss is None: 33 | loss = torch.tensor(0.0) 34 | self.loss_buf.append(loss) 35 | 36 | def compute_iou(self): 37 | iou = self.intersection_buf.float() / \ 38 | torch.max(torch.stack([self.union_buf, self.ones]), dim=0)[0] 39 | iou = iou.index_select(1, self.class_ids_interest) 40 | miou = iou[1].mean() * 100 41 | 42 | fb_iou = (self.intersection_buf.index_select(1, self.class_ids_interest).sum(dim=1) / 43 | self.union_buf.index_select(1, self.class_ids_interest).sum(dim=1)).mean() * 100 44 | 45 | return miou, fb_iou 46 | 47 | def write_result(self, split, epoch): 48 | iou, fb_iou = self.compute_iou() 49 | 50 | loss_buf = torch.stack(self.loss_buf) 51 | msg = '\n*** %s ' % split 52 | msg += '[@Epoch %02d] ' % epoch 53 | msg += 'Avg L: %6.5f ' % loss_buf.mean() 54 | msg += 'mIoU: %5.2f ' % iou 55 | msg += 'FB-IoU: %5.2f ' % fb_iou 56 | 57 | msg += '***\n' 58 | Logger.info(msg) 59 | 60 | def write_process(self, batch_idx, datalen, epoch, write_batch_idx=20): 61 | if batch_idx % write_batch_idx == 0: 62 | msg = '[Epoch: %02d] ' % epoch if epoch != -1 else '' 63 | msg += '[Batch: %04d/%04d] ' % (batch_idx+1, datalen) 64 | iou, fb_iou = self.compute_iou() 65 | if epoch != -1: 66 | loss_buf = torch.stack(self.loss_buf) 67 | msg += 'L: %6.5f ' % loss_buf[-1] 68 | msg += 'Avg L: %6.5f ' % loss_buf.mean() 69 | msg += 'mIoU: %5.2f | ' % iou 70 | msg += 'FB-IoU: %5.2f' % fb_iou 71 | Logger.info(msg) 72 | 73 | 74 | class Logger: 75 | r""" Writes evaluation results of training/testing """ 76 | @classmethod 77 | def initialize(cls, args, training): 78 | logtime = datetime.datetime.now().__format__('_%m%d_%H%M%S') 79 | logpath = args.logpath if training else '_TEST_' + args.load.split('/')[-2].split('.')[0] + logtime 80 | if logpath == '': logpath = logtime 81 | 82 | cls.logpath = os.path.join('logs', logpath + '.log') 83 | cls.benchmark = args.benchmark 84 | os.makedirs(cls.logpath) 85 
| 86 | logging.basicConfig(filemode='w', 87 | filename=os.path.join(cls.logpath, 'log.txt'), 88 | level=logging.INFO, 89 | format='%(message)s', 90 | datefmt='%m-%d %H:%M:%S') 91 | 92 | # Console log config 93 | console = logging.StreamHandler() 94 | console.setLevel(logging.INFO) 95 | formatter = logging.Formatter('%(message)s') 96 | console.setFormatter(formatter) 97 | logging.getLogger('').addHandler(console) 98 | 99 | # Tensorboard writer 100 | cls.tbd_writer = SummaryWriter(os.path.join(cls.logpath, 'tbd/runs')) 101 | 102 | # Log arguments 103 | logging.info('\n:=========== Few-shot Seg. with HSNet ===========') 104 | for arg_key in args.__dict__: 105 | logging.info('| %20s: %-24s' % (arg_key, str(args.__dict__[arg_key]))) 106 | logging.info(':================================================\n') 107 | 108 | @classmethod 109 | def info(cls, msg): 110 | r""" Writes log message to log.txt """ 111 | logging.info(msg) 112 | 113 | @classmethod 114 | def save_model_miou(cls, model, epoch, val_miou): 115 | torch.save(model.state_dict(), os.path.join(cls.logpath, 'best_model.pt')) 116 | cls.info('Model saved @%d w/ val. 
mIoU: %5.2f.\n' % (epoch, val_miou)) 117 | 118 | @classmethod 119 | def log_params(cls, model): 120 | backbone_param = 0 121 | learner_param = 0 122 | for k in model.state_dict().keys(): 123 | n_param = model.state_dict()[k].view(-1).size(0) 124 | if k.split('.')[0] in 'backbone': 125 | if k.split('.')[1] in ['classifier', 'fc']: # as fc layers are not used in HSNet 126 | continue 127 | backbone_param += n_param 128 | else: 129 | learner_param += n_param 130 | Logger.info('Backbone # param.: %d' % backbone_param) 131 | Logger.info('Learnable # param.: %d' % learner_param) 132 | Logger.info('Total # param.: %d' % (backbone_param + learner_param)) 133 | 134 | -------------------------------------------------------------------------------- /model/mshnet.py: -------------------------------------------------------------------------------- 1 | r""" Hypercorrelation Squeeze Network """ 2 | from functools import reduce 3 | from operator import add 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torchvision.models import resnet 9 | from torchvision.models import vgg 10 | from .base.merge import merge 11 | from .base.feature import extract_feat_vgg, extract_feat_res 12 | from .base.correlation import Correlation 13 | 14 | 15 | class MsimilarityHyperrelationNetwork(nn.Module): 16 | def __init__(self, backbone, use_original_imgsize,shot=1): 17 | super(MsimilarityHyperrelationNetwork, self).__init__() 18 | 19 | # 1. 
Backbone network initialization 20 | self.backbone_type = backbone 21 | self.use_original_imgsize = use_original_imgsize 22 | self.shot=shot 23 | if backbone == 'vgg16': 24 | self.backbone = vgg.vgg16(pretrained=True) 25 | self.feat_ids = [17, 19, 21, 24, 26, 28, 30] 26 | self.extract_feats = extract_feat_vgg 27 | nbottlenecks = [2, 2, 3, 3, 3, 1] 28 | elif backbone == 'resnet50': 29 | self.backbone = resnet.resnet50(pretrained=True) 30 | self.feat_ids = list(range(4, 17)) 31 | self.extract_feats = extract_feat_res 32 | nbottlenecks = [3, 4, 6, 3] 33 | elif backbone == 'resnet101': 34 | self.backbone = resnet.resnet101(pretrained=True) 35 | self.feat_ids = list(range(4, 34)) 36 | self.extract_feats = extract_feat_res 37 | nbottlenecks = [3, 4, 23, 3] 38 | else: 39 | raise Exception('Unavailable backbone: %s' % backbone) 40 | 41 | self.bottleneck_ids = reduce(add, list(map(lambda x: list(range(x)), nbottlenecks))) 42 | self.lids = reduce(add, [[i + 1] * x for i, x in enumerate(nbottlenecks)]) 43 | self.stack_ids = torch.tensor(self.lids).bincount().__reversed__().cumsum(dim=0)[:3] 44 | #print(self.bottleneck_ids)#[0, 1, 2, 0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 0, 1, 2] 45 | #print(self.lids)#[1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 4, 4, 4] 46 | #print(self.stack_ids)#[ 3, 9, 13] 47 | #print(self.feat_ids)#[4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16] 48 | self.backbone.eval() 49 | self.cross_entropy_loss =nn.CrossEntropyLoss()#(weight=torch.tensor([0.2,0.8]).cuda()) 50 | self.merge=merge(shot,nsimlairy=list(reversed(nbottlenecks[-3:])),criter=self.cross_entropy_loss) 51 | def forward(self, query_img, support_img, support_mask,gt=None): 52 | sup_feats=[]#shot 53 | corrs=[] 54 | with torch.no_grad(): 55 | query_feats = self.extract_feats(query_img, self.backbone, self.feat_ids, self.bottleneck_ids, self.lids) 56 | for i in range(self.shot): 57 | support_feats = self.extract_feats(support_img[:,i,:,:,:], self.backbone, self.feat_ids, self.bottleneck_ids, self.lids) 58 | 
support_feats = self.mask_feature(support_feats, support_mask[:,i,:,:]) 59 | corr,sups = Correlation.multilayer_correlation(query_feats, support_feats, self.stack_ids)#[corr_l4, corr_l3, corr_l2] 60 | corrs.append(corr)#s,l,b,c,h,w 61 | sup_feats.append(sups)#s,l,n,b,2*c,hw 62 | 63 | logit_mask,loss = self.merge(sup_feats,corrs,gt) 64 | if not self.use_original_imgsize: 65 | logit_mask = F.interpolate(logit_mask, support_img.size()[-2:], mode='bilinear', align_corners=True) 66 | 67 | return logit_mask,loss 68 | 69 | def mask_feature(self, features, support_mask):#bchw 70 | bs=features[0].shape[0] 71 | initSize=((features[0].shape[-1])*2,)*2 72 | support_mask = support_mask.unsqueeze(1).float() 73 | support_mask = F.interpolate(support_mask,initSize, mode='bilinear', align_corners=True) 74 | for idx, feature in enumerate(features): 75 | feat=[] 76 | if support_mask.shape[-1]!=feature.shape[-1]: 77 | support_mask = F.interpolate(support_mask, feature.size()[2:], mode='bilinear', align_corners=True) 78 | for i in range(bs): 79 | featI=feature[i].flatten(start_dim=1)#c,hw 80 | maskI=support_mask[i].flatten(start_dim=1)#hw 81 | featI = featI * maskI 82 | maskI=maskI.squeeze() 83 | meanVal=maskI[maskI>0].mean() 84 | realSupI=featI[:,maskI>=meanVal] 85 | if maskI.sum()==0: 86 | realSupI=torch.zeros(featI.shape[0],1).cuda() 87 | feat.append(realSupI)#[b,]ch,w 88 | features[idx] = feat#nfeatures ,bs,ch,w 89 | return features 90 | 91 | def predict_mask_nshot(self, batch, nshot): 92 | logit_mask,loss = self(batch['query_img'], batch['support_imgs'], batch['support_masks'],batch['query_mask']) 93 | if self.use_original_imgsize: 94 | org_qry_imsize = tuple([batch['org_query_imsize'][1].item(), batch['org_query_imsize'][0].item()]) 95 | logit_mask = F.interpolate(logit_mask, org_qry_imsize, mode='bilinear', align_corners=True) 96 | 97 | logit_mask_agg = logit_mask.argmax(dim=1) 98 | return logit_mask_agg 99 | 100 | def compute_objective(self, logit_mask, gt_mask): 101 | bsz = 
logit_mask.size(0) 102 | logit_mask = logit_mask.view(bsz, 2, -1) 103 | gt_mask = gt_mask.view(bsz, -1).long() 104 | 105 | return self.cross_entropy_loss(logit_mask, gt_mask) 106 | 107 | def train_mode(self): 108 | self.train() 109 | self.backbone.eval() # to prevent BN from learning data statistics with exponential averaging 110 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | r""" MSHNet training (validation) code """ 2 | import argparse 3 | 4 | import torch.optim as optim 5 | import torch.nn as nn 6 | import torch 7 | import os 8 | from model.mshnet import MsimilarityHyperrelationNetwork 9 | from common.logger import Logger, AverageMeter 10 | from common.evaluation import Evaluator 11 | from common import utils 12 | from data.dataset import FSSDataset 13 | def train(epoch, model, dataloader, optimizer, training): 14 | r""" Train MSHNet """ 15 | 16 | # Force randomness during training / freeze randomness during testing 17 | utils.fix_randseed(None) if training else utils.fix_randseed(0) 18 | model.train_mode() if training else model.eval() 19 | average_meter = AverageMeter(dataloader.dataset) 20 | 21 | for idx, batch in enumerate(dataloader): 22 | 23 | # 1. MSHNet forward pass 24 | batch = utils.to_cuda(batch) 25 | logit_mask,loss = model(batch['query_img'], batch['support_imgs'], batch['support_masks'],batch['query_mask']) 26 | pred_mask = logit_mask.argmax(dim=1) 27 | 28 | # 2. Compute loss & update model parameters 29 | #loss = model.compute_objective(logit_mask, batch['query_mask']) 30 | if training: 31 | optimizer.zero_grad() 32 | loss.backward() 33 | optimizer.step() 34 | 35 | # 3. 
Evaluate prediction 36 | area_inter, area_union = Evaluator.classify_prediction(pred_mask, batch) 37 | average_meter.update(area_inter, area_union, batch['class_id'], loss.detach().clone()) 38 | average_meter.write_process(idx, len(dataloader), epoch, write_batch_idx=50) 39 | 40 | # Write evaluation results 41 | average_meter.write_result('Training' if training else 'Validation', epoch) 42 | avg_loss = utils.mean(average_meter.loss_buf) 43 | miou, fb_iou = average_meter.compute_iou() 44 | 45 | return avg_loss, miou, fb_iou 46 | 47 | 48 | if __name__ == '__main__': 49 | 50 | # Arguments parsing 51 | parser = argparse.ArgumentParser(description='MSHNet Pytorch Implementat1ion') 52 | parser.add_argument('--datapath', type=str, default='/home/alex/pytorch/data/VOCdevkit') 53 | #parser.add_argument('--datapath', type=str, default='/home/alex/pytorch/data') 54 | parser.add_argument('--save_path', type=str, default='./resume') 55 | parser.add_argument('--benchmark', type=str, default='pascal', choices=['pascal', 'coco', 'fss']) 56 | parser.add_argument('--logpath', type=str, default='') 57 | parser.add_argument('--bsz', type=int, default=11) 58 | parser.add_argument('--shot', type=int, default=1) 59 | parser.add_argument('--momentum', type=float, default=0.9) 60 | parser.add_argument('--weight_decay', type=float, default=0.00005) 61 | parser.add_argument('--lr', type=float, default=0.025) 62 | parser.add_argument('--niter', type=int, default=300) 63 | parser.add_argument('--nworker', type=int, default=12) 64 | parser.add_argument('--fold', type=int, default=2, choices=[0, 1, 2, 3]) 65 | parser.add_argument('--backbone', type=str, default='resnet50', choices=['vgg16', 'resnet50', 'resnet101']) 66 | args = parser.parse_args() 67 | Logger.initialize(args, training=True) 68 | 69 | # Model initialization 70 | model = MsimilarityHyperrelationNetwork(args.backbone, False,shot=args.shot) 71 | Logger.log_params(model) 72 | 73 | # Device setup 74 | device = torch.device("cuda:0" if 
torch.cuda.is_available() else "cpu") 75 | Logger.info('# available GPUs: %d' % torch.cuda.device_count()) 76 | #model = nn.DataParallel(model) 77 | model=model.cuda() 78 | # Helper classes (for training) initialization 79 | optimizer = torch.optim.SGD( 80 | model.merge.parameters(), 81 | lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 82 | lrschem = optim.lr_scheduler.ExponentialLR(optimizer, 0.9) 83 | Evaluator.initialize() 84 | if os.path.exists(args.save_path + '/resume.pth'): 85 | print('---load state_dict from: ',args.save_path + '/resume.pth') 86 | val=torch.load(args.save_path + '/resume.pth') 87 | epoch=val['epoch'] 88 | model.load_state_dict(val['state_dict']) 89 | optimizer.load_state_dict(val['optimizer']) 90 | lrschem.load_state_dict(val['lr']) 91 | elif os.path.exists(args.save_path + '/weight.pth'): 92 | model.load_state_dict(torch.load(args.save_path + '/weight.pth')) 93 | epoch=0 94 | else: 95 | epoch=0 96 | print('---there is no resume or weight: ') 97 | # Dataset initialization 98 | print('epoch:',epoch,'lr:',lrschem.get_lr()) 99 | FSSDataset.initialize(img_size=473, datapath=args.datapath, use_original_imgsize=False) 100 | dataloader_trn = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'trn',shot=args.shot) 101 | dataloader_val = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'val',shot=args.shot) 102 | 103 | # Train MSHNet 104 | best_val_miou = float('-inf') 105 | best_val_loss = float('inf') 106 | while epoch < args.niter: 107 | trn_loss, trn_miou, trn_fb_iou = train(epoch, model, dataloader_trn, optimizer, training=True) 108 | filename = args.save_path + '/resume.pth' 109 | torch.save({'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict(),'lr':lrschem.state_dict()}, filename) 110 | if epoch%5==0 and epoch!=0: 111 | lrschem.step() 112 | print("current lr:",lrschem.get_lr()) 113 | with torch.no_grad(): 114 | val_loss, val_miou, 
val_fb_iou = train(epoch, model, dataloader_val, optimizer, training=False) 115 | 116 | # Save the best model 117 | if val_miou > best_val_miou: 118 | best_val_miou = val_miou 119 | Logger.save_model_miou(model, epoch, val_miou) 120 | filename = args.save_path + '/train_epoch_' + str(epoch) + '_' + str(best_val_miou.item()) + '.pth' 121 | torch.save(model.state_dict(),filename) 122 | Logger.tbd_writer.add_scalars('data/loss', {'trn_loss': trn_loss, 'val_loss': val_loss}, epoch) 123 | Logger.tbd_writer.add_scalars('data/miou', {'trn_miou': trn_miou, 'val_miou': val_miou}, epoch) 124 | Logger.tbd_writer.add_scalars('data/fb_iou', {'trn_fb_iou': trn_fb_iou, 'val_fb_iou': val_fb_iou}, epoch) 125 | Logger.tbd_writer.flush() 126 | epoch+=1 127 | Logger.tbd_writer.close() 128 | Logger.info('==================== Finished Training ====================') 129 | -------------------------------------------------------------------------------- /train_coco.py: -------------------------------------------------------------------------------- 1 | r""" MSHNet training (validation) code """ 2 | import argparse 3 | 4 | import torch.optim as optim 5 | import torch.nn as nn 6 | import torch 7 | import os 8 | #os.environ["CUDA_VISIBLE_DEVICES"] = "0" 9 | from model.mshnet import MsimilarityHyperrelationNetwork 10 | from common.logger import Logger, AverageMeter 11 | from common.evaluation import Evaluator 12 | from common import utils 13 | from data.dataset import FSSDataset 14 | def train(epoch, model, dataloader, optimizer, training): 15 | r""" Train MSHNet """ 16 | 17 | # Force randomness during training / freeze randomness during testing 18 | utils.fix_randseed(None) if training else utils.fix_randseed(0) 19 | model.train_mode() if training else model.eval() 20 | average_meter = AverageMeter(dataloader.dataset) 21 | 22 | for idx, batch in enumerate(dataloader): 23 | 24 | # 1. 
MSHNet forward pass 25 | batch = utils.to_cuda(batch) 26 | logit_mask,loss = model(batch['query_img'], batch['support_imgs'], batch['support_masks'],batch['query_mask']) 27 | pred_mask = logit_mask.argmax(dim=1) 28 | 29 | # 2. Compute loss & update model parameters 30 | #loss = model.compute_objective(logit_mask, batch['query_mask']) 31 | if training: 32 | optimizer.zero_grad() 33 | loss.backward() 34 | optimizer.step() 35 | 36 | # 3. Evaluate prediction 37 | area_inter, area_union = Evaluator.classify_prediction(pred_mask, batch) 38 | average_meter.update(area_inter, area_union, batch['class_id'], loss.detach().clone()) 39 | average_meter.write_process(idx, len(dataloader), epoch, write_batch_idx=50) 40 | 41 | # Write evaluation results 42 | average_meter.write_result('Training' if training else 'Validation', epoch) 43 | avg_loss = utils.mean(average_meter.loss_buf) 44 | miou, fb_iou = average_meter.compute_iou() 45 | 46 | return avg_loss, miou, fb_iou 47 | 48 | 49 | if __name__ == '__main__': 50 | 51 | # Arguments parsing 52 | parser = argparse.ArgumentParser(description='MSHNet Pytorch Implementat1ion') 53 | #parser.add_argument('--datapath', type=str, default='/home/alex/project/data/voc/VOCdevkit-temp') 54 | parser.add_argument('--datapath', type=str, default='/home/alex/SSD/data') 55 | parser.add_argument('--save_path', type=str, default='./resume-coco101_0') 56 | parser.add_argument('--benchmark', type=str, default='coco', choices=['pascal', 'coco', 'fss']) 57 | parser.add_argument('--logpath', type=str, default='') 58 | parser.add_argument('--bsz', type=int, default=14) 59 | parser.add_argument('--shot', type=int, default=1) 60 | parser.add_argument('--momentum', type=float, default=0.9) 61 | parser.add_argument('--weight_decay', type=float, default=0.00005) 62 | parser.add_argument('--lr', type=float, default=0.025) 63 | parser.add_argument('--niter', type=int, default=300) 64 | parser.add_argument('--nworker', type=int, default=15) 65 | 
parser.add_argument('--fold', type=int, default=0, choices=[0, 1, 2, 3]) 66 | parser.add_argument('--backbone', type=str, default='resnet101', choices=['vgg16', 'resnet50', 'resnet101']) 67 | args = parser.parse_args() 68 | Logger.initialize(args, training=True) 69 | 70 | # Model initialization 71 | model = MsimilarityHyperrelationNetwork(args.backbone, False,shot=args.shot) 72 | Logger.log_params(model) 73 | 74 | # Device setup 75 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 76 | Logger.info('# available GPUs: %d' % torch.cuda.device_count()) 77 | #model = nn.DataParallel(model) 78 | model=model.cuda() 79 | # Helper classes (for training) initialization 80 | optimizer = torch.optim.SGD( 81 | model.merge.parameters(), 82 | lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 83 | lrschem = optim.lr_scheduler.ExponentialLR(optimizer, 0.9) 84 | Evaluator.initialize() 85 | if os.path.exists(args.save_path + '/resume.pth'): 86 | print('---load state_dict from: ',args.save_path + '/resume.pth') 87 | val=torch.load(args.save_path + '/resume.pth') 88 | epoch=val['epoch'] 89 | model.load_state_dict(val['state_dict']) 90 | optimizer.load_state_dict(val['optimizer']) 91 | lrschem.load_state_dict(val['lr']) 92 | elif os.path.exists(args.save_path + '/weight.pth'): 93 | model.load_state_dict(torch.load(args.save_path + '/weight.pth')) 94 | epoch=0 95 | else: 96 | epoch=0 97 | print('---there is no resume or weight: ') 98 | # Dataset initialization 99 | print('epoch:',epoch,'lr:',lrschem.get_lr()) 100 | FSSDataset.initialize(img_size=473, datapath=args.datapath, use_original_imgsize=False) 101 | dataloader_trn = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'trn',shot=args.shot) 102 | dataloader_val = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'val',shot=args.shot) 103 | 104 | # Train MSHNet 105 | best_val_miou = float('-inf') 106 | best_val_loss = float('inf') 107 | 
while epoch < args.niter: 108 | trn_loss, trn_miou, trn_fb_iou = train(epoch, model, dataloader_trn, optimizer, training=True) 109 | filename = args.save_path + '/resume.pth' 110 | torch.save({'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict(),'lr':lrschem.state_dict()}, filename) 111 | if epoch%5==0 and epoch!=0: 112 | lrschem.step() 113 | print("current lr:",lrschem.get_lr()) 114 | with torch.no_grad(): 115 | val_loss, val_miou, val_fb_iou = train(epoch, model, dataloader_val, optimizer, training=False) 116 | 117 | # Save the best model 118 | if val_miou > best_val_miou: 119 | best_val_miou = val_miou 120 | Logger.save_model_miou(model, epoch, val_miou) 121 | filename = args.save_path + '/train_epoch_' + str(epoch) + '_' + str(best_val_miou.item()) + '.pth' 122 | torch.save(model.state_dict(),filename) 123 | Logger.tbd_writer.add_scalars('data/loss', {'trn_loss': trn_loss, 'val_loss': val_loss}, epoch) 124 | Logger.tbd_writer.add_scalars('data/miou', {'trn_miou': trn_miou, 'val_miou': val_miou}, epoch) 125 | Logger.tbd_writer.add_scalars('data/fb_iou', {'trn_fb_iou': trn_fb_iou, 'val_fb_iou': val_fb_iou}, epoch) 126 | Logger.tbd_writer.flush() 127 | epoch+=1 128 | Logger.tbd_writer.close() 129 | Logger.info('==================== Finished Training ====================') 130 | -------------------------------------------------------------------------------- /data/splits/pascal/val/fold0.txt: -------------------------------------------------------------------------------- 1 | 2007_000033__01 2 | 2007_000061__04 3 | 2007_000129__02 4 | 2007_000346__05 5 | 2007_000529__04 6 | 2007_000559__05 7 | 2007_000572__02 8 | 2007_000762__05 9 | 2007_001288__01 10 | 2007_001289__03 11 | 2007_001311__02 12 | 2007_001408__05 13 | 2007_001568__01 14 | 2007_001630__02 15 | 2007_001761__01 16 | 2007_001884__01 17 | 2007_002094__03 18 | 2007_002266__01 19 | 2007_002376__01 20 | 2007_002400__03 21 | 2007_002619__01 22 | 2007_002719__04 23 | 
2007_003088__05 24 | 2007_003131__04 25 | 2007_003188__02 26 | 2007_003349__03 27 | 2007_003571__04 28 | 2007_003621__02 29 | 2007_003682__03 30 | 2007_003861__04 31 | 2007_004052__01 32 | 2007_004143__03 33 | 2007_004241__04 34 | 2007_004468__05 35 | 2007_005074__04 36 | 2007_005107__02 37 | 2007_005294__05 38 | 2007_005304__05 39 | 2007_005428__05 40 | 2007_005509__01 41 | 2007_005600__01 42 | 2007_005705__04 43 | 2007_005828__01 44 | 2007_006076__03 45 | 2007_006086__05 46 | 2007_006449__02 47 | 2007_006946__01 48 | 2007_007084__03 49 | 2007_007235__02 50 | 2007_007341__01 51 | 2007_007470__01 52 | 2007_007477__04 53 | 2007_007836__02 54 | 2007_008051__03 55 | 2007_008084__03 56 | 2007_008204__05 57 | 2007_008670__03 58 | 2007_009088__03 59 | 2007_009258__02 60 | 2007_009323__03 61 | 2007_009458__05 62 | 2007_009687__05 63 | 2007_009817__03 64 | 2007_009911__01 65 | 2008_000120__04 66 | 2008_000123__03 67 | 2008_000533__03 68 | 2008_000725__02 69 | 2008_000911__05 70 | 2008_001013__04 71 | 2008_001040__04 72 | 2008_001135__04 73 | 2008_001260__04 74 | 2008_001404__02 75 | 2008_001514__03 76 | 2008_001531__02 77 | 2008_001546__01 78 | 2008_001580__04 79 | 2008_001966__03 80 | 2008_001971__01 81 | 2008_002043__03 82 | 2008_002269__02 83 | 2008_002358__01 84 | 2008_002429__03 85 | 2008_002467__05 86 | 2008_002504__04 87 | 2008_002775__05 88 | 2008_002864__05 89 | 2008_003034__04 90 | 2008_003076__05 91 | 2008_003108__02 92 | 2008_003110__03 93 | 2008_003155__01 94 | 2008_003270__02 95 | 2008_003369__01 96 | 2008_003858__04 97 | 2008_003876__01 98 | 2008_003886__04 99 | 2008_003926__01 100 | 2008_003976__01 101 | 2008_004363__02 102 | 2008_004654__02 103 | 2008_004659__05 104 | 2008_004704__01 105 | 2008_004758__02 106 | 2008_004995__02 107 | 2008_005262__05 108 | 2008_005338__01 109 | 2008_005628__04 110 | 2008_005727__02 111 | 2008_005812__05 112 | 2008_005904__05 113 | 2008_006216__01 114 | 2008_006229__04 115 | 2008_006254__02 116 | 2008_006703__01 117 | 
2008_007120__03 118 | 2008_007143__04 119 | 2008_007219__05 120 | 2008_007350__01 121 | 2008_007498__03 122 | 2008_007811__05 123 | 2008_007994__03 124 | 2008_008268__03 125 | 2008_008629__02 126 | 2008_008711__02 127 | 2008_008746__03 128 | 2009_000032__01 129 | 2009_000037__03 130 | 2009_000121__05 131 | 2009_000149__02 132 | 2009_000201__05 133 | 2009_000205__01 134 | 2009_000318__03 135 | 2009_000354__02 136 | 2009_000387__01 137 | 2009_000421__04 138 | 2009_000440__01 139 | 2009_000446__04 140 | 2009_000457__02 141 | 2009_000469__04 142 | 2009_000573__02 143 | 2009_000619__03 144 | 2009_000664__03 145 | 2009_000723__04 146 | 2009_000828__04 147 | 2009_000840__05 148 | 2009_000879__03 149 | 2009_000991__03 150 | 2009_000998__03 151 | 2009_001108__03 152 | 2009_001160__03 153 | 2009_001255__02 154 | 2009_001278__05 155 | 2009_001314__03 156 | 2009_001332__01 157 | 2009_001565__03 158 | 2009_001607__03 159 | 2009_001683__03 160 | 2009_001718__02 161 | 2009_001765__03 162 | 2009_001818__05 163 | 2009_001850__01 164 | 2009_001851__01 165 | 2009_001941__04 166 | 2009_002185__05 167 | 2009_002295__02 168 | 2009_002320__01 169 | 2009_002372__05 170 | 2009_002521__05 171 | 2009_002594__05 172 | 2009_002604__03 173 | 2009_002649__05 174 | 2009_002727__04 175 | 2009_002732__05 176 | 2009_002749__05 177 | 2009_002808__01 178 | 2009_002856__05 179 | 2009_002888__01 180 | 2009_002928__02 181 | 2009_003003__05 182 | 2009_003005__01 183 | 2009_003043__04 184 | 2009_003080__04 185 | 2009_003193__02 186 | 2009_003224__02 187 | 2009_003269__05 188 | 2009_003273__03 189 | 2009_003343__02 190 | 2009_003378__03 191 | 2009_003450__03 192 | 2009_003498__03 193 | 2009_003504__04 194 | 2009_003517__05 195 | 2009_003640__03 196 | 2009_003696__01 197 | 2009_003707__04 198 | 2009_003806__01 199 | 2009_003858__03 200 | 2009_003971__02 201 | 2009_004021__03 202 | 2009_004084__03 203 | 2009_004125__04 204 | 2009_004247__05 205 | 2009_004324__05 206 | 2009_004509__03 207 | 2009_004540__03 208 
| 2009_004568__03 209 | 2009_004579__05 210 | 2009_004635__04 211 | 2009_004653__01 212 | 2009_004848__02 213 | 2009_004882__02 214 | 2009_004886__03 215 | 2009_004895__03 216 | 2009_004969__01 217 | 2009_005038__05 218 | 2009_005137__03 219 | 2009_005156__02 220 | 2009_005189__01 221 | 2009_005190__05 222 | 2009_005260__03 223 | 2009_005262__03 224 | 2009_005302__05 225 | 2010_000065__02 226 | 2010_000083__02 227 | 2010_000084__04 228 | 2010_000238__01 229 | 2010_000241__03 230 | 2010_000272__04 231 | 2010_000342__02 232 | 2010_000426__05 233 | 2010_000572__01 234 | 2010_000622__01 235 | 2010_000814__03 236 | 2010_000906__04 237 | 2010_000961__03 238 | 2010_001016__03 239 | 2010_001017__01 240 | 2010_001024__01 241 | 2010_001036__04 242 | 2010_001061__03 243 | 2010_001069__03 244 | 2010_001174__01 245 | 2010_001367__02 246 | 2010_001367__05 247 | 2010_001448__01 248 | 2010_001830__05 249 | 2010_001995__03 250 | 2010_002017__05 251 | 2010_002030__02 252 | 2010_002142__03 253 | 2010_002147__01 254 | 2010_002150__04 255 | 2010_002200__01 256 | 2010_002310__01 257 | 2010_002536__02 258 | 2010_002546__04 259 | 2010_002693__02 260 | 2010_002939__01 261 | 2010_003127__01 262 | 2010_003132__01 263 | 2010_003168__03 264 | 2010_003362__03 265 | 2010_003365__01 266 | 2010_003418__03 267 | 2010_003468__05 268 | 2010_003473__03 269 | 2010_003495__01 270 | 2010_003547__04 271 | 2010_003716__01 272 | 2010_003771__03 273 | 2010_003781__05 274 | 2010_003820__03 275 | 2010_003912__02 276 | 2010_003915__01 277 | 2010_004041__04 278 | 2010_004056__05 279 | 2010_004208__04 280 | 2010_004314__01 281 | 2010_004419__01 282 | 2010_004520__05 283 | 2010_004529__05 284 | 2010_004551__05 285 | 2010_004556__03 286 | 2010_004559__03 287 | 2010_004662__04 288 | 2010_004772__04 289 | 2010_004828__05 290 | 2010_004994__03 291 | 2010_005252__04 292 | 2010_005401__04 293 | 2010_005428__03 294 | 2010_005496__05 295 | 2010_005531__03 296 | 2010_005534__01 297 | 2010_005582__05 298 | 2010_005664__02 
299 | 2010_005705__04 300 | 2010_005718__01 301 | 2010_005762__05 302 | 2010_005877__01 303 | 2010_005888__01 304 | 2010_006034__01 305 | 2010_006070__02 306 | 2011_000066__05 307 | 2011_000112__03 308 | 2011_000185__03 309 | 2011_000234__04 310 | 2011_000238__04 311 | 2011_000412__02 312 | 2011_000435__04 313 | 2011_000456__03 314 | 2011_000482__03 315 | 2011_000585__02 316 | 2011_000669__03 317 | 2011_000747__05 318 | 2011_000874__01 319 | 2011_001114__01 320 | 2011_001161__04 321 | 2011_001263__01 322 | 2011_001287__03 323 | 2011_001407__01 324 | 2011_001421__03 325 | 2011_001434__01 326 | 2011_001589__04 327 | 2011_001624__01 328 | 2011_001793__04 329 | 2011_001880__01 330 | 2011_001988__02 331 | 2011_002064__02 332 | 2011_002098__05 333 | 2011_002223__02 334 | 2011_002295__03 335 | 2011_002327__01 336 | 2011_002515__01 337 | 2011_002675__01 338 | 2011_002713__02 339 | 2011_002754__04 340 | 2011_002863__05 341 | 2011_002929__01 342 | 2011_002975__04 343 | 2011_003003__02 344 | 2011_003030__03 345 | 2011_003145__03 346 | 2011_003271__05 347 | -------------------------------------------------------------------------------- /data/splits/pascal/val/fold3.txt: -------------------------------------------------------------------------------- 1 | 2007_000042__19 2 | 2007_000123__19 3 | 2007_000175__17 4 | 2007_000187__20 5 | 2007_000452__18 6 | 2007_000559__20 7 | 2007_000629__19 8 | 2007_000636__19 9 | 2007_000661__18 10 | 2007_000676__17 11 | 2007_000804__18 12 | 2007_000925__17 13 | 2007_001154__18 14 | 2007_001175__20 15 | 2007_001408__16 16 | 2007_001430__16 17 | 2007_001430__20 18 | 2007_001457__18 19 | 2007_001458__18 20 | 2007_001585__18 21 | 2007_001594__17 22 | 2007_001678__20 23 | 2007_001717__20 24 | 2007_001733__17 25 | 2007_001763__18 26 | 2007_001763__20 27 | 2007_002119__20 28 | 2007_002132__20 29 | 2007_002268__18 30 | 2007_002284__16 31 | 2007_002378__16 32 | 2007_002426__18 33 | 2007_002427__18 34 | 2007_002565__19 35 | 2007_002618__17 36 | 
2007_002648__17 37 | 2007_002728__19 38 | 2007_003011__18 39 | 2007_003011__20 40 | 2007_003169__18 41 | 2007_003367__16 42 | 2007_003499__19 43 | 2007_003506__16 44 | 2007_003530__18 45 | 2007_003587__19 46 | 2007_003714__17 47 | 2007_003848__19 48 | 2007_003957__19 49 | 2007_004190__20 50 | 2007_004193__20 51 | 2007_004275__16 52 | 2007_004281__19 53 | 2007_004483__19 54 | 2007_004510__20 55 | 2007_004558__16 56 | 2007_004649__19 57 | 2007_004712__16 58 | 2007_004969__17 59 | 2007_005469__17 60 | 2007_005626__19 61 | 2007_005689__19 62 | 2007_005813__16 63 | 2007_005857__16 64 | 2007_005915__17 65 | 2007_006171__18 66 | 2007_006348__20 67 | 2007_006373__18 68 | 2007_006678__17 69 | 2007_006680__19 70 | 2007_006802__19 71 | 2007_007130__20 72 | 2007_007165__17 73 | 2007_007168__19 74 | 2007_007195__19 75 | 2007_007196__20 76 | 2007_007203__20 77 | 2007_007417__18 78 | 2007_007534__17 79 | 2007_007624__16 80 | 2007_007795__16 81 | 2007_007881__19 82 | 2007_007996__18 83 | 2007_008204__20 84 | 2007_008260__18 85 | 2007_008339__19 86 | 2007_008374__20 87 | 2007_008543__18 88 | 2007_008547__16 89 | 2007_009068__18 90 | 2007_009252__18 91 | 2007_009320__17 92 | 2007_009419__16 93 | 2007_009446__20 94 | 2007_009521__18 95 | 2007_009521__20 96 | 2007_009592__18 97 | 2007_009655__18 98 | 2007_009684__18 99 | 2007_009750__16 100 | 2008_000016__20 101 | 2008_000149__18 102 | 2008_000270__18 103 | 2008_000391__16 104 | 2008_000589__18 105 | 2008_000657__19 106 | 2008_001078__16 107 | 2008_001283__16 108 | 2008_001688__16 109 | 2008_001688__20 110 | 2008_001966__16 111 | 2008_002273__16 112 | 2008_002379__16 113 | 2008_002464__20 114 | 2008_002536__17 115 | 2008_002680__20 116 | 2008_002900__19 117 | 2008_002929__18 118 | 2008_003003__20 119 | 2008_003026__20 120 | 2008_003105__19 121 | 2008_003135__16 122 | 2008_003676__16 123 | 2008_003709__18 124 | 2008_003733__18 125 | 2008_003885__20 126 | 2008_004172__18 127 | 2008_004212__19 128 | 2008_004279__20 129 | 2008_004367__19 
130 | 2008_004453__17 131 | 2008_004477__16 132 | 2008_004562__18 133 | 2008_004610__19 134 | 2008_004621__17 135 | 2008_004754__20 136 | 2008_004854__17 137 | 2008_004910__20 138 | 2008_005089__20 139 | 2008_005217__16 140 | 2008_005242__16 141 | 2008_005254__20 142 | 2008_005439__20 143 | 2008_005445__20 144 | 2008_005544__19 145 | 2008_005633__17 146 | 2008_005680__16 147 | 2008_006055__19 148 | 2008_006159__20 149 | 2008_006327__17 150 | 2008_006523__19 151 | 2008_006553__19 152 | 2008_006752__19 153 | 2008_006784__18 154 | 2008_006835__17 155 | 2008_007497__17 156 | 2008_007527__20 157 | 2008_007677__17 158 | 2008_007814__17 159 | 2008_007828__20 160 | 2008_008103__18 161 | 2008_008221__19 162 | 2008_008434__16 163 | 2009_000022__19 164 | 2009_000039__17 165 | 2009_000087__18 166 | 2009_000096__18 167 | 2009_000136__20 168 | 2009_000242__18 169 | 2009_000391__20 170 | 2009_000418__16 171 | 2009_000418__18 172 | 2009_000487__18 173 | 2009_000488__16 174 | 2009_000488__20 175 | 2009_000628__19 176 | 2009_000675__17 177 | 2009_000704__20 178 | 2009_000712__19 179 | 2009_000732__18 180 | 2009_000845__19 181 | 2009_000924__17 182 | 2009_001300__19 183 | 2009_001333__19 184 | 2009_001363__20 185 | 2009_001505__17 186 | 2009_001644__16 187 | 2009_001644__18 188 | 2009_001644__20 189 | 2009_001684__16 190 | 2009_001731__18 191 | 2009_001768__17 192 | 2009_001775__16 193 | 2009_001775__18 194 | 2009_001991__17 195 | 2009_002082__17 196 | 2009_002094__20 197 | 2009_002202__19 198 | 2009_002265__19 199 | 2009_002291__19 200 | 2009_002346__18 201 | 2009_002366__20 202 | 2009_002390__18 203 | 2009_002487__16 204 | 2009_002562__20 205 | 2009_002568__19 206 | 2009_002571__16 207 | 2009_002571__18 208 | 2009_002573__20 209 | 2009_002584__16 210 | 2009_002638__19 211 | 2009_002732__18 212 | 2009_002887__19 213 | 2009_002982__19 214 | 2009_003105__19 215 | 2009_003123__18 216 | 2009_003299__19 217 | 2009_003311__19 218 | 2009_003433__19 219 | 2009_003523__20 220 | 
2009_003551__20 221 | 2009_003564__16 222 | 2009_003564__18 223 | 2009_003607__18 224 | 2009_003666__17 225 | 2009_003857__20 226 | 2009_003895__18 227 | 2009_003895__20 228 | 2009_003938__19 229 | 2009_004099__18 230 | 2009_004140__18 231 | 2009_004255__19 232 | 2009_004298__18 233 | 2009_004687__18 234 | 2009_004730__19 235 | 2009_004799__19 236 | 2009_004993__18 237 | 2009_004993__20 238 | 2009_005148__19 239 | 2009_005220__19 240 | 2010_000256__18 241 | 2010_000284__18 242 | 2010_000309__17 243 | 2010_000318__20 244 | 2010_000330__16 245 | 2010_000639__16 246 | 2010_000738__20 247 | 2010_000764__19 248 | 2010_001011__17 249 | 2010_001079__17 250 | 2010_001104__19 251 | 2010_001149__18 252 | 2010_001151__19 253 | 2010_001246__16 254 | 2010_001256__17 255 | 2010_001327__18 256 | 2010_001367__20 257 | 2010_001522__17 258 | 2010_001557__17 259 | 2010_001577__17 260 | 2010_001699__16 261 | 2010_001734__19 262 | 2010_001752__20 263 | 2010_001767__18 264 | 2010_001773__16 265 | 2010_001851__16 266 | 2010_001951__19 267 | 2010_001962__18 268 | 2010_002106__17 269 | 2010_002137__16 270 | 2010_002137__18 271 | 2010_002232__17 272 | 2010_002531__18 273 | 2010_002682__19 274 | 2010_002921__20 275 | 2010_003014__18 276 | 2010_003123__16 277 | 2010_003302__16 278 | 2010_003514__19 279 | 2010_003541__17 280 | 2010_003597__18 281 | 2010_003781__16 282 | 2010_003956__19 283 | 2010_004149__19 284 | 2010_004226__17 285 | 2010_004382__16 286 | 2010_004479__20 287 | 2010_004757__16 288 | 2010_004757__18 289 | 2010_004783__18 290 | 2010_004825__16 291 | 2010_004857__20 292 | 2010_004951__19 293 | 2010_004980__19 294 | 2010_005180__18 295 | 2010_005187__16 296 | 2010_005305__20 297 | 2010_005606__18 298 | 2010_005706__19 299 | 2010_005719__17 300 | 2010_005727__19 301 | 2010_005788__17 302 | 2010_005860__16 303 | 2010_005871__19 304 | 2010_005991__18 305 | 2010_006054__19 306 | 2011_000070__18 307 | 2011_000173__18 308 | 2011_000283__19 309 | 2011_000291__19 310 | 2011_000310__18 311 
| 2011_000436__17 312 | 2011_000521__19 313 | 2011_000747__16 314 | 2011_001005__18 315 | 2011_001060__19 316 | 2011_001281__19 317 | 2011_001350__17 318 | 2011_001567__18 319 | 2011_001601__18 320 | 2011_001614__19 321 | 2011_001674__18 322 | 2011_001713__16 323 | 2011_001713__18 324 | 2011_001726__20 325 | 2011_001794__18 326 | 2011_001862__18 327 | 2011_001863__16 328 | 2011_001910__20 329 | 2011_002124__18 330 | 2011_002156__20 331 | 2011_002178__17 332 | 2011_002247__19 333 | 2011_002379__19 334 | 2011_002391__18 335 | 2011_002532__20 336 | 2011_002535__19 337 | 2011_002644__18 338 | 2011_002644__20 339 | 2011_002879__18 340 | 2011_002879__20 341 | 2011_003103__16 342 | 2011_003103__18 343 | 2011_003146__19 344 | 2011_003182__18 345 | 2011_003197__19 346 | 2011_003256__18 347 | -------------------------------------------------------------------------------- /data/pascal.py: -------------------------------------------------------------------------------- 1 | r""" PASCAL-5i few-shot semantic segmentation dataset """ 2 | import os 3 | 4 | from torch.utils.data import Dataset 5 | import torch.nn.functional as F 6 | import torch 7 | import PIL.Image as Image 8 | import numpy as np 9 | import random 10 | 11 | class DatasetPASCAL(Dataset): 12 | def __init__(self, datapath, fold, transform, split, shot, use_original_imgsize): 13 | self.split = 'val' if split in ['val', 'test'] else 'trn' 14 | self.fold = fold 15 | self.nfolds = 4 16 | self.nclass = 20 17 | self.benchmark = 'pascal' 18 | self.shot = shot 19 | self.use_original_imgsize = use_original_imgsize 20 | 21 | self.img_path = os.path.join(datapath, 'VOC2012/JPEGImages/') 22 | self.ann_path = os.path.join(datapath, 'VOC2012/SegmentationClassAug/') 23 | self.transform = transform 24 | 25 | self.class_ids = self.build_class_ids() 26 | print(self.class_ids) 27 | self.img_metadata,self.img_metadata_val = self.build_img_metadata() 28 | self.img_metadata_classwise = self.build_img_metadata_classwise() 29 | 30 | def 
__len__(self): 31 | return len(self.img_metadata) if self.split == 'trn' else 1000 32 | 33 | def __getitem__(self, idx): 34 | idx %= len(self.img_metadata) # for testing, as n_images < 1000 35 | query_name, support_names, class_sample = self.sample_episode(idx) 36 | query_img, query_cmask, support_imgs, support_cmasks, org_qry_imsize = self.load_frame(query_name, support_names) 37 | 38 | query_img = self.transform(query_img) 39 | if not self.use_original_imgsize: 40 | query_cmask = F.interpolate(query_cmask.unsqueeze(0).unsqueeze(0).float(), query_img.size()[-2:], mode='nearest').squeeze() 41 | query_mask, query_ignore_idx = self.extract_ignore_idx(query_cmask.float(), class_sample) 42 | 43 | support_imgs = torch.stack([self.transform(support_img) for support_img in support_imgs]) 44 | 45 | support_masks = [] 46 | support_ignore_idxs = [] 47 | for scmask in support_cmasks: 48 | scmask = F.interpolate(scmask.unsqueeze(0).unsqueeze(0).float(), support_imgs.size()[-2:], mode='nearest').squeeze() 49 | support_mask, support_ignore_idx = self.extract_ignore_idx(scmask, class_sample) 50 | support_masks.append(support_mask) 51 | support_ignore_idxs.append(support_ignore_idx) 52 | support_masks = torch.stack(support_masks) 53 | support_ignore_idxs = torch.stack(support_ignore_idxs) 54 | batch = {'query_img': query_img, 55 | 'query_mask': query_mask, 56 | 'query_name': query_name, 57 | 'query_ignore_idx': query_ignore_idx, 58 | 59 | 'org_query_imsize': org_qry_imsize, 60 | 61 | 'support_imgs': support_imgs, 62 | 'support_masks': support_masks, 63 | 'support_names': support_names, 64 | 'support_ignore_idxs': support_ignore_idxs, 65 | 66 | 'class_id': torch.tensor(class_sample)} 67 | 68 | return batch 69 | 70 | def extract_ignore_idx(self, mask, class_id): 71 | boundary = (mask / 255).floor() 72 | mask[mask != class_id + 1] = 0 73 | mask[mask == class_id + 1] = 1 74 | 75 | return mask, boundary 76 | 77 | def load_frame(self, query_name, support_names): 78 | query_img = 
self.read_img(query_name) 79 | query_mask = self.read_mask(query_name) 80 | support_imgs = [self.read_img(name) for name in support_names] 81 | support_masks = [self.read_mask(name) for name in support_names] 82 | 83 | org_qry_imsize = query_img.size 84 | 85 | return query_img, query_mask, support_imgs, support_masks, org_qry_imsize 86 | 87 | def read_mask(self, img_name): 88 | r"""Return segmentation mask in PIL Image""" 89 | mask = torch.tensor(np.array(Image.open(os.path.join(self.ann_path, img_name) + '.png'))) 90 | return mask 91 | 92 | def read_img(self, img_name): 93 | r"""Return RGB image in PIL Image""" 94 | return Image.open(os.path.join(self.img_path, img_name) + '.jpg') 95 | 96 | def sample_episode(self, idx): 97 | query_name, class_sample = self.img_metadata[idx] 98 | 99 | support_names = [] 100 | while True: # keep sampling support set if query == support 101 | support_name = np.random.choice(self.img_metadata_classwise[class_sample], 1, replace=False)[0] 102 | if query_name != support_name: support_names.append(support_name) 103 | if len(support_names) == self.shot: break 104 | 105 | return query_name, support_names, class_sample 106 | 107 | def build_class_ids(self): 108 | nclass_trn = self.nclass // self.nfolds 109 | class_ids_val = [self.fold * nclass_trn + i for i in range(nclass_trn)] 110 | class_ids_trn = [x for x in range(self.nclass) if x not in class_ids_val] 111 | if self.split == 'trn': 112 | return class_ids_trn 113 | else: 114 | return class_ids_val 115 | 116 | def build_img_metadata(self): 117 | 118 | def read_metadata(split, fold_id): 119 | fold_n_metadata = os.path.join('data/splits/pascal/%s/fold%d.txt' % (split, fold_id)) 120 | with open(fold_n_metadata, 'r') as f: 121 | fold_n_metadata = f.read().split('\n')[:-1] 122 | fold_n_metadata = [[data.split('__')[0], int(data.split('__')[1]) - 1] for data in fold_n_metadata] 123 | return fold_n_metadata 124 | 125 | img_metadata = [] 126 | val_metadata = [] 127 | if self.split == 'trn': # 
For training, read image-metadata of "the other" folds 128 | for fold_id in range(self.nfolds): 129 | if fold_id == self.fold: # Skip validation fold 130 | continue 131 | img_metadata += read_metadata(self.split, fold_id) 132 | val_metadata+=read_metadata('trn', self.fold) 133 | elif self.split == 'val': # For validation, read image-metadata of "current" fold 134 | img_metadata = read_metadata(self.split, self.fold) 135 | else: 136 | raise Exception('Undefined split %s: ' % self.split) 137 | 138 | print('Total (%s) images are : %d' % (self.split, len(img_metadata))) 139 | 140 | return img_metadata,val_metadata 141 | 142 | def build_img_metadata_classwise(self): 143 | img_metadata_classwise = {} 144 | img_metadata_classwise_val =set([items[0] for items in self.img_metadata_val]) 145 | meta_data=[] 146 | for class_id in range(self.nclass): 147 | img_metadata_classwise[class_id] = [] 148 | 149 | for img_name, img_class in self.img_metadata: 150 | if img_name in img_metadata_classwise_val: 151 | continue 152 | img_metadata_classwise[img_class] += [img_name] 153 | if self.split=='trn': 154 | for key in self.class_ids: 155 | print(key,":",len(img_metadata_classwise[key])) 156 | img_meta=random.choices(img_metadata_classwise[key],k=750) 157 | meta_data+=[[img,key] for img in img_meta] 158 | self.img_metadata=meta_data 159 | random.shuffle(self.img_metadata) 160 | return img_metadata_classwise 161 | -------------------------------------------------------------------------------- /data/splits/fss/trn.txt: -------------------------------------------------------------------------------- 1 | fountain 2 | taxi 3 | assult_rifle 4 | radio 5 | comb 6 | box_turtle 7 | igloo 8 | head_cabbage 9 | cottontail 10 | coho 11 | ashtray 12 | joystick 13 | sleeping_bag 14 | jackfruit 15 | trailer_truck 16 | shower_cap 17 | ibex 18 | kinguin 19 | squirrel 20 | ac_wall 21 | sidewinder 22 | remote_control 23 | marshmallow 24 | bolotie 25 | polar_bear 26 | rock_beauty 27 | tokyo_tower 28 | wafer 
29 | red_bayberry 30 | electronic_toothbrush 31 | hartebeest 32 | cassette 33 | oil_filter 34 | bomb 35 | walnut 36 | toilet_tissue 37 | memory_stick 38 | wild_boar 39 | cableways 40 | chihuahua 41 | envelope 42 | bison 43 | poker 44 | pubg_lvl3helmet 45 | indian_cobra 46 | staffordshire 47 | park_bench 48 | wombat 49 | black_grouse 50 | submarine 51 | washer 52 | agama 53 | coyote 54 | feeder 55 | sarong 56 | buckingham_palace 57 | frog 58 | steam_locomotive 59 | acorn 60 | german_pointer 61 | obelisk 62 | polecat 63 | black_swan 64 | butterfly 65 | mountain_tent 66 | gorilla 67 | sloth_bear 68 | aubergine 69 | stinkhorn 70 | stole 71 | owl 72 | mooli 73 | pool_table 74 | collar 75 | lhasa_apso 76 | ambulance 77 | spade 78 | pufferfish 79 | paint_brush 80 | lark 81 | golf_ball 82 | hock 83 | fork 84 | drake 85 | bee_house 86 | mooncake 87 | wok 88 | cocacola 89 | water_bike 90 | ladder 91 | psp 92 | bassoon 93 | bear 94 | border_terrier 95 | petri_dish 96 | pill_bottle 97 | aircraft_carrier 98 | panther 99 | canoe 100 | baseball_player 101 | turtle 102 | espresso 103 | throne 104 | cornet 105 | coucal 106 | eletrical_switch 107 | bra 108 | snail 109 | backpack 110 | jacamar 111 | scroll_brush 112 | gliding_lizard 113 | raft 114 | pinwheel 115 | grasshopper 116 | green_mamba 117 | eft_newt 118 | computer_mouse 119 | vine_snake 120 | recreational_vehicle 121 | llama 122 | meerkat 123 | chainsaw 124 | ferret 125 | garbage_can 126 | kangaroo 127 | litchi 128 | carbonara 129 | housefinch 130 | modem 131 | tebby_cat 132 | thatch 133 | face_powder 134 | tomb 135 | apple 136 | ladybug 137 | killer_whale 138 | rocket 139 | airship 140 | surfboard 141 | lesser_panda 142 | jordan_logo 143 | banana 144 | nail_scissor 145 | swab 146 | perfume 147 | punching_bag 148 | victor_icon 149 | waffle_iron 150 | trimaran 151 | garlic 152 | flute 153 | langur 154 | starfish 155 | parallel_bars 156 | dandie_dinmont 157 | cosmetic_brush 158 | screwdriver 159 | brick_card 160 | 
balance_weight 161 | hornet 162 | carton 163 | toothpaste 164 | bracelet 165 | egg_tart 166 | pencil_sharpener2 167 | swimming_glasses 168 | howler_monkey 169 | camel 170 | dragonfly 171 | lionfish 172 | convertible 173 | mule 174 | usb 175 | conch 176 | papaya 177 | garbage_truck 178 | dingo 179 | radiator 180 | solar_dish 181 | streetcar 182 | trilobite 183 | bouzouki 184 | ringlet_butterfly 185 | space_shuttle 186 | waffle 187 | american_staffordshire 188 | violin 189 | flowerpot 190 | forklift 191 | manx 192 | sundial 193 | snowmobile 194 | chickadee_bird 195 | ruffed_grouse 196 | brick_tea 197 | paddle 198 | stove 199 | carousel 200 | spatula 201 | beaker 202 | gas_pump 203 | lawn_mower 204 | speaker 205 | tank 206 | tresher 207 | kappa_logo 208 | hare 209 | tennis_racket 210 | shopping_cart 211 | thimble 212 | tractor 213 | anemone_fish 214 | trolleybus 215 | steak 216 | capuchin 217 | red_breasted_merganser 218 | golden_retriever 219 | light_tube 220 | flatworm 221 | melon_seed 222 | digital_watch 223 | jacko_lantern 224 | brown_bear 225 | cairn 226 | mushroom 227 | chalk 228 | skull 229 | stapler 230 | potato 231 | telescope 232 | proboscis 233 | microphone 234 | torii 235 | baseball_bat 236 | dhole 237 | excavator 238 | fig 239 | snake 240 | bradypod 241 | pepitas 242 | prairie_chicken 243 | scorpion 244 | shotgun 245 | bottle_cap 246 | file_cabinet 247 | grey_whale 248 | one-armed_bandit 249 | banded_gecko 250 | flying_disc 251 | croissant 252 | toothbrush 253 | miniskirt 254 | pokermon_ball 255 | gazelle 256 | grey_fox 257 | esport_chair 258 | necklace 259 | ptarmigan 260 | watermelon 261 | besom 262 | pomelo 263 | radio_telescope 264 | studio_couch 265 | black_stork 266 | vestment 267 | koala 268 | brambling 269 | muscle_car 270 | window_shade 271 | space_heater 272 | sunglasses 273 | motor_scooter 274 | ladyfinger 275 | pencil_box 276 | titi_monkey 277 | chicken_wings 278 | mount_fuji 279 | giant_panda 280 | dart 281 | fire_engine 282 | running_shoe 
283 | dumbbell 284 | donkey 285 | loafer 286 | hard_disk 287 | globe 288 | lifeboat 289 | medical_kit 290 | brain_coral 291 | paper_towel 292 | dugong 293 | seatbelt 294 | skunk 295 | military_vest 296 | cocktail_shaker 297 | zucchini 298 | quad_drone 299 | ocicat 300 | shih-tzu 301 | teapot 302 | tile_roof 303 | cheese_burger 304 | handshower 305 | red_wolf 306 | stop_sign 307 | mouse 308 | battery 309 | adidas_logo2 310 | earplug 311 | hummingbird 312 | brush_pen 313 | pistachio 314 | hamster 315 | air_strip 316 | indian_elephant 317 | otter 318 | cucumber 319 | scabbard 320 | hawthorn 321 | bullet_train 322 | leopard 323 | whale 324 | cream 325 | chinese_date 326 | jellyfish 327 | lobster 328 | skua 329 | single_log 330 | chicory 331 | bagel 332 | beacon 333 | pingpong_racket 334 | spoon 335 | yurt 336 | wallaby 337 | egret 338 | christmas_stocking 339 | mcdonald_uncle 340 | wrench 341 | spark_plug 342 | triceratops 343 | wall_clock 344 | jinrikisha 345 | pickup 346 | rhinoceros 347 | swimming_trunk 348 | band-aid 349 | spotted_salamander 350 | leeks 351 | marmot 352 | warthog 353 | cello 354 | stool 355 | chest 356 | toilet_plunger 357 | wardrobe 358 | cannon 359 | adidas_logo1 360 | drumstick 361 | lady_slipper 362 | puma_logo 363 | great_wall 364 | white_shark 365 | witch_hat 366 | vending_machine 367 | wreck 368 | chopsticks 369 | garfish 370 | african_elephant 371 | children_slide 372 | hornbill 373 | zebra 374 | boa_constrictor 375 | armour 376 | pineapple 377 | angora 378 | brick 379 | car_wheel 380 | wallet 381 | boston_bull 382 | hyena 383 | lynx 384 | crash_helmet 385 | terrapin_turtle 386 | persian_cat 387 | shift_gear 388 | cactus_ball 389 | fur_coat 390 | plate 391 | pen 392 | okra 393 | mario 394 | airedale 395 | cowboy_hat 396 | celery 397 | macaque 398 | candle 399 | goose 400 | raccoon 401 | brasscica 402 | almond 403 | maotai_bottle 404 | soccer_ball 405 | sports_car 406 | tobacco_pipe 407 | water_polo 408 | eggnog 409 | hook 410 | ostrich 411 
| patas 412 | table_lamp 413 | teddy 414 | mongoose 415 | spoonbill 416 | redheart 417 | crane 418 | dinosaur 419 | kitchen_knife 420 | seal 421 | baboon 422 | golfcart 423 | roller_coaster 424 | avocado 425 | birdhouse 426 | yorkshire_terrier 427 | saluki 428 | basketball 429 | buckler 430 | harvester 431 | afghan_hound 432 | beam_bridge 433 | guinea_pig 434 | lorikeet 435 | shakuhachi 436 | motarboard 437 | statue_liberty 438 | police_car 439 | sulphur_crested 440 | gourd 441 | sombrero 442 | mailbox 443 | adhensive_tape 444 | night_snake 445 | bushtit 446 | mouthpiece 447 | beaver 448 | bathtub 449 | printer 450 | cumquat 451 | orange 452 | cleaver 453 | quill_pen 454 | panpipe 455 | diamond 456 | gypsy_moth 457 | cauliflower 458 | lampshade 459 | cougar 460 | traffic_light 461 | briefcase 462 | ballpoint 463 | african_grey 464 | kremlin 465 | barometer 466 | peacock 467 | paper_crane 468 | sunscreen 469 | tofu 470 | bedlington_terrier 471 | snowball 472 | carrot 473 | tiger 474 | mink 475 | cristo_redentor 476 | ladle 477 | keyboard 478 | maraca 479 | monitor 480 | water_snake 481 | can_opener 482 | mud_turtle 483 | bald_eagle 484 | carp 485 | cn_tower 486 | egyptian_cat 487 | hen_of_the_woods 488 | measuring_cup 489 | roller_skate 490 | kite 491 | sandwich_cookies 492 | sandwich 493 | persimmon 494 | chess_bishop 495 | coffin 496 | ruddy_turnstone 497 | prayer_rug 498 | rain_barrel 499 | neck_brace 500 | nematode 501 | rosehip 502 | dutch_oven 503 | goldfish 504 | blossom_card 505 | dough 506 | trench_coat 507 | sponge 508 | stupa 509 | wash_basin 510 | electric_fan 511 | spring_scroll 512 | potted_plant 513 | sparrow 514 | car_mirror 515 | gecko 516 | diaper 517 | leatherback_turtle 518 | strainer 519 | guacamole 520 | microwave 521 | -------------------------------------------------------------------------------- /data/splits/pascal/val/fold1.txt: -------------------------------------------------------------------------------- 1 | 2007_000452__09 2 | 
2007_000464__10 3 | 2007_000491__10 4 | 2007_000663__06 5 | 2007_000663__07 6 | 2007_000727__06 7 | 2007_000727__07 8 | 2007_000804__09 9 | 2007_000830__09 10 | 2007_001299__10 11 | 2007_001321__07 12 | 2007_001457__09 13 | 2007_001677__09 14 | 2007_001717__09 15 | 2007_001763__08 16 | 2007_001774__08 17 | 2007_001884__06 18 | 2007_002268__08 19 | 2007_002387__10 20 | 2007_002445__08 21 | 2007_002470__08 22 | 2007_002539__06 23 | 2007_002597__08 24 | 2007_002643__07 25 | 2007_002903__10 26 | 2007_003011__09 27 | 2007_003051__07 28 | 2007_003101__06 29 | 2007_003106__08 30 | 2007_003137__06 31 | 2007_003143__07 32 | 2007_003169__08 33 | 2007_003195__06 34 | 2007_003201__10 35 | 2007_003503__06 36 | 2007_003503__07 37 | 2007_003621__06 38 | 2007_003711__06 39 | 2007_003786__06 40 | 2007_003841__10 41 | 2007_003917__07 42 | 2007_003991__08 43 | 2007_004193__09 44 | 2007_004392__09 45 | 2007_004405__09 46 | 2007_004510__09 47 | 2007_004712__09 48 | 2007_004856__08 49 | 2007_004866__08 50 | 2007_005074__07 51 | 2007_005114__10 52 | 2007_005296__07 53 | 2007_005331__07 54 | 2007_005460__08 55 | 2007_005547__07 56 | 2007_005547__10 57 | 2007_005844__09 58 | 2007_005845__08 59 | 2007_005911__06 60 | 2007_005978__06 61 | 2007_006035__07 62 | 2007_006086__09 63 | 2007_006241__09 64 | 2007_006260__08 65 | 2007_006277__07 66 | 2007_006348__09 67 | 2007_006553__09 68 | 2007_006761__10 69 | 2007_006841__10 70 | 2007_007414__07 71 | 2007_007417__08 72 | 2007_007524__08 73 | 2007_007815__07 74 | 2007_007818__07 75 | 2007_007996__09 76 | 2007_008106__09 77 | 2007_008110__09 78 | 2007_008543__09 79 | 2007_008722__10 80 | 2007_008747__06 81 | 2007_008815__08 82 | 2007_008897__09 83 | 2007_008973__10 84 | 2007_009015__06 85 | 2007_009015__07 86 | 2007_009068__09 87 | 2007_009084__09 88 | 2007_009096__07 89 | 2007_009221__08 90 | 2007_009245__10 91 | 2007_009346__08 92 | 2007_009392__06 93 | 2007_009392__07 94 | 2007_009413__09 95 | 2007_009521__09 96 | 2007_009764__06 97 | 
2007_009794__08 98 | 2007_009897__10 99 | 2007_009923__08 100 | 2007_009938__07 101 | 2008_000009__10 102 | 2008_000073__10 103 | 2008_000075__06 104 | 2008_000107__09 105 | 2008_000149__09 106 | 2008_000182__08 107 | 2008_000345__08 108 | 2008_000401__08 109 | 2008_000464__08 110 | 2008_000501__07 111 | 2008_000673__09 112 | 2008_000853__08 113 | 2008_000919__10 114 | 2008_001078__08 115 | 2008_001433__08 116 | 2008_001439__09 117 | 2008_001513__08 118 | 2008_001640__08 119 | 2008_001715__09 120 | 2008_001885__08 121 | 2008_002152__08 122 | 2008_002205__06 123 | 2008_002212__07 124 | 2008_002379__09 125 | 2008_002521__09 126 | 2008_002623__08 127 | 2008_002681__08 128 | 2008_002778__10 129 | 2008_002958__07 130 | 2008_003141__06 131 | 2008_003141__07 132 | 2008_003333__07 133 | 2008_003477__09 134 | 2008_003499__08 135 | 2008_003577__07 136 | 2008_003777__06 137 | 2008_003821__09 138 | 2008_003846__07 139 | 2008_004069__07 140 | 2008_004339__07 141 | 2008_004552__07 142 | 2008_004612__09 143 | 2008_004701__10 144 | 2008_005097__10 145 | 2008_005105__10 146 | 2008_005245__07 147 | 2008_005676__06 148 | 2008_006008__09 149 | 2008_006063__10 150 | 2008_006254__07 151 | 2008_006325__08 152 | 2008_006341__08 153 | 2008_006480__08 154 | 2008_006528__10 155 | 2008_006554__06 156 | 2008_006986__07 157 | 2008_007025__10 158 | 2008_007031__10 159 | 2008_007048__09 160 | 2008_007123__10 161 | 2008_007194__09 162 | 2008_007273__10 163 | 2008_007378__09 164 | 2008_007402__09 165 | 2008_007527__09 166 | 2008_007548__08 167 | 2008_007596__10 168 | 2008_007737__09 169 | 2008_007797__06 170 | 2008_007804__07 171 | 2008_007828__09 172 | 2008_008252__06 173 | 2008_008301__06 174 | 2008_008469__06 175 | 2008_008682__06 176 | 2009_000013__08 177 | 2009_000080__08 178 | 2009_000219__10 179 | 2009_000309__10 180 | 2009_000335__06 181 | 2009_000335__07 182 | 2009_000426__06 183 | 2009_000455__06 184 | 2009_000457__07 185 | 2009_000523__07 186 | 2009_000641__10 187 | 2009_000716__08 188 | 
2009_000731__10 189 | 2009_000771__10 190 | 2009_000825__07 191 | 2009_000964__08 192 | 2009_001008__08 193 | 2009_001082__06 194 | 2009_001240__07 195 | 2009_001255__07 196 | 2009_001299__09 197 | 2009_001391__08 198 | 2009_001411__08 199 | 2009_001536__07 200 | 2009_001775__09 201 | 2009_001804__06 202 | 2009_001816__06 203 | 2009_001854__06 204 | 2009_002035__10 205 | 2009_002122__10 206 | 2009_002150__10 207 | 2009_002164__07 208 | 2009_002171__10 209 | 2009_002221__10 210 | 2009_002238__06 211 | 2009_002238__07 212 | 2009_002239__07 213 | 2009_002268__08 214 | 2009_002346__09 215 | 2009_002415__09 216 | 2009_002487__09 217 | 2009_002527__08 218 | 2009_002535__06 219 | 2009_002549__10 220 | 2009_002571__09 221 | 2009_002618__07 222 | 2009_002635__10 223 | 2009_002753__08 224 | 2009_002936__08 225 | 2009_002990__07 226 | 2009_003003__07 227 | 2009_003059__10 228 | 2009_003071__09 229 | 2009_003269__07 230 | 2009_003304__06 231 | 2009_003387__07 232 | 2009_003406__07 233 | 2009_003494__09 234 | 2009_003507__09 235 | 2009_003542__10 236 | 2009_003549__07 237 | 2009_003569__10 238 | 2009_003589__07 239 | 2009_003703__06 240 | 2009_003771__08 241 | 2009_003773__10 242 | 2009_003849__09 243 | 2009_003895__09 244 | 2009_003904__08 245 | 2009_004072__06 246 | 2009_004140__09 247 | 2009_004217__09 248 | 2009_004248__08 249 | 2009_004455__07 250 | 2009_004504__08 251 | 2009_004590__06 252 | 2009_004594__07 253 | 2009_004687__09 254 | 2009_004721__08 255 | 2009_004732__06 256 | 2009_004748__07 257 | 2009_004789__06 258 | 2009_004859__09 259 | 2009_004867__06 260 | 2009_005158__08 261 | 2009_005219__08 262 | 2009_005231__06 263 | 2010_000003__09 264 | 2010_000160__07 265 | 2010_000163__08 266 | 2010_000372__07 267 | 2010_000427__10 268 | 2010_000530__07 269 | 2010_000552__08 270 | 2010_000573__06 271 | 2010_000628__07 272 | 2010_000639__09 273 | 2010_000682__06 274 | 2010_000683__08 275 | 2010_000724__08 276 | 2010_000907__10 277 | 2010_000941__08 278 | 2010_000952__07 279 
| 2010_001000__10 280 | 2010_001010__10 281 | 2010_001070__08 282 | 2010_001206__06 283 | 2010_001292__08 284 | 2010_001331__08 285 | 2010_001351__08 286 | 2010_001403__06 287 | 2010_001403__07 288 | 2010_001534__08 289 | 2010_001553__07 290 | 2010_001579__09 291 | 2010_001646__06 292 | 2010_001656__08 293 | 2010_001692__10 294 | 2010_001699__09 295 | 2010_001767__07 296 | 2010_001851__09 297 | 2010_001913__08 298 | 2010_002017__07 299 | 2010_002017__09 300 | 2010_002025__08 301 | 2010_002137__08 302 | 2010_002146__08 303 | 2010_002305__08 304 | 2010_002336__09 305 | 2010_002348__08 306 | 2010_002361__07 307 | 2010_002390__10 308 | 2010_002422__08 309 | 2010_002512__08 310 | 2010_002531__08 311 | 2010_002546__06 312 | 2010_002623__09 313 | 2010_002693__08 314 | 2010_002693__09 315 | 2010_002763__08 316 | 2010_002763__10 317 | 2010_002868__06 318 | 2010_002900__08 319 | 2010_002902__07 320 | 2010_002921__09 321 | 2010_002929__07 322 | 2010_002988__07 323 | 2010_003123__07 324 | 2010_003183__10 325 | 2010_003231__07 326 | 2010_003239__10 327 | 2010_003275__08 328 | 2010_003276__07 329 | 2010_003293__06 330 | 2010_003302__09 331 | 2010_003325__09 332 | 2010_003381__07 333 | 2010_003402__08 334 | 2010_003409__09 335 | 2010_003446__07 336 | 2010_003453__07 337 | 2010_003468__08 338 | 2010_003531__09 339 | 2010_003675__08 340 | 2010_003746__07 341 | 2010_003758__08 342 | 2010_003764__08 343 | 2010_003768__07 344 | 2010_003772__06 345 | 2010_003781__08 346 | 2010_003813__07 347 | 2010_003854__07 348 | 2010_003971__08 349 | 2010_003971__09 350 | 2010_004104__08 351 | 2010_004120__08 352 | 2010_004320__08 353 | 2010_004322__10 354 | 2010_004348__06 355 | 2010_004369__08 356 | 2010_004472__07 357 | 2010_004479__08 358 | 2010_004635__10 359 | 2010_004763__09 360 | 2010_004783__09 361 | 2010_004789__10 362 | 2010_004815__08 363 | 2010_004825__09 364 | 2010_004861__08 365 | 2010_004946__07 366 | 2010_005013__07 367 | 2010_005021__08 368 | 2010_005021__09 369 | 2010_005063__06 
370 | 2010_005108__08 371 | 2010_005118__06 372 | 2010_005160__06 373 | 2010_005166__10 374 | 2010_005284__06 375 | 2010_005344__08 376 | 2010_005421__08 377 | 2010_005432__07 378 | 2010_005501__07 379 | 2010_005508__08 380 | 2010_005606__08 381 | 2010_005709__08 382 | 2010_005718__07 383 | 2010_005860__07 384 | 2010_005899__08 385 | 2010_006070__07 386 | 2011_000178__06 387 | 2011_000226__09 388 | 2011_000239__06 389 | 2011_000248__06 390 | 2011_000312__06 391 | 2011_000338__09 392 | 2011_000419__08 393 | 2011_000503__07 394 | 2011_000548__10 395 | 2011_000566__10 396 | 2011_000607__09 397 | 2011_000661__08 398 | 2011_000661__09 399 | 2011_000780__08 400 | 2011_000789__08 401 | 2011_000809__09 402 | 2011_000813__08 403 | 2011_000813__09 404 | 2011_000830__06 405 | 2011_000843__09 406 | 2011_000888__06 407 | 2011_000900__07 408 | 2011_000969__06 409 | 2011_001047__10 410 | 2011_001064__06 411 | 2011_001071__09 412 | 2011_001110__07 413 | 2011_001159__10 414 | 2011_001232__10 415 | 2011_001292__08 416 | 2011_001341__06 417 | 2011_001346__09 418 | 2011_001447__09 419 | 2011_001530__10 420 | 2011_001534__08 421 | 2011_001546__10 422 | 2011_001567__09 423 | 2011_001597__08 424 | 2011_001601__08 425 | 2011_001607__08 426 | 2011_001665__09 427 | 2011_001708__10 428 | 2011_001775__08 429 | 2011_001782__10 430 | 2011_001812__09 431 | 2011_002041__09 432 | 2011_002064__07 433 | 2011_002124__09 434 | 2011_002200__09 435 | 2011_002298__09 436 | 2011_002322__07 437 | 2011_002343__09 438 | 2011_002358__09 439 | 2011_002391__09 440 | 2011_002509__09 441 | 2011_002592__07 442 | 2011_002644__09 443 | 2011_002685__08 444 | 2011_002812__07 445 | 2011_002885__10 446 | 2011_003011__09 447 | 2011_003019__07 448 | 2011_003019__10 449 | 2011_003055__07 450 | 2011_003103__09 451 | 2011_003114__06 452 | -------------------------------------------------------------------------------- /model/base/merge_cor.py: -------------------------------------------------------------------------------- 1 
| import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | class BasicConv(nn.Sequential): 6 | def __init__(self, inCh, outCh): 7 | super(BasicConv, self).__init__() 8 | self.conv0 = nn.Conv2d(inCh, outCh, kernel_size=3, padding=1) 9 | self.bn0 = nn.BatchNorm2d(outCh) 10 | self.relu0 = nn.ReLU() 11 | class BasicConv1x1(nn.Sequential): 12 | def __init__(self, inCh, outCh): 13 | super(BasicConv1x1, self).__init__() 14 | self.conv0 = nn.Conv2d(inCh, outCh, kernel_size=1) 15 | self.bn0 = nn.BatchNorm2d(outCh) 16 | self.relu0 = nn.ReLU() 17 | self.conv1 = nn.Conv2d(outCh, outCh, kernel_size=3, padding=1) 18 | self.bn1 = nn.BatchNorm2d(outCh) 19 | self.relu1 = nn.ReLU() 20 | #self.drop=nn.Dropout2d(0.3) 21 | 22 | class ShotConv(nn.Module): 23 | def __init__(self, inCh, outCh): 24 | super(ShotConv, self).__init__() 25 | self.conv0 =nn.Sequential(nn.Conv2d(inCh, outCh, kernel_size=1), nn.BatchNorm2d(outCh),nn.ReLU()) 26 | self.conv1 = nn.Sequential(nn.Conv2d(inCh, outCh, kernel_size=3, dilation=2, padding=2),nn.BatchNorm2d(outCh), nn.ReLU()) 27 | self.conv2 = nn.Sequential(nn.Conv2d(inCh, outCh, kernel_size=3, dilation=4, padding=4),nn.BatchNorm2d(outCh),nn.ReLU()) 28 | self.conv3 = nn.Sequential(nn.Conv2d(outCh*3, outCh, kernel_size=3,padding=1), nn.BatchNorm2d(outCh),nn.ReLU()) 29 | self.drop=nn.Dropout2d(0.2) 30 | def forward(self,x): 31 | x0=self.conv0(x) 32 | x1=self.conv1(x) 33 | x2=self.conv2(x) 34 | x= self.conv3(torch.cat([x0,x1,x2],dim=1)) 35 | return self.drop(x) 36 | class Flatten(nn.Module): 37 | def forward(self, x): 38 | return x.flatten(start_dim=1) 39 | 40 | 41 | class ChannelGate(nn.Module): 42 | def __init__(self, gate_channels, reduction_ratio=4, pool_types=['avg', 'max']): 43 | super(ChannelGate, self).__init__() 44 | self.gate_channels = gate_channels 45 | self.mlp = nn.Sequential( 46 | Flatten(), 47 | nn.Linear(gate_channels, gate_channels // reduction_ratio), 48 | nn.ReLU(), 49 | nn.Linear(gate_channels // reduction_ratio, 
gate_channels) 50 | ) 51 | self.pool_types = pool_types 52 | 53 | def forward(self, x): 54 | channel_att_sum = None 55 | for pool_type in self.pool_types: 56 | if pool_type == 'avg': 57 | avg_pool = F.adaptive_avg_pool2d(x, (1, 1)) 58 | channel_att_raw = self.mlp(avg_pool) 59 | elif pool_type == 'max': 60 | max_pool = F.adaptive_max_pool2d(x, (1, 1)) 61 | channel_att_raw = self.mlp(max_pool) 62 | 63 | if channel_att_sum is None: 64 | channel_att_sum = channel_att_raw 65 | else: 66 | channel_att_sum = channel_att_sum + channel_att_raw 67 | # visMap.heatmap(scale, win='abc') 68 | # time.sleep(2) 69 | channel_att_sum = torch.sigmoid(channel_att_sum) 70 | scale = channel_att_sum.unsqueeze(2).unsqueeze(3) 71 | return x * scale 72 | 73 | 74 | class MergeConv1(nn.Module): 75 | def __init__(self, inplan1, outplan): 76 | super(MergeConv1, self).__init__() 77 | self.conv1 = nn.Conv2d(inplan1, outplan, kernel_size=1) 78 | #self.conv2 = nn.Conv2d(inplan2, outplan, kernel_size=1) 79 | self.chATT = ChannelGate(outplan, 4) 80 | self.merge = BasicConv(outplan, outplan) 81 | self.merge2 = BasicConv(outplan, outplan) 82 | self.drop = nn.Dropout2d(0.2) 83 | 84 | def forward(self, x1): 85 | x1 = self.conv1(x1) 86 | out = self.merge(x1) 87 | x = self.chATT(out) 88 | out = self.merge2(x) 89 | return self.drop(out) 90 | 91 | 92 | class MergeConv(nn.Module): 93 | def __init__(self, inplan1, inplan2, outplan): 94 | super(MergeConv, self).__init__() 95 | self.conv1 = nn.Conv2d(inplan1, outplan, kernel_size=1) 96 | self.conv2 = nn.Conv2d(inplan2, outplan, kernel_size=1) 97 | self.merge = BasicConv(2 * outplan, outplan) 98 | self.merge2 = BasicConv(outplan, outplan) 99 | self.drop = nn.Dropout2d(0.2) 100 | 101 | def forward(self, x1, x2): 102 | # shape=x1.shape 103 | x1 = self.conv1(x1) 104 | # x2= F.interpolate(x2, size=(shape[2], shape[3]), mode='bilinear', align_corners=True) 105 | x2 = self.conv2(x2) 106 | out = self.merge(torch.cat([x1, x2], dim=1)) 107 | out = self.merge2(out) 108 | 
return self.drop(out) 109 | 110 | 111 | class MergeDown(nn.Module): 112 | def __init__(self, inplan1, inplan2, outplan): 113 | super(MergeDown, self).__init__() 114 | self.conv1 = BasicConv(inplan1, outplan) 115 | self.conv2 = nn.Conv2d(inplan2, outplan, kernel_size=1) 116 | self.merge = BasicConv(2 * outplan, outplan) 117 | self.convout = BasicConv(outplan, outplan) 118 | 119 | def forward(self, x1, x2): 120 | x1 = self.conv1(x1) 121 | x2 = self.conv2(x2) 122 | out = self.merge(torch.cat([x1, x2], dim=1))#+x1 123 | val = self.convout(out) 124 | return val 125 | 126 | class qsSim(nn.Module): 127 | def __init__(self, inplan): 128 | super(qsSim, self).__init__() 129 | self.conv2 = nn.Conv2d(inplan, 1, kernel_size=1,bias=False) 130 | def forward(self, x): 131 | x = self.conv2(x) 132 | return torch.sigmoid(x) 133 | class sqOut(nn.Module): 134 | def __init__(self, inCh, outCh): 135 | super(sqOut, self).__init__() 136 | self.conv0 =nn.Sequential(nn.Conv2d(inCh, outCh, kernel_size=1),nn.BatchNorm2d(outCh),nn.ReLU()) 137 | self.conv1=BasicConv(outCh,outCh) 138 | self.conv2=BasicConv(outCh,outCh) 139 | def forward(self,x): 140 | x0=self.conv0(x) 141 | x1=self.conv1(x0) 142 | x=self.conv2(x0+x1) 143 | return x 144 | 145 | class merge(nn.Module): 146 | def __init__(self, shot, nfeatures=[2048 * 2, 1024 * 2, 512 * 2], nsimlairy=[3, 6, 4],criter=None): 147 | super(merge, self).__init__() 148 | self.shot = shot 149 | self.nsimlairy = nsimlairy 150 | self.sim_layer_corrConv = [] 151 | self.criter=criter 152 | for num in nsimlairy: 153 | self.sim_layer_corrConv.append(BasicConv1x1(num, 256)) 154 | self.sim_layer_corrConv = nn.ModuleList(self.sim_layer_corrConv) 155 | 156 | self.simShotConv4 = BasicConv(256, 256) 157 | self.simShotConv3 = BasicConv(256, 256) 158 | self.simShotConv2 = BasicConv(256, 256) 159 | self.conv4 = MergeConv1(256, 256) 160 | self.conv3 = MergeConv1(256, 256) 161 | self.conv2 = MergeConv1(256, 256) 162 | self.conv43 = MergeConv(256, 256, 512) 163 | 
self.conv432 = MergeDown(512, 256, 512) 164 | self.decoder1 = sqOut(512, 256) 165 | 166 | self.decoder4 = nn.Sequential(nn.Conv2d(256, 128, (3, 3), padding=(1, 1), bias=False), 167 | nn.ReLU(), 168 | nn.Conv2d(128, 2, (3, 3), padding=(1, 1))) 169 | 170 | self.decoder3 = nn.Sequential(nn.Conv2d(256, 128, (3, 3), padding=(1, 1), bias=False), 171 | nn.ReLU(), 172 | nn.Conv2d(128, 2, (3, 3), padding=(1, 1))) 173 | 174 | self.decoder2 = nn.Sequential(nn.Conv2d(256, 128, (3, 3), padding=(1, 1), bias=False), 175 | nn.ReLU(), 176 | nn.Conv2d(128, 2, (3, 3), padding=(1, 1))) 177 | 178 | self.decoderOut = nn.Sequential(nn.Conv2d(256, 128, (3, 3), padding=(1, 1), bias=False), 179 | nn.ReLU(), 180 | nn.Conv2d(128, 2, (3, 3), padding=(1, 1))) 181 | 182 | def forward(self, qur_sups, sims,gt=None): 183 | corr_sim = [] 184 | for s in range(self.shot): 185 | sup = qur_sups[s] 186 | simS = sims[s] 187 | simList = [] 188 | for l in range(len(sup)): # L4,L3,L2 189 | simL = self.sim_layer_corrConv[l](simS[l]) 190 | simList.append(simL) # L,b,128,h,w 191 | corr_sim.append(simList) 192 | 193 | corr_lySim = [corr_sim[0][i] for i in range(len(self.nsimlairy))] # l,b,128,h,w 194 | for ly in range(len(self.nsimlairy)): 195 | for s in range(1,self.shot): 196 | corr_lySim[ly]+=(corr_sim[s][ly]) 197 | x4 = self.conv4(self.simShotConv4(corr_lySim[0]/self.shot)) 198 | x3 = self.conv3(self.simShotConv3(corr_lySim[1]/self.shot)) 199 | x2 = self.conv2(self.simShotConv2(corr_lySim[2]/self.shot)) 200 | x4 = F.interpolate(x4, x3.size()[-2:], mode='bilinear', align_corners=True) 201 | x43 = self.conv43(x4, x3) 202 | x43 = F.interpolate(x43, x2.size()[-2:], mode='bilinear', align_corners=True) 203 | x432 = self.conv432(x43, x2) 204 | d1 = self.decoder1(x432) 205 | upsize = (d1.shape[-1] * 2,) * 2 206 | d1 = F.interpolate(d1, upsize, mode='bilinear', align_corners=True) 207 | d2 = self.decoderOut(d1) 208 | if self.training: 209 | lossSize=x432.size()[-2:] 210 | gt=gt.unsqueeze(1).float() 211 | gtOut = 
F.interpolate(gt, upsize, mode='nearest') 212 | gtOut=gtOut.squeeze(1) 213 | gtOut=gtOut.long() 214 | gt=F.interpolate(gt, lossSize, mode='nearest') 215 | gt=gt.squeeze(1) 216 | gt=gt.long() 217 | x2=self.decoder2(x2) 218 | x3 = self.decoder3(x3) 219 | x4 = self.decoder4(x4) 220 | x2=F.interpolate(x2, lossSize,mode='bilinear', align_corners=True) 221 | x3 = F.interpolate(x3, lossSize, mode='bilinear', align_corners=True) 222 | x4 = F.interpolate(x4, lossSize, mode='bilinear', align_corners=True) 223 | loss=0.3*(self.criter(x2,gt)+self.criter(x3,gt)+self.criter(x4,gt))+self.criter(d2,gtOut) 224 | return d2,loss 225 | else: 226 | gt = gt.unsqueeze(1).float() 227 | gtOut = F.interpolate(gt, upsize, mode='nearest') 228 | gtOut=gtOut.squeeze(1) 229 | gtOut = gtOut.long() 230 | loss=self.criter(d2, gtOut) 231 | return d2,loss 232 | -------------------------------------------------------------------------------- /model/base/merge_pro.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | class BasicConv(nn.Sequential): 6 | def __init__(self, inCh, outCh): 7 | super(BasicConv, self).__init__() 8 | self.conv0 = nn.Conv2d(inCh, outCh, kernel_size=3, padding=1) 9 | self.bn0 = nn.BatchNorm2d(outCh) 10 | self.relu0 = nn.ReLU() 11 | class BasicConv1x1(nn.Sequential): 12 | def __init__(self, inCh, outCh): 13 | super(BasicConv1x1, self).__init__() 14 | self.conv0 = nn.Conv2d(inCh, outCh, kernel_size=1) 15 | self.bn0 = nn.BatchNorm2d(outCh) 16 | self.relu0 = nn.ReLU() 17 | self.conv1 = nn.Conv2d(outCh, outCh, kernel_size=3, padding=1) 18 | self.bn1 = nn.BatchNorm2d(outCh) 19 | self.relu1 = nn.ReLU() 20 | #self.drop=nn.Dropout2d(0.3) 21 | 22 | class ShotConv(nn.Module): 23 | def __init__(self, inCh, outCh): 24 | super(ShotConv, self).__init__() 25 | self.conv0 =nn.Sequential(nn.Conv2d(inCh, outCh, kernel_size=1), nn.BatchNorm2d(outCh),nn.ReLU()) 26 | self.conv1 = 
nn.Sequential(nn.Conv2d(inCh, outCh, kernel_size=3, dilation=2, padding=2),nn.BatchNorm2d(outCh), nn.ReLU()) 27 | self.conv2 = nn.Sequential(nn.Conv2d(inCh, outCh, kernel_size=3, dilation=4, padding=4),nn.BatchNorm2d(outCh),nn.ReLU()) 28 | self.conv3 = nn.Sequential(nn.Conv2d(outCh*3, outCh, kernel_size=3,padding=1), nn.BatchNorm2d(outCh),nn.ReLU()) 29 | self.drop=nn.Dropout2d(0.2) 30 | def forward(self,x): 31 | x0=self.conv0(x) 32 | x1=self.conv1(x) 33 | x2=self.conv2(x) 34 | x= self.conv3(torch.cat([x0,x1,x2],dim=1)) 35 | return self.drop(x) 36 | class Flatten(nn.Module): 37 | def forward(self, x): 38 | return x.flatten(start_dim=1) 39 | 40 | 41 | class ChannelGate(nn.Module): 42 | def __init__(self, gate_channels, reduction_ratio=4, pool_types=['avg', 'max']): 43 | super(ChannelGate, self).__init__() 44 | self.gate_channels = gate_channels 45 | self.mlp = nn.Sequential( 46 | Flatten(), 47 | nn.Linear(gate_channels, gate_channels // reduction_ratio), 48 | nn.ReLU(), 49 | nn.Linear(gate_channels // reduction_ratio, gate_channels) 50 | ) 51 | self.pool_types = pool_types 52 | 53 | def forward(self, x): 54 | channel_att_sum = None 55 | for pool_type in self.pool_types: 56 | if pool_type == 'avg': 57 | avg_pool = F.adaptive_avg_pool2d(x, (1, 1)) 58 | channel_att_raw = self.mlp(avg_pool) 59 | elif pool_type == 'max': 60 | max_pool = F.adaptive_max_pool2d(x, (1, 1)) 61 | channel_att_raw = self.mlp(max_pool) 62 | 63 | if channel_att_sum is None: 64 | channel_att_sum = channel_att_raw 65 | else: 66 | channel_att_sum = channel_att_sum + channel_att_raw 67 | # visMap.heatmap(scale, win='abc') 68 | # time.sleep(2) 69 | channel_att_sum = torch.sigmoid(channel_att_sum) 70 | scale = channel_att_sum.unsqueeze(2).unsqueeze(3) 71 | return x * scale 72 | 73 | 74 | class MergeConv1(nn.Module): 75 | def __init__(self, inplan2, outplan): 76 | super(MergeConv1, self).__init__() 77 | #self.conv1 = nn.Conv2d(inplan1, outplan, kernel_size=1) 78 | self.conv2 = nn.Conv2d(inplan2, outplan, 
kernel_size=1) 79 | self.chATT = ChannelGate(outplan, 4) 80 | self.merge = BasicConv(outplan, outplan) 81 | self.merge2 = BasicConv(outplan, outplan) 82 | self.drop = nn.Dropout2d(0.2) 83 | 84 | def forward(self, x2): 85 | #x1 = self.conv1(x1) 86 | x2 = self.conv2(x2) 87 | out = self.merge(x2) 88 | x = self.chATT(out) 89 | out = self.merge2(x) 90 | return self.drop(out) 91 | 92 | 93 | class MergeConv(nn.Module): 94 | def __init__(self, inplan1, inplan2, outplan): 95 | super(MergeConv, self).__init__() 96 | self.conv1 = nn.Conv2d(inplan1, outplan, kernel_size=1) 97 | self.conv2 = nn.Conv2d(inplan2, outplan, kernel_size=1) 98 | self.merge = BasicConv(2 * outplan, outplan) 99 | self.merge2 = BasicConv(outplan, outplan) 100 | self.drop = nn.Dropout2d(0.2) 101 | 102 | def forward(self, x1, x2): 103 | # shape=x1.shape 104 | x1 = self.conv1(x1) 105 | # x2= F.interpolate(x2, size=(shape[2], shape[3]), mode='bilinear', align_corners=True) 106 | x2 = self.conv2(x2) 107 | out = self.merge(torch.cat([x1, x2], dim=1)) 108 | out = self.merge2(out) 109 | return self.drop(out) 110 | 111 | 112 | class MergeDown(nn.Module): 113 | def __init__(self, inplan1, inplan2, outplan): 114 | super(MergeDown, self).__init__() 115 | self.conv1 = BasicConv(inplan1, outplan) 116 | self.conv2 = nn.Conv2d(inplan2, outplan, kernel_size=1) 117 | self.merge = BasicConv(2 * outplan, outplan) 118 | self.convout = BasicConv(outplan, outplan) 119 | 120 | def forward(self, x1, x2): 121 | x1 = self.conv1(x1) 122 | x2 = self.conv2(x2) 123 | out = self.merge(torch.cat([x1, x2], dim=1))#+x1 124 | val = self.convout(out) 125 | return val 126 | 127 | class qsSim(nn.Module): 128 | def __init__(self, inplan): 129 | super(qsSim, self).__init__() 130 | self.conv2 = nn.Conv2d(inplan, 1, kernel_size=1,bias=False) 131 | def forward(self, x): 132 | x = self.conv2(x) 133 | return torch.sigmoid(x) 134 | class sqOut(nn.Module): 135 | def __init__(self, inCh, outCh): 136 | super(sqOut, self).__init__() 137 | self.conv0 
=nn.Sequential(nn.Conv2d(inCh, outCh, kernel_size=1),nn.BatchNorm2d(outCh),nn.ReLU()) 138 | self.conv1=BasicConv(outCh,outCh) 139 | self.conv2=BasicConv(outCh,outCh) 140 | def forward(self,x): 141 | x0=self.conv0(x) 142 | x1=self.conv1(x0) 143 | x=self.conv2(x0+x1) 144 | return x 145 | 146 | class merge(nn.Module): 147 | def __init__(self, shot, nfeatures=[2048 * 2, 1024 * 2, 512 * 2], nsimlairy=[3, 6, 4],criter=None): 148 | super(merge, self).__init__() 149 | self.shot = shot 150 | self.nsimlairy = nsimlairy 151 | self.query_sup_corrConv = [] 152 | self.qs_layer_corrConv = [] 153 | self.qs_sim = [] 154 | self.criter=criter 155 | for num in nfeatures: 156 | self.query_sup_corrConv.append(qsSim(num)) 157 | for num in nsimlairy: 158 | self.qs_layer_corrConv.append(BasicConv1x1(num, 128)) 159 | self.query_sup_corrConv = nn.ModuleList(self.query_sup_corrConv) 160 | self.qs_layer_corrConv = nn.ModuleList(self.qs_layer_corrConv) 161 | self.sqShotConv4 = ShotConv(128, 128) 162 | self.sqShotConv3 = ShotConv(128, 128) 163 | self.sqShotConv2 = ShotConv(128, 128) 164 | 165 | self.conv4 = MergeConv1(128, 256) 166 | self.conv3 = MergeConv1(128, 256) 167 | self.conv2 = MergeConv1(128, 256) 168 | self.conv43 = MergeConv(256, 256, 512) 169 | self.conv432 = MergeDown(512, 256, 512) 170 | self.decoder1 = sqOut(512, 256) 171 | 172 | self.decoder4 = nn.Sequential(nn.Conv2d(256, 128, (3, 3), padding=(1, 1), bias=False), 173 | nn.ReLU(), 174 | nn.Conv2d(128, 2, (3, 3), padding=(1, 1))) 175 | 176 | self.decoder3 = nn.Sequential(nn.Conv2d(256, 128, (3, 3), padding=(1, 1), bias=False), 177 | nn.ReLU(), 178 | nn.Conv2d(128, 2, (3, 3), padding=(1, 1))) 179 | 180 | self.decoder2 = nn.Sequential(nn.Conv2d(256, 128, (3, 3), padding=(1, 1), bias=False), 181 | nn.ReLU(), 182 | nn.Conv2d(128, 2, (3, 3), padding=(1, 1))) 183 | 184 | self.decoderOut = nn.Sequential(nn.Conv2d(256, 128, (3, 3), padding=(1, 1), bias=False), 185 | nn.ReLU(), 186 | nn.Conv2d(128, 2, (3, 3), padding=(1, 1))) 187 | 188 | 
def forward(self, qur_sups, sims,gt=None): 189 | qs_sim = [] 190 | corr_sim = [] 191 | for s in range(self.shot): 192 | sup = qur_sups[s] 193 | supsimList = [] 194 | for l in range(len(sup)): # L4,L3,L2 195 | supL = sup[l] 196 | supLsim = [] 197 | for i in range(len(supL)): 198 | supSim = self.query_sup_corrConv[l](supL[i]) # b,1,h,w 199 | supLsim.append(supSim) 200 | supLsim = torch.cat(supLsim, dim=1) # b,n,h,w 201 | supLsim = self.qs_layer_corrConv[l](supLsim) 202 | supsimList.append(supLsim) # L,b,128,h,w 203 | qs_sim.append(supsimList) # s,l,b,128,h,w 204 | qs_lySim = [qs_sim[0][i] for i in range(len(self.nsimlairy))] # l,b,128,h,w 205 | for ly in range(len(self.nsimlairy)): 206 | for s in range(1,self.shot): 207 | qs_lySim[ly]+=(qs_sim[s][ly]) 208 | x4 = self.conv4(self.sqShotConv4(qs_lySim[0]/self.shot)) 209 | x3 = self.conv3(self.sqShotConv3(qs_lySim[1]/self.shot)) 210 | x2 = self.conv2(self.sqShotConv2(qs_lySim[2]/self.shot)) 211 | x4 = F.interpolate(x4, x3.size()[-2:], mode='bilinear', align_corners=True) 212 | x43 = self.conv43(x4, x3) 213 | x43 = F.interpolate(x43, x2.size()[-2:], mode='bilinear', align_corners=True) 214 | x432 = self.conv432(x43, x2) 215 | d1 = self.decoder1(x432) 216 | upsize = (d1.shape[-1] * 2,) * 2 217 | d1 = F.interpolate(d1, upsize, mode='bilinear', align_corners=True) 218 | d2 = self.decoderOut(d1) 219 | if self.training: 220 | lossSize=x432.size()[-2:] 221 | gt=gt.unsqueeze(1).float() 222 | gtOut = F.interpolate(gt, upsize, mode='nearest') 223 | gtOut=gtOut.squeeze(1) 224 | gtOut=gtOut.long() 225 | gt=F.interpolate(gt, lossSize, mode='nearest') 226 | gt=gt.squeeze(1) 227 | gt=gt.long() 228 | x2=self.decoder2(x2) 229 | x3 = self.decoder3(x3) 230 | x4 = self.decoder4(x4) 231 | x2=F.interpolate(x2, lossSize,mode='bilinear', align_corners=True) 232 | x3 = F.interpolate(x3, lossSize, mode='bilinear', align_corners=True) 233 | x4 = F.interpolate(x4, lossSize, mode='bilinear', align_corners=True) 234 | 
loss=0.3*(self.criter(x2,gt)+self.criter(x3,gt)+self.criter(x4,gt))+self.criter(d2,gtOut) 235 | return d2,loss 236 | else: 237 | gt = gt.unsqueeze(1).float() 238 | gtOut = F.interpolate(gt, upsize, mode='nearest') 239 | gtOut=gtOut.squeeze(1) 240 | gtOut = gtOut.long() 241 | loss=self.criter(d2, gtOut) 242 | return d2,loss 243 | -------------------------------------------------------------------------------- /model/base/merge.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | class BasicConv(nn.Sequential): 6 | def __init__(self, inCh, outCh): 7 | super(BasicConv, self).__init__() 8 | self.conv0 = nn.Conv2d(inCh, outCh, kernel_size=3, padding=1) 9 | self.bn0 = nn.BatchNorm2d(outCh) 10 | self.relu0 = nn.ReLU() 11 | class BasicConv1x1(nn.Sequential): 12 | def __init__(self, inCh, outCh): 13 | super(BasicConv1x1, self).__init__() 14 | self.conv0 = nn.Conv2d(inCh, outCh, kernel_size=1) 15 | self.bn0 = nn.BatchNorm2d(outCh) 16 | self.relu0 = nn.ReLU() 17 | self.conv1 = nn.Conv2d(outCh, outCh, kernel_size=3, padding=1) 18 | self.bn1 = nn.BatchNorm2d(outCh) 19 | self.relu1 = nn.ReLU() 20 | #self.drop=nn.Dropout2d(0.3) 21 | 22 | class ShotConv(nn.Module): 23 | def __init__(self, inCh, outCh): 24 | super(ShotConv, self).__init__() 25 | self.conv0 =nn.Sequential(nn.Conv2d(inCh, outCh, kernel_size=1), nn.BatchNorm2d(outCh),nn.ReLU()) 26 | self.conv1 = nn.Sequential(nn.Conv2d(inCh, outCh, kernel_size=3, dilation=2, padding=2),nn.BatchNorm2d(outCh), nn.ReLU()) 27 | self.conv2 = nn.Sequential(nn.Conv2d(inCh, outCh, kernel_size=3, dilation=4, padding=4),nn.BatchNorm2d(outCh),nn.ReLU()) 28 | self.conv3 = nn.Sequential(nn.Conv2d(outCh*3, outCh, kernel_size=3,padding=1), nn.BatchNorm2d(outCh),nn.ReLU()) 29 | self.drop=nn.Dropout2d(0.2) 30 | def forward(self,x): 31 | x0=self.conv0(x) 32 | x1=self.conv1(x) 33 | x2=self.conv2(x) 34 | x= 
self.conv3(torch.cat([x0,x1,x2],dim=1)) 35 | return self.drop(x) 36 | class Flatten(nn.Module): 37 | def forward(self, x): 38 | return x.flatten(start_dim=1) 39 | 40 | 41 | class ChannelGate(nn.Module): 42 | def __init__(self, gate_channels, reduction_ratio=4, pool_types=['avg', 'max']): 43 | super(ChannelGate, self).__init__() 44 | self.gate_channels = gate_channels 45 | self.mlp = nn.Sequential( 46 | Flatten(), 47 | nn.Linear(gate_channels, gate_channels // reduction_ratio), 48 | nn.ReLU(), 49 | nn.Linear(gate_channels // reduction_ratio, gate_channels) 50 | ) 51 | self.pool_types = pool_types 52 | 53 | def forward(self, x): 54 | channel_att_sum = None 55 | for pool_type in self.pool_types: 56 | if pool_type == 'avg': 57 | avg_pool = F.adaptive_avg_pool2d(x, (1, 1)) 58 | channel_att_raw = self.mlp(avg_pool) 59 | elif pool_type == 'max': 60 | max_pool = F.adaptive_max_pool2d(x, (1, 1)) 61 | channel_att_raw = self.mlp(max_pool) 62 | 63 | if channel_att_sum is None: 64 | channel_att_sum = channel_att_raw 65 | else: 66 | channel_att_sum = channel_att_sum + channel_att_raw 67 | # visMap.heatmap(scale, win='abc') 68 | # time.sleep(2) 69 | channel_att_sum = torch.sigmoid(channel_att_sum) 70 | scale = channel_att_sum.unsqueeze(2).unsqueeze(3) 71 | return x * scale 72 | 73 | 74 | class MergeConv1(nn.Module): 75 | def __init__(self, inplan1, inplan2, outplan): 76 | super(MergeConv1, self).__init__() 77 | self.conv1 = nn.Conv2d(inplan1, outplan, kernel_size=1) 78 | self.conv2 = nn.Conv2d(inplan2, outplan, kernel_size=1) 79 | self.chATT = ChannelGate(outplan, 4) 80 | self.merge = BasicConv(outplan, outplan) 81 | self.merge2 = BasicConv(outplan, outplan) 82 | self.drop = nn.Dropout2d(0.2) 83 | 84 | def forward(self, x1, x2): 85 | x1 = self.conv1(x1) 86 | x2 = self.conv2(x2) 87 | out = self.merge(x1 + x2) 88 | x = self.chATT(out) 89 | out = self.merge2(x) 90 | return self.drop(out) 91 | 92 | 93 | class MergeConv(nn.Module): 94 | def __init__(self, inplan1, inplan2, outplan): 
95 | super(MergeConv, self).__init__() 96 | self.conv1 = nn.Conv2d(inplan1, outplan, kernel_size=1) 97 | self.conv2 = nn.Conv2d(inplan2, outplan, kernel_size=1) 98 | self.merge = BasicConv(2 * outplan, outplan) 99 | self.merge2 = BasicConv(outplan, outplan) 100 | self.drop = nn.Dropout2d(0.2) 101 | 102 | def forward(self, x1, x2): 103 | # shape=x1.shape 104 | x1 = self.conv1(x1) 105 | # x2= F.interpolate(x2, size=(shape[2], shape[3]), mode='bilinear', align_corners=True) 106 | x2 = self.conv2(x2) 107 | out = self.merge(torch.cat([x1, x2], dim=1)) 108 | out = self.merge2(out) 109 | return self.drop(out) 110 | 111 | 112 | class MergeDown(nn.Module): 113 | def __init__(self, inplan1, inplan2, outplan): 114 | super(MergeDown, self).__init__() 115 | self.conv1 = BasicConv(inplan1, outplan) 116 | self.conv2 = nn.Conv2d(inplan2, outplan, kernel_size=1) 117 | self.merge = BasicConv(2 * outplan, outplan) 118 | self.convout = BasicConv(outplan, outplan) 119 | 120 | def forward(self, x1, x2): 121 | x1 = self.conv1(x1) 122 | x2 = self.conv2(x2) 123 | out = self.merge(torch.cat([x1, x2], dim=1))#+x1 124 | val = self.convout(out) 125 | return val 126 | 127 | class qsSim(nn.Module): 128 | def __init__(self, inplan): 129 | super(qsSim, self).__init__() 130 | self.conv2 = nn.Conv2d(inplan, 1, kernel_size=1,bias=False) 131 | def forward(self, x): 132 | x = self.conv2(x) 133 | return torch.sigmoid(x) 134 | class sqOut(nn.Module): 135 | def __init__(self, inCh, outCh): 136 | super(sqOut, self).__init__() 137 | self.conv0 =nn.Sequential(nn.Conv2d(inCh, outCh, kernel_size=1),nn.BatchNorm2d(outCh),nn.ReLU()) 138 | self.conv1=BasicConv(outCh,outCh) 139 | self.conv2=BasicConv(outCh,outCh) 140 | def forward(self,x): 141 | x0=self.conv0(x) 142 | x1=self.conv1(x0) 143 | x=self.conv2(x0+x1) 144 | return x 145 | 146 | class merge(nn.Module): 147 | def __init__(self, shot, nfeatures=[2048 * 2, 1024 * 2, 512 * 2], nsimlairy=[3, 6, 4],criter=None): 148 | super(merge, self).__init__() 149 | 
self.shot = shot 150 | self.nsimlairy = nsimlairy 151 | self.query_sup_corrConv = [] 152 | self.qs_layer_corrConv = [] 153 | self.sim_layer_corrConv = [] 154 | self.qs_sim = [] 155 | self.criter=criter 156 | for num in nfeatures: 157 | self.query_sup_corrConv.append(qsSim(num)) 158 | for num in nsimlairy: 159 | self.qs_layer_corrConv.append(BasicConv1x1(num, 128)) 160 | self.sim_layer_corrConv.append(BasicConv1x1(num, 256)) 161 | self.query_sup_corrConv = nn.ModuleList(self.query_sup_corrConv) 162 | self.qs_layer_corrConv = nn.ModuleList(self.qs_layer_corrConv) 163 | self.sim_layer_corrConv = nn.ModuleList(self.sim_layer_corrConv) 164 | self.sqShotConv4 = ShotConv(128, 128) 165 | self.sqShotConv3 = ShotConv(128, 128) 166 | self.sqShotConv2 = ShotConv(128, 128) 167 | 168 | self.simShotConv4 = BasicConv(256, 256) 169 | self.simShotConv3 = BasicConv(256, 256) 170 | self.simShotConv2 = BasicConv(256, 256) 171 | self.conv4 = MergeConv1(256, 128, 256) 172 | self.conv3 = MergeConv1(256, 128, 256) 173 | self.conv2 = MergeConv1(256, 128, 256) 174 | self.conv43 = MergeConv(256, 256, 512) 175 | self.conv432 = MergeDown(512, 256, 512) 176 | self.decoder1 = sqOut(512, 256) 177 | 178 | self.decoder4 = nn.Sequential(nn.Conv2d(256, 128, (3, 3), padding=(1, 1), bias=False), 179 | nn.ReLU(), 180 | nn.Conv2d(128, 2, (3, 3), padding=(1, 1))) 181 | 182 | self.decoder3 = nn.Sequential(nn.Conv2d(256, 128, (3, 3), padding=(1, 1), bias=False), 183 | nn.ReLU(), 184 | nn.Conv2d(128, 2, (3, 3), padding=(1, 1))) 185 | 186 | self.decoder2 = nn.Sequential(nn.Conv2d(256, 128, (3, 3), padding=(1, 1), bias=False), 187 | nn.ReLU(), 188 | nn.Conv2d(128, 2, (3, 3), padding=(1, 1))) 189 | 190 | self.decoderOut = nn.Sequential(nn.Conv2d(256, 128, (3, 3), padding=(1, 1), bias=False), 191 | nn.ReLU(), 192 | nn.Conv2d(128, 2, (3, 3), padding=(1, 1))) 193 | 194 | def forward(self, qur_sups, sims,gt=None): 195 | qs_sim = [] 196 | corr_sim = [] 197 | for s in range(self.shot): 198 | sup = qur_sups[s] 199 | 
simS = sims[s] 200 | supsimList = [] 201 | simList = [] 202 | for l in range(len(sup)): # L4,L3,L2 203 | supL = sup[l] 204 | supLsim = [] 205 | for i in range(len(supL)): 206 | supSim = self.query_sup_corrConv[l](supL[i]) # b,1,h,w 207 | supLsim.append(supSim) 208 | supLsim = torch.cat(supLsim, dim=1) # b,n,h,w 209 | supLsim = self.qs_layer_corrConv[l](supLsim) 210 | supsimList.append(supLsim) # L,b,128,h,w 211 | simL = self.sim_layer_corrConv[l](simS[l]) 212 | simList.append(simL) # L,b,128,h,w 213 | qs_sim.append(supsimList) # s,l,b,128,h,w 214 | corr_sim.append(simList) 215 | 216 | corr_lySim = [corr_sim[0][i] for i in range(len(self.nsimlairy))] # l,b,128,h,w 217 | qs_lySim = [qs_sim[0][i] for i in range(len(self.nsimlairy))] # l,b,128,h,w 218 | for ly in range(len(self.nsimlairy)): 219 | for s in range(1,self.shot): 220 | corr_lySim[ly]+=(corr_sim[s][ly]) 221 | qs_lySim[ly]+=(qs_sim[s][ly]) 222 | x4 = self.conv4(self.simShotConv4(corr_lySim[0]/self.shot), 223 | self.sqShotConv4(qs_lySim[0]/self.shot)) 224 | x3 = self.conv3(self.simShotConv3(corr_lySim[1]/self.shot), 225 | self.sqShotConv3(qs_lySim[1]/self.shot)) 226 | x2 = self.conv2(self.simShotConv2(corr_lySim[2]/self.shot), 227 | self.sqShotConv2(qs_lySim[2]/self.shot)) 228 | x4 = F.interpolate(x4, x3.size()[-2:], mode='bilinear', align_corners=True) 229 | x43 = self.conv43(x4, x3) 230 | x43 = F.interpolate(x43, x2.size()[-2:], mode='bilinear', align_corners=True) 231 | x432 = self.conv432(x43, x2) 232 | d1 = self.decoder1(x432) 233 | upsize = (d1.shape[-1] * 2,) * 2 234 | d1 = F.interpolate(d1, upsize, mode='bilinear', align_corners=True) 235 | d2 = self.decoderOut(d1) 236 | if self.training: 237 | lossSize=x432.size()[-2:] 238 | gt=gt.unsqueeze(1).float() 239 | gtOut = F.interpolate(gt, upsize, mode='nearest') 240 | gtOut=gtOut.squeeze(1) 241 | gtOut=gtOut.long() 242 | gt=F.interpolate(gt, lossSize, mode='nearest') 243 | gt=gt.squeeze(1) 244 | gt=gt.long() 245 | x2=self.decoder2(x2) 246 | x3 = 
self.decoder3(x3) 247 | x4 = self.decoder4(x4) 248 | x2=F.interpolate(x2, lossSize,mode='bilinear', align_corners=True) 249 | x3 = F.interpolate(x3, lossSize, mode='bilinear', align_corners=True) 250 | x4 = F.interpolate(x4, lossSize, mode='bilinear', align_corners=True) 251 | loss=0.3*(self.criter(x2,gt)+self.criter(x3,gt)+self.criter(x4,gt))+self.criter(d2,gtOut) 252 | return d2,loss 253 | else: 254 | gt = gt.unsqueeze(1).float() 255 | gtOut = F.interpolate(gt, upsize, mode='nearest') 256 | gtOut=gtOut.squeeze(1) 257 | gtOut = gtOut.long() 258 | loss=self.criter(d2, gtOut) 259 | return d2,loss 260 | -------------------------------------------------------------------------------- /data/coco.py: -------------------------------------------------------------------------------- 1 | r""" COCO-20i few-shot semantic segmentation dataset """ 2 | import os 3 | import pickle 4 | 5 | from torch.utils.data import Dataset 6 | import torch.nn.functional as F 7 | import torch 8 | import PIL.Image as Image 9 | #from pycocotools.coco import orgCOCO 10 | import json 11 | import time 12 | import numpy as np 13 | import itertools 14 | from collections import defaultdict 15 | import sys 16 | import cv2 17 | PYTHON_VERSION = sys.version_info[0] 18 | if PYTHON_VERSION == 2: 19 | from urllib import urlretrieve 20 | elif PYTHON_VERSION == 3: 21 | from urllib.request import urlretrieve 22 | 23 | 24 | def _isArrayLike(obj): 25 | return hasattr(obj, '__iter__') and hasattr(obj, '__len__') 26 | 27 | 28 | class COCO: 29 | def __init__(self, annotation_file=None): 30 | """ 31 | Constructor of Microsoft COCO helper class for reading and visualizing annotations. 32 | :param annotation_file (str): location of annotation file 33 | :param image_folder (str): location to the folder that hosts images. 
34 | :return: 35 | """ 36 | # load dataset 37 | self.dataset,self.anns,self.cats,self.imgs,self.imgNameId,self.clsId = dict(),dict(),dict(),dict(),dict(),dict() 38 | self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list) 39 | if not annotation_file == None: 40 | print('loading annotations into memory...') 41 | tic = time.time() 42 | dataset = json.load(open(annotation_file, 'r')) 43 | assert type(dataset)==dict, 'annotation file format {} not supported'.format(type(dataset)) 44 | print('Done (t={:0.2f}s)'.format(time.time()- tic)) 45 | self.dataset = dataset 46 | self.createIndex() 47 | self.idToCId() 48 | def idToCId(self): 49 | for idx, val in enumerate(self.dataset['categories']): 50 | self.clsId[idx]=val['id'] 51 | def createIndex(self): 52 | # create index 53 | print('creating index...') 54 | anns, cats, imgs = {}, {}, {} 55 | imgToAnns,catToImgs = defaultdict(list),defaultdict(list) 56 | if 'annotations' in self.dataset: 57 | for ann in self.dataset['annotations']: 58 | imgToAnns[ann['image_id']].append(ann) 59 | anns[ann['id']] = ann 60 | 61 | if 'images' in self.dataset: 62 | for img in self.dataset['images']: 63 | imgs[img['id']] = img 64 | if 'images' in self.dataset: 65 | for img in self.dataset['images']: 66 | self.imgNameId[img['file_name']] = img 67 | 68 | if 'categories' in self.dataset: 69 | for cat in self.dataset['categories']: 70 | cats[cat['id']] = cat 71 | 72 | if 'annotations' in self.dataset and 'categories' in self.dataset: 73 | for ann in self.dataset['annotations']: 74 | catToImgs[ann['category_id']].append(ann['image_id']) 75 | 76 | print('index created!') 77 | 78 | # create class members 79 | self.anns = anns 80 | self.imgToAnns = imgToAnns 81 | self.catToImgs = catToImgs 82 | self.imgs = imgs 83 | self.cats = cats 84 | 85 | def info(self): 86 | """ 87 | Print information about the annotation file. 
88 | :return: 89 | """ 90 | for key, value in self.dataset['info'].items(): 91 | print('{}: {}'.format(key, value)) 92 | 93 | def getAnnIds(self, imgIds=[], catIds=[], areaRng=[], iscrowd=None): 94 | """ 95 | Get ann ids that satisfy given filter conditions. default skips that filter 96 | :param imgIds (int array) : get anns for given imgs 97 | catIds (int array) : get anns for given cats 98 | areaRng (float array) : get anns for given area range (e.g. [0 inf]) 99 | iscrowd (boolean) : get anns for given crowd label (False or True) 100 | :return: ids (int array) : integer array of ann ids 101 | """ 102 | imgIds = imgIds if _isArrayLike(imgIds) else [imgIds] 103 | catIds = catIds if _isArrayLike(catIds) else [catIds] 104 | 105 | if len(imgIds) == len(catIds) == len(areaRng) == 0: 106 | anns = self.dataset['annotations'] 107 | else: 108 | if not len(imgIds) == 0: 109 | lists = [self.imgToAnns[imgId] for imgId in imgIds if imgId in self.imgToAnns] 110 | anns = list(itertools.chain.from_iterable(lists)) 111 | else: 112 | anns = self.dataset['annotations'] 113 | anns = anns if len(catIds) == 0 else [ann for ann in anns if ann['category_id'] in catIds] 114 | anns = anns if len(areaRng) == 0 else [ann for ann in anns if ann['area'] > areaRng[0] and ann['area'] < areaRng[1]] 115 | if not iscrowd == None: 116 | ids = [ann['id'] for ann in anns if ann['iscrowd'] == iscrowd] 117 | else: 118 | ids = [ann['id'] for ann in anns] 119 | return ids 120 | 121 | def getCatIds(self, catNms=[], supNms=[], catIds=[]): 122 | """ 123 | filtering parameters. default skips that filter. 
124 | :param catNms (str array) : get cats for given cat names 125 | :param supNms (str array) : get cats for given supercategory names 126 | :param catIds (int array) : get cats for given cat ids 127 | :return: ids (int array) : integer array of cat ids 128 | """ 129 | catNms = catNms if _isArrayLike(catNms) else [catNms] 130 | supNms = supNms if _isArrayLike(supNms) else [supNms] 131 | catIds = catIds if _isArrayLike(catIds) else [catIds] 132 | 133 | if len(catNms) == len(supNms) == len(catIds) == 0: 134 | cats = self.dataset['categories'] 135 | else: 136 | cats = self.dataset['categories'] 137 | cats = cats if len(catNms) == 0 else [cat for cat in cats if cat['name'] in catNms] 138 | cats = cats if len(supNms) == 0 else [cat for cat in cats if cat['supercategory'] in supNms] 139 | cats = cats if len(catIds) == 0 else [cat for cat in cats if cat['id'] in catIds] 140 | ids = [cat['id'] for cat in cats] 141 | return ids 142 | 143 | def getImgIds(self, imgIds=[], catIds=[]): 144 | ''' 145 | Get img ids that satisfy given filter conditions. 146 | :param imgIds (int array) : get imgs for given ids 147 | :param catIds (int array) : get imgs with all given cats 148 | :return: ids (int array) : integer array of img ids 149 | ''' 150 | imgIds = imgIds if _isArrayLike(imgIds) else [imgIds] 151 | catIds = catIds if _isArrayLike(catIds) else [catIds] 152 | 153 | if len(imgIds) == len(catIds) == 0: 154 | ids = self.imgs.keys() 155 | else: 156 | ids = set(imgIds) 157 | for i, catId in enumerate(catIds): 158 | if i == 0 and len(ids) == 0: 159 | ids = set(self.catToImgs[catId]) 160 | else: 161 | ids &= set(self.catToImgs[catId]) 162 | return list(ids) 163 | 164 | def loadAnns(self, ids=[]): 165 | """ 166 | Load anns with the specified ids. 
167 | :param ids (int array) : integer ids specifying anns 168 | :return: anns (object array) : loaded ann objects 169 | """ 170 | if _isArrayLike(ids): 171 | return [self.anns[id] for id in ids] 172 | elif type(ids) == int: 173 | return [self.anns[ids]] 174 | 175 | def loadCats(self, ids=[]): 176 | """ 177 | Load cats with the specified ids. 178 | :param ids (int array) : integer ids specifying cats 179 | :return: cats (object array) : loaded cat objects 180 | """ 181 | if _isArrayLike(ids): 182 | return [self.cats[id] for id in ids] 183 | elif type(ids) == int: 184 | return [self.cats[ids]] 185 | 186 | def loadImgs(self, ids=[]): 187 | """ 188 | Load anns with the specified ids. 189 | :param ids (int array) : integer ids specifying img 190 | :return: imgs (object array) : loaded img objects 191 | """ 192 | if _isArrayLike(ids): 193 | return [self.imgs[id] for id in ids] 194 | elif type(ids) == int: 195 | return [self.imgs[ids]] 196 | 197 | class DatasetCOCO(Dataset): 198 | def __init__(self, datapath, fold, transform, split, shot, use_original_imgsize): 199 | self.split = 'val' if split in ['val', 'test'] else 'trn' 200 | self.fold = fold 201 | self.nfolds = 4 202 | self.nclass = 80 203 | self.benchmark = 'coco' 204 | self.shot = shot 205 | self.split_coco = 'val2014' if self.split == 'val' else 'train2014' 206 | self.base_path = os.path.join(datapath, 'COCO2014') 207 | self.annotion_path=self.base_path+'/annotations/instances_'+self.split_coco+'.json' 208 | print('load annotions in:',self.annotion_path,self.split) 209 | self.coco=COCO(self.annotion_path) 210 | self.transform = transform 211 | self.use_original_imgsize = use_original_imgsize 212 | 213 | self.class_ids = self.build_class_ids() 214 | self.img_metadata_classwise = self.build_img_metadata_classwise() 215 | self.img_metadata = self.build_img_metadata() 216 | 217 | def __len__(self): 218 | return 7500 if self.split == 'trn' else 1000 219 | 220 | def __getitem__(self, idx): 221 | # ignores idx during 
training & testing and perform uniform sampling over object classes to form an episode 222 | # (due to the large size of the COCO dataset) 223 | query_img, query_mask, support_imgs, support_masks, query_name, support_names, class_sample, org_qry_imsize = self.load_frame() 224 | 225 | query_img = self.transform(query_img) 226 | query_mask = query_mask.float() 227 | if not self.use_original_imgsize: 228 | query_mask = F.interpolate(query_mask.unsqueeze(0).unsqueeze(0).float(), query_img.size()[-2:], mode='nearest').squeeze() 229 | 230 | support_imgs = torch.stack([self.transform(support_img) for support_img in support_imgs]) 231 | for midx, smask in enumerate(support_masks): 232 | support_masks[midx] = F.interpolate(smask.unsqueeze(0).unsqueeze(0).float(), support_imgs.size()[-2:], mode='nearest').squeeze() 233 | support_masks = torch.stack(support_masks) 234 | 235 | batch = {'query_img': query_img, 236 | 'query_mask': query_mask, 237 | 'query_name': query_name, 238 | 239 | 'org_query_imsize': org_qry_imsize, 240 | 241 | 'support_imgs': support_imgs, 242 | 'support_masks': support_masks, 243 | 'support_names': support_names, 244 | 'class_id': torch.tensor(class_sample)} 245 | 246 | return batch 247 | 248 | def build_class_ids(self): 249 | nclass_trn = self.nclass // self.nfolds 250 | class_ids_val = [self.fold + self.nfolds * v for v in range(nclass_trn)] 251 | class_ids_trn = [x for x in range(self.nclass) if x not in class_ids_val] 252 | class_ids = class_ids_trn if self.split == 'trn' else class_ids_val 253 | 254 | return class_ids 255 | 256 | def build_img_metadata_classwise(self): 257 | with open('./data/splits/coco/%s/fold%d.pkl' % (self.split, self.fold), 'rb') as f: 258 | img_metadata_classwise = pickle.load(f) 259 | return img_metadata_classwise 260 | 261 | def build_img_metadata(self): 262 | img_metadata = [] 263 | for k in self.img_metadata_classwise.keys(): 264 | img_metadata += self.img_metadata_classwise[k] 265 | return sorted(list(set(img_metadata))) 
266 | 267 | def read_mask(self, name,class_id): 268 | calssSeg=[] 269 | imgInfo = self.coco.imgNameId[name] 270 | h = imgInfo['height'] 271 | w = imgInfo['width'] 272 | gt = np.zeros((h, w), dtype=np.uint8) 273 | imgid = self.coco.getAnnIds(imgInfo['id']) 274 | for id in imgid: 275 | annotion = self.coco.anns[id] 276 | if annotion['category_id']==self.coco.clsId[class_id]: 277 | if annotion['iscrowd']: 278 | continue 279 | for val in annotion['segmentation']: 280 | seg = np.array(val).reshape(-1, 2) 281 | calssSeg.append(seg.astype(np.int32)[np.newaxis, :, :]) 282 | cv2.fillPoly(gt, calssSeg, int(1)) 283 | return torch.tensor(gt) 284 | 285 | def load_frame(self): 286 | class_sample = np.random.choice(self.class_ids, 1, replace=False)[0] 287 | query_name = np.random.choice(self.img_metadata_classwise[class_sample], 1, replace=False)[0] 288 | query_img = Image.open(os.path.join(self.base_path, query_name)).convert('RGB') 289 | query_mask = self.read_mask(query_name.split('/')[-1],class_sample) 290 | 291 | org_qry_imsize = query_img.size 292 | support_names = [] 293 | while True: # keep sampling support set if query == support 294 | support_name = np.random.choice(self.img_metadata_classwise[class_sample], 1, replace=False)[0] 295 | if query_name != support_name: support_names.append(support_name) 296 | if len(support_names) == self.shot: break 297 | support_imgs = [] 298 | support_masks = [] 299 | for support_name in support_names: 300 | support_imgs.append(Image.open(os.path.join(self.base_path, support_name)).convert('RGB')) 301 | support_mask = self.read_mask(support_name.split('/')[-1],class_sample) 302 | support_masks.append(support_mask) 303 | 304 | return query_img, query_mask, support_imgs, support_masks, query_name, support_names, class_sample, org_qry_imsize 305 | 306 | -------------------------------------------------------------------------------- /data/splits/pascal/val/fold2.txt: 
-------------------------------------------------------------------------------- 1 | 2007_000129__15 2 | 2007_000323__15 3 | 2007_000332__13 4 | 2007_000346__15 5 | 2007_000762__11 6 | 2007_000762__15 7 | 2007_000783__13 8 | 2007_000783__15 9 | 2007_000799__13 10 | 2007_000799__15 11 | 2007_000830__11 12 | 2007_000847__11 13 | 2007_000847__15 14 | 2007_000999__15 15 | 2007_001175__15 16 | 2007_001239__12 17 | 2007_001284__15 18 | 2007_001311__15 19 | 2007_001408__15 20 | 2007_001423__15 21 | 2007_001430__11 22 | 2007_001430__15 23 | 2007_001526__15 24 | 2007_001585__15 25 | 2007_001586__13 26 | 2007_001586__15 27 | 2007_001594__15 28 | 2007_001630__15 29 | 2007_001677__11 30 | 2007_001678__15 31 | 2007_001717__15 32 | 2007_001763__12 33 | 2007_001955__13 34 | 2007_002046__13 35 | 2007_002119__15 36 | 2007_002260__14 37 | 2007_002268__12 38 | 2007_002378__15 39 | 2007_002426__15 40 | 2007_002539__15 41 | 2007_002565__15 42 | 2007_002597__12 43 | 2007_002624__11 44 | 2007_002624__15 45 | 2007_002643__15 46 | 2007_002728__15 47 | 2007_002823__14 48 | 2007_002823__15 49 | 2007_002824__15 50 | 2007_002852__12 51 | 2007_003011__11 52 | 2007_003020__15 53 | 2007_003022__13 54 | 2007_003022__15 55 | 2007_003088__15 56 | 2007_003106__15 57 | 2007_003110__12 58 | 2007_003134__15 59 | 2007_003188__15 60 | 2007_003194__12 61 | 2007_003367__14 62 | 2007_003367__15 63 | 2007_003373__12 64 | 2007_003373__15 65 | 2007_003530__15 66 | 2007_003621__15 67 | 2007_003742__11 68 | 2007_003742__15 69 | 2007_003872__12 70 | 2007_004033__14 71 | 2007_004033__15 72 | 2007_004112__12 73 | 2007_004112__15 74 | 2007_004121__15 75 | 2007_004189__12 76 | 2007_004275__14 77 | 2007_004275__15 78 | 2007_004281__15 79 | 2007_004380__14 80 | 2007_004380__15 81 | 2007_004392__15 82 | 2007_004405__11 83 | 2007_004538__13 84 | 2007_004538__15 85 | 2007_004644__12 86 | 2007_004712__11 87 | 2007_004712__15 88 | 2007_004722__13 89 | 2007_004722__15 90 | 2007_004902__13 91 | 2007_004902__15 92 | 
2007_005114__13 93 | 2007_005114__15 94 | 2007_005149__12 95 | 2007_005173__14 96 | 2007_005173__15 97 | 2007_005281__15 98 | 2007_005304__15 99 | 2007_005331__13 100 | 2007_005331__15 101 | 2007_005354__14 102 | 2007_005354__15 103 | 2007_005509__15 104 | 2007_005547__15 105 | 2007_005608__14 106 | 2007_005608__15 107 | 2007_005696__12 108 | 2007_005759__14 109 | 2007_005803__11 110 | 2007_005844__11 111 | 2007_005845__15 112 | 2007_006028__15 113 | 2007_006076__15 114 | 2007_006086__11 115 | 2007_006117__15 116 | 2007_006171__12 117 | 2007_006171__15 118 | 2007_006241__11 119 | 2007_006364__13 120 | 2007_006364__15 121 | 2007_006373__15 122 | 2007_006444__12 123 | 2007_006444__15 124 | 2007_006560__15 125 | 2007_006647__14 126 | 2007_006647__15 127 | 2007_006698__15 128 | 2007_006802__15 129 | 2007_006841__15 130 | 2007_006864__15 131 | 2007_006866__13 132 | 2007_006866__15 133 | 2007_007007__11 134 | 2007_007007__15 135 | 2007_007109__13 136 | 2007_007109__15 137 | 2007_007195__15 138 | 2007_007203__15 139 | 2007_007211__14 140 | 2007_007235__15 141 | 2007_007417__12 142 | 2007_007493__15 143 | 2007_007498__11 144 | 2007_007498__15 145 | 2007_007651__11 146 | 2007_007651__15 147 | 2007_007688__14 148 | 2007_007748__13 149 | 2007_007748__15 150 | 2007_007795__15 151 | 2007_007810__11 152 | 2007_007810__15 153 | 2007_007815__15 154 | 2007_007836__15 155 | 2007_007849__15 156 | 2007_007996__15 157 | 2007_008110__15 158 | 2007_008204__15 159 | 2007_008222__12 160 | 2007_008256__13 161 | 2007_008256__15 162 | 2007_008260__12 163 | 2007_008374__15 164 | 2007_008415__12 165 | 2007_008430__15 166 | 2007_008596__13 167 | 2007_008596__15 168 | 2007_008708__15 169 | 2007_008802__13 170 | 2007_008897__15 171 | 2007_008944__15 172 | 2007_008964__12 173 | 2007_008964__15 174 | 2007_008980__12 175 | 2007_009068__15 176 | 2007_009084__12 177 | 2007_009084__14 178 | 2007_009251__13 179 | 2007_009251__15 180 | 2007_009258__15 181 | 2007_009320__15 182 | 2007_009331__12 183 | 
2007_009331__13 184 | 2007_009331__15 185 | 2007_009413__11 186 | 2007_009413__15 187 | 2007_009521__11 188 | 2007_009562__12 189 | 2007_009592__12 190 | 2007_009654__15 191 | 2007_009655__15 192 | 2007_009684__15 193 | 2007_009687__15 194 | 2007_009691__14 195 | 2007_009691__15 196 | 2007_009706__11 197 | 2007_009750__15 198 | 2007_009756__14 199 | 2007_009756__15 200 | 2007_009841__13 201 | 2007_009938__14 202 | 2008_000080__12 203 | 2008_000213__15 204 | 2008_000215__15 205 | 2008_000223__15 206 | 2008_000233__15 207 | 2008_000234__15 208 | 2008_000239__12 209 | 2008_000270__12 210 | 2008_000270__15 211 | 2008_000271__15 212 | 2008_000359__15 213 | 2008_000474__15 214 | 2008_000510__15 215 | 2008_000573__11 216 | 2008_000573__15 217 | 2008_000602__13 218 | 2008_000630__15 219 | 2008_000661__12 220 | 2008_000661__15 221 | 2008_000662__15 222 | 2008_000666__15 223 | 2008_000673__15 224 | 2008_000700__15 225 | 2008_000725__15 226 | 2008_000731__15 227 | 2008_000763__11 228 | 2008_000763__15 229 | 2008_000765__13 230 | 2008_000782__14 231 | 2008_000795__15 232 | 2008_000811__14 233 | 2008_000811__15 234 | 2008_000863__12 235 | 2008_000943__12 236 | 2008_000992__15 237 | 2008_001013__15 238 | 2008_001028__15 239 | 2008_001070__12 240 | 2008_001074__15 241 | 2008_001076__15 242 | 2008_001150__14 243 | 2008_001170__15 244 | 2008_001231__15 245 | 2008_001249__15 246 | 2008_001283__15 247 | 2008_001308__15 248 | 2008_001379__12 249 | 2008_001404__15 250 | 2008_001478__12 251 | 2008_001491__15 252 | 2008_001504__15 253 | 2008_001531__15 254 | 2008_001547__15 255 | 2008_001629__15 256 | 2008_001682__13 257 | 2008_001821__15 258 | 2008_001874__15 259 | 2008_001895__12 260 | 2008_001895__15 261 | 2008_001992__13 262 | 2008_001992__15 263 | 2008_002212__15 264 | 2008_002239__12 265 | 2008_002240__14 266 | 2008_002241__15 267 | 2008_002379__11 268 | 2008_002383__14 269 | 2008_002495__15 270 | 2008_002536__12 271 | 2008_002588__15 272 | 2008_002775__11 273 | 2008_002775__15 274 
| 2008_002835__13 275 | 2008_002835__15 276 | 2008_002859__12 277 | 2008_002864__11 278 | 2008_002864__15 279 | 2008_002904__12 280 | 2008_002929__15 281 | 2008_002936__12 282 | 2008_002942__15 283 | 2008_002958__12 284 | 2008_003034__15 285 | 2008_003076__15 286 | 2008_003108__15 287 | 2008_003141__15 288 | 2008_003210__15 289 | 2008_003238__12 290 | 2008_003238__15 291 | 2008_003330__15 292 | 2008_003333__14 293 | 2008_003333__15 294 | 2008_003379__13 295 | 2008_003451__14 296 | 2008_003451__15 297 | 2008_003461__13 298 | 2008_003461__15 299 | 2008_003477__11 300 | 2008_003492__15 301 | 2008_003511__12 302 | 2008_003511__15 303 | 2008_003546__15 304 | 2008_003576__12 305 | 2008_003676__15 306 | 2008_003733__15 307 | 2008_003782__13 308 | 2008_003856__15 309 | 2008_003874__15 310 | 2008_004101__15 311 | 2008_004140__11 312 | 2008_004140__15 313 | 2008_004175__13 314 | 2008_004345__14 315 | 2008_004396__13 316 | 2008_004399__14 317 | 2008_004399__15 318 | 2008_004575__11 319 | 2008_004575__15 320 | 2008_004624__13 321 | 2008_004654__15 322 | 2008_004687__13 323 | 2008_004705__13 324 | 2008_005049__14 325 | 2008_005089__15 326 | 2008_005145__11 327 | 2008_005197__12 328 | 2008_005197__15 329 | 2008_005245__14 330 | 2008_005245__15 331 | 2008_005399__15 332 | 2008_005422__14 333 | 2008_005445__15 334 | 2008_005525__13 335 | 2008_005637__14 336 | 2008_005642__13 337 | 2008_005691__13 338 | 2008_005738__15 339 | 2008_005812__15 340 | 2008_005915__14 341 | 2008_006008__11 342 | 2008_006036__13 343 | 2008_006108__11 344 | 2008_006108__15 345 | 2008_006130__12 346 | 2008_006216__15 347 | 2008_006219__13 348 | 2008_006254__15 349 | 2008_006275__15 350 | 2008_006341__15 351 | 2008_006408__11 352 | 2008_006408__15 353 | 2008_006526__14 354 | 2008_006526__15 355 | 2008_006554__15 356 | 2008_006722__12 357 | 2008_006722__15 358 | 2008_006874__14 359 | 2008_006874__15 360 | 2008_006981__12 361 | 2008_007048__11 362 | 2008_007219__15 363 | 2008_007378__11 364 | 2008_007378__12 
365 | 2008_007392__13 366 | 2008_007392__15 367 | 2008_007402__11 368 | 2008_007402__15 369 | 2008_007513__12 370 | 2008_007737__15 371 | 2008_007828__15 372 | 2008_007945__13 373 | 2008_007994__15 374 | 2008_008051__11 375 | 2008_008127__14 376 | 2008_008127__15 377 | 2008_008221__15 378 | 2008_008335__11 379 | 2008_008335__15 380 | 2008_008362__11 381 | 2008_008362__15 382 | 2008_008392__13 383 | 2008_008393__13 384 | 2008_008421__13 385 | 2008_008469__15 386 | 2009_000012__13 387 | 2009_000074__14 388 | 2009_000074__15 389 | 2009_000156__12 390 | 2009_000219__15 391 | 2009_000309__15 392 | 2009_000412__13 393 | 2009_000418__15 394 | 2009_000421__15 395 | 2009_000457__15 396 | 2009_000704__15 397 | 2009_000705__13 398 | 2009_000727__13 399 | 2009_000730__14 400 | 2009_000730__15 401 | 2009_000825__14 402 | 2009_000825__15 403 | 2009_000839__12 404 | 2009_000892__12 405 | 2009_000931__13 406 | 2009_000935__12 407 | 2009_001215__11 408 | 2009_001215__15 409 | 2009_001299__15 410 | 2009_001433__13 411 | 2009_001433__15 412 | 2009_001535__12 413 | 2009_001663__15 414 | 2009_001687__12 415 | 2009_001687__15 416 | 2009_001718__15 417 | 2009_001768__15 418 | 2009_001854__15 419 | 2009_002012__12 420 | 2009_002042__15 421 | 2009_002097__13 422 | 2009_002155__12 423 | 2009_002165__13 424 | 2009_002185__15 425 | 2009_002239__14 426 | 2009_002239__15 427 | 2009_002317__14 428 | 2009_002317__15 429 | 2009_002346__12 430 | 2009_002346__15 431 | 2009_002372__15 432 | 2009_002382__14 433 | 2009_002382__15 434 | 2009_002415__11 435 | 2009_002445__12 436 | 2009_002487__11 437 | 2009_002539__12 438 | 2009_002571__11 439 | 2009_002584__15 440 | 2009_002649__15 441 | 2009_002651__14 442 | 2009_002651__15 443 | 2009_002732__15 444 | 2009_002975__13 445 | 2009_003003__11 446 | 2009_003003__15 447 | 2009_003063__12 448 | 2009_003065__15 449 | 2009_003071__11 450 | 2009_003071__15 451 | 2009_003123__11 452 | 2009_003196__14 453 | 2009_003217__12 454 | 2009_003241__12 455 | 
2009_003269__15 456 | 2009_003323__13 457 | 2009_003323__15 458 | 2009_003466__12 459 | 2009_003481__13 460 | 2009_003494__15 461 | 2009_003507__11 462 | 2009_003576__14 463 | 2009_003576__15 464 | 2009_003756__12 465 | 2009_003804__13 466 | 2009_003810__12 467 | 2009_003849__11 468 | 2009_003849__15 469 | 2009_003903__13 470 | 2009_003928__12 471 | 2009_003991__11 472 | 2009_003991__15 473 | 2009_004033__12 474 | 2009_004043__14 475 | 2009_004043__15 476 | 2009_004140__11 477 | 2009_004221__15 478 | 2009_004455__14 479 | 2009_004497__13 480 | 2009_004507__12 481 | 2009_004507__15 482 | 2009_004581__12 483 | 2009_004592__12 484 | 2009_004738__14 485 | 2009_004738__15 486 | 2009_004848__15 487 | 2009_004859__11 488 | 2009_004859__15 489 | 2009_004942__13 490 | 2009_004987__14 491 | 2009_004987__15 492 | 2009_004994__12 493 | 2009_004994__15 494 | 2009_005038__11 495 | 2009_005038__15 496 | 2009_005078__14 497 | 2009_005087__15 498 | 2009_005217__13 499 | 2009_005217__15 500 | 2010_000003__12 501 | 2010_000038__13 502 | 2010_000038__15 503 | 2010_000087__14 504 | 2010_000087__15 505 | 2010_000110__12 506 | 2010_000110__15 507 | 2010_000159__12 508 | 2010_000174__11 509 | 2010_000174__15 510 | 2010_000216__12 511 | 2010_000238__15 512 | 2010_000256__15 513 | 2010_000422__12 514 | 2010_000530__15 515 | 2010_000559__15 516 | 2010_000639__12 517 | 2010_000666__13 518 | 2010_000666__15 519 | 2010_000738__15 520 | 2010_000788__12 521 | 2010_000874__13 522 | 2010_000904__12 523 | 2010_001024__15 524 | 2010_001124__12 525 | 2010_001251__14 526 | 2010_001264__12 527 | 2010_001313__14 528 | 2010_001313__15 529 | 2010_001367__15 530 | 2010_001376__12 531 | 2010_001451__13 532 | 2010_001553__14 533 | 2010_001563__12 534 | 2010_001563__15 535 | 2010_001579__11 536 | 2010_001579__15 537 | 2010_001692__15 538 | 2010_001699__15 539 | 2010_001734__15 540 | 2010_001767__15 541 | 2010_001851__11 542 | 2010_001908__12 543 | 2010_001956__12 544 | 2010_002017__15 545 | 2010_002137__15 546 
| 2010_002161__13 547 | 2010_002161__15 548 | 2010_002228__12 549 | 2010_002251__14 550 | 2010_002251__15 551 | 2010_002271__14 552 | 2010_002336__11 553 | 2010_002396__14 554 | 2010_002396__15 555 | 2010_002480__12 556 | 2010_002623__15 557 | 2010_002691__13 558 | 2010_002763__15 559 | 2010_002792__15 560 | 2010_002902__15 561 | 2010_002929__15 562 | 2010_003014__15 563 | 2010_003060__12 564 | 2010_003187__12 565 | 2010_003207__14 566 | 2010_003239__15 567 | 2010_003325__11 568 | 2010_003325__15 569 | 2010_003381__15 570 | 2010_003409__15 571 | 2010_003446__15 572 | 2010_003506__12 573 | 2010_003531__11 574 | 2010_003532__13 575 | 2010_003597__11 576 | 2010_003597__15 577 | 2010_003746__12 578 | 2010_003746__15 579 | 2010_003947__14 580 | 2010_003971__11 581 | 2010_004042__14 582 | 2010_004165__12 583 | 2010_004165__15 584 | 2010_004219__14 585 | 2010_004219__15 586 | 2010_004337__15 587 | 2010_004355__14 588 | 2010_004432__15 589 | 2010_004472__15 590 | 2010_004479__15 591 | 2010_004519__13 592 | 2010_004550__12 593 | 2010_004559__15 594 | 2010_004628__12 595 | 2010_004697__14 596 | 2010_004697__15 597 | 2010_004795__12 598 | 2010_004815__15 599 | 2010_004825__11 600 | 2010_004828__15 601 | 2010_004856__13 602 | 2010_004941__14 603 | 2010_004951__15 604 | 2010_005046__11 605 | 2010_005046__15 606 | 2010_005118__15 607 | 2010_005159__12 608 | 2010_005160__14 609 | 2010_005166__15 610 | 2010_005174__13 611 | 2010_005206__12 612 | 2010_005245__12 613 | 2010_005245__15 614 | 2010_005252__14 615 | 2010_005252__15 616 | 2010_005284__15 617 | 2010_005366__14 618 | 2010_005433__14 619 | 2010_005501__14 620 | 2010_005575__12 621 | 2010_005582__15 622 | 2010_005606__15 623 | 2010_005626__11 624 | 2010_005626__15 625 | 2010_005644__12 626 | 2010_005709__15 627 | 2010_005871__15 628 | 2010_005991__12 629 | 2010_005991__15 630 | 2010_005992__12 631 | 2011_000045__12 632 | 2011_000051__15 633 | 2011_000054__15 634 | 2011_000178__15 635 | 2011_000226__11 636 | 2011_000248__15 
637 | 2011_000338__11 638 | 2011_000396__13 639 | 2011_000435__15 640 | 2011_000438__15 641 | 2011_000455__14 642 | 2011_000455__15 643 | 2011_000479__15 644 | 2011_000512__14 645 | 2011_000526__13 646 | 2011_000536__12 647 | 2011_000566__15 648 | 2011_000585__15 649 | 2011_000598__11 650 | 2011_000618__14 651 | 2011_000618__15 652 | 2011_000638__15 653 | 2011_000780__15 654 | 2011_000809__11 655 | 2011_000809__15 656 | 2011_000843__15 657 | 2011_000953__11 658 | 2011_000953__15 659 | 2011_001014__12 660 | 2011_001060__15 661 | 2011_001069__15 662 | 2011_001071__15 663 | 2011_001159__15 664 | 2011_001276__11 665 | 2011_001276__12 666 | 2011_001276__15 667 | 2011_001346__15 668 | 2011_001416__15 669 | 2011_001447__15 670 | 2011_001530__15 671 | 2011_001567__15 672 | 2011_001619__15 673 | 2011_001642__12 674 | 2011_001665__11 675 | 2011_001674__15 676 | 2011_001714__12 677 | 2011_001714__15 678 | 2011_001722__13 679 | 2011_001745__12 680 | 2011_001794__15 681 | 2011_001862__11 682 | 2011_001862__12 683 | 2011_001868__12 684 | 2011_001984__12 685 | 2011_001988__15 686 | 2011_002002__15 687 | 2011_002040__12 688 | 2011_002075__11 689 | 2011_002075__15 690 | 2011_002098__12 691 | 2011_002110__12 692 | 2011_002110__15 693 | 2011_002121__12 694 | 2011_002124__15 695 | 2011_002156__12 696 | 2011_002200__11 697 | 2011_002200__15 698 | 2011_002247__15 699 | 2011_002279__12 700 | 2011_002298__12 701 | 2011_002308__15 702 | 2011_002317__15 703 | 2011_002322__14 704 | 2011_002322__15 705 | 2011_002343__15 706 | 2011_002358__11 707 | 2011_002358__15 708 | 2011_002371__12 709 | 2011_002498__15 710 | 2011_002509__15 711 | 2011_002532__15 712 | 2011_002575__15 713 | 2011_002578__15 714 | 2011_002589__12 715 | 2011_002623__15 716 | 2011_002641__15 717 | 2011_002675__15 718 | 2011_002951__13 719 | 2011_002997__15 720 | 2011_003019__14 721 | 2011_003019__15 722 | 2011_003085__13 723 | 2011_003114__15 724 | 2011_003240__15 725 | 2011_003256__12 726 | 
-------------------------------------------------------------------------------- /data/splits/pascal/trn/fold3.txt: -------------------------------------------------------------------------------- 1 | 2007_001149__16 2 | 2007_001420__16 3 | 2007_002361__16 4 | 2007_002967__16 5 | 2007_003189__16 6 | 2007_003778__16 7 | 2007_004081__16 8 | 2007_004707__16 9 | 2007_004948__16 10 | 2007_006303__16 11 | 2007_006605__16 12 | 2007_007890__16 13 | 2007_008043__16 14 | 2007_008140__16 15 | 2007_008821__16 16 | 2007_009630__16 17 | 2008_000188__16 18 | 2008_000196__16 19 | 2008_000274__16 20 | 2008_000287__16 21 | 2008_000491__16 22 | 2008_000564__16 23 | 2008_000790__16 24 | 2008_000841__16 25 | 2008_000857__16 26 | 2008_000916__16 27 | 2008_000923__16 28 | 2008_000960__16 29 | 2008_001133__16 30 | 2008_001451__16 31 | 2008_001460__16 32 | 2008_001467__16 33 | 2008_001784__16 34 | 2008_001862__16 35 | 2008_001865__16 36 | 2008_002026__16 37 | 2008_002191__16 38 | 2008_002317__16 39 | 2008_002653__16 40 | 2008_003225__16 41 | 2008_003320__16 42 | 2008_003434__16 43 | 2008_003466__16 44 | 2008_003523__16 45 | 2008_003547__16 46 | 2008_003636__16 47 | 2008_003665__16 48 | 2008_003726__16 49 | 2008_003767__16 50 | 2008_003768__16 51 | 2008_003815__16 52 | 2008_003838__16 53 | 2008_003978__16 54 | 2008_004002__16 55 | 2008_004003__16 56 | 2008_004092__16 57 | 2008_004171__16 58 | 2008_004380__16 59 | 2008_004428__16 60 | 2008_004435__16 61 | 2008_004497__16 62 | 2008_004615__16 63 | 2008_004619__16 64 | 2008_004634__16 65 | 2008_004661__16 66 | 2008_004756__16 67 | 2008_004977__16 68 | 2008_005001__16 69 | 2008_005040__16 70 | 2008_005042__16 71 | 2008_005111__16 72 | 2008_005146__16 73 | 2008_005214__16 74 | 2008_005345__16 75 | 2008_005417__16 76 | 2008_005501__16 77 | 2008_005511__16 78 | 2008_005519__16 79 | 2008_005608__16 80 | 2008_005794__16 81 | 2008_005798__16 82 | 2008_005847__16 83 | 2008_005874__16 84 | 2008_005897__16 85 | 2008_005914__16 86 | 2008_005954__16 87 | 
2008_006068__16 88 | 2008_006112__16 89 | 2008_006203__16 90 | 2008_006207__16 91 | 2008_006262__16 92 | 2008_006295__16 93 | 2008_006337__16 94 | 2008_006441__16 95 | 2008_006524__16 96 | 2008_006534__16 97 | 2008_006543__16 98 | 2008_006562__16 99 | 2008_006751__16 100 | 2008_006796__16 101 | 2008_006807__16 102 | 2008_006816__16 103 | 2008_006828__16 104 | 2008_006864__16 105 | 2008_006881__16 106 | 2008_006950__16 107 | 2008_007042__16 108 | 2008_007108__16 109 | 2008_007223__16 110 | 2008_007226__16 111 | 2008_007281__16 112 | 2008_007388__16 113 | 2008_007461__16 114 | 2008_007525__16 115 | 2008_007621__16 116 | 2008_007701__16 117 | 2008_007823__16 118 | 2008_007831__16 119 | 2008_007835__16 120 | 2008_007973__16 121 | 2008_007977__16 122 | 2008_008024__16 123 | 2008_008070__16 124 | 2008_008096__16 125 | 2008_008184__16 126 | 2008_008208__16 127 | 2008_008235__16 128 | 2008_008237__16 129 | 2008_008310__16 130 | 2008_008330__16 131 | 2008_008331__16 132 | 2008_008341__16 133 | 2008_008363__16 134 | 2008_008517__16 135 | 2008_008531__16 136 | 2008_008608__16 137 | 2008_008621__16 138 | 2008_008641__16 139 | 2008_008689__16 140 | 2009_000158__16 141 | 2009_000198__16 142 | 2009_000297__16 143 | 2009_000419__16 144 | 2009_000526__16 145 | 2009_000590__16 146 | 2009_000624__16 147 | 2009_000635__16 148 | 2009_000760__16 149 | 2009_000867__16 150 | 2009_000926__16 151 | 2009_001085__16 152 | 2009_001100__16 153 | 2009_001137__16 154 | 2009_001229__16 155 | 2009_001249__16 156 | 2009_001417__16 157 | 2009_001440__16 158 | 2009_001514__16 159 | 2009_001627__16 160 | 2009_001667__16 161 | 2009_001704__16 162 | 2009_001806__16 163 | 2009_001864__16 164 | 2009_001888__16 165 | 2009_001922__16 166 | 2009_001934__16 167 | 2009_001975__16 168 | 2009_002088__16 169 | 2009_002123__16 170 | 2009_002386__16 171 | 2009_002433__16 172 | 2009_002444__16 173 | 2009_002628__16 174 | 2009_002670__16 175 | 2009_002688__16 176 | 2009_002698__16 177 | 2009_002741__16 178 | 
2009_002755__16 179 | 2009_002817__16 180 | 2009_002935__16 181 | 2009_002952__16 182 | 2009_003039__16 183 | 2009_003074__16 184 | 2009_003077__16 185 | 2009_003238__16 186 | 2009_003288__16 187 | 2009_003301__16 188 | 2009_003351__16 189 | 2009_003384__16 190 | 2009_003440__16 191 | 2009_003476__16 192 | 2009_003642__16 193 | 2009_003654__16 194 | 2009_003677__16 195 | 2009_003679__16 196 | 2009_003697__16 197 | 2009_003815__16 198 | 2009_003888__16 199 | 2009_003929__16 200 | 2009_004025__16 201 | 2009_004050__16 202 | 2009_004051__16 203 | 2009_004093__16 204 | 2009_004191__16 205 | 2009_004263__16 206 | 2009_004274__16 207 | 2009_004283__16 208 | 2009_004394__16 209 | 2009_004404__16 210 | 2009_004426__16 211 | 2009_004438__16 212 | 2009_004514__16 213 | 2009_004532__16 214 | 2009_004537__16 215 | 2009_004554__16 216 | 2009_004631__16 217 | 2009_004642__16 218 | 2009_004645__16 219 | 2009_004745__16 220 | 2009_004794__16 221 | 2009_004869__16 222 | 2009_004885__16 223 | 2009_004901__16 224 | 2009_004919__16 225 | 2009_004983__16 226 | 2009_004990__16 227 | 2009_005008__16 228 | 2009_005069__16 229 | 2009_005070__16 230 | 2009_005165__16 231 | 2009_005170__16 232 | 2009_005278__16 233 | 2009_005286__16 234 | 2009_005310__16 235 | 2010_000015__16 236 | 2010_000027__16 237 | 2010_000132__16 238 | 2010_000202__16 239 | 2010_000399__16 240 | 2010_000470__16 241 | 2010_000567__16 242 | 2010_000601__16 243 | 2010_000669__16 244 | 2010_000737__16 245 | 2010_000773__16 246 | 2010_000800__16 247 | 2010_000871__16 248 | 2010_000876__16 249 | 2010_000973__16 250 | 2010_001111__16 251 | 2010_001134__16 252 | 2010_001219__16 253 | 2010_001310__16 254 | 2010_001479__16 255 | 2010_001544__16 256 | 2010_001717__16 257 | 2010_001743__16 258 | 2010_001787__16 259 | 2010_001843__16 260 | 2010_001864__16 261 | 2010_001940__16 262 | 2010_002193__16 263 | 2010_002195__16 264 | 2010_002316__16 265 | 2010_002366__16 266 | 2010_002379__16 267 | 2010_002462__16 268 | 2010_002537__16 269 
| 2010_002605__16 270 | 2010_002661__16 271 | 2010_002676__16 272 | 2010_002742__16 273 | 2010_002830__16 274 | 2010_002982__16 275 | 2010_003017__16 276 | 2010_003101__16 277 | 2010_003103__16 278 | 2010_003203__16 279 | 2010_003218__16 280 | 2010_003390__16 281 | 2010_003556__16 282 | 2010_003651__16 283 | 2010_003667__16 284 | 2010_003674__16 285 | 2010_003728__16 286 | 2010_003729__16 287 | 2010_003910__16 288 | 2010_004062__16 289 | 2010_004072__16 290 | 2010_004224__16 291 | 2010_004358__16 292 | 2010_004466__16 293 | 2010_004683__16 294 | 2010_004908__16 295 | 2010_004910__16 296 | 2010_004944__16 297 | 2010_004974__16 298 | 2010_004982__16 299 | 2010_005158__16 300 | 2010_005274__16 301 | 2010_005279__16 302 | 2010_005455__16 303 | 2010_005536__16 304 | 2010_005593__16 305 | 2010_005758__16 306 | 2010_005830__16 307 | 2010_005930__16 308 | 2010_005932__16 309 | 2010_005975__16 310 | 2011_000072__16 311 | 2011_000145__16 312 | 2011_000196__16 313 | 2011_000361__16 314 | 2011_000388__16 315 | 2011_000468__16 316 | 2011_000514__16 317 | 2011_000530__16 318 | 2011_000572__16 319 | 2011_000731__16 320 | 2011_000743__16 321 | 2011_000823__16 322 | 2011_000875__16 323 | 2011_000885__16 324 | 2011_000919__16 325 | 2011_000934__16 326 | 2011_000957__16 327 | 2011_001009__16 328 | 2011_001011__16 329 | 2011_001022__16 330 | 2011_001034__16 331 | 2011_001055__16 332 | 2011_001221__16 333 | 2011_001226__16 334 | 2011_001360__16 335 | 2011_001369__16 336 | 2011_001382__16 337 | 2011_001440__16 338 | 2011_001456__16 339 | 2011_001600__16 340 | 2011_001611__16 341 | 2011_001689__16 342 | 2011_001766__16 343 | 2011_001820__16 344 | 2011_001845__16 345 | 2011_001946__16 346 | 2011_002022__16 347 | 2011_002031__16 348 | 2011_002318__16 349 | 2011_002386__16 350 | 2011_002443__16 351 | 2011_002614__16 352 | 2011_002808__16 353 | 2011_002810__16 354 | 2011_002924__16 355 | 2011_002978__16 356 | 2011_003002__16 357 | 2011_003047__16 358 | 2011_003162__16 359 | 2007_001416__17 
360 | 2007_001872__17 361 | 2007_002845__17 362 | 2007_003190__17 363 | 2007_003593__17 364 | 2007_004423__17 365 | 2007_004768__17 366 | 2007_006136__17 367 | 2007_006832__17 368 | 2007_006899__17 369 | 2007_006944__17 370 | 2007_007048__17 371 | 2007_007230__17 372 | 2007_007621__17 373 | 2008_000084__17 374 | 2008_000099__17 375 | 2008_000669__17 376 | 2008_001601__17 377 | 2008_002061__17 378 | 2008_002150__17 379 | 2008_002343__17 380 | 2008_002430__17 381 | 2008_003147__17 382 | 2008_004007__17 383 | 2008_004629__17 384 | 2008_005447__17 385 | 2008_005494__17 386 | 2008_005505__17 387 | 2008_005635__17 388 | 2008_005706__17 389 | 2008_005736__17 390 | 2008_005938__17 391 | 2008_005987__17 392 | 2008_006059__17 393 | 2008_006070__17 394 | 2008_006100__17 395 | 2008_006221__17 396 | 2008_006339__17 397 | 2008_006477__17 398 | 2008_006570__17 399 | 2008_006892__17 400 | 2008_006939__17 401 | 2008_007069__17 402 | 2008_007070__17 403 | 2008_007245__17 404 | 2008_007334__17 405 | 2008_007430__17 406 | 2008_007693__17 407 | 2008_007806__17 408 | 2008_007890__17 409 | 2008_007909__17 410 | 2008_007985__17 411 | 2008_008109__17 412 | 2008_008319__17 413 | 2008_008322__17 414 | 2008_008323__17 415 | 2008_008601__17 416 | 2008_008613__17 417 | 2008_008623__17 418 | 2008_008665__17 419 | 2008_008666__17 420 | 2008_008714__17 421 | 2008_008744__17 422 | 2009_000168__17 423 | 2009_000223__17 424 | 2009_000289__17 425 | 2009_000356__17 426 | 2009_000670__17 427 | 2009_000725__17 428 | 2009_000750__17 429 | 2009_000837__17 430 | 2009_001172__17 431 | 2009_001177__17 432 | 2009_001203__17 433 | 2009_001236__17 434 | 2009_001263__17 435 | 2009_001349__17 436 | 2009_001403__17 437 | 2009_001537__17 438 | 2009_001618__17 439 | 2009_001643__17 440 | 2009_001699__17 441 | 2009_001732__17 442 | 2009_001738__17 443 | 2009_001783__17 444 | 2009_001959__17 445 | 2009_002133__17 446 | 2009_002245__17 447 | 2009_002282__17 448 | 2009_002391__17 449 | 2009_002719__17 450 | 
2009_002921__17 451 | 2009_002988__17 452 | 2009_003076__17 453 | 2009_003249__17 454 | 2009_003254__17 455 | 2009_003271__17 456 | 2009_003425__17 457 | 2009_003430__17 458 | 2009_003460__17 459 | 2009_003541__17 460 | 2009_003618__17 461 | 2009_003624__17 462 | 2009_003784__17 463 | 2009_004164__17 464 | 2009_004171__17 465 | 2009_004181__17 466 | 2009_004222__17 467 | 2009_004513__17 468 | 2009_004547__17 469 | 2009_004706__17 470 | 2009_004768__17 471 | 2009_004805__17 472 | 2009_004834__17 473 | 2009_004943__17 474 | 2009_004945__17 475 | 2009_005005__17 476 | 2009_005193__17 477 | 2010_000002__17 478 | 2010_000052__17 479 | 2010_000089__17 480 | 2010_000117__17 481 | 2010_000139__17 482 | 2010_000189__17 483 | 2010_000190__17 484 | 2010_000307__17 485 | 2010_000327__17 486 | 2010_000390__17 487 | 2010_000436__17 488 | 2010_000483__17 489 | 2010_000527__17 490 | 2010_000562__17 491 | 2010_000641__17 492 | 2010_000667__17 493 | 2010_000727__17 494 | 2010_000735__17 495 | 2010_000822__17 496 | 2010_000831__17 497 | 2010_000847__17 498 | 2010_000866__17 499 | 2010_000920__17 500 | 2010_000970__17 501 | 2010_000978__17 502 | 2010_001076__17 503 | 2010_001082__17 504 | 2010_001160__17 505 | 2010_001175__17 506 | 2010_001245__17 507 | 2010_001257__17 508 | 2010_001286__17 509 | 2010_001356__17 510 | 2010_001607__17 511 | 2010_001746__17 512 | 2010_001796__17 513 | 2010_001881__17 514 | 2010_001987__17 515 | 2010_002039__17 516 | 2010_002041__17 517 | 2010_002058__17 518 | 2010_002113__17 519 | 2010_002124__17 520 | 2010_002130__17 521 | 2010_002176__17 522 | 2010_002215__17 523 | 2010_002294__17 524 | 2010_002346__17 525 | 2010_002353__17 526 | 2010_002378__17 527 | 2010_002501__17 528 | 2010_002507__17 529 | 2010_002518__17 530 | 2010_002582__17 531 | 2010_002624__17 532 | 2010_002628__17 533 | 2010_002665__17 534 | 2010_002705__17 535 | 2010_002736__17 536 | 2010_002737__17 537 | 2010_002821__17 538 | 2010_003013__17 539 | 2010_003028__17 540 | 2010_003074__17 541 
| 2010_003094__17 542 | 2010_003102__17 543 | 2010_003153__17 544 | 2010_003253__17 545 | 2010_003343__17 546 | 2010_003372__17 547 | 2010_003376__17 548 | 2010_003429__17 549 | 2010_003491__17 550 | 2010_003567__17 551 | 2010_003725__17 552 | 2010_003742__17 553 | 2010_003754__17 554 | 2010_003761__17 555 | 2010_003774__17 556 | 2010_003792__17 557 | 2010_003806__17 558 | 2010_003865__17 559 | 2010_003919__17 560 | 2010_003939__17 561 | 2010_003954__17 562 | 2010_003996__17 563 | 2010_004061__17 564 | 2010_004065__17 565 | 2010_004067__17 566 | 2010_004074__17 567 | 2010_004089__17 568 | 2010_004105__17 569 | 2010_004188__17 570 | 2010_004225__17 571 | 2010_004259__17 572 | 2010_004332__17 573 | 2010_004428__17 574 | 2010_004431__17 575 | 2010_004436__17 576 | 2010_004499__17 577 | 2010_004514__17 578 | 2010_004560__17 579 | 2010_004629__17 580 | 2010_004659__17 581 | 2010_004694__17 582 | 2010_004704__17 583 | 2010_004710__17 584 | 2010_004812__17 585 | 2010_004868__17 586 | 2010_005002__17 587 | 2010_005026__17 588 | 2010_005066__17 589 | 2010_005098__17 590 | 2010_005120__17 591 | 2010_005183__17 592 | 2010_005260__17 593 | 2010_005285__17 594 | 2010_005310__17 595 | 2010_005385__17 596 | 2010_005416__17 597 | 2010_005466__17 598 | 2010_005514__17 599 | 2010_005519__17 600 | 2010_005566__17 601 | 2010_005567__17 602 | 2010_005601__17 603 | 2010_005635__17 604 | 2010_005688__17 605 | 2010_005736__17 606 | 2010_005740__17 607 | 2010_005767__17 608 | 2010_005986__17 609 | 2011_000102__17 610 | 2011_000332__17 611 | 2011_000404__17 612 | 2011_000450__17 613 | 2011_000454__17 614 | 2011_000641__17 615 | 2011_000759__17 616 | 2011_000829__17 617 | 2011_000834__17 618 | 2011_001240__17 619 | 2011_001246__17 620 | 2011_001329__17 621 | 2011_001373__17 622 | 2011_001549__17 623 | 2011_001757__17 624 | 2011_001822__17 625 | 2011_001986__17 626 | 2011_002027__17 627 | 2011_002119__17 628 | 2011_002169__17 629 | 2011_002447__17 630 | 2011_002464__17 631 | 2011_002553__17 
632 | 2011_002571__17 633 | 2011_002817__17 634 | 2011_003023__17 635 | 2011_003223__17 636 | 2007_000584__18 637 | 2007_001027__18 638 | 2007_001149__18 639 | 2007_001901__18 640 | 2007_002055__18 641 | 2007_002368__18 642 | 2007_002545__18 643 | 2007_003451__18 644 | 2007_004166__18 645 | 2007_005212__18 646 | 2007_005266__18 647 | 2007_005647__18 648 | 2007_006066__18 649 | 2007_006530__18 650 | 2007_008203__18 651 | 2007_008468__18 652 | 2007_008821__18 653 | 2007_009435__18 654 | 2007_009554__18 655 | 2008_000093__18 656 | 2008_000128__18 657 | 2008_000321__18 658 | 2008_000419__18 659 | 2008_000421__18 660 | 2008_000465__18 661 | 2008_000493__18 662 | 2008_000541__18 663 | 2008_000636__18 664 | 2008_000648__18 665 | 2008_000704__18 666 | 2008_000857__18 667 | 2008_001030__18 668 | 2008_001092__18 669 | 2008_001133__18 670 | 2008_001238__18 671 | 2008_001333__18 672 | 2008_001366__18 673 | 2008_001390__18 674 | 2008_001399__18 675 | 2008_001461__18 676 | 2008_001589__18 677 | 2008_001660__18 678 | 2008_001694__18 679 | 2008_001781__18 680 | 2008_001787__18 681 | 2008_001838__18 682 | 2008_001869__18 683 | 2008_001896__18 684 | 2008_002082__18 685 | 2008_002092__18 686 | 2008_002119__18 687 | 2008_002434__18 688 | 2008_002508__18 689 | 2008_002533__18 690 | 2008_002776__18 691 | 2008_002801__18 692 | 2008_002916__18 693 | 2008_002920__18 694 | 2008_002922__18 695 | 2008_002948__18 696 | 2008_003271__18 697 | 2008_003393__18 698 | 2008_003562__18 699 | 2008_003607__18 700 | 2008_003814__18 701 | 2008_004269__18 702 | 2008_004271__18 703 | 2008_004321__18 704 | 2008_004416__18 705 | 2008_004435__18 706 | 2008_004492__18 707 | 2008_004497__18 708 | 2008_004632__18 709 | 2008_004661__18 710 | 2008_004670__18 711 | 2008_004697__18 712 | 2008_004774__18 713 | 2008_004881__18 714 | 2008_004887__18 715 | 2008_004938__18 716 | 2008_004964__18 717 | 2008_005090__18 718 | 2008_005323__18 719 | 2008_005395__18 720 | 2008_005444__18 721 | 2008_005623__18 722 | 
2008_005627__18 723 | 2008_005788__18 724 | 2008_005850__18 725 | 2008_005882__18 726 | 2008_005926__18 727 | 2008_006038__18 728 | 2008_006117__18 729 | 2008_006276__18 730 | 2008_006370__18 731 | 2008_006389__18 732 | 2008_006436__18 733 | 2008_006616__18 734 | 2008_006665__18 735 | 2008_006737__18 736 | 2008_006773__18 737 | 2008_006843__18 738 | 2008_006868__18 739 | 2008_006979__18 740 | 2008_007021__18 741 | 2008_007043__18 742 | 2008_007050__18 743 | 2008_007169__18 744 | 2008_007182__18 745 | 2008_007218__18 746 | 2008_007282__18 747 | 2008_007285__18 748 | 2008_007511__18 749 | 2008_007682__18 750 | 2008_007733__18 751 | 2008_007837__18 752 | 2008_008029__18 753 | 2008_008106__18 754 | 2008_008162__18 755 | 2008_008190__18 756 | 2008_008206__18 757 | 2008_008271__18 758 | 2008_008276__18 759 | 2008_008313__18 760 | 2008_008410__18 761 | 2008_008433__18 762 | 2008_008470__18 763 | 2008_008517__18 764 | 2008_008522__18 765 | 2008_008526__18 766 | 2008_008538__18 767 | 2008_008550__18 768 | 2008_008554__18 769 | 2008_008560__18 770 | 2008_008567__18 771 | 2008_008574__18 772 | 2008_008578__18 773 | 2008_008588__18 774 | 2008_008590__18 775 | 2008_008606__18 776 | 2008_008608__18 777 | 2008_008621__18 778 | 2008_008622__18 779 | 2008_008628__18 780 | 2008_008642__18 781 | 2008_008649__18 782 | 2008_008658__18 783 | 2008_008772__18 784 | 2009_000014__18 785 | 2009_000016__18 786 | 2009_000142__18 787 | 2009_000189__18 788 | 2009_000217__18 789 | 2009_000251__18 790 | 2009_000300__18 791 | 2009_000316__18 792 | 2009_000342__18 793 | 2009_000375__18 794 | 2009_000379__18 795 | 2009_000416__18 796 | 2009_000422__18 797 | 2009_000449__18 798 | 2009_000474__18 799 | 2009_000505__18 800 | 2009_000563__18 801 | 2009_000577__18 802 | 2009_000615__18 803 | 2009_000653__18 804 | 2009_000672__18 805 | 2009_000674__18 806 | 2009_000779__18 807 | 2009_000925__18 808 | 2009_000926__18 809 | 2009_000937__18 810 | 2009_000939__18 811 | 2009_000973__18 812 | 2009_000995__18 813 
| 2009_001021__18 814 | 2009_001081__18 815 | 2009_001107__18 816 | 2009_001146__18 817 | 2009_001190__18 818 | 2009_001212__18 819 | 2009_001241__18 820 | 2009_001243__18 821 | 2009_001249__18 822 | 2009_001268__18 823 | 2009_001313__18 824 | 2009_001343__18 825 | 2009_001357__18 826 | 2009_001376__18 827 | 2009_001437__18 828 | 2009_001440__18 829 | 2009_001446__18 830 | 2009_001470__18 831 | 2009_001577__18 832 | 2009_001581__18 833 | 2009_001605__18 834 | 2009_001608__18 835 | 2009_001631__18 836 | 2009_001719__18 837 | 2009_001743__18 838 | 2009_001746__18 839 | 2009_001774__18 840 | 2009_001871__18 841 | 2009_001874__18 842 | 2009_001888__18 843 | 2009_001906__18 844 | 2009_001908__18 845 | 2009_001961__18 846 | 2009_001980__18 847 | 2009_002083__18 848 | 2009_002192__18 849 | 2009_002208__18 850 | 2009_002253__18 851 | 2009_002325__18 852 | 2009_002370__18 853 | 2009_002408__18 854 | 2009_002522__18 855 | 2009_002523__18 856 | 2009_002558__18 857 | 2009_002611__18 858 | 2009_002612__18 859 | 2009_002663__18 860 | 2009_002673__18 861 | 2009_002681__18 862 | 2009_002683__18 863 | 2009_002713__18 864 | 2009_002717__18 865 | 2009_002893__18 866 | 2009_002972__18 867 | 2009_002998__18 868 | 2009_003087__18 869 | 2009_003129__18 870 | 2009_003156__18 871 | 2009_003208__18 872 | 2009_003373__18 873 | 2009_003377__18 874 | 2009_003394__18 875 | 2009_003409__18 876 | 2009_003441__18 877 | 2009_003459__18 878 | 2009_003581__18 879 | 2009_003605__18 880 | 2009_003613__18 881 | 2009_003642__18 882 | 2009_003646__18 883 | 2009_003656__18 884 | 2009_003671__18 885 | 2009_003695__18 886 | 2009_003711__18 887 | 2009_003785__18 888 | 2009_003795__18 889 | 2009_003819__18 890 | 2009_003835__18 891 | 2009_003843__18 892 | 2009_003848__18 893 | 2009_003965__18 894 | 2009_003966__18 895 | 2009_003995__18 896 | 2009_004073__18 897 | 2009_004076__18 898 | 2009_004088__18 899 | 2009_004091__18 900 | 2009_004161__18 901 | 2009_004165__18 902 | 2009_004177__18 903 | 2009_004180__18 
904 | 2009_004188__18 905 | 2009_004193__18 906 | 2009_004264__18 907 | 2009_004283__18 908 | 2009_004291__18 909 | 2009_004301__18 910 | 2009_004322__18 911 | 2009_004419__18 912 | 2009_004426__18 913 | 2009_004449__18 914 | 2009_004452__18 915 | 2009_004456__18 916 | 2009_004457__18 917 | 2009_004464__18 918 | 2009_004560__18 919 | 2009_004582__18 920 | 2009_004588__18 921 | 2009_004593__18 922 | 2009_004674__18 923 | 2009_004701__18 924 | 2009_004718__18 925 | 2009_004782__18 926 | 2009_004823__18 927 | 2009_004839__18 928 | 2009_004901__18 929 | 2009_004983__18 930 | 2009_005070__18 931 | 2009_005240__18 932 | 2009_005299__18 933 | 2010_000018__18 934 | 2010_000097__18 935 | 2010_000329__18 936 | 2010_000344__18 937 | 2010_000588__18 938 | 2010_000671__18 939 | 2010_000691__18 940 | 2010_000694__18 941 | 2010_000821__18 942 | 2010_000830__18 943 | 2010_000922__18 944 | 2010_000968__18 945 | 2010_001051__18 946 | 2010_001066__18 947 | 2010_001111__18 948 | 2010_001127__18 949 | 2010_001148__18 950 | 2010_001189__18 951 | 2010_001277__18 952 | 2010_001287__18 953 | 2010_001434__18 954 | 2010_001514__18 955 | 2010_001547__18 956 | 2010_001586__18 957 | 2010_001636__18 958 | 2010_001743__18 959 | 2010_001763__18 960 | 2010_001933__18 961 | 2010_002015__18 962 | 2010_002045__18 963 | 2010_002191__18 964 | 2010_002193__18 965 | 2010_002312__18 966 | 2010_002337__18 967 | 2010_002461__18 968 | 2010_002527__18 969 | 2010_002659__18 970 | 2010_002710__18 971 | 2010_002811__18 972 | 2010_002817__18 973 | 2010_002860__18 974 | 2010_002962__18 975 | 2010_003027__18 976 | 2010_003114__18 977 | 2010_003143__18 978 | 2010_003149__18 979 | 2010_003169__18 980 | 2010_003174__18 981 | 2010_003203__18 982 | 2010_003278__18 983 | 2010_003305__18 984 | 2010_003331__18 985 | 2010_003401__18 986 | 2010_003451__18 987 | 2010_003556__18 988 | 2010_003613__18 989 | 2010_003717__18 990 | 2010_003804__18 991 | 2010_003822__18 992 | 2010_003861__18 993 | 2010_003864__18 994 | 
2010_004025__18 995 | 2010_004043__18 996 | 2010_004062__18 997 | 2010_004095__18 998 | 2010_004109__18 999 | 2010_004111__18 1000 | 2010_004125__18 1001 | 2010_004358__18 1002 | 2010_004409__18 1003 | 2010_004447__18 1004 | 2010_004481__18 1005 | 2010_004741__18 1006 | 2010_004765__18 1007 | 2010_004805__18 1008 | 2010_005049__18 1009 | 2010_005054__18 1010 | 2010_005170__18 1011 | 2010_005193__18 1012 | 2010_005388__18 1013 | 2010_005398__18 1014 | 2010_005532__18 1015 | 2010_005610__18 1016 | 2010_005614__18 1017 | 2010_005681__18 1018 | 2010_005692__18 1019 | 2010_005734__18 1020 | 2010_005770__18 1021 | 2010_005830__18 1022 | 2010_005898__18 1023 | 2010_005937__18 1024 | 2010_005980__18 1025 | 2010_006056__18 1026 | 2010_006073__18 1027 | 2011_000006__18 1028 | 2011_000037__18 1029 | 2011_000082__18 1030 | 2011_000122__18 1031 | 2011_000142__18 1032 | 2011_000146__18 1033 | 2011_000182__18 1034 | 2011_000224__18 1035 | 2011_000304__18 1036 | 2011_000364__18 1037 | 2011_000379__18 1038 | 2011_000386__18 1039 | 2011_000399__18 1040 | 2011_000434__18 1041 | 2011_000457__18 1042 | 2011_000475__18 1043 | 2011_000477__18 1044 | 2011_000499__18 1045 | 2011_000550__18 1046 | 2011_000565__18 1047 | 2011_000572__18 1048 | 2011_000608__18 1049 | 2011_000630__18 1050 | 2011_000646__18 1051 | 2011_000657__18 1052 | 2011_000689__18 1053 | 2011_000765__18 1054 | 2011_000820__18 1055 | 2011_000947__18 1056 | 2011_001027__18 1057 | 2011_001031__18 1058 | 2011_001167__18 1059 | 2011_001175__18 1060 | 2011_001192__18 1061 | 2011_001198__18 1062 | 2011_001215__18 1063 | 2011_001283__18 1064 | 2011_001304__18 1065 | 2011_001330__18 1066 | 2011_001402__18 1067 | 2011_001404__18 1068 | 2011_001412__18 1069 | 2011_001440__18 1070 | 2011_001451__18 1071 | 2011_001518__18 1072 | 2011_001531__18 1073 | 2011_001547__18 1074 | 2011_001600__18 1075 | 2011_001662__18 1076 | 2011_001691__18 1077 | 2011_001733__18 1078 | 2011_001739__18 1079 | 2011_001751__18 1080 | 2011_001811__18 1081 | 
2011_001820__18 1082 | 2011_001845__18 1083 | 2011_001856__18 1084 | 2011_001895__18 1085 | 2011_001914__18 1086 | 2011_001922__18 1087 | 2011_001932__18 1088 | 2011_001974__18 1089 | 2011_001977__18 1090 | 2011_001980__18 1091 | 2011_002109__18 1092 | 2011_002184__18 1093 | 2011_002186__18 1094 | 2011_002268__18 1095 | 2011_002291__18 1096 | 2011_002335__18 1097 | 2011_002359__18 1098 | 2011_002395__18 1099 | 2011_002414__18 1100 | 2011_002507__18 1101 | 2011_002554__18 1102 | 2011_002561__18 1103 | 2011_002594__18 1104 | 2011_002714__18 1105 | 2011_002726__18 1106 | 2011_002752__18 1107 | 2011_002756__18 1108 | 2011_002775__18 1109 | 2011_002784__18 1110 | 2011_002810__18 1111 | 2011_002814__18 1112 | 2011_002834__18 1113 | 2011_002852__18 1114 | 2011_002953__18 1115 | 2011_002965__18 1116 | 2011_003038__18 1117 | 2011_003039__18 1118 | 2011_003044__18 1119 | 2011_003049__18 1120 | 2011_003188__18 1121 | 2011_003201__18 1122 | 2011_003212__18 1123 | 2007_000333__19 1124 | 2007_002462__19 1125 | 2007_003178__19 1126 | 2007_003286__19 1127 | 2007_004627__19 1128 | 2007_004663__19 1129 | 2007_004951__19 1130 | 2007_005360__19 1131 | 2007_006254__19 1132 | 2007_006400__19 1133 | 2007_006803__19 1134 | 2007_007387__19 1135 | 2007_007726__19 1136 | 2007_007947__19 1137 | 2007_009436__19 1138 | 2007_009580__19 1139 | 2007_009597__19 1140 | 2007_009950__19 1141 | 2008_000003__19 1142 | 2008_000045__19 1143 | 2008_000343__19 1144 | 2008_000373__19 1145 | 2008_000470__19 1146 | 2008_000916__19 1147 | 2008_001105__19 1148 | 2008_001114__19 1149 | 2008_001118__19 1150 | 2008_001164__19 1151 | 2008_001169__19 1152 | 2008_001358__19 1153 | 2008_001625__19 1154 | 2008_001710__19 1155 | 2008_001850__19 1156 | 2008_001866__19 1157 | 2008_001905__19 1158 | 2008_001926__19 1159 | 2008_001956__19 1160 | 2008_002158__19 1161 | 2008_002193__19 1162 | 2008_002222__19 1163 | 2008_002279__19 1164 | 2008_002325__19 1165 | 2008_002344__19 1166 | 2008_002452__19 1167 | 2008_002457__19 1168 
| 2008_002465__19 1169 | 2008_002965__19 1170 | 2008_003025__19 1171 | 2008_003068__19 1172 | 2008_003083__19 1173 | 2008_003263__19 1174 | 2008_003414__19 1175 | 2008_003571__19 1176 | 2008_003578__19 1177 | 2008_003826__19 1178 | 2008_003992__19 1179 | 2008_004110__19 1180 | 2008_004214__19 1181 | 2008_004235__19 1182 | 2008_004357__19 1183 | 2008_004358__19 1184 | 2008_004547__19 1185 | 2008_004663__19 1186 | 2008_004770__19 1187 | 2008_004852__19 1188 | 2008_004869__19 1189 | 2008_004946__19 1190 | 2008_005085__19 1191 | 2008_005185__19 1192 | 2008_005269__19 1193 | 2008_005282__19 1194 | 2008_005354__19 1195 | 2008_005446__19 1196 | 2008_005653__19 1197 | 2008_005742__19 1198 | 2008_005763__19 1199 | 2008_005801__19 1200 | 2008_005825__19 1201 | 2008_005968__19 1202 | 2008_006010__19 1203 | 2008_006158__19 1204 | 2008_006365__19 1205 | 2008_006368__19 1206 | 2008_006655__19 1207 | 2008_006818__19 1208 | 2008_006849__19 1209 | 2008_006865__19 1210 | 2008_006900__19 1211 | 2008_006919__19 1212 | 2008_007011__19 1213 | 2008_007084__19 1214 | 2008_007105__19 1215 | 2008_007115__19 1216 | 2008_007185__19 1217 | 2008_007189__19 1218 | 2008_007201__19 1219 | 2008_007231__19 1220 | 2008_007247__19 1221 | 2008_007280__19 1222 | 2008_007383__19 1223 | 2008_007521__19 1224 | 2008_007648__19 1225 | 2008_007749__19 1226 | 2008_007759__19 1227 | 2008_007760__19 1228 | 2008_007779__19 1229 | 2008_007787__19 1230 | 2008_007829__19 1231 | 2008_007887__19 1232 | 2008_007999__19 1233 | 2008_008001__19 1234 | 2008_008020__19 1235 | 2008_008055__19 1236 | 2008_008074__19 1237 | 2008_008123__19 1238 | 2008_008152__19 1239 | 2008_008192__19 1240 | 2008_008200__19 1241 | 2008_008203__19 1242 | 2008_008223__19 1243 | 2008_008275__19 1244 | 2008_008297__19 1245 | 2008_008302__19 1246 | 2008_008321__19 1247 | 2008_008336__19 1248 | 2008_008342__19 1249 | 2008_008364__19 1250 | 2008_008379__19 1251 | 2008_008382__19 1252 | 2008_008527__19 1253 | 2008_008545__19 1254 | 2008_008583__19 
1255 | 2008_008615__19 1256 | 2008_008618__19 1257 | 2008_008632__19 1258 | 2008_008637__19 1259 | 2008_008641__19 1260 | 2008_008662__19 1261 | 2008_008673__19 1262 | 2008_008676__19 1263 | 2008_008681__19 1264 | 2008_008690__19 1265 | 2008_008696__19 1266 | 2008_008697__19 1267 | 2008_008726__19 1268 | 2008_008732__19 1269 | 2008_008735__19 1270 | 2008_008739__19 1271 | 2008_008749__19 1272 | 2008_008751__19 1273 | 2008_008757__19 1274 | 2008_008767__19 1275 | 2008_008770__19 1276 | 2009_000011__19 1277 | 2009_000051__19 1278 | 2009_000073__19 1279 | 2009_000090__19 1280 | 2009_000105__19 1281 | 2009_000137__19 1282 | 2009_000177__19 1283 | 2009_000244__19 1284 | 2009_000283__19 1285 | 2009_000347__19 1286 | 2009_000443__19 1287 | 2009_000476__19 1288 | 2009_000501__19 1289 | 2009_000592__19 1290 | 2009_000597__19 1291 | 2009_000658__19 1292 | 2009_000663__19 1293 | 2009_000689__19 1294 | 2009_000789__19 1295 | 2009_000824__19 1296 | 2009_000890__19 1297 | 2009_000910__19 1298 | 2009_000920__19 1299 | 2009_000974__19 1300 | 2009_001007__19 1301 | 2009_001042__19 1302 | 2009_001078__19 1303 | 2009_001118__19 1304 | 2009_001152__19 1305 | 2009_001164__19 1306 | 2009_001192__19 1307 | 2009_001245__19 1308 | 2009_001259__19 1309 | 2009_001291__19 1310 | 2009_001350__19 1311 | 2009_001359__19 1312 | 2009_001412__19 1313 | 2009_001468__19 1314 | 2009_001493__19 1315 | 2009_001516__19 1316 | 2009_001519__19 1317 | 2009_001534__19 1318 | 2009_001648__19 1319 | 2009_001651__19 1320 | 2009_001671__19 1321 | 2009_001735__19 1322 | 2009_001747__19 1323 | 2009_001802__19 1324 | 2009_001823__19 1325 | 2009_001831__19 1326 | 2009_001853__19 1327 | 2009_001865__19 1328 | 2009_001868__19 1329 | 2009_001904__19 1330 | 2009_001977__19 1331 | 2009_002009__19 1332 | 2009_002116__19 1333 | 2009_002144__19 1334 | 2009_002175__19 1335 | 2009_002197__19 1336 | 2009_002214__19 1337 | 2009_002219__19 1338 | 2009_002225__19 1339 | 2009_002274__19 1340 | 2009_002281__19 1341 | 
2009_002377__19 1342 | 2009_002441__19 1343 | 2009_002557__19 1344 | 2009_002616__19 1345 | 2009_002624__19 1346 | 2009_002669__19 1347 | 2009_002676__19 1348 | 2009_002689__19 1349 | 2009_002695__19 1350 | 2009_002712__19 1351 | 2009_002725__19 1352 | 2009_002734__19 1353 | 2009_002774__19 1354 | 2009_002838__19 1355 | 2009_002867__19 1356 | 2009_002938__19 1357 | 2009_002947__19 1358 | 2009_003022__19 1359 | 2009_003054__19 1360 | 2009_003185__19 1361 | 2009_003230__19 1362 | 2009_003233__19 1363 | 2009_003333__19 1364 | 2009_003348__19 1365 | 2009_003407__19 1366 | 2009_003436__19 1367 | 2009_003453__19 1368 | 2009_003492__19 1369 | 2009_003497__19 1370 | 2009_003534__19 1371 | 2009_003543__19 1372 | 2009_003583__19 1373 | 2009_003638__19 1374 | 2009_003650__19 1375 | 2009_003758__19 1376 | 2009_003765__19 1377 | 2009_003790__19 1378 | 2009_003821__19 1379 | 2009_003863__19 1380 | 2009_003892__19 1381 | 2009_003942__19 1382 | 2009_003951__19 1383 | 2009_004019__19 1384 | 2009_004109__19 1385 | 2009_004159__19 1386 | 2009_004163__19 1387 | 2009_004170__19 1388 | 2009_004211__19 1389 | 2009_004213__19 1390 | 2009_004329__19 1391 | 2009_004336__19 1392 | 2009_004371__19 1393 | 2009_004406__19 1394 | 2009_004453__19 1395 | 2009_004468__19 1396 | 2009_004511__19 1397 | 2009_004527__19 1398 | 2009_004559__19 1399 | 2009_004619__19 1400 | 2009_004624__19 1401 | 2009_004669__19 1402 | 2009_004671__19 1403 | 2009_004677__19 1404 | 2009_004708__19 1405 | 2009_004766__19 1406 | 2009_004771__19 1407 | 2009_004804__19 1408 | 2009_004880__19 1409 | 2009_004956__19 1410 | 2009_004958__19 1411 | 2009_004977__19 1412 | 2009_004988__19 1413 | 2009_005024__19 1414 | 2009_005061__19 1415 | 2009_005084__19 1416 | 2009_005126__19 1417 | 2009_005128__19 1418 | 2009_005149__19 1419 | 2009_005246__19 1420 | 2009_005287__19 1421 | 2009_005292__19 1422 | 2009_005303__19 1423 | 2010_000080__19 1424 | 2010_000085__19 1425 | 2010_000136__19 1426 | 2010_000199__19 1427 | 2010_000233__19 1428 
| 2010_000249__19 1429 | 2010_000313__19 1430 | 2010_000321__19 1431 | 2010_000406__19 1432 | 2010_000513__19 1433 | 2010_000537__19 1434 | 2010_000581__19 1435 | 2010_000633__19 1436 | 2010_000651__19 1437 | 2010_000740__19 1438 | 2010_000743__19 1439 | 2010_000786__19 1440 | 2010_000860__19 1441 | 2010_000865__19 1442 | 2010_000955__19 1443 | 2010_000959__19 1444 | 2010_000984__19 1445 | 2010_001052__19 1446 | 2010_001105__19 1447 | 2010_001143__19 1448 | 2010_001250__19 1449 | 2010_001272__19 1450 | 2010_001374__19 1451 | 2010_001395__19 1452 | 2010_001405__19 1453 | 2010_001408__19 1454 | 2010_001502__19 1455 | 2010_001515__19 1456 | 2010_001539__19 1457 | 2010_001560__19 1458 | 2010_001586__19 1459 | 2010_001625__19 1460 | 2010_001719__19 1461 | 2010_001748__19 1462 | 2010_001788__19 1463 | 2010_001801__19 1464 | 2010_001941__19 1465 | 2010_002073__19 1466 | 2010_002080__19 1467 | 2010_002179__19 1468 | 2010_002182__19 1469 | 2010_002208__19 1470 | 2010_002223__19 1471 | 2010_002261__19 1472 | 2010_002369__19 1473 | 2010_002420__19 1474 | 2010_002487__19 1475 | 2010_002556__19 1476 | 2010_002618__19 1477 | 2010_002667__19 1478 | 2010_002697__19 1479 | 2010_002722__19 1480 | 2010_002838__19 1481 | 2010_002840__19 1482 | 2010_002851__19 1483 | 2010_002896__19 1484 | 2010_002937__19 1485 | 2010_002946__19 1486 | 2010_003129__19 1487 | 2010_003160__19 1488 | 2010_003274__19 1489 | 2010_003335__19 1490 | 2010_003384__19 1491 | 2010_003470__19 1492 | 2010_003482__19 1493 | 2010_003601__19 1494 | 2010_003609__19 1495 | 2010_003628__19 1496 | 2010_003630__19 1497 | 2010_003788__19 1498 | 2010_003847__19 1499 | 2010_003900__19 1500 | 2010_003944__19 1501 | 2010_004075__19 1502 | 2010_004148__19 1503 | 2010_004168__19 1504 | 2010_004193__19 1505 | 2010_004256__19 1506 | 2010_004313__19 1507 | 2010_004350__19 1508 | 2010_004412__19 1509 | 2010_004469__19 1510 | 2010_004475__19 1511 | 2010_004478__19 1512 | 2010_004536__19 1513 | 2010_004604__19 1514 | 2010_004669__19 
1515 | 2010_004677__19 1516 | 2010_004779__19 1517 | 2010_004826__19 1518 | 2010_005055__19 1519 | 2010_005130__19 1520 | 2010_005309__19 1521 | 2010_005463__19 1522 | 2010_005506__19 1523 | 2010_005515__19 1524 | 2010_005559__19 1525 | 2010_005565__19 1526 | 2010_005643__19 1527 | 2010_005768__19 1528 | 2010_005810__19 1529 | 2010_005816__19 1530 | 2010_005934__19 1531 | 2010_005996__19 1532 | 2011_000012__19 1533 | 2011_000058__19 1534 | 2011_000105__19 1535 | 2011_000195__19 1536 | 2011_000197__19 1537 | 2011_000210__19 1538 | 2011_000221__19 1539 | 2011_000241__19 1540 | 2011_000250__19 1541 | 2011_000277__19 1542 | 2011_000346__19 1543 | 2011_000398__19 1544 | 2011_000442__19 1545 | 2011_000491__19 1546 | 2011_000498__19 1547 | 2011_000513__19 1548 | 2011_000558__19 1549 | 2011_000627__19 1550 | 2011_000688__19 1551 | 2011_000819__19 1552 | 2011_000848__19 1553 | 2011_000858__19 1554 | 2011_000895__19 1555 | 2011_000909__19 1556 | 2011_000944__19 1557 | 2011_000979__19 1558 | 2011_000987__19 1559 | 2011_000997__19 1560 | 2011_001019__19 1561 | 2011_001052__19 1562 | 2011_001086__19 1563 | 2011_001126__19 1564 | 2011_001152__19 1565 | 2011_001166__19 1566 | 2011_001217__19 1567 | 2011_001337__19 1568 | 2011_001375__19 1569 | 2011_001381__19 1570 | 2011_001525__19 1571 | 2011_001560__19 1572 | 2011_001602__19 1573 | 2011_001655__19 1574 | 2011_001671__19 1575 | 2011_001741__19 1576 | 2011_001776__19 1577 | 2011_001796__19 1578 | 2011_001827__19 1579 | 2011_001889__19 1580 | 2011_001904__19 1581 | 2011_001927__19 1582 | 2011_001951__19 1583 | 2011_002053__19 1584 | 2011_002105__19 1585 | 2011_002173__19 1586 | 2011_002251__19 1587 | 2011_002278__19 1588 | 2011_002280__19 1589 | 2011_002294__19 1590 | 2011_002385__19 1591 | 2011_002396__19 1592 | 2011_002492__19 1593 | 2011_002516__19 1594 | 2011_002528__19 1595 | 2011_002636__19 1596 | 2011_002738__19 1597 | 2011_002779__19 1598 | 2011_002821__19 1599 | 2011_002917__19 1600 | 2011_002932__19 1601 | 
2011_002987__19 1602 | 2011_003028__19 1603 | 2011_003059__19 1604 | 2011_003124__19 1605 | 2011_003132__19 1606 | 2011_003149__19 1607 | 2011_003187__19 1608 | 2011_003228__19 1609 | 2011_003260__19 1610 | 2011_003274__19 1611 | 2007_000039__20 1612 | 2007_000121__20 1613 | 2007_001027__20 1614 | 2007_001149__20 1615 | 2007_001704__20 1616 | 2007_002227__20 1617 | 2007_002953__20 1618 | 2007_003451__20 1619 | 2007_003604__20 1620 | 2007_005210__20 1621 | 2007_005902__20 1622 | 2007_006066__20 1623 | 2007_006704__20 1624 | 2007_007250__20 1625 | 2007_007432__20 1626 | 2007_007530__20 1627 | 2007_008407__20 1628 | 2007_008948__20 1629 | 2007_009216__20 1630 | 2007_009295__20 1631 | 2007_009594__20 1632 | 2008_000002__20 1633 | 2008_000023__20 1634 | 2008_000093__20 1635 | 2008_000145__20 1636 | 2008_000202__20 1637 | 2008_000244__20 1638 | 2008_000305__20 1639 | 2008_000309__20 1640 | 2008_000348__20 1641 | 2008_000383__20 1642 | 2008_000495__20 1643 | 2008_000566__20 1644 | 2008_000578__20 1645 | 2008_000904__20 1646 | 2008_001021__20 1647 | 2008_001073__20 1648 | 2008_001130__20 1649 | 2008_001401__20 1650 | 2008_001428__20 1651 | 2008_001481__20 1652 | 2008_001576__20 1653 | 2008_001641__20 1654 | 2008_001704__20 1655 | 2008_001781__20 1656 | 2008_001815__20 1657 | 2008_001880__20 1658 | 2008_001888__20 1659 | 2008_001896__20 1660 | 2008_001920__20 1661 | 2008_001997__20 1662 | 2008_002066__20 1663 | 2008_002082__20 1664 | 2008_002140__20 1665 | 2008_002218__20 1666 | 2008_002328__20 1667 | 2008_002547__20 1668 | 2008_002650__20 1669 | 2008_002676__20 1670 | 2008_002776__20 1671 | 2008_002817__20 1672 | 2008_002826__20 1673 | 2008_002831__20 1674 | 2008_002954__20 1675 | 2008_003200__20 1676 | 2008_003213__20 1677 | 2008_003248__20 1678 | 2008_003280__20 1679 | 2008_003348__20 1680 | 2008_003432__20 1681 | 2008_003434__20 1682 | 2008_003435__20 1683 | 2008_003466__20 1684 | 2008_003500__20 1685 | 2008_003585__20 1686 | 2008_003589__20 1687 | 2008_003609__20 1688 
| 2008_003667__20 1689 | 2008_003712__20 1690 | 2008_003814__20 1691 | 2008_003825__20 1692 | 2008_003883__20 1693 | 2008_003948__20 1694 | 2008_003995__20 1695 | 2008_004004__20 1696 | 2008_004006__20 1697 | 2008_004008__20 1698 | 2008_004093__20 1699 | 2008_004097__20 1700 | 2008_004217__20 1701 | 2008_004259__20 1702 | 2008_004297__20 1703 | 2008_004301__20 1704 | 2008_004321__20 1705 | 2008_004330__20 1706 | 2008_004333__20 1707 | 2008_004501__20 1708 | 2008_004506__20 1709 | 2008_004526__20 1710 | 2008_004541__20 1711 | 2008_004550__20 1712 | 2008_004606__20 1713 | 2008_004719__20 1714 | 2008_004720__20 1715 | 2008_004781__20 1716 | 2008_004807__20 1717 | 2008_004881__20 1718 | 2008_004898__20 1719 | 2008_004908__20 1720 | 2008_004930__20 1721 | 2008_004961__20 1722 | 2008_005006__20 1723 | 2008_005008__20 1724 | 2008_005037__20 1725 | 2008_005064__20 1726 | 2008_005066__20 1727 | 2008_005090__20 1728 | 2008_005094__20 1729 | 2008_005191__20 1730 | 2008_005231__20 1731 | 2008_005255__20 1732 | 2008_005329__20 1733 | 2008_005342__20 1734 | 2008_005393__20 1735 | 2008_005569__20 1736 | 2008_005609__20 1737 | 2008_005625__20 1738 | 2008_005639__20 1739 | 2008_005660__20 1740 | 2008_005678__20 1741 | 2008_005732__20 1742 | 2008_005817__20 1743 | 2008_005877__20 1744 | 2008_005918__20 1745 | 2008_005929__20 1746 | 2008_005954__20 1747 | 2008_005957__20 1748 | 2008_005962__20 1749 | 2008_005967__20 1750 | 2008_005976__20 1751 | 2008_006031__20 1752 | 2008_006047__20 1753 | 2008_006062__20 1754 | 2008_006135__20 1755 | 2008_006136__20 1756 | 2008_006147__20 1757 | 2008_006233__20 1758 | 2008_006267__20 1759 | 2008_006271__20 1760 | 2008_006273__20 1761 | 2008_006288__20 1762 | 2008_006295__20 1763 | 2008_006366__20 1764 | 2008_006373__20 1765 | 2008_006409__20 1766 | 2008_006433__20 1767 | 2008_006591__20 1768 | 2008_006605__20 1769 | 2008_006606__20 1770 | 2008_006617__20 1771 | 2008_006624__20 1772 | 2008_006662__20 1773 | 2008_006668__20 1774 | 2008_006710__20 
1775 | 2008_006719__20 1776 | 2008_006733__20 1777 | 2008_006946__20 1778 | 2008_007010__20 1779 | 2008_007038__20 1780 | 2008_007114__20 1781 | 2008_007196__20 1782 | 2008_007217__20 1783 | 2008_007242__20 1784 | 2008_007246__20 1785 | 2008_007324__20 1786 | 2008_007332__20 1787 | 2008_007361__20 1788 | 2008_007446__20 1789 | 2008_007476__20 1790 | 2008_007536__20 1791 | 2008_007561__20 1792 | 2008_007567__20 1793 | 2008_007685__20 1794 | 2008_007696__20 1795 | 2008_007798__20 1796 | 2008_007864__20 1797 | 2008_007916__20 1798 | 2008_007933__20 1799 | 2008_007962__20 1800 | 2008_007987__20 1801 | 2008_008269__20 1802 | 2008_008429__20 1803 | 2008_008439__20 1804 | 2008_008524__20 1805 | 2008_008590__20 1806 | 2008_008608__20 1807 | 2008_008649__20 1808 | 2009_000010__20 1809 | 2009_000014__20 1810 | 2009_000041__20 1811 | 2009_000157__20 1812 | 2009_000214__20 1813 | 2009_000216__20 1814 | 2009_000336__20 1815 | 2009_000398__20 1816 | 2009_000439__20 1817 | 2009_000444__20 1818 | 2009_000544__20 1819 | 2009_000549__20 1820 | 2009_000552__20 1821 | 2009_000585__20 1822 | 2009_000629__20 1823 | 2009_000679__20 1824 | 2009_000722__20 1825 | 2009_000791__20 1826 | 2009_000848__20 1827 | 2009_000895__20 1828 | 2009_000981__20 1829 | 2009_000987__20 1830 | 2009_001069__20 1831 | 2009_001103__20 1832 | 2009_001106__20 1833 | 2009_001111__20 1834 | 2009_001133__20 1835 | 2009_001188__20 1836 | 2009_001357__20 1837 | 2009_001393__20 1838 | 2009_001452__20 1839 | 2009_001526__20 1840 | 2009_001553__20 1841 | 2009_001555__20 1842 | 2009_001608__20 1843 | 2009_001615__20 1844 | 2009_001682__20 1845 | 2009_001779__20 1846 | 2009_001809__20 1847 | 2009_001812__20 1848 | 2009_001839__20 1849 | 2009_001852__20 1850 | 2009_001864__20 1851 | 2009_001874__20 1852 | 2009_001875__20 1853 | 2009_001961__20 1854 | 2009_001964__20 1855 | 2009_002110__20 1856 | 2009_002139__20 1857 | 2009_002232__20 1858 | 2009_002409__20 1859 | 2009_002537__20 1860 | 2009_002652__20 1861 | 
2009_002663__20 1862 | 2009_002705__20 1863 | 2009_002733__20 1864 | 2009_002755__20 1865 | 2009_002758__20 1866 | 2009_002820__20 1867 | 2009_002827__20 1868 | 2009_002849__20 1869 | 2009_002872__20 1870 | 2009_002932__20 1871 | 2009_002967__20 1872 | 2009_002970__20 1873 | 2009_002984__20 1874 | 2009_002995__20 1875 | 2009_003078__20 1876 | 2009_003093__20 1877 | 2009_003140__20 1878 | 2009_003191__20 1879 | 2009_003204__20 1880 | 2009_003214__20 1881 | 2009_003316__20 1882 | 2009_003367__20 1883 | 2009_003537__20 1884 | 2009_003554__20 1885 | 2009_003646__20 1886 | 2009_003720__20 1887 | 2009_003753__20 1888 | 2009_003852__20 1889 | 2009_003920__20 1890 | 2009_004062__20 1891 | 2009_004128__20 1892 | 2009_004138__20 1893 | 2009_004176__20 1894 | 2009_004243__20 1895 | 2009_004301__20 1896 | 2009_004341__20 1897 | 2009_004357__20 1898 | 2009_004359__20 1899 | 2009_004478__20 1900 | 2009_004503__20 1901 | 2009_004519__20 1902 | 2009_004631__20 1903 | 2009_004719__20 1904 | 2009_004760__20 1905 | 2009_004763__20 1906 | 2009_004902__20 1907 | 2009_004905__20 1908 | 2009_004922__20 1909 | 2009_004965__20 1910 | 2009_005030__20 1911 | 2009_005042__20 1912 | 2009_005062__20 1913 | 2009_005070__20 1914 | 2009_005221__20 1915 | 2009_005240__20 1916 | 2009_005256__20 1917 | 2010_000053__20 1918 | 2010_000141__20 1919 | 2010_000291__20 1920 | 2010_000375__20 1921 | 2010_000379__20 1922 | 2010_000442__20 1923 | 2010_000449__20 1924 | 2010_000578__20 1925 | 2010_000658__20 1926 | 2010_000669__20 1927 | 2010_000705__20 1928 | 2010_000773__20 1929 | 2010_000787__20 1930 | 2010_000800__20 1931 | 2010_000807__20 1932 | 2010_000931__20 1933 | 2010_000944__20 1934 | 2010_000974__20 1935 | 2010_001099__20 1936 | 2010_001127__20 1937 | 2010_001270__20 1938 | 2010_001277__20 1939 | 2010_001363__20 1940 | 2010_001533__20 1941 | 2010_001562__20 1942 | 2010_001580__20 1943 | 2010_001690__20 1944 | 2010_001780__20 1945 | 2010_001860__20 1946 | 2010_001918__20 1947 | 2010_001939__20 1948 
| 2010_002002__20 1949 | 2010_002015__20 1950 | 2010_002094__20 1951 | 2010_002097__20 1952 | 2010_002152__20 1953 | 2010_002167__20 1954 | 2010_002193__20 1955 | 2010_002245__20 1956 | 2010_002247__20 1957 | 2010_002327__20 1958 | 2010_002427__20 1959 | 2010_002513__20 1960 | 2010_002526__20 1961 | 2010_002561__20 1962 | 2010_002567__20 1963 | 2010_002586__20 1964 | 2010_002652__20 1965 | 2010_002686__20 1966 | 2010_002770__20 1967 | 2010_002791__20 1968 | 2010_002843__20 1969 | 2010_002982__20 1970 | 2010_003035__20 1971 | 2010_003103__20 1972 | 2010_003137__20 1973 | 2010_003236__20 1974 | 2010_003241__20 1975 | 2010_003287__20 1976 | 2010_003405__20 1977 | 2010_003437__20 1978 | 2010_003461__20 1979 | 2010_003674__20 1980 | 2010_003688__20 1981 | 2010_003719__20 1982 | 2010_003728__20 1983 | 2010_003770__20 1984 | 2010_003844__20 1985 | 2010_003857__20 1986 | 2010_003864__20 1987 | 2010_003874__20 1988 | 2010_003892__20 1989 | 2010_003942__20 1990 | 2010_004009__20 1991 | 2010_004050__20 1992 | 2010_004095__20 1993 | 2010_004102__20 1994 | 2010_004109__20 1995 | 2010_004137__20 1996 | 2010_004249__20 1997 | 2010_004254__20 1998 | 2010_004295__20 1999 | 2010_004306__20 2000 | 2010_004368__20 2001 | 2010_004460__20 2002 | 2010_004503__20 2003 | 2010_004523__20 2004 | 2010_004545__20 2005 | 2010_004586__20 2006 | 2010_004591__20 2007 | 2010_004816__20 2008 | 2010_004836__20 2009 | 2010_004944__20 2010 | 2010_004982__20 2011 | 2010_005049__20 2012 | 2010_005071__20 2013 | 2010_005133__20 2014 | 2010_005158__20 2015 | 2010_005190__20 2016 | 2010_005239__20 2017 | 2010_005345__20 2018 | 2010_005372__20 2019 | 2010_005450__20 2020 | 2010_005676__20 2021 | 2010_005678__20 2022 | 2010_005744__20 2023 | 2010_005805__20 2024 | 2010_005827__20 2025 | 2010_005841__20 2026 | 2010_006050__20 2027 | 2011_000009__20 2028 | 2011_000036__20 2029 | 2011_000037__20 2030 | 2011_000038__20 2031 | 2011_000061__20 2032 | 2011_000071__20 2033 | 2011_000077__20 2034 | 2011_000192__20 
2035 | 2011_000253__20 2036 | 2011_000288__20 2037 | 2011_000290__20 2038 | 2011_000364__20 2039 | 2011_000382__20 2040 | 2011_000399__20 2041 | 2011_000400__20 2042 | 2011_000434__20 2043 | 2011_000444__20 2044 | 2011_000608__20 2045 | 2011_000685__20 2046 | 2011_000755__20 2047 | 2011_000965__20 2048 | 2011_001149__20 2049 | 2011_001223__20 2050 | 2011_001302__20 2051 | 2011_001357__20 2052 | 2011_001387__20 2053 | 2011_001394__20 2054 | 2011_001456__20 2055 | 2011_001466__20 2056 | 2011_001507__20 2057 | 2011_001573__20 2058 | 2011_001689__20 2059 | 2011_001705__20 2060 | 2011_001730__20 2061 | 2011_001833__20 2062 | 2011_001837__20 2063 | 2011_001856__20 2064 | 2011_001926__20 2065 | 2011_001928__20 2066 | 2011_002005__20 2067 | 2011_002154__20 2068 | 2011_002218__20 2069 | 2011_002292__20 2070 | 2011_002418__20 2071 | 2011_002462__20 2072 | 2011_002511__20 2073 | 2011_002514__20 2074 | 2011_002554__20 2075 | 2011_002560__20 2076 | 2011_002656__20 2077 | 2011_002756__20 2078 | 2011_002775__20 2079 | 2011_002802__20 2080 | 2011_002942__20 2081 | 2011_002966__20 2082 | 2011_002970__20 2083 | 2011_003047__20 2084 | 2011_003079__20 2085 | 2011_003194__20 2086 | 2011_003254__20 2087 | --------------------------------------------------------------------------------