├── .gitignore
├── .gitmodules
├── README.md
├── assets
│   └── seqnet.jpg
├── datasets.py
├── download.sh
├── get_datasets.py
├── get_models.py
├── main.py
├── seqNet.py
├── structFiles
│   ├── imageNamesFiles
│   │   ├── msls_amman_database_imageNames.txt
│   │   ├── msls_amman_query_imageNames.txt
│   │   ├── msls_austin_database_imageNames.txt
│   │   ├── msls_austin_query_imageNames.txt
│   │   ├── msls_melbourne_database_imageNames.txt
│   │   ├── msls_melbourne_query_imageNames.txt
│   │   ├── nordland_clean_imageNames.txt
│   │   ├── oxford_2014-12-16-18-44-24_imagenames_subsampled-2m.txt
│   │   └── oxford_2015-03-17-11-08-44_imagenames_subsampled-2m.txt
│   ├── msls_amman_d-20_d2-5.db
│   ├── msls_austin_d-20_d2-5.db
│   ├── msls_melbourne_d-20_d2-5.db
│   ├── nordland_test_d-1_d2-1.db
│   ├── nordland_train_d-40_d2-10.db
│   ├── nordland_val_d-1_d2-1.db
│   ├── oxford-pnv_2015-03-17-11-08-44_2014-12-16-18-44-24_test_d-25_d2-5.db
│   ├── oxford-pnv_2015-03-17-11-08-44_2014-12-16-18-44-24_train_d-20_d2-5.db
│   ├── oxford-pnv_splitInds.npz
│   ├── oxford-v1.0_2015-03-17-11-08-44_2014-12-16-18-44-24_test_d-10_d2-5.db
│   ├── oxford-v1.0_2015-03-17-11-08-44_2014-12-16-18-44-24_train_d-20_d2-5.db
│   ├── oxford-v1.0_splitInds.npz
│   └── seqBoundsFiles
│       ├── msls_amman_database_seqBounds.txt
│       ├── msls_amman_query_seqBounds.txt
│       ├── msls_austin_database_seqBounds.txt
│       ├── msls_austin_query_seqBounds.txt
│       ├── msls_melbourne_database_seqBounds.txt
│       └── msls_melbourne_query_seqBounds.txt
├── test.py
└── train.py

/.gitignore:
--------------------------------------------------------------------------------
*.ipynb*
*__pycache__/*
data/

--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
[submodule "thirdparty/Patch-NetVLAD"]
	path = thirdparty/Patch-NetVLAD
	url = https://github.com/oravus/Patch-NetVLAD.git

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SeqNet: Learning Descriptors for Sequence-Based Hierarchical Place Recognition

[[ArXiv+Supplementary](https://arxiv.org/abs/2102.11603)] [[IEEE Xplore RA-L 2021](https://ieeexplore.ieee.org/abstract/document/9382076/)] [[ICRA 2021 YouTube Video](https://www.youtube.com/watch?v=KYw7RhDfxY0)]

**and**

# SeqNetVLAD vs PointNetVLAD: Image Sequence vs 3D Point Clouds for Day-Night Place Recognition

[[ArXiv](https://arxiv.org/abs/2106.11481)] [[CVPR 2021 Workshop 3DVR](https://sites.google.com/view/cvpr2021-3d-vision-robotics/)]

*Figure: Sequence-Based Hierarchical Visual Place Recognition.*

## News:
**Jan 27, 2024**: Download all pretrained models from [here](https://universityofadelaide.box.com/s/mp45yapl0j0by6aijf5kj8obt8ky0swk), the Nordland dataset from [here](https://universityofadelaide.box.com/s/zkfk1akpbo5318fzqmtvlpp7030ex4up) and the precomputed descriptors from [here](https://universityofadelaide.box.com/s/p8uh5yncsaxk7g8lwr8pihnwkqbc2pkf).

**Jan 18, 2022**: MSLS training setup included.

**Jan 07, 2022**: Single-image vanilla NetVLAD feature extraction enabled.

**Oct 13, 2021**: ~~Oxford & Brisbane Day-Night pretrained models [download link](https://cloudstor.aarnet.edu.au/plus/s/wx0zIGi3WBTtq5F).~~ (use the latest link provided above)

**Aug 03, 2021**: Added Oxford dataset files ~~and a [direct link](https://cloudstor.aarnet.edu.au/plus/s/8L7loyTZjK0FsWT) to download the Nordland dataset.~~ (use the latest link provided above)

**Jun 23, 2021**: CVPR 2021 Workshop 3DVR paper, "SeqNetVLAD vs PointNetVLAD", now available on [arXiv](https://arxiv.org/abs/2106.11481).

## Setup
### Conda
```bash
conda create -n seqnet numpy pytorch=1.8.0 torchvision tqdm scikit-learn faiss tensorboardx h5py -c pytorch -c conda-forge
```

### Download
~~Run `bash download.sh` to download single image NetVLAD descriptors (3.4 GB) for the Nordland-clean dataset [[a]](#nordclean) and the Oxford dataset (0.3 GB), and Nordland-trained model files (1.5 GB) [[b]](#saveLoc).
Other pre-trained models for Oxford and Brisbane Day-Night can be downloaded from [here](https://universityofadelaide.box.com/s/mp45yapl0j0by6aijf5kj8obt8ky0swk).~~ [Please use the download links in the 27 Jan 2024 news item at the top.]

## Run

### Train
To train sequential descriptors through SeqNet on the Nordland dataset:
```bash
python main.py --mode train --pooling seqnet --dataset nordland-sw --seqL 10 --w 5 --outDims 4096 --expName "w5"
```
or on the Oxford dataset (set `--dataset oxford-pnv` for the PointNetVLAD-like data split described in the [CVPR 2021 Workshop paper](https://arxiv.org/abs/2106.11481)):
```bash
python main.py --mode train --pooling seqnet --dataset oxford-v1.0 --seqL 5 --w 3 --outDims 4096 --expName "w3"
```
or on the MSLS dataset (the `--msls_trainCity` and `--msls_valCity` values shown here are the defaults):
```bash
python main.py --mode train --pooling seqnet --dataset msls --msls_trainCity melbourne --msls_valCity austin --seqL 5 --w 3 --outDims 4096 --expName "msls_w3"
```

To train transformed single-image descriptors (S1) through SeqNet:
```bash
python main.py --mode train --pooling seqnet --dataset nordland-sw --seqL 1 --w 1 --outDims 4096 --expName "w1"
```

### Test
On the Nordland dataset:
```bash
python main.py --mode test --pooling seqnet --dataset nordland-sf --seqL 5 --split test --resume ./data/runs/Jun03_15-22-44_l10_w5/
```
On the MSLS dataset (`--msls_valCity` can also be set to `melbourne` or `austin`):
```bash
python main.py --mode test --pooling seqnet --dataset msls --msls_valCity amman --seqL 5 --split test --resume ./data/runs//
```

The commands above reproduce the SeqNet (S5) results reported in [Supp. Table III on Page 10](https://arxiv.org/pdf/2102.11603.pdf).
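
`--seqL` sets how many consecutive single-image descriptors form one sequence. To make the windowing concrete, the snippet below is an illustrative pure-Python re-implementation of the window logic in `getSeqInds` (`datasets.py`); the name `seq_window` is hypothetical and is not part of the repository:

```python
def seq_window(idx, seq_len, max_num, min_num=0):
    """Return the seq_len frame indices used for frame idx.

    Illustrative re-implementation of getSeqInds in datasets.py:
    centre a window on idx, clip it to [min_num, max_num), then shift it
    inward so a window touching either boundary still has seq_len entries.
    """
    start = max(min_num, idx - seq_len // 2)
    stop = min(idx + seq_len - seq_len // 2, max_num)
    inds = list(range(start, stop))
    len_diff = seq_len - len(inds)  # frames lost to clipping
    if inds[0] == min_num:
        # Window starts at the sequence start: extend forward by the clipped frames.
        inds += list(range(inds[-1] + 1, inds[-1] + 1 + len_diff))
    elif len_diff > 0 and inds[-1] == max_num - 1:
        # Window was clipped at the sequence end: extend it backward instead.
        inds = list(range(inds[0] - len_diff, inds[0])) + inds
    return inds


if __name__ == "__main__":
    for i in (0, 5, 9):
        print(i, seq_window(i, seq_len=5, max_num=10))
```

With `seq_len=5` in a 10-frame traversal, frame 5 uses frames 3 to 7, while frame 0 uses frames 0 to 4 and frame 9 uses frames 5 to 9 rather than a clipped window. When `seqL_filterData` is set (it defaults to `seqL` for MSLS validation and testing), `getValidSeqInds` instead marks frames whose centred window would need this shift as invalid, and they are dropped.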

To obtain the other results from the same table in the paper:

```bash
# Raw Single (NetVLAD) Descriptor
python main.py --mode test --pooling single --dataset nordland-sf --seqL 1 --split test

# SeqNet (S1)
python main.py --mode test --pooling seqnet --dataset nordland-sf --seqL 1 --split test --resume ./data/runs/Jun03_15-07-46_l1_w1/

# Raw + Smoothing
python main.py --mode test --pooling smooth --dataset nordland-sf --seqL 5 --split test

# Raw + Delta
python main.py --mode test --pooling delta --dataset nordland-sf --seqL 5 --split test

# Raw + SeqMatch
python main.py --mode test --pooling single+seqmatch --dataset nordland-sf --seqL 5 --split test

# SeqNet (S1) + SeqMatch
python main.py --mode test --pooling s1+seqmatch --dataset nordland-sf --seqL 5 --split test --resume ./data/runs/Jun03_15-07-46_l1_w1/

# HVPR (S5 to S1)
# Run S5 first and save its predictions by specifying `--resultsPath`
python main.py --mode test --pooling seqnet --dataset nordland-sf --seqL 5 --split test --resume ./data/runs/Jun03_15-22-44_l10_w5/ --resultsPath ./data/results/
# Now run S1 + SeqMatch using the results from above (the timestamp in your `--predictionsFile` name will differ)
python main.py --mode test --pooling s1+seqmatch --dataset nordland-sf --seqL 5 --split test --resume ./data/runs/Jun03_15-07-46_l1_w1/ --predictionsFile ./data/results/Jun03_16-07-36_l5_0.npz
```
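
The `single+seqmatch` and `s1+seqmatch` modes pair per-frame descriptors with sequence matching, and HVPR feeds the saved S5 predictions into S1 + SeqMatch. As a rough sketch of that general idea only, assuming a SeqSLAM-style score that averages distances along the diagonal of a query-by-reference distance matrix (the repository's own SeqMatch code is not part of this excerpt and may differ), the hypothetical helpers below re-rank a shortlist of reference indices:

```python
import math


def l2(a, b):
    """Euclidean distance between two equal-length descriptor vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def distance_matrix(q_descs, db_descs):
    """Query-by-reference distance matrix as nested lists."""
    return [[l2(q, d) for d in db_descs] for q in q_descs]


def seq_score(dmat, qi, dj, seq_len):
    """Mean distance along the diagonal window centred on (qi, dj).

    Assumes both traversals move at roughly the same speed, so query frame
    qi + k should match reference frame dj + k. Entries outside the matrix
    are skipped; taking the mean keeps clipped windows comparable.
    """
    half = seq_len // 2
    total, count = 0.0, 0
    for k in range(-half, seq_len - half):
        i, j = qi + k, dj + k
        if 0 <= i < len(dmat) and 0 <= j < len(dmat[0]):
            total += dmat[i][j]
            count += 1
    return total / count


def rerank(dmat, qi, shortlist, seq_len):
    """Order a shortlist of reference indices by sequence score (lower is better)."""
    return sorted(shortlist, key=lambda dj: seq_score(dmat, qi, dj, seq_len))


if __name__ == "__main__":
    # Toy 1-D descriptors: reference frame 4 aliases query frame 2 exactly.
    queries = [[0], [1], [2], [3], [4], [5]]
    refs = [[0], [1], [2], [3], [2], [9]]
    dmat = distance_matrix(queries, refs)
    print(rerank(dmat, qi=2, shortlist=[4, 2], seq_len=3))  # [2, 4]
```

In the toy data a single-frame comparison cannot separate reference frames 2 and 4 (both are at distance 0 from query frame 2), whereas the three-frame diagonal score prefers frame 2 because its neighbours also agree. In the hierarchical setting the shortlist would come from the first-stage (S5) predictions.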

### Single Image Vanilla NetVLAD Extraction

To compute the single-image vanilla NetVLAD descriptors yourself (i.e. the precomputed .npy descriptors provided above):

```bash
# Set up the Patch-NetVLAD submodule from the seqNet repo:
cd seqNet
git submodule update --init

# Download the NetVLAD+PCA model (URL quoted so the shell does not treat '?' as a glob)
cd thirdparty/Patch-NetVLAD/patchnetvlad/pretrained_models
wget -O pitts_orig_WPCA4096.pth.tar "https://huggingface.co/TobiasRobotics/Patch-NetVLAD/resolve/main/pitts_WPCA4096.pth.tar?download=true"

# Compute global descriptors
cd ../../../Patch-NetVLAD/
python feature_extract.py --config_path patchnetvlad/configs/seqnet.ini --dataset_file_path ../../structFiles/imageNamesFiles/oxford_2014-12-16-18-44-24_imagenames_subsampled-2m.txt --dataset_root_dir --output_features_fullpath ../../data/descData/netvlad-pytorch/oxford_2014-12-16-18-44-24_stereo_left.npy

# Example for MSLS (replace 'database' with 'query' and use the other city names to compute all descriptors)
python feature_extract.py --config_path patchnetvlad/configs/seqnet.ini --dataset_file_path ../../structFiles/imageNamesFiles/msls_melbourne_database_imageNames.txt --dataset_root_dir --output_features_fullpath ../../data/descData/netvlad-pytorch/msls_melbourne_database.npy
```
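
`get_dataset()` in `get_datasets.py` reads MSLS descriptors from `./data/descData/{descType}/msls_{city}_{traversal}.npy`, so the extracted files need to land exactly there. The hypothetical helper below, assuming `--descType netvlad-pytorch` to match the output directory in the commands above (not a confirmed default), lists the four files a given train/val city pair needs:

```python
from itertools import product
from os.path import join


def expected_msls_desc_paths(train_city, val_city, desc_type="netvlad-pytorch", prefix="./data/"):
    """Return the .npy paths get_dataset() loads for --dataset msls.

    The order mirrors get_datasets.py: train database, train query,
    val database, val query. desc_type="netvlad-pytorch" is an assumption
    matching the extraction commands above.
    """
    template = join(prefix, "descData/{}/msls_{}_{}.npy")
    return [template.format(desc_type, city, trav)
            for city, trav in product([train_city, val_city], ["database", "query"])]


if __name__ == "__main__":
    for path in expected_msls_desc_paths("melbourne", "austin"):
        print(path)
```

For the `melbourne`/`austin` pair used in the MSLS training command, the first path is `./data/descData/netvlad-pytorch/msls_melbourne_database.npy`, which is where the MSLS extraction example above writes its output.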

## Acknowledgement
The code in this repository is based on [Nanne/pytorch-NetVlad](https://github.com/Nanne/pytorch-NetVlad). Thanks to [Tobias Fischer](https://github.com/Tobias-Fischer) for his contributions to this code during the development of our project [QVPR/Patch-NetVLAD](https://github.com/QVPR/Patch-NetVLAD).

## Citation
```
@article{garg2021seqnet,
  title={SeqNet: Learning Descriptors for Sequence-based Hierarchical Place Recognition},
  author={Garg, Sourav and Milford, Michael},
  journal={IEEE Robotics and Automation Letters},
  volume={6},
  number={3},
  pages={4305-4312},
  year={2021},
  publisher={IEEE},
  doi={10.1109/LRA.2021.3067633}
}

@misc{garg2021seqnetvlad,
  title={SeqNetVLAD vs PointNetVLAD: Image Sequence vs 3D Point Clouds for Day-Night Place Recognition},
  author={Garg, Sourav and Milford, Michael},
  howpublished={CVPR 2021 Workshop on 3D Vision and Robotics (3DVR)},
  month={Jun},
  year={2021},
}
```

#### Other Related Projects
[SeqMatchNet (2021)](https://github.com/oravus/SeqMatchNet);
[Patch-NetVLAD (2021)](https://github.com/QVPR/Patch-NetVLAD);
[Delta Descriptors (2020)](https://github.com/oravus/DeltaDescriptors);
[CoarseHash (2020)](https://github.com/oravus/CoarseHash);
[seq2single (2019)](https://github.com/oravus/seq2single);
[LoST (2018)](https://github.com/oravus/lostX)

[a] This is the clean version of the dataset, which excludes images from tunnels and red lights; it can be downloaded from [here](https://universityofadelaide.app.box.com/s/zkfk1akpbo5318fzqmtvlpp7030ex4up).

[b] These are saved to `./data/` by default; to use a different working directory, modify this path in [download.sh](https://github.com/oravus/seqNet/blob/main/download.sh) and [get_datasets.py](https://github.com/oravus/seqNet/blob/5450829c4294fe1d14966bfa1ac9b7c93237369b/get_datasets.py#L6).

--------------------------------------------------------------------------------
/assets/seqnet.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oravus/seqNet/6b83d36aec4a49cae415c3e63c8069066329da02/assets/seqnet.jpg
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
import torch
import torch.utils.data as data
import itertools

import os
from os.path import join, exists
from scipy.io import loadmat, savemat
import numpy as np
from collections import namedtuple

from sklearn.neighbors import NearestNeighbors
import faiss
import h5py

dbStruct = namedtuple('dbStruct', ['whichSet', 'dataset',
    'dbImage', 'utmDb', 'qImage', 'utmQ', 'numDb', 'numQ',
    'posDistThr', 'posDistSqThr', 'nonTrivPosDistSqThr',
    'dbTimeStamp', 'qTimeStamp', 'gpsDb', 'gpsQ'])

class Dataset():
    def __init__(self, dataset_name, train_mat_file, test_mat_file, val_mat_file, opt):
        self.dataset_name = dataset_name
        self.train_mat_file = train_mat_file
        self.test_mat_file = test_mat_file
        self.val_mat_file = val_mat_file
        self.struct_dir = "./structFiles/"
        self.seqL = opt.seqL
        self.seqL_filterData = opt.seqL_filterData

        # descriptor settings
        self.dbDescs = None
        self.qDescs = None
        self.trainInds = None
        self.valInds = None
        self.testInds = None
        self.db_seqBounds = None
        self.q_seqBounds = None

    def loadPreComputedDescriptors(self,ft1,ft2,seqBounds=None):
        self.dbDescs = ft1
        self.qDescs = ft2
        print("All Db descs: ", self.dbDescs.shape)
        print("All Qry descs: ", self.qDescs.shape)
        if seqBounds is None:
            self.db_seqBounds = None
            self.q_seqBounds = None
        else:
            self.db_seqBounds = seqBounds[0]
            self.q_seqBounds = seqBounds[1]
        return self.dbDescs.shape[1]

    def get_whole_training_set(self, onlyDB=False):
        structFile = join(self.struct_dir, self.train_mat_file)
        indsSplit = self.trainInds
        return WholeDatasetFromStruct( structFile, indsSplit, self.dbDescs, self.qDescs, seqL=self.seqL, onlyDB=onlyDB, seqBounds=[self.db_seqBounds,self.q_seqBounds],seqL_filterData=self.seqL_filterData)

    def get_whole_val_set(self):
        structFile = join(self.struct_dir, self.val_mat_file)
        indsSplit = self.valInds
        if self.seqL_filterData is None and self.dataset_name == 'msls':
            self.seqL_filterData = self.seqL
        return WholeDatasetFromStruct(structFile, indsSplit, self.dbDescs, self.qDescs, seqL=self.seqL, seqBounds=[self.db_seqBounds,self.q_seqBounds],seqL_filterData=self.seqL_filterData)

    def get_whole_test_set(self):
        if self.test_mat_file is not None:
            structFile = join(self.struct_dir, self.test_mat_file)
            indsSplit = self.testInds
            if self.seqL_filterData is None and self.dataset_name == 'msls':
                self.seqL_filterData = self.seqL
            return WholeDatasetFromStruct(structFile, indsSplit, self.dbDescs, self.qDescs, seqL=self.seqL, seqBounds=[self.db_seqBounds,self.q_seqBounds],seqL_filterData=self.seqL_filterData)
        else:
            raise ValueError('test set not available for dataset ' + self.dataset_name)

    def get_training_query_set(self, margin=0.1, nNegSample=1000, use_regions=False):
        structFile = join(self.struct_dir, self.train_mat_file)
        indsSplit = self.trainInds
        return QueryDatasetFromStruct(structFile,indsSplit, self.dbDescs, self.qDescs, nNegSample=nNegSample, margin=margin,use_regions=use_regions, seqL=self.seqL, seqBounds=[self.db_seqBounds,self.q_seqBounds])

    def get_val_query_set(self):
        structFile = join(self.struct_dir, self.val_mat_file)
        indsSplit = self.valInds
        return QueryDatasetFromStruct(structFile, indsSplit, self.dbDescs, self.qDescs, seqL=self.seqL, seqBounds=[self.db_seqBounds,self.q_seqBounds])

    @staticmethod
    def collate_fn(batch):
        """Creates mini-batch tensors from the list of tuples (query, positive, negatives, indices).

        Args:
            batch: list of tuple (query, positive, negatives, indices); None entries are dropped.
                - query: array of shape (T, C). e.g. (5,4096)
                - positive: array of shape (T, C).
                - negatives: torch tensor of shape (N, T, C).
                - indices: list [query index, positive index, negative indices...].
        Returns:
            query: torch tensor of shape (batch_size, T, C).
            positive: torch tensor of shape (batch_size, T, C).
            negatives: torch tensor of shape (sum of N over the batch, T, C).
            negCounts: torch tensor of shape (batch_size,), the number of negatives per sample.
            indices: flat list of all indices in the batch.
        """

        batch = list(filter(lambda x: x is not None, batch))
        if len(batch) == 0:
            return None, None, None, None, None

        query, positive, negatives, indices = zip(*batch)

        query = data.dataloader.default_collate(query)
        positive = data.dataloader.default_collate(positive)
        negCounts = data.dataloader.default_collate([x.shape[0] for x in negatives])
        negatives = torch.cat(negatives, 0)
        indices = list(itertools.chain(*indices))

        return query, positive, negatives, negCounts, indices

def getSeqInds(idx,seqL,maxNum,minNum=0,retLenDiff=False):
    seqLOrig = seqL
    seqInds = np.arange(max(minNum,idx-seqL//2),min(idx+seqL-seqL//2,maxNum),1)
    lenDiff = seqLOrig - len(seqInds)
    if retLenDiff:
        return lenDiff

    if seqInds[0] == minNum:
        seqInds = np.concatenate([seqInds,np.arange(seqInds[-1]+1,seqInds[-1]+1+lenDiff,1)])
    elif lenDiff > 0 and seqInds[-1] in range(maxNum-1,maxNum):
        seqInds = np.concatenate([np.arange(seqInds[0]-lenDiff,seqInds[0],1),seqInds])
    return seqInds

def getValidSeqInds(seqBounds,seqL):
    validFlags = []
    for i in range(len(seqBounds)):
        sIdMin, sIdMax = seqBounds[i]
        lenDiff = getSeqInds(i,seqL,sIdMax,minNum=sIdMin,retLenDiff=True)
        validFlags.append(True if lenDiff == 0 else False)
    return validFlags

def parse_db_struct(path):
    mat = loadmat(path)

    fieldnames = list(mat['dbStruct'][0, 0].dtype.names)

    dataset = mat['dbStruct'][0, 0]['dataset'].item()
    whichSet = mat['dbStruct'][0, 0]['whichSet'].item()

    dbImage = [f[0].item() for f in mat['dbStruct'][0, 0]['dbImageFns']]
    qImage = [f[0].item() for f in mat['dbStruct'][0, 0]['qImageFns']]

    numDb = mat['dbStruct'][0, 0]['numImages'].item()
    numQ = mat['dbStruct'][0, 0]['numQueries'].item()

    posDistThr = mat['dbStruct'][0, 0]['posDistThr'].item()
    posDistSqThr = mat['dbStruct'][0, 0]['posDistSqThr'].item()
    if 'nonTrivPosDistSqThr' in fieldnames:
        nonTrivPosDistSqThr = mat['dbStruct'][0, 0]['nonTrivPosDistSqThr'].item()
    else:
        nonTrivPosDistSqThr = None

    if 'dbTimeStamp' in fieldnames and 'qTimeStamp' in fieldnames:
        dbTimeStamp = [f[0].item() for f in mat['dbStruct'][0, 0]['dbTimeStamp'].T]
        qTimeStamp = [f[0].item() for f in mat['dbStruct'][0, 0]['qTimeStamp'].T]
        dbTimeStamp = np.array(dbTimeStamp)
        qTimeStamp = np.array(qTimeStamp)
    else:
        dbTimeStamp = None
        qTimeStamp = None

    if 'utmQ' in fieldnames and 'utmDb' in fieldnames:
        utmDb = mat['dbStruct'][0, 0]['utmDb'].T
        utmQ = mat['dbStruct'][0, 0]['utmQ'].T
    else:
        utmQ = None
        utmDb = None

    if 'gpsQ' in fieldnames and 'gpsDb' in fieldnames:
        gpsDb = mat['dbStruct'][0, 0]['gpsDb'].T
        gpsQ = mat['dbStruct'][0, 0]['gpsQ'].T
    else:
        gpsQ = None
        gpsDb = None

    # argument order must follow the dbStruct field order (..., gpsDb, gpsQ)
    return dbStruct(whichSet, dataset, dbImage, utmDb, qImage, utmQ, numDb, numQ, posDistThr,
                    posDistSqThr, nonTrivPosDistSqThr, dbTimeStamp, qTimeStamp, gpsDb, gpsQ)


def save_db_struct(path, db_struct):
    assert db_struct.numDb == len(db_struct.dbImage)
    assert db_struct.numQ == len(db_struct.qImage)

    inner_dict = {
        'whichSet': db_struct.whichSet,
        # np.object was removed in NumPy 1.24; the builtin object is equivalent
        'dbImageFns': np.array(db_struct.dbImage, dtype=object).reshape(-1, 1),
        'qImageFns': np.array(db_struct.qImage, dtype=object).reshape(-1, 1),
        'numImages': db_struct.numDb,
        'numQueries': db_struct.numQ,
        'posDistThr': db_struct.posDistThr,
        'posDistSqThr': db_struct.posDistSqThr,
    }

    if db_struct.dataset is not None:
        inner_dict['dataset'] = db_struct.dataset

    if db_struct.nonTrivPosDistSqThr is not None:
        inner_dict['nonTrivPosDistSqThr'] = db_struct.nonTrivPosDistSqThr

    if db_struct.utmDb is not None and db_struct.utmQ is not None:
        assert db_struct.numDb == len(db_struct.utmDb)
        assert db_struct.numQ == len(db_struct.utmQ)
        inner_dict['utmDb'] = db_struct.utmDb.T
        inner_dict['utmQ'] = db_struct.utmQ.T

    if db_struct.gpsDb is not None and db_struct.gpsQ is not None:
        assert db_struct.numDb == len(db_struct.gpsDb)
        assert db_struct.numQ == len(db_struct.gpsQ)
        inner_dict['gpsDb'] = db_struct.gpsDb.T
        inner_dict['gpsQ'] = db_struct.gpsQ.T

    if db_struct.dbTimeStamp is not None and db_struct.qTimeStamp is not None:
        inner_dict['dbTimeStamp'] = db_struct.dbTimeStamp.astype(np.float64)
        inner_dict['qTimeStamp'] = db_struct.qTimeStamp.astype(np.float64)

    savemat(path, {'dbStruct': inner_dict})

def print_db_concise(db):
    [print('\033[1m' + k + '\033[0m', v[:10] if type(v) is list else v) for k,v in db._asdict().items()]

class WholeDatasetFromStruct(data.Dataset):
    def __init__(self, structFile, indsSplit, dbDescs, qDescs, onlyDB=False, seqL=1, seqBounds=None,seqL_filterData=None):
        super().__init__()

        self.seqL = seqL
        self.filterBoundaryInds = False if seqL_filterData is None else True

        self.dbStruct = parse_db_struct(structFile)

        self.images = dbDescs[indsSplit[0]]

        if seqBounds[0] is None:
            self.seqBounds = np.array([[0,len(self.images)] for _ in range(len(self.images))])

        if not onlyDB:
            qImages = qDescs[indsSplit[1]]
            self.images = np.concatenate([self.images,qImages],0)
            if seqBounds[0] is None:
                q_seqBounds = np.array([[len(self.seqBounds),len(self.images)] for _ in range(len(qImages))])
                self.seqBounds = np.vstack([self.seqBounds,q_seqBounds])

        if seqBounds[0] is not None:
            db_seqBounds = seqBounds[0][indsSplit[0]]
            q_seqBounds = db_seqBounds[-1,-1] + seqBounds[1][indsSplit[1]]
            self.seqBounds = np.vstack([db_seqBounds,q_seqBounds])

        self.validInds = np.arange(len(self.images))
        self.validInds_db = np.arange(self.dbStruct.numDb)
        self.validInds_q = np.arange(self.dbStruct.numQ)
        if self.filterBoundaryInds:
            validFlags = getValidSeqInds(self.seqBounds,seqL_filterData)
            self.validInds = np.argwhere(validFlags).flatten()
            self.validInds_db = np.argwhere(validFlags[:self.dbStruct.numDb]).flatten()
            self.validInds_q = np.argwhere(validFlags[self.dbStruct.numDb:]).flatten()
            self.dbStruct = self.dbStruct._replace(utmDb=self.dbStruct.utmDb[self.validInds_db], numDb=len(self.validInds_db), utmQ=self.dbStruct.utmQ[self.validInds_q], numQ=len(self.validInds_q))

        self.whichSet = self.dbStruct.whichSet
        self.dataset = self.dbStruct.dataset

        self.positives = None
        self.distances = None

    def __getitem__(self, index):
        origIndex = index
        index = self.validInds[index]
        sIdMin, sIdMax = self.seqBounds[index]
        img = self.images[getSeqInds(index,self.seqL,sIdMax,minNum=sIdMin)]

        return img, origIndex

    def __len__(self):
        return len(self.validInds)

    def get_positives(self,retDists=False):
        # positives for evaluation are those within trivial threshold range
        # fit NN to find them, search by radius
        if self.positives is None:
            knn = NearestNeighbors(n_jobs=-1)
            knn.fit(self.dbStruct.utmDb)

            print("Using Localization Radius: ", self.dbStruct.posDistThr)
            self.distances, self.positives = knn.radius_neighbors(self.dbStruct.utmQ, radius=self.dbStruct.posDistThr)

        if retDists:
            return self.positives, self.distances
        else:
            return self.positives


class QueryDatasetFromStruct(data.Dataset):
    def __init__(self, structFile, indsSplit, dbDescs, qDescs, nNegSample=1000, nNeg=10, margin=0.1, use_regions=False, seqL=1, seqBounds=None):
        super().__init__()

        self.seqL = seqL

        self.dbDescs = dbDescs[indsSplit[0]]
        self.qDescs = qDescs[indsSplit[1]]

        self.margin = margin

        self.dbStruct = parse_db_struct(structFile)

        if seqBounds[0] is None:
            self.db_seqBounds = np.array([[0,len(self.dbDescs)] for _ in range(len(self.dbDescs))])
            self.q_seqBounds = np.array([[0,len(self.qDescs)] for _ in range(len(self.qDescs))])
        else:
            self.db_seqBounds = seqBounds[0][indsSplit[0]]
            self.q_seqBounds = seqBounds[1][indsSplit[1]]
        self.whichSet = self.dbStruct.whichSet
        self.dataset = self.dbStruct.dataset
        self.nNegSample = nNegSample # number of negatives to randomly sample
        self.nNeg = nNeg # number of negatives used for training
        self.use_faiss = True
        self.use_regions = use_regions

        # potential positives are those within nontrivial threshold range
        # fit NN to find them, search by radius
        knn = NearestNeighbors(n_jobs=-1)
        knn.fit(self.dbStruct.utmDb)

        # TODO use sqeuclidean as metric?
        self.nontrivial_distances, self.nontrivial_positives = \
                knn.radius_neighbors(self.dbStruct.utmQ, radius=self.dbStruct.nonTrivPosDistSqThr**0.5,
                        return_distance=True)

        self.nontrivial_positives = list(self.nontrivial_positives)

        # radius_neighbors returns unsorted results; sort once now so we don't have to later
        for i, posi in enumerate(self.nontrivial_positives):
            self.nontrivial_positives[i] = np.sort(posi)

        # it's possible some queries don't have any non-trivial potential positives;
        # filter those out
        self.queries = np.where(np.array([len(x) for x in self.nontrivial_positives]) > 0)[0]
        print("\n Queries within range ",len(self.queries), len(self.nontrivial_positives),"\n")

        # potential negatives are those outside of posDistThr range
        potential_positives = knn.radius_neighbors(self.dbStruct.utmQ,
                radius=self.dbStruct.posDistThr,
                return_distance=False)

        self.potential_negatives = []
        for pos in potential_positives:
            self.potential_negatives.append(np.setdiff1d(np.arange(self.dbStruct.numDb), pos, assume_unique=True))

        self.cache = None # filepath of HDF5 containing feature vectors for images
        self.h5feat = None

        self.negCache = [np.empty((0,)) for _ in range(self.dbStruct.numQ)]

    def __getitem__(self, index):
        with h5py.File(self.cache, mode='r') as h5:
            h5feat = h5.get("features")

            qOffset = self.dbStruct.numDb
            qFeat = h5feat[index + qOffset]

            posFeat = h5feat[self.nontrivial_positives[index].tolist()]

            if self.use_faiss:
                faiss_index = faiss.IndexFlatL2(posFeat.shape[1])
                # noinspection PyArgumentList
                faiss_index.add(posFeat)
                # noinspection PyArgumentList
                dPos, posNN = faiss_index.search(qFeat.reshape(1, -1), 1)#posFeat.shape[0])
                dPos = np.sqrt(dPos) # faiss returns squared distance
            else:
                knn = NearestNeighbors(n_jobs=-1)
                knn.fit(posFeat)
                dPos, posNN = knn.kneighbors(qFeat.reshape(1, -1), 1)#posFeat.shape[0])
            if len(self.nontrivial_positives[index]) < 1:
                # skip queries that have no non-trivial positives
                return None
            dPos = dPos[0][-1].item()
            posIndex = self.nontrivial_positives[index][posNN[0,-1]].item()

            negSample = np.random.choice(self.potential_negatives[index], self.nNegSample)
            negSample = np.unique(np.concatenate([self.negCache[index], negSample]))
            negSample = np.sort(negSample) # essential to order ascending, speeds up h5 by about double

            negFeat = h5feat[negSample.astype(int).tolist()]
            if self.use_faiss:
                faiss_index = faiss.IndexFlatL2(negFeat.shape[1])
                # noinspection PyArgumentList
                faiss_index.add(negFeat)
                # noinspection PyArgumentList
                dNeg, negNN = faiss_index.search(qFeat.reshape(1, -1), self.nNeg * 10)
                dNeg = np.sqrt(dNeg)
            else:
                knn.fit(negFeat)

                # to quote netvlad paper code: 10x is hacky but fine
                dNeg, negNN = knn.kneighbors(qFeat.reshape(1, -1), self.nNeg * 10)

            dNeg = dNeg.reshape(-1)
            negNN = negNN.reshape(-1)

            # try to find negatives that are within margin; if there aren't any, return None
            violatingNeg = dNeg < dPos + self.margin**0.5

            if np.sum(violatingNeg) < 1:
                # if none are violating then skip this query
                return None

            negNN = negNN[violatingNeg][:self.nNeg]
            negIndices = negSample[negNN].astype(np.int32)
            self.negCache[index] = negIndices

        sIdMin_q, sIdMax_q = self.q_seqBounds[index]
        query = self.qDescs[getSeqInds(index,self.seqL,sIdMax_q,sIdMin_q)]
        sIdMin_p, sIdMax_p = self.db_seqBounds[posIndex]
        positive = self.dbDescs[getSeqInds(posIndex,self.seqL,sIdMax_p,sIdMin_p)]

        negatives = []
        for negIndex in negIndices:
            sIdMin_n, sIdMax_n = self.db_seqBounds[negIndex]
            negative = torch.tensor(self.dbDescs[getSeqInds(negIndex,self.seqL,sIdMax_n,sIdMin_n)])
            negatives.append(negative)

        negatives = torch.stack(negatives, 0)

        # noinspection PyTypeChecker
        return query, positive, negatives, [index, posIndex] + negIndices.tolist()

    def __len__(self):
        return len(self.qDescs)

--------------------------------------------------------------------------------
/download.sh:
--------------------------------------------------------------------------------
# EDIT: 27 Jan 2024
# Please download the nordland-clean dataset from here: https://universityofadelaide.box.com/s/zkfk1akpbo5318fzqmtvlpp7030ex4up
# Please download the nordland/oxford precomputed descriptors from here: https://universityofadelaide.box.com/s/p8uh5yncsaxk7g8lwr8pihnwkqbc2pkf
# Please download the trained models from here: https://universityofadelaide.box.com/s/mp45yapl0j0by6aijf5kj8obt8ky0swk

# download nordland-clean dataset
#wget -cO - https://cloudstor.aarnet.edu.au/plus/s/PK98pDvLAesL1aL/download > nordland-clean.zip
#mkdir -p ./data/
#unzip nordland-clean.zip -d ./data/
#rm nordland-clean.zip

# download oxford descriptors
#wget -cO - https://cloudstor.aarnet.edu.au/plus/s/T0M1Ry4HXOAkkGz/download > oxford_2014-12-16-18-44-24_stereo_left.npy
#wget -cO - https://cloudstor.aarnet.edu.au/plus/s/vr21RnhMmOkW8S9/download > oxford_2015-03-17-11-08-44_stereo_left.npy
#mv oxford* ./data/descData/netvlad-pytorch/

# download trained models
#wget -cO - https://cloudstor.aarnet.edu.au/plus/s/oMwpOzex5ld4nQq/download > models-nordland.zip
#unzip models-nordland.zip -d ./data/
#rm models-nordland.zip

--------------------------------------------------------------------------------
/get_datasets.py:
--------------------------------------------------------------------------------
from datasets import Dataset
from torch.utils.data import DataLoader, SubsetRandomSampler
import numpy as np
from os.path import join
from itertools import product

prefix_data = "./data/"

def get_dataset(opt):

    if 'nordland' in opt.dataset.lower():
        dataset = Dataset('nordland', 'nordland_train_d-40_d2-10.db', 'nordland_test_d-1_d2-1.db', 'nordland_val_d-1_d2-1.db', opt) # train, test, val structs
        if 'sw' in opt.dataset.lower():
            ref, qry = 'summer', 'winter'
        elif 'sf' in opt.dataset.lower():
            ref, qry = 'spring', 'fall'
        else:
            raise ValueError('Unknown Nordland traversal pair: ' + opt.dataset)
        ft1 = np.load(join(prefix_data,"descData/{}/nordland-clean-{}.npy".format(opt.descType,ref)))
        ft2 = np.load(join(prefix_data,"descData/{}/nordland-clean-{}.npy".format(opt.descType,qry)))
        trainInds, testInds, valInds = np.arange(15000), np.arange(15100,18100), np.arange(18200,21200)

        dataset.trainInds = [trainInds, trainInds]
        dataset.valInds = [valInds, valInds]
        dataset.testInds = [testInds, testInds]
        encoder_dim = dataset.loadPreComputedDescriptors(ft1,ft2)

    elif 'oxford' in opt.dataset.lower():
        ref, qry = '2015-03-17-11-08-44', '2014-12-16-18-44-24'
        structStr = "{}_{}_{}".format(opt.dataset,ref,qry)
        # note: for now temporarily use ox_test as ox_val
        if 'v1.0' in opt.dataset:
            testStr = '_test_d-10_d2-5.db'
        elif 'pnv' in opt.dataset:
            testStr = '_test_d-25_d2-5.db'
        else:
            raise ValueError('Unknown Oxford split: ' + opt.dataset)
        dataset = Dataset(opt.dataset, structStr+'_train_d-20_d2-5.db', structStr+testStr, structStr+testStr, opt) # train, test, val structs
        ft1 = np.load(join(prefix_data,"descData/{}/oxford_{}_stereo_left.npy".format(opt.descType,ref)))
        ft2 = np.load(join(prefix_data,"descData/{}/oxford_{}_stereo_left.npy".format(opt.descType,qry)))
        splitInds = np.load("./structFiles/{}_splitInds.npz".format(opt.dataset), allow_pickle=True)

        dataset.trainInds = splitInds['trainInds'].tolist()
        dataset.valInds = splitInds['valInds'].tolist()
dataset.testInds = splitInds['testInds'].tolist() 42 | encoder_dim = dataset.loadPreComputedDescriptors(ft1,ft2) 43 | 44 | elif 'msls' in opt.dataset.lower(): 45 | def get_msls_modImgNames(names): 46 | return ["/".join(n.split("/")[8:]) for n in names] 47 | trav1, trav2 = 'database', 'query' 48 | trainCity, valCity = opt.msls_trainCity, opt.msls_valCity 49 | dbFileName_train = f'msls_{trainCity}_d-20_d2-5.db' 50 | dbFileName_val = f'msls_{valCity}_d-20_d2-5.db' 51 | dataset = Dataset('msls', dbFileName_train, dbFileName_val, dbFileName_val, opt) # train, test, val structs 52 | ftReadPath = join(prefix_data,"descData/{}/msls_{}_{}.npy") 53 | seqBounds_all, ft_all = [], [] 54 | for splitCity, trav in product([trainCity, valCity],[trav1, trav2]): 55 | seqBounds_all.append(np.loadtxt(f"./structFiles/seqBoundsFiles/msls_{splitCity}_{trav}_seqBounds.txt",int)) 56 | ft_all.append(np.load(ftReadPath.format(opt.descType,splitCity,trav))) 57 | ft_train_ref, ft_train_qry, ft_val_ref, ft_val_qry = ft_all 58 | sb_train_ref, sb_train_qry, sb_val_ref, sb_val_qry = seqBounds_all 59 | dataset.trainInds = [np.arange(ft_train_ref.shape[0]),np.arange(ft_train_qry.shape[0])] # append ref & qry 60 | dataset.valInds = [ft_train_ref.shape[0]+np.arange(ft_val_ref.shape[0]),ft_train_qry.shape[0]+np.arange(ft_val_qry.shape[0])] # shift val by train count 61 | dataset.testInds = dataset.valInds 62 | encoder_dim = dataset.loadPreComputedDescriptors(np.vstack([ft_train_ref,ft_val_ref]), np.vstack([ft_train_qry,ft_val_qry]), \ 63 | [np.vstack([sb_train_ref,sb_val_ref]),np.vstack([sb_train_qry,sb_val_qry])]) 64 | 65 | else: 66 | raise Exception('Unknown dataset') 67 | 68 | return dataset, encoder_dim 69 | 70 | 71 | def get_splits(opt, dataset): 72 | whole_train_set, whole_training_data_loader, train_set, whole_test_set = None, None, None, None 73 | if opt.mode.lower() == 'train': 74 | whole_train_set = dataset.get_whole_training_set() 75 | whole_training_data_loader = 
DataLoader(dataset=whole_train_set, 76 | num_workers=opt.threads, batch_size=opt.cacheBatchSize, shuffle=False, 77 | pin_memory=not opt.nocuda) 78 | 79 | train_set = dataset.get_training_query_set(opt.margin) 80 | 81 | print('====> Training whole set:', len(whole_train_set)) 82 | print('====> Training query set:', len(train_set)) 83 | whole_test_set = dataset.get_whole_val_set() 84 | print('===> Evaluating on val set, query count:', whole_test_set.dbStruct.numQ) 85 | elif opt.mode.lower() == 'test': 86 | if opt.split.lower() == 'test': 87 | whole_test_set = dataset.get_whole_test_set() 88 | print('===> Evaluating on test set') 89 | elif opt.split.lower() == 'train': 90 | whole_test_set = dataset.get_whole_training_set() 91 | print('===> Evaluating on train set') 92 | elif opt.split.lower() in ['val']: 93 | whole_test_set = dataset.get_whole_val_set() 94 | print('===> Evaluating on val set') 95 | else: 96 | raise ValueError('Unknown dataset split: ' + opt.split) 97 | print('====> Query count:', whole_test_set.dbStruct.numQ) 98 | 99 | return whole_train_set, whole_training_data_loader, train_set, whole_test_set 100 | -------------------------------------------------------------------------------- /get_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | from os.path import join, isfile 6 | import seqNet 7 | 8 | class Flatten(nn.Module): 9 | def forward(self, input): 10 | return input.view(input.size(0), -1) 11 | 12 | class L2Norm(nn.Module): 13 | def __init__(self, dim=1): 14 | super().__init__() 15 | self.dim = dim 16 | 17 | def forward(self, input): 18 | return F.normalize(input, p=2, dim=self.dim) 19 | 20 | def get_model(opt,encoder_dim,device): 21 | model = nn.Module() 22 | 23 | if opt.seqL == 1 and opt.pooling.lower() not in ['single', 'seqnet']: 24 | raise Exception("For sequential matching/pooling, set seqL > 
1") 25 | elif opt.seqL != 1 and opt.pooling.lower() in ['single']: 26 | raise Exception("For single frame based evaluation, set seqL = 1") 27 | 28 | if opt.pooling.lower() == 'smooth': 29 | global_pool = nn.AdaptiveAvgPool2d((1,None)) 30 | model.add_module('pool', nn.Sequential(*[global_pool, Flatten(), L2Norm()])) 31 | elif opt.pooling.lower() == 'seqnet': 32 | seqFt = seqNet.seqNet(encoder_dim, opt.outDims, opt.seqL, opt.w) 33 | model.add_module('pool', nn.Sequential(*[seqFt, Flatten(), L2Norm()])) 34 | elif opt.pooling.lower() == 's1+seqmatch': 35 | seqFt = seqNet.seqNet(encoder_dim, opt.outDims, 1, opt.w) 36 | model.add_module('pool', nn.Sequential(*[seqFt, Flatten(), L2Norm()])) 37 | elif opt.pooling.lower() == 'delta': 38 | deltaFt = seqNet.Delta(inDims=encoder_dim,seqL=opt.seqL) 39 | model.add_module('pool', nn.Sequential(*[deltaFt, L2Norm()])) 40 | elif opt.pooling.lower() == 'single': 41 | single = nn.AdaptiveAvgPool2d((opt.seqL,None)) # should have no effect, since seqL == 1 here 42 | model.add_module('pool', nn.Sequential(*[single, Flatten(), L2Norm()])) 43 | elif opt.pooling.lower() == 'single+seqmatch': 44 | 45 | model.add_module('pool', nn.Sequential(*[L2Norm(dim=-1)])) 46 | else: 47 | raise ValueError('Unknown pooling type: ' + opt.pooling) 48 | 49 | if not opt.resume: 50 | model = model.to(device) 51 | 52 | scheduler, optimizer, criterion = None, None, None 53 | if opt.mode.lower() == 'train': 54 | if opt.optim.upper() == 'ADAM': 55 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, 56 | model.parameters()), lr=opt.lr)#, betas=(0,0.9)) 57 | elif opt.optim.upper() == 'SGD': 58 | optimizer = optim.SGD(filter(lambda p: p.requires_grad, 59 | model.parameters()), lr=opt.lr, 60 | momentum=opt.momentum, 61 | weight_decay=opt.weightDecay) 62 | 63 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=opt.lrStep, gamma=opt.lrGamma) 64 | else: 65 | raise ValueError('Unknown optimizer: ' + opt.optim) 66 | 67 | criterion = nn.TripletMarginLoss(margin=opt.margin**0.5,
p=2, reduction='sum').to(device) 68 | 69 | if opt.resume: 70 | if opt.ckpt.lower() == 'latest': 71 | resume_ckpt = join(opt.resume, 'checkpoints', 'checkpoint.pth.tar') 72 | elif opt.ckpt.lower() == 'best': 73 | resume_ckpt = join(opt.resume, 'checkpoints', 'model_best.pth.tar') 74 | 75 | if isfile(resume_ckpt): 76 | print("=> loading checkpoint '{}'".format(resume_ckpt)) 77 | checkpoint = torch.load(resume_ckpt, map_location=lambda storage, loc: storage) 78 | opt.start_epoch = checkpoint['epoch'] 79 | best_metric = checkpoint['best_score'] 80 | model.load_state_dict(checkpoint['state_dict']) 81 | model = model.to(device) 82 | if opt.mode == 'train': 83 | optimizer.load_state_dict(checkpoint['optimizer']) 84 | print("=> loaded checkpoint '{}' (epoch {})" 85 | .format(resume_ckpt, checkpoint['epoch'])) 86 | else: 87 | print("=> no checkpoint found at '{}'".format(resume_ckpt)) 88 | 89 | return model, optimizer, scheduler, criterion -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import random, shutil, json 4 | from os.path import join, exists 5 | from os import makedirs 6 | 7 | import torch 8 | from datetime import datetime 9 | from tqdm import tqdm 10 | 11 | from tensorboardX import SummaryWriter 12 | import numpy as np 13 | import sys 14 | 15 | from get_datasets import get_dataset, get_splits, prefix_data 16 | from get_models import get_model 17 | from train import train 18 | from test import test 19 | 20 | parser = argparse.ArgumentParser(description='seqnet') 21 | parser.add_argument('--mode', type=str, default='train', help='Mode', choices=['train', 'test']) 22 | 23 | # train settings 24 | parser.add_argument('--batchSize', type=int, default=16, help='Number of triplets (query, pos, negs). 
Each triplet consists of 12 images.') 25 | parser.add_argument('--cacheBatchSize', type=int, default=24, help='Batch size for caching and testing') 26 | parser.add_argument('--cacheRefreshRate', type=int, default=0, help='How often to refresh cache, in number of queries. 0 for off') 27 | parser.add_argument('--nEpochs', type=int, default=200, help='Number of epochs to train for') 28 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', help='Manual epoch number (useful on restarts)') 29 | parser.add_argument('--nGPU', type=int, default=1, help='Number of GPUs to use.') 30 | parser.add_argument('--optim', type=str, default='SGD', help='Optimizer to use', choices=['SGD', 'ADAM']) 31 | parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate.') 32 | parser.add_argument('--lrStep', type=float, default=50, help='Decay LR every N epochs.') 33 | parser.add_argument('--lrGamma', type=float, default=0.5, help='Multiply LR by gamma at each decay step.') 34 | parser.add_argument('--weightDecay', type=float, default=0.001, help='Weight decay for SGD.') 35 | parser.add_argument('--momentum', type=float, default=0.9, help='Momentum for SGD.') 36 | parser.add_argument('--nocuda', action='store_true', help='Do not use CUDA') 37 | parser.add_argument('--threads', type=int, default=8, help='Number of threads for each data loader to use') 38 | parser.add_argument('--seed', type=int, default=123, help='Random seed to use.') 39 | parser.add_argument('--expName', default='0', help='Unique string for an experiment') 40 | 41 | # path settings 42 | parser.add_argument('--runsPath', type=str, default=join(prefix_data,'runs'), help='Path to save runs to.') 43 | parser.add_argument('--savePath', type=str, default='checkpoints', help='Path to save checkpoints to in logdir.
Default=checkpoints/') 44 | parser.add_argument('--cachePath', type=str, default=join(prefix_data,'cache'), help='Path to save cache to.') 45 | parser.add_argument('--resultsPath', type=str, default=None, help='Path to save evaluation results to when mode=test') 46 | 47 | # test settings 48 | parser.add_argument('--resume', type=str, default='', help='Path to load checkpoint from, for resuming training or testing.') 49 | parser.add_argument('--ckpt', type=str, default='latest', help='Resume from latest or best checkpoint.', choices=['latest', 'best']) 50 | parser.add_argument('--evalEvery', type=int, default=1, help='Do a validation set run, and save, every N epochs.') 51 | parser.add_argument('--patience', type=int, default=0, help='Patience for early stopping. 0 is off.') 52 | parser.add_argument('--split', type=str, default='val', help='Data split to use for testing. Default is val', choices=['test', 'train', 'val']) 53 | parser.add_argument('--numSamples2Project', type=int, default=-1, help='t-SNE uses this many samples ([:n]) to project data to 2D; set to -1 to disable') 54 | parser.add_argument('--extractOnly', action='store_true', help='Only extract descriptors (skip evaluation)') 55 | parser.add_argument('--predictionsFile', type=str, default=None, help='Path to prior predictions data') 56 | parser.add_argument('--seqL_filterData', type=int, help='During testing, remove db and query indices that violate sequence boundaries for this sequence length') 57 | 58 | # dataset, model etc.
59 | parser.add_argument('--dataset', type=str, default='nordland-sw', help='Dataset to use', choices=['nordland-sw', 'nordland-sf', 'oxford-v1.0', 'oxford-pnv', 'msls']) 60 | parser.add_argument('--msls_trainCity', type=str, default='melbourne', help='trainCityName') 61 | parser.add_argument('--msls_valCity', type=str, default='austin', help='valCityName') 62 | parser.add_argument('--pooling', type=str, default='seqnet', help='type of pooling to use', choices=[ 'seqnet', 'smooth', 'delta', 'single','single+seqmatch', 's1+seqmatch']) 63 | parser.add_argument('--seqL', type=int, default=5, help='Sequence Length') 64 | parser.add_argument('--w', type=int, default=3, help='filter size for seqNet') 65 | parser.add_argument('--outDims', type=int, default=None, help='Output descriptor dimensions') 66 | parser.add_argument('--margin', type=float, default=0.1, help='Margin for triplet loss. Default=0.1') 67 | parser.add_argument('--descType', type=str, default="netvlad-pytorch", help='underlying descriptor type') 68 | 69 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 70 | model_out_path = join(opt.savePath, filename) 71 | torch.save(state, model_out_path) 72 | if is_best: 73 | shutil.copyfile(model_out_path, join(opt.savePath, 'model_best.pth.tar')) 74 | 75 | if __name__ == "__main__": 76 | 77 | opt = parser.parse_args() 78 | 79 | restore_var = ['lr', 'lrStep', 'lrGamma', 'weightDecay', 'momentum', 80 | 'runsPath', 'savePath', 'optim', 'margin', 'seed', 'patience', 'outDims', 'w'] 81 | if opt.pooling.lower() != 's1+seqmatch': 82 | restore_var = restore_var + ['pooling'] 83 | if opt.resume: 84 | if opt.pooling.lower() in ['single', 'smooth', 'delta', 'single+seqmatch']: 85 | raise Exception("Use pooling 'seqnet' with resume") 86 | flag_file = join(opt.resume, 'checkpoints', 'flags.json') 87 | if exists(flag_file): 88 | with open(flag_file, 'r') as f: 89 | stored_flags = {'--'+k : str(v) for k,v in json.load(f).items() if k in restore_var} 90 | to_del 
= [] 91 | for flag, val in stored_flags.items(): 92 | for act in parser._actions: 93 | if act.dest == flag[2:]: 94 | # store_true / store_false args don't accept arguments, filter these 95 | if type(act.const) == type(True): 96 | if val == str(act.default): 97 | to_del.append(flag) 98 | else: 99 | stored_flags[flag] = '' 100 | for flag in to_del: del stored_flags[flag] 101 | 102 | train_flags = [x for x in list(sum(stored_flags.items(), tuple())) if len(x) > 0] 103 | print('Restored flags:', train_flags) 104 | opt = parser.parse_args(train_flags, namespace=opt) 105 | 106 | print(opt) 107 | 108 | cuda = not opt.nocuda 109 | if cuda and not torch.cuda.is_available(): 110 | raise Exception("No GPU found, please run with --nocuda") 111 | 112 | device = torch.device("cuda" if cuda else "cpu") 113 | 114 | random.seed(opt.seed) 115 | np.random.seed(opt.seed) 116 | torch.manual_seed(opt.seed) 117 | if cuda: 118 | torch.cuda.manual_seed(opt.seed) 119 | 120 | print('===> Loading dataset(s)') 121 | dataset, encoder_dim = get_dataset(opt) 122 | whole_train_set, whole_training_data_loader, train_set, whole_test_set = get_splits(opt, dataset) 123 | 124 | print('===> Building model') 125 | model, optimizer, scheduler, criterion = get_model(opt, encoder_dim, device) 126 | 127 | unique_string = datetime.now().strftime('%b%d_%H-%M-%S')+'_l'+str(opt.seqL)+'_'+ opt.expName 128 | writer = None 129 | 130 | if opt.mode.lower() == 'test': 131 | print('===> Running evaluation step') 132 | epoch = 1 133 | recallsOrDesc, dbEmb, qEmb, rAtL, preds = test(opt, model, encoder_dim, device, whole_test_set, writer, epoch, extract_noEval=opt.extractOnly) 134 | if opt.resultsPath is not None: 135 | if not exists(opt.resultsPath): 136 | makedirs(opt.resultsPath) 137 | if opt.extractOnly: 138 | gt = whole_test_set.get_positives() 139 | numDb = whole_test_set.dbStruct.numDb 140 | np.savez(join(opt.resultsPath,unique_string),dbDesc=recallsOrDesc[:numDb],qDesc=recallsOrDesc[numDb:],gt=gt) 141 | else: 142 
| np.savez(join(opt.resultsPath,unique_string),args=opt.__dict__,recalls=recallsOrDesc, dbEmb=dbEmb,qEmb=qEmb,rAtL=rAtL,preds=preds) 143 | 144 | elif opt.mode.lower() == 'train': 145 | print('===> Training model') 146 | writer = SummaryWriter(log_dir=join(opt.runsPath,unique_string)) 147 | train_set.cache = join(opt.cachePath, train_set.whichSet + '_feat_cache_{}.hdf5'.format(unique_string)) 148 | if not exists(opt.cachePath): 149 | makedirs(opt.cachePath) 150 | 151 | # write checkpoints in logdir 152 | logdir = writer.file_writer.get_logdir() 153 | opt.savePath = join(logdir, opt.savePath) 154 | if not opt.resume: 155 | makedirs(opt.savePath) 156 | 157 | with open(join(opt.savePath, 'flags.json'), 'w') as f: 158 | f.write(json.dumps( 159 | {k:v for k,v in vars(opt).items()} 160 | )) 161 | print('===> Saving state to:', logdir) 162 | 163 | not_improved = 0 164 | best_score = 0 165 | for epoch in range(opt.start_epoch+1, opt.nEpochs + 1): 166 | train(opt, model, encoder_dim, device, dataset, criterion, optimizer, train_set, whole_train_set, whole_training_data_loader, epoch, writer) 167 | if opt.optim.upper() == 'SGD': 168 | scheduler.step(epoch) 169 | if (epoch % opt.evalEvery) == 0: 170 | recalls = test(opt, model, encoder_dim, device, whole_test_set, writer, epoch)[0] 171 | is_best = recalls[5] > best_score 172 | if is_best: 173 | not_improved = 0 174 | best_score = recalls[5] 175 | else: 176 | not_improved += 1 177 | 178 | save_checkpoint({ 179 | 'epoch': epoch, 180 | 'state_dict': model.state_dict(), 181 | 'recalls': recalls, 182 | 'best_score': best_score, 183 | 'optimizer' : optimizer.state_dict(), 184 | 'parallel' : False, 185 | }, is_best) 186 | 187 | if opt.patience > 0 and not_improved > (opt.patience / opt.evalEvery): 188 | print('Performance did not improve for', opt.patience, 'epochs. 
Stopping.') 189 | break 190 | 191 | print("=> Best Recall@5: {:.4f}".format(best_score), flush=True) 192 | writer.close() 193 | -------------------------------------------------------------------------------- /seqNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | class seqNet(nn.Module): 6 | def __init__(self, inDims, outDims, seqL, w=5): 7 | # note: seqL is accepted but unused; the temporal kernel width is w 8 | super(seqNet, self).__init__() 9 | self.inDims = inDims 10 | self.outDims = outDims 11 | self.w = w 12 | self.conv = nn.Conv1d(inDims, outDims, kernel_size=self.w) 13 | 14 | def forward(self, x): 15 | 16 | if len(x.shape) < 3: 17 | x = x.unsqueeze(1) # convert [B,C] to [B,1,C] 18 | 19 | x = x.permute(0,2,1) # from [B,T,C] to [B,C,T] 20 | seqFt = self.conv(x) # [B,outDims,T-w+1] 21 | seqFt = torch.mean(seqFt,-1) # average over the remaining temporal positions 22 | 23 | return seqFt 24 | 25 | class Delta(nn.Module): 26 | def __init__(self, inDims, seqL): 27 | # fixed weights: -2/seqL on the first seqL//2 frames, +2/seqL on the rest 28 | super(Delta, self).__init__() 29 | self.inDims = inDims 30 | self.weight = (np.ones(seqL,np.float32))/(seqL/2.0) 31 | self.weight[:seqL//2] *= -1 32 | self.weight = nn.Parameter(torch.from_numpy(self.weight),requires_grad=False) 33 | 34 | def forward(self, x): 35 | 36 | # move the descriptor dim before time so that matmul contracts over T 37 | x = x.permute(0,2,1) # [B,T,C] -> [B,C,T] 38 | delta = torch.matmul(x,self.weight) # [B,C] 39 | 40 | return delta 41 | -------------------------------------------------------------------------------- /structFiles/msls_amman_d-20_d2-5.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oravus/seqNet/6b83d36aec4a49cae415c3e63c8069066329da02/structFiles/msls_amman_d-20_d2-5.db -------------------------------------------------------------------------------- /structFiles/msls_austin_d-20_d2-5.db: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/oravus/seqNet/6b83d36aec4a49cae415c3e63c8069066329da02/structFiles/msls_austin_d-20_d2-5.db -------------------------------------------------------------------------------- /structFiles/msls_melbourne_d-20_d2-5.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oravus/seqNet/6b83d36aec4a49cae415c3e63c8069066329da02/structFiles/msls_melbourne_d-20_d2-5.db -------------------------------------------------------------------------------- /structFiles/nordland_test_d-1_d2-1.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oravus/seqNet/6b83d36aec4a49cae415c3e63c8069066329da02/structFiles/nordland_test_d-1_d2-1.db -------------------------------------------------------------------------------- /structFiles/nordland_train_d-40_d2-10.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oravus/seqNet/6b83d36aec4a49cae415c3e63c8069066329da02/structFiles/nordland_train_d-40_d2-10.db -------------------------------------------------------------------------------- /structFiles/nordland_val_d-1_d2-1.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oravus/seqNet/6b83d36aec4a49cae415c3e63c8069066329da02/structFiles/nordland_val_d-1_d2-1.db -------------------------------------------------------------------------------- /structFiles/oxford-pnv_2015-03-17-11-08-44_2014-12-16-18-44-24_test_d-25_d2-5.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oravus/seqNet/6b83d36aec4a49cae415c3e63c8069066329da02/structFiles/oxford-pnv_2015-03-17-11-08-44_2014-12-16-18-44-24_test_d-25_d2-5.db -------------------------------------------------------------------------------- 
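The files under `structFiles/seqBoundsFiles/` (loaded in `get_datasets.py` with `np.loadtxt(..., int)`) appear to store one row per frame: row i holds the half-open range `[start, end)` of the contiguous sequence that frame i belongs to. For example, `msls_amman_database_seqBounds.txt` begins with ten rows of `0 10`, then eight rows of `10 18`, then nine rows of `18 27`. The sketch below is an illustration under that reading, not code from this repository: it rebuilds such a table from a list of sequence lengths and picks a length-`seqL` window around a frame without crossing its sequence boundary. The helper names `seq_bounds_from_lengths` and `seq_window` are hypothetical, and the repository's own `getSeqInds` in `datasets.py` may place or clamp windows differently.

```python
import numpy as np

def seq_bounds_from_lengths(lengths):
    """Hypothetical helper: build an (N, 2) int array where row i is the
    half-open [start, end) range of the sequence that frame i belongs to."""
    lengths = np.asarray(lengths, dtype=int)
    ends = np.cumsum(lengths)
    starts = ends - lengths
    # Repeat each sequence's (start, end) pair once per frame in that sequence.
    return np.repeat(np.stack([starts, ends], axis=1), lengths, axis=0)

def seq_window(index, seqL, bounds):
    """Hypothetical helper: return seqL consecutive frame indices centred on
    `index` where possible, shifted so the window stays inside [start, end).
    Assumes the sequence containing `index` has at least seqL frames."""
    start, end = bounds[index]
    first = index - seqL // 2
    first = max(start, min(first, end - seqL))
    return np.arange(first, first + seqL)

# Same layout as the first 27 rows of msls_amman_database_seqBounds.txt.
bounds = seq_bounds_from_lengths([10, 8, 9])
print(bounds[9].tolist(), bounds[10].tolist())  # [0, 10] [10, 18]
print(seq_window(9, 5, bounds).tolist())        # [5, 6, 7, 8, 9]
```

Near the end of a sequence the window is shifted back rather than padded, so every index in it still belongs to the same traversal segment; padding by repeating edge frames would be an equally plausible design.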
/structFiles/oxford-pnv_2015-03-17-11-08-44_2014-12-16-18-44-24_train_d-20_d2-5.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oravus/seqNet/6b83d36aec4a49cae415c3e63c8069066329da02/structFiles/oxford-pnv_2015-03-17-11-08-44_2014-12-16-18-44-24_train_d-20_d2-5.db -------------------------------------------------------------------------------- /structFiles/oxford-pnv_splitInds.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oravus/seqNet/6b83d36aec4a49cae415c3e63c8069066329da02/structFiles/oxford-pnv_splitInds.npz -------------------------------------------------------------------------------- /structFiles/oxford-v1.0_2015-03-17-11-08-44_2014-12-16-18-44-24_test_d-10_d2-5.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oravus/seqNet/6b83d36aec4a49cae415c3e63c8069066329da02/structFiles/oxford-v1.0_2015-03-17-11-08-44_2014-12-16-18-44-24_test_d-10_d2-5.db -------------------------------------------------------------------------------- /structFiles/oxford-v1.0_2015-03-17-11-08-44_2014-12-16-18-44-24_train_d-20_d2-5.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oravus/seqNet/6b83d36aec4a49cae415c3e63c8069066329da02/structFiles/oxford-v1.0_2015-03-17-11-08-44_2014-12-16-18-44-24_train_d-20_d2-5.db -------------------------------------------------------------------------------- /structFiles/oxford-v1.0_splitInds.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oravus/seqNet/6b83d36aec4a49cae415c3e63c8069066329da02/structFiles/oxford-v1.0_splitInds.npz -------------------------------------------------------------------------------- /structFiles/seqBoundsFiles/msls_amman_database_seqBounds.txt: 
-------------------------------------------------------------------------------- 1 | 0 10 2 | 0 10 3 | 0 10 4 | 0 10 5 | 0 10 6 | 0 10 7 | 0 10 8 | 0 10 9 | 0 10 10 | 0 10 11 | 10 18 12 | 10 18 13 | 10 18 14 | 10 18 15 | 10 18 16 | 10 18 17 | 10 18 18 | 10 18 19 | 18 27 20 | 18 27 21 | 18 27 22 | 18 27 23 | 18 27 24 | 18 27 25 | 18 27 26 | 18 27 27 | 18 27 28 | 27 35 29 | 27 35 30 | 27 35 31 | 27 35 32 | 27 35 33 | 27 35 34 | 27 35 35 | 27 35 36 | 35 42 37 | 35 42 38 | 35 42 39 | 35 42 40 | 35 42 41 | 35 42 42 | 35 42 43 | 42 48 44 | 42 48 45 | 42 48 46 | 42 48 47 | 42 48 48 | 42 48 49 | 48 56 50 | 48 56 51 | 48 56 52 | 48 56 53 | 48 56 54 | 48 56 55 | 48 56 56 | 48 56 57 | 56 62 58 | 56 62 59 | 56 62 60 | 56 62 61 | 56 62 62 | 56 62 63 | 62 152 64 | 62 152 65 | 62 152 66 | 62 152 67 | 62 152 68 | 62 152 69 | 62 152 70 | 62 152 71 | 62 152 72 | 62 152 73 | 62 152 74 | 62 152 75 | 62 152 76 | 62 152 77 | 62 152 78 | 62 152 79 | 62 152 80 | 62 152 81 | 62 152 82 | 62 152 83 | 62 152 84 | 62 152 85 | 62 152 86 | 62 152 87 | 62 152 88 | 62 152 89 | 62 152 90 | 62 152 91 | 62 152 92 | 62 152 93 | 62 152 94 | 62 152 95 | 62 152 96 | 62 152 97 | 62 152 98 | 62 152 99 | 62 152 100 | 62 152 101 | 62 152 102 | 62 152 103 | 62 152 104 | 62 152 105 | 62 152 106 | 62 152 107 | 62 152 108 | 62 152 109 | 62 152 110 | 62 152 111 | 62 152 112 | 62 152 113 | 62 152 114 | 62 152 115 | 62 152 116 | 62 152 117 | 62 152 118 | 62 152 119 | 62 152 120 | 62 152 121 | 62 152 122 | 62 152 123 | 62 152 124 | 62 152 125 | 62 152 126 | 62 152 127 | 62 152 128 | 62 152 129 | 62 152 130 | 62 152 131 | 62 152 132 | 62 152 133 | 62 152 134 | 62 152 135 | 62 152 136 | 62 152 137 | 62 152 138 | 62 152 139 | 62 152 140 | 62 152 141 | 62 152 142 | 62 152 143 | 62 152 144 | 62 152 145 | 62 152 146 | 62 152 147 | 62 152 148 | 62 152 149 | 62 152 150 | 62 152 151 | 62 152 152 | 62 152 153 | 152 193 154 | 152 193 155 | 152 193 156 | 152 193 157 | 152 193 158 | 152 193 159 | 152 193 160 | 152 193 161 | 152 
193 162 | 152 193 163 | 152 193 164 | 152 193 165 | 152 193 166 | 152 193 167 | 152 193 168 | 152 193 169 | 152 193 170 | 152 193 171 | 152 193 172 | 152 193 173 | 152 193 174 | 152 193 175 | 152 193 176 | 152 193 177 | 152 193 178 | 152 193 179 | 152 193 180 | 152 193 181 | 152 193 182 | 152 193 183 | 152 193 184 | 152 193 185 | 152 193 186 | 152 193 187 | 152 193 188 | 152 193 189 | 152 193 190 | 152 193 191 | 152 193 192 | 152 193 193 | 152 193 194 | 193 209 195 | 193 209 196 | 193 209 197 | 193 209 198 | 193 209 199 | 193 209 200 | 193 209 201 | 193 209 202 | 193 209 203 | 193 209 204 | 193 209 205 | 193 209 206 | 193 209 207 | 193 209 208 | 193 209 209 | 193 209 210 | 209 246 211 | 209 246 212 | 209 246 213 | 209 246 214 | 209 246 215 | 209 246 216 | 209 246 217 | 209 246 218 | 209 246 219 | 209 246 220 | 209 246 221 | 209 246 222 | 209 246 223 | 209 246 224 | 209 246 225 | 209 246 226 | 209 246 227 | 209 246 228 | 209 246 229 | 209 246 230 | 209 246 231 | 209 246 232 | 209 246 233 | 209 246 234 | 209 246 235 | 209 246 236 | 209 246 237 | 209 246 238 | 209 246 239 | 209 246 240 | 209 246 241 | 209 246 242 | 209 246 243 | 209 246 244 | 209 246 245 | 209 246 246 | 209 246 247 | 246 261 248 | 246 261 249 | 246 261 250 | 246 261 251 | 246 261 252 | 246 261 253 | 246 261 254 | 246 261 255 | 246 261 256 | 246 261 257 | 246 261 258 | 246 261 259 | 246 261 260 | 246 261 261 | 246 261 262 | 261 271 263 | 261 271 264 | 261 271 265 | 261 271 266 | 261 271 267 | 261 271 268 | 261 271 269 | 261 271 270 | 261 271 271 | 261 271 272 | 271 278 273 | 271 278 274 | 271 278 275 | 271 278 276 | 271 278 277 | 271 278 278 | 271 278 279 | 278 301 280 | 278 301 281 | 278 301 282 | 278 301 283 | 278 301 284 | 278 301 285 | 278 301 286 | 278 301 287 | 278 301 288 | 278 301 289 | 278 301 290 | 278 301 291 | 278 301 292 | 278 301 293 | 278 301 294 | 278 301 295 | 278 301 296 | 278 301 297 | 278 301 298 | 278 301 299 | 278 301 300 | 278 301 301 | 278 301 302 | 301 330 303 | 301 330 304 | 
301 330 305 | 301 330 306 | 301 330 307 | 301 330 308 | 301 330 309 | 301 330 310 | 301 330 311 | 301 330 312 | 301 330 313 | 301 330 314 | 301 330 315 | 301 330 316 | 301 330 317 | 301 330 318 | 301 330 319 | 301 330 320 | 301 330 321 | 301 330 322 | 301 330 323 | 301 330 324 | 301 330 325 | 301 330 326 | 301 330 327 | 301 330 328 | 301 330 329 | 301 330 330 | 301 330 331 | 330 360 332 | 330 360 333 | 330 360 334 | 330 360 335 | 330 360 336 | 330 360 337 | 330 360 338 | 330 360 339 | 330 360 340 | 330 360 341 | 330 360 342 | 330 360 343 | 330 360 344 | 330 360 345 | 330 360 346 | 330 360 347 | 330 360 348 | 330 360 349 | 330 360 350 | 330 360 351 | 330 360 352 | 330 360 353 | 330 360 354 | 330 360 355 | 330 360 356 | 330 360 357 | 330 360 358 | 330 360 359 | 330 360 360 | 330 360 361 | 360 366 362 | 360 366 363 | 360 366 364 | 360 366 365 | 360 366 366 | 360 366 367 | 366 372 368 | 366 372 369 | 366 372 370 | 366 372 371 | 366 372 372 | 366 372 373 | 372 416 374 | 372 416 375 | 372 416 376 | 372 416 377 | 372 416 378 | 372 416 379 | 372 416 380 | 372 416 381 | 372 416 382 | 372 416 383 | 372 416 384 | 372 416 385 | 372 416 386 | 372 416 387 | 372 416 388 | 372 416 389 | 372 416 390 | 372 416 391 | 372 416 392 | 372 416 393 | 372 416 394 | 372 416 395 | 372 416 396 | 372 416 397 | 372 416 398 | 372 416 399 | 372 416 400 | 372 416 401 | 372 416 402 | 372 416 403 | 372 416 404 | 372 416 405 | 372 416 406 | 372 416 407 | 372 416 408 | 372 416 409 | 372 416 410 | 372 416 411 | 372 416 412 | 372 416 413 | 372 416 414 | 372 416 415 | 372 416 416 | 372 416 417 | 416 434 418 | 416 434 419 | 416 434 420 | 416 434 421 | 416 434 422 | 416 434 423 | 416 434 424 | 416 434 425 | 416 434 426 | 416 434 427 | 416 434 428 | 416 434 429 | 416 434 430 | 416 434 431 | 416 434 432 | 416 434 433 | 416 434 434 | 416 434 435 | 434 440 436 | 434 440 437 | 434 440 438 | 434 440 439 | 434 440 440 | 434 440 441 | 440 447 442 | 440 447 443 | 440 447 444 | 440 447 445 | 440 447 446 | 440 447 447 
| 440 447 448 | 447 473 449 | 447 473 450 | 447 473 451 | 447 473 452 | 447 473 453 | 447 473 454 | 447 473 455 | 447 473 456 | 447 473 457 | 447 473 458 | 447 473 459 | 447 473 460 | 447 473 461 | 447 473 462 | 447 473 463 | 447 473 464 | 447 473 465 | 447 473 466 | 447 473 467 | 447 473 468 | 447 473 469 | 447 473 470 | 447 473 471 | 447 473 472 | 447 473 473 | 447 473 474 | 473 485 475 | 473 485 476 | 473 485 477 | 473 485 478 | 473 485 479 | 473 485 480 | 473 485 481 | 473 485 482 | 473 485 483 | 473 485 484 | 473 485 485 | 473 485 486 | 485 582 487 | 485 582 488 | 485 582 489 | 485 582 490 | 485 582 491 | 485 582 492 | 485 582 493 | 485 582 494 | 485 582 495 | 485 582 496 | 485 582 497 | 485 582 498 | 485 582 499 | 485 582 500 | 485 582 501 | 485 582 502 | 485 582 503 | 485 582 504 | 485 582 505 | 485 582 506 | 485 582 507 | 485 582 508 | 485 582 509 | 485 582 510 | 485 582 511 | 485 582 512 | 485 582 513 | 485 582 514 | 485 582 515 | 485 582 516 | 485 582 517 | 485 582 518 | 485 582 519 | 485 582 520 | 485 582 521 | 485 582 522 | 485 582 523 | 485 582 524 | 485 582 525 | 485 582 526 | 485 582 527 | 485 582 528 | 485 582 529 | 485 582 530 | 485 582 531 | 485 582 532 | 485 582 533 | 485 582 534 | 485 582 535 | 485 582 536 | 485 582 537 | 485 582 538 | 485 582 539 | 485 582 540 | 485 582 541 | 485 582 542 | 485 582 543 | 485 582 544 | 485 582 545 | 485 582 546 | 485 582 547 | 485 582 548 | 485 582 549 | 485 582 550 | 485 582 551 | 485 582 552 | 485 582 553 | 485 582 554 | 485 582 555 | 485 582 556 | 485 582 557 | 485 582 558 | 485 582 559 | 485 582 560 | 485 582 561 | 485 582 562 | 485 582 563 | 485 582 564 | 485 582 565 | 485 582 566 | 485 582 567 | 485 582 568 | 485 582 569 | 485 582 570 | 485 582 571 | 485 582 572 | 485 582 573 | 485 582 574 | 485 582 575 | 485 582 576 | 485 582 577 | 485 582 578 | 485 582 579 | 485 582 580 | 485 582 581 | 485 582 582 | 485 582 583 | 582 620 584 | 582 620 585 | 582 620 586 | 582 620 587 | 582 620 588 | 582 620 589 | 582 620 
590 | 582 620 591 | 582 620 592 | 582 620 593 | 582 620 594 | 582 620 595 | 582 620 596 | 582 620 597 | 582 620 598 | 582 620 599 | 582 620 600 | 582 620 601 | 582 620 602 | 582 620 603 | 582 620 604 | 582 620 605 | 582 620 606 | 582 620 607 | 582 620 608 | 582 620 609 | 582 620 610 | 582 620 611 | 582 620 612 | 582 620 613 | 582 620 614 | 582 620 615 | 582 620 616 | 582 620 617 | 582 620 618 | 582 620 619 | 582 620 620 | 582 620 621 | 620 643 622 | 620 643 623 | 620 643 624 | 620 643 625 | 620 643 626 | 620 643 627 | 620 643 628 | 620 643 629 | 620 643 630 | 620 643 631 | 620 643 632 | 620 643 633 | 620 643 634 | 620 643 635 | 620 643 636 | 620 643 637 | 620 643 638 | 620 643 639 | 620 643 640 | 620 643 641 | 620 643 642 | 620 643 643 | 620 643 644 | 643 653 645 | 643 653 646 | 643 653 647 | 643 653 648 | 643 653 649 | 643 653 650 | 643 653 651 | 643 653 652 | 643 653 653 | 643 653 654 | 653 660 655 | 653 660 656 | 653 660 657 | 653 660 658 | 653 660 659 | 653 660 660 | 653 660 661 | 660 668 662 | 660 668 663 | 660 668 664 | 660 668 665 | 660 668 666 | 660 668 667 | 660 668 668 | 660 668 669 | 668 676 670 | 668 676 671 | 668 676 672 | 668 676 673 | 668 676 674 | 668 676 675 | 668 676 676 | 668 676 677 | 676 685 678 | 676 685 679 | 676 685 680 | 676 685 681 | 676 685 682 | 676 685 683 | 676 685 684 | 676 685 685 | 676 685 686 | 685 693 687 | 685 693 688 | 685 693 689 | 685 693 690 | 685 693 691 | 685 693 692 | 685 693 693 | 685 693 694 | 693 719 695 | 693 719 696 | 693 719 697 | 693 719 698 | 693 719 699 | 693 719 700 | 693 719 701 | 693 719 702 | 693 719 703 | 693 719 704 | 693 719 705 | 693 719 706 | 693 719 707 | 693 719 708 | 693 719 709 | 693 719 710 | 693 719 711 | 693 719 712 | 693 719 713 | 693 719 714 | 693 719 715 | 693 719 716 | 693 719 717 | 693 719 718 | 693 719 719 | 693 719 720 | 719 735 721 | 719 735 722 | 719 735 723 | 719 735 724 | 719 735 725 | 719 735 726 | 719 735 727 | 719 735 728 | 719 735 729 | 719 735 730 | 719 735 731 | 719 735 732 | 719 
735 733 | 719 735 734 | 719 735 735 | 719 735 736 | 735 759 737 | 735 759 738 | 735 759 739 | 735 759 740 | 735 759 741 | 735 759 742 | 735 759 743 | 735 759 744 | 735 759 745 | 735 759 746 | 735 759 747 | 735 759 748 | 735 759 749 | 735 759 750 | 735 759 751 | 735 759 752 | 735 759 753 | 735 759 754 | 735 759 755 | 735 759 756 | 735 759 757 | 735 759 758 | 735 759 759 | 735 759 760 | 759 777 761 | 759 777 762 | 759 777 763 | 759 777 764 | 759 777 765 | 759 777 766 | 759 777 767 | 759 777 768 | 759 777 769 | 759 777 770 | 759 777 771 | 759 777 772 | 759 777 773 | 759 777 774 | 759 777 775 | 759 777 776 | 759 777 777 | 759 777 778 | 777 840 779 | 777 840 780 | 777 840 781 | 777 840 782 | 777 840 783 | 777 840 784 | 777 840 785 | 777 840 786 | 777 840 787 | 777 840 788 | 777 840 789 | 777 840 790 | 777 840 791 | 777 840 792 | 777 840 793 | 777 840 794 | 777 840 795 | 777 840 796 | 777 840 797 | 777 840 798 | 777 840 799 | 777 840 800 | 777 840 801 | 777 840 802 | 777 840 803 | 777 840 804 | 777 840 805 | 777 840 806 | 777 840 807 | 777 840 808 | 777 840 809 | 777 840 810 | 777 840 811 | 777 840 812 | 777 840 813 | 777 840 814 | 777 840 815 | 777 840 816 | 777 840 817 | 777 840 818 | 777 840 819 | 777 840 820 | 777 840 821 | 777 840 822 | 777 840 823 | 777 840 824 | 777 840 825 | 777 840 826 | 777 840 827 | 777 840 828 | 777 840 829 | 777 840 830 | 777 840 831 | 777 840 832 | 777 840 833 | 777 840 834 | 777 840 835 | 777 840 836 | 777 840 837 | 777 840 838 | 777 840 839 | 777 840 840 | 777 840 841 | 840 880 842 | 840 880 843 | 840 880 844 | 840 880 845 | 840 880 846 | 840 880 847 | 840 880 848 | 840 880 849 | 840 880 850 | 840 880 851 | 840 880 852 | 840 880 853 | 840 880 854 | 840 880 855 | 840 880 856 | 840 880 857 | 840 880 858 | 840 880 859 | 840 880 860 | 840 880 861 | 840 880 862 | 840 880 863 | 840 880 864 | 840 880 865 | 840 880 866 | 840 880 867 | 840 880 868 | 840 880 869 | 840 880 870 | 840 880 871 | 840 880 872 | 840 880 873 | 840 880 874 | 840 880 875 | 
840 880 876 | 840 880 877 | 840 880 878 | 840 880 879 | 840 880 880 | 840 880 881 | 880 921 882 | 880 921 883 | 880 921 884 | 880 921 885 | 880 921 886 | 880 921 887 | 880 921 888 | 880 921 889 | 880 921 890 | 880 921 891 | 880 921 892 | 880 921 893 | 880 921 894 | 880 921 895 | 880 921 896 | 880 921 897 | 880 921 898 | 880 921 899 | 880 921 900 | 880 921 901 | 880 921 902 | 880 921 903 | 880 921 904 | 880 921 905 | 880 921 906 | 880 921 907 | 880 921 908 | 880 921 909 | 880 921 910 | 880 921 911 | 880 921 912 | 880 921 913 | 880 921 914 | 880 921 915 | 880 921 916 | 880 921 917 | 880 921 918 | 880 921 919 | 880 921 920 | 880 921 921 | 880 921 922 | 921 953 923 | 921 953 924 | 921 953 925 | 921 953 926 | 921 953 927 | 921 953 928 | 921 953 929 | 921 953 930 | 921 953 931 | 921 953 932 | 921 953 933 | 921 953 934 | 921 953 935 | 921 953 936 | 921 953 937 | 921 953 938 | 921 953 939 | 921 953 940 | 921 953 941 | 921 953 942 | 921 953 943 | 921 953 944 | 921 953 945 | 921 953 946 | 921 953 947 | 921 953 948 | 921 953 949 | 921 953 950 | 921 953 951 | 921 953 952 | 921 953 953 | 921 953 954 | -------------------------------------------------------------------------------- /structFiles/seqBoundsFiles/msls_amman_query_seqBounds.txt: -------------------------------------------------------------------------------- 1 | 0 75 2 | 0 75 3 | 0 75 4 | 0 75 5 | 0 75 6 | 0 75 7 | 0 75 8 | 0 75 9 | 0 75 10 | 0 75 11 | 0 75 12 | 0 75 13 | 0 75 14 | 0 75 15 | 0 75 16 | 0 75 17 | 0 75 18 | 0 75 19 | 0 75 20 | 0 75 21 | 0 75 22 | 0 75 23 | 0 75 24 | 0 75 25 | 0 75 26 | 0 75 27 | 0 75 28 | 0 75 29 | 0 75 30 | 0 75 31 | 0 75 32 | 0 75 33 | 0 75 34 | 0 75 35 | 0 75 36 | 0 75 37 | 0 75 38 | 0 75 39 | 0 75 40 | 0 75 41 | 0 75 42 | 0 75 43 | 0 75 44 | 0 75 45 | 0 75 46 | 0 75 47 | 0 75 48 | 0 75 49 | 0 75 50 | 0 75 51 | 0 75 52 | 0 75 53 | 0 75 54 | 0 75 55 | 0 75 56 | 0 75 57 | 0 75 58 | 0 75 59 | 0 75 60 | 0 75 61 | 0 75 62 | 0 75 63 | 0 75 64 | 0 75 65 | 0 75 66 | 0 75 67 | 0 75 68 | 0 75 
69 | 0 75 70 | 0 75 71 | 0 75 72 | 0 75 73 | 0 75 74 | 0 75 75 | 0 75 76 | 75 84 77 | 75 84 78 | 75 84 79 | 75 84 80 | 75 84 81 | 75 84 82 | 75 84 83 | 75 84 84 | 75 84 85 | 84 99 86 | 84 99 87 | 84 99 88 | 84 99 89 | 84 99 90 | 84 99 91 | 84 99 92 | 84 99 93 | 84 99 94 | 84 99 95 | 84 99 96 | 84 99 97 | 84 99 98 | 84 99 99 | 84 99 100 | 99 113 101 | 99 113 102 | 99 113 103 | 99 113 104 | 99 113 105 | 99 113 106 | 99 113 107 | 99 113 108 | 99 113 109 | 99 113 110 | 99 113 111 | 99 113 112 | 99 113 113 | 99 113 114 | 113 152 115 | 113 152 116 | 113 152 117 | 113 152 118 | 113 152 119 | 113 152 120 | 113 152 121 | 113 152 122 | 113 152 123 | 113 152 124 | 113 152 125 | 113 152 126 | 113 152 127 | 113 152 128 | 113 152 129 | 113 152 130 | 113 152 131 | 113 152 132 | 113 152 133 | 113 152 134 | 113 152 135 | 113 152 136 | 113 152 137 | 113 152 138 | 113 152 139 | 113 152 140 | 113 152 141 | 113 152 142 | 113 152 143 | 113 152 144 | 113 152 145 | 113 152 146 | 113 152 147 | 113 152 148 | 113 152 149 | 113 152 150 | 113 152 151 | 113 152 152 | 113 152 153 | 152 172 154 | 152 172 155 | 152 172 156 | 152 172 157 | 152 172 158 | 152 172 159 | 152 172 160 | 152 172 161 | 152 172 162 | 152 172 163 | 152 172 164 | 152 172 165 | 152 172 166 | 152 172 167 | 152 172 168 | 152 172 169 | 152 172 170 | 152 172 171 | 152 172 172 | 152 172 173 | 172 204 174 | 172 204 175 | 172 204 176 | 172 204 177 | 172 204 178 | 172 204 179 | 172 204 180 | 172 204 181 | 172 204 182 | 172 204 183 | 172 204 184 | 172 204 185 | 172 204 186 | 172 204 187 | 172 204 188 | 172 204 189 | 172 204 190 | 172 204 191 | 172 204 192 | 172 204 193 | 172 204 194 | 172 204 195 | 172 204 196 | 172 204 197 | 172 204 198 | 172 204 199 | 172 204 200 | 172 204 201 | 172 204 202 | 172 204 203 | 172 204 204 | 172 204 205 | 204 233 206 | 204 233 207 | 204 233 208 | 204 233 209 | 204 233 210 | 204 233 211 | 204 233 212 | 204 233 213 | 204 233 214 | 204 233 215 | 204 233 216 | 204 233 217 | 204 233 218 | 204 233 219 | 204 233 
220 | 204 233 221 | 204 233 222 | 204 233 223 | 204 233 224 | 204 233 225 | 204 233 226 | 204 233 227 | 204 233 228 | 204 233 229 | 204 233 230 | 204 233 231 | 204 233 232 | 204 233 233 | 204 233 234 | 233 267 235 | 233 267 236 | 233 267 237 | 233 267 238 | 233 267 239 | 233 267 240 | 233 267 241 | 233 267 242 | 233 267 243 | 233 267 244 | 233 267 245 | 233 267 246 | 233 267 247 | 233 267 248 | 233 267 249 | 233 267 250 | 233 267 251 | 233 267 252 | 233 267 253 | 233 267 254 | 233 267 255 | 233 267 256 | 233 267 257 | 233 267 258 | 233 267 259 | 233 267 260 | 233 267 261 | 233 267 262 | 233 267 263 | 233 267 264 | 233 267 265 | 233 267 266 | 233 267 267 | 233 267 268 | 267 282 269 | 267 282 270 | 267 282 271 | 267 282 272 | 267 282 273 | 267 282 274 | 267 282 275 | 267 282 276 | 267 282 277 | 267 282 278 | 267 282 279 | 267 282 280 | 267 282 281 | 267 282 282 | 267 282 283 | 282 342 284 | 282 342 285 | 282 342 286 | 282 342 287 | 282 342 288 | 282 342 289 | 282 342 290 | 282 342 291 | 282 342 292 | 282 342 293 | 282 342 294 | 282 342 295 | 282 342 296 | 282 342 297 | 282 342 298 | 282 342 299 | 282 342 300 | 282 342 301 | 282 342 302 | 282 342 303 | 282 342 304 | 282 342 305 | 282 342 306 | 282 342 307 | 282 342 308 | 282 342 309 | 282 342 310 | 282 342 311 | 282 342 312 | 282 342 313 | 282 342 314 | 282 342 315 | 282 342 316 | 282 342 317 | 282 342 318 | 282 342 319 | 282 342 320 | 282 342 321 | 282 342 322 | 282 342 323 | 282 342 324 | 282 342 325 | 282 342 326 | 282 342 327 | 282 342 328 | 282 342 329 | 282 342 330 | 282 342 331 | 282 342 332 | 282 342 333 | 282 342 334 | 282 342 335 | 282 342 336 | 282 342 337 | 282 342 338 | 282 342 339 | 282 342 340 | 282 342 341 | 282 342 342 | 282 342 343 | 342 350 344 | 342 350 345 | 342 350 346 | 342 350 347 | 342 350 348 | 342 350 349 | 342 350 350 | 342 350 351 | 350 366 352 | 350 366 353 | 350 366 354 | 350 366 355 | 350 366 356 | 350 366 357 | 350 366 358 | 350 366 359 | 350 366 360 | 350 366 361 | 350 366 362 | 350 
366 363 | 350 366 364 | 350 366 365 | 350 366 366 | 350 366 367 | 366 372 368 | 366 372 369 | 366 372 370 | 366 372 371 | 366 372 372 | 366 372 373 | 372 385 374 | 372 385 375 | 372 385 376 | 372 385 377 | 372 385 378 | 372 385 379 | 372 385 380 | 372 385 381 | 372 385 382 | 372 385 383 | 372 385 384 | 372 385 385 | 372 385 386 | 385 486 387 | 385 486 388 | 385 486 389 | 385 486 390 | 385 486 391 | 385 486 392 | 385 486 393 | 385 486 394 | 385 486 395 | 385 486 396 | 385 486 397 | 385 486 398 | 385 486 399 | 385 486 400 | 385 486 401 | 385 486 402 | 385 486 403 | 385 486 404 | 385 486 405 | 385 486 406 | 385 486 407 | 385 486 408 | 385 486 409 | 385 486 410 | 385 486 411 | 385 486 412 | 385 486 413 | 385 486 414 | 385 486 415 | 385 486 416 | 385 486 417 | 385 486 418 | 385 486 419 | 385 486 420 | 385 486 421 | 385 486 422 | 385 486 423 | 385 486 424 | 385 486 425 | 385 486 426 | 385 486 427 | 385 486 428 | 385 486 429 | 385 486 430 | 385 486 431 | 385 486 432 | 385 486 433 | 385 486 434 | 385 486 435 | 385 486 436 | 385 486 437 | 385 486 438 | 385 486 439 | 385 486 440 | 385 486 441 | 385 486 442 | 385 486 443 | 385 486 444 | 385 486 445 | 385 486 446 | 385 486 447 | 385 486 448 | 385 486 449 | 385 486 450 | 385 486 451 | 385 486 452 | 385 486 453 | 385 486 454 | 385 486 455 | 385 486 456 | 385 486 457 | 385 486 458 | 385 486 459 | 385 486 460 | 385 486 461 | 385 486 462 | 385 486 463 | 385 486 464 | 385 486 465 | 385 486 466 | 385 486 467 | 385 486 468 | 385 486 469 | 385 486 470 | 385 486 471 | 385 486 472 | 385 486 473 | 385 486 474 | 385 486 475 | 385 486 476 | 385 486 477 | 385 486 478 | 385 486 479 | 385 486 480 | 385 486 481 | 385 486 482 | 385 486 483 | 385 486 484 | 385 486 485 | 385 486 486 | 385 486 487 | 486 582 488 | 486 582 489 | 486 582 490 | 486 582 491 | 486 582 492 | 486 582 493 | 486 582 494 | 486 582 495 | 486 582 496 | 486 582 497 | 486 582 498 | 486 582 499 | 486 582 500 | 486 582 501 | 486 582 502 | 486 582 503 | 486 582 504 | 486 582 505 | 
486 582 506 | 486 582 507 | 486 582 508 | 486 582 509 | 486 582 510 | 486 582 511 | 486 582 512 | 486 582 513 | 486 582 514 | 486 582 515 | 486 582 516 | 486 582 517 | 486 582 518 | 486 582 519 | 486 582 520 | 486 582 521 | 486 582 522 | 486 582 523 | 486 582 524 | 486 582 525 | 486 582 526 | 486 582 527 | 486 582 528 | 486 582 529 | 486 582 530 | 486 582 531 | 486 582 532 | 486 582 533 | 486 582 534 | 486 582 535 | 486 582 536 | 486 582 537 | 486 582 538 | 486 582 539 | 486 582 540 | 486 582 541 | 486 582 542 | 486 582 543 | 486 582 544 | 486 582 545 | 486 582 546 | 486 582 547 | 486 582 548 | 486 582 549 | 486 582 550 | 486 582 551 | 486 582 552 | 486 582 553 | 486 582 554 | 486 582 555 | 486 582 556 | 486 582 557 | 486 582 558 | 486 582 559 | 486 582 560 | 486 582 561 | 486 582 562 | 486 582 563 | 486 582 564 | 486 582 565 | 486 582 566 | 486 582 567 | 486 582 568 | 486 582 569 | 486 582 570 | 486 582 571 | 486 582 572 | 486 582 573 | 486 582 574 | 486 582 575 | 486 582 576 | 486 582 577 | 486 582 578 | 486 582 579 | 486 582 580 | 486 582 581 | 486 582 582 | 486 582 583 | 582 624 584 | 582 624 585 | 582 624 586 | 582 624 587 | 582 624 588 | 582 624 589 | 582 624 590 | 582 624 591 | 582 624 592 | 582 624 593 | 582 624 594 | 582 624 595 | 582 624 596 | 582 624 597 | 582 624 598 | 582 624 599 | 582 624 600 | 582 624 601 | 582 624 602 | 582 624 603 | 582 624 604 | 582 624 605 | 582 624 606 | 582 624 607 | 582 624 608 | 582 624 609 | 582 624 610 | 582 624 611 | 582 624 612 | 582 624 613 | 582 624 614 | 582 624 615 | 582 624 616 | 582 624 617 | 582 624 618 | 582 624 619 | 582 624 620 | 582 624 621 | 582 624 622 | 582 624 623 | 582 624 624 | 582 624 625 | 624 631 626 | 624 631 627 | 624 631 628 | 624 631 629 | 624 631 630 | 624 631 631 | 624 631 632 | 631 638 633 | 631 638 634 | 631 638 635 | 631 638 636 | 631 638 637 | 631 638 638 | 631 638 639 | 638 665 640 | 638 665 641 | 638 665 642 | 638 665 643 | 638 665 644 | 638 665 645 | 638 665 646 | 638 665 647 | 638 665 648 
| 638 665 649 | 638 665 650 | 638 665 651 | 638 665 652 | 638 665 653 | 638 665 654 | 638 665 655 | 638 665 656 | 638 665 657 | 638 665 658 | 638 665 659 | 638 665 660 | 638 665 661 | 638 665 662 | 638 665 663 | 638 665 664 | 638 665 665 | 638 665 666 | 665 753 667 | 665 753 668 | 665 753 669 | 665 753 670 | 665 753 671 | 665 753 672 | 665 753 673 | 665 753 674 | 665 753 675 | 665 753 676 | 665 753 677 | 665 753 678 | 665 753 679 | 665 753 680 | 665 753 681 | 665 753 682 | 665 753 683 | 665 753 684 | 665 753 685 | 665 753 686 | 665 753 687 | 665 753 688 | 665 753 689 | 665 753 690 | 665 753 691 | 665 753 692 | 665 753 693 | 665 753 694 | 665 753 695 | 665 753 696 | 665 753 697 | 665 753 698 | 665 753 699 | 665 753 700 | 665 753 701 | 665 753 702 | 665 753 703 | 665 753 704 | 665 753 705 | 665 753 706 | 665 753 707 | 665 753 708 | 665 753 709 | 665 753 710 | 665 753 711 | 665 753 712 | 665 753 713 | 665 753 714 | 665 753 715 | 665 753 716 | 665 753 717 | 665 753 718 | 665 753 719 | 665 753 720 | 665 753 721 | 665 753 722 | 665 753 723 | 665 753 724 | 665 753 725 | 665 753 726 | 665 753 727 | 665 753 728 | 665 753 729 | 665 753 730 | 665 753 731 | 665 753 732 | 665 753 733 | 665 753 734 | 665 753 735 | 665 753 736 | 665 753 737 | 665 753 738 | 665 753 739 | 665 753 740 | 665 753 741 | 665 753 742 | 665 753 743 | 665 753 744 | 665 753 745 | 665 753 746 | 665 753 747 | 665 753 748 | 665 753 749 | 665 753 750 | 665 753 751 | 665 753 752 | 665 753 753 | 665 753 754 | 753 765 755 | 753 765 756 | 753 765 757 | 753 765 758 | 753 765 759 | 753 765 760 | 753 765 761 | 753 765 762 | 753 765 763 | 753 765 764 | 753 765 765 | 753 765 766 | 765 790 767 | 765 790 768 | 765 790 769 | 765 790 770 | 765 790 771 | 765 790 772 | 765 790 773 | 765 790 774 | 765 790 775 | 765 790 776 | 765 790 777 | 765 790 778 | 765 790 779 | 765 790 780 | 765 790 781 | 765 790 782 | 765 790 783 | 765 790 784 | 765 790 785 | 765 790 786 | 765 790 787 | 765 790 788 | 765 790 789 | 765 790 790 | 765 790 
791 | 790 796 792 | 790 796 793 | 790 796 794 | 790 796 795 | 790 796 796 | 790 796 797 | 796 829 798 | 796 829 799 | 796 829 800 | 796 829 801 | 796 829 802 | 796 829 803 | 796 829 804 | 796 829 805 | 796 829 806 | 796 829 807 | 796 829 808 | 796 829 809 | 796 829 810 | 796 829 811 | 796 829 812 | 796 829 813 | 796 829 814 | 796 829 815 | 796 829 816 | 796 829 817 | 796 829 818 | 796 829 819 | 796 829 820 | 796 829 821 | 796 829 822 | 796 829 823 | 796 829 824 | 796 829 825 | 796 829 826 | 796 829 827 | 796 829 828 | 796 829 829 | 796 829 830 | 829 835 831 | 829 835 832 | 829 835 833 | 829 835 834 | 829 835 835 | 829 835 836 | -------------------------------------------------------------------------------- /structFiles/seqBoundsFiles/msls_austin_query_seqBounds.txt: -------------------------------------------------------------------------------- 1 | 0 159 2 | 0 159 3 | 0 159 4 | 0 159 5 | 0 159 6 | 0 159 7 | 0 159 8 | 0 159 9 | 0 159 10 | 0 159 11 | 0 159 12 | 0 159 13 | 0 159 14 | 0 159 15 | 0 159 16 | 0 159 17 | 0 159 18 | 0 159 19 | 0 159 20 | 0 159 21 | 0 159 22 | 0 159 23 | 0 159 24 | 0 159 25 | 0 159 26 | 0 159 27 | 0 159 28 | 0 159 29 | 0 159 30 | 0 159 31 | 0 159 32 | 0 159 33 | 0 159 34 | 0 159 35 | 0 159 36 | 0 159 37 | 0 159 38 | 0 159 39 | 0 159 40 | 0 159 41 | 0 159 42 | 0 159 43 | 0 159 44 | 0 159 45 | 0 159 46 | 0 159 47 | 0 159 48 | 0 159 49 | 0 159 50 | 0 159 51 | 0 159 52 | 0 159 53 | 0 159 54 | 0 159 55 | 0 159 56 | 0 159 57 | 0 159 58 | 0 159 59 | 0 159 60 | 0 159 61 | 0 159 62 | 0 159 63 | 0 159 64 | 0 159 65 | 0 159 66 | 0 159 67 | 0 159 68 | 0 159 69 | 0 159 70 | 0 159 71 | 0 159 72 | 0 159 73 | 0 159 74 | 0 159 75 | 0 159 76 | 0 159 77 | 0 159 78 | 0 159 79 | 0 159 80 | 0 159 81 | 0 159 82 | 0 159 83 | 0 159 84 | 0 159 85 | 0 159 86 | 0 159 87 | 0 159 88 | 0 159 89 | 0 159 90 | 0 159 91 | 0 159 92 | 0 159 93 | 0 159 94 | 0 159 95 | 0 159 96 | 0 159 97 | 0 159 98 | 0 159 99 | 0 159 100 | 0 159 101 | 0 159 102 | 0 159 103 | 0 159 104 | 0 159 
105 | 0 159 106 | 0 159 107 | 0 159 108 | 0 159 109 | 0 159 110 | 0 159 111 | 0 159 112 | 0 159 113 | 0 159 114 | 0 159 115 | 0 159 116 | 0 159 117 | 0 159 118 | 0 159 119 | 0 159 120 | 0 159 121 | 0 159 122 | 0 159 123 | 0 159 124 | 0 159 125 | 0 159 126 | 0 159 127 | 0 159 128 | 0 159 129 | 0 159 130 | 0 159 131 | 0 159 132 | 0 159 133 | 0 159 134 | 0 159 135 | 0 159 136 | 0 159 137 | 0 159 138 | 0 159 139 | 0 159 140 | 0 159 141 | 0 159 142 | 0 159 143 | 0 159 144 | 0 159 145 | 0 159 146 | 0 159 147 | 0 159 148 | 0 159 149 | 0 159 150 | 0 159 151 | 0 159 152 | 0 159 153 | 0 159 154 | 0 159 155 | 0 159 156 | 0 159 157 | 0 159 158 | 0 159 159 | 0 159 160 | 159 186 161 | 159 186 162 | 159 186 163 | 159 186 164 | 159 186 165 | 159 186 166 | 159 186 167 | 159 186 168 | 159 186 169 | 159 186 170 | 159 186 171 | 159 186 172 | 159 186 173 | 159 186 174 | 159 186 175 | 159 186 176 | 159 186 177 | 159 186 178 | 159 186 179 | 159 186 180 | 159 186 181 | 159 186 182 | 159 186 183 | 159 186 184 | 159 186 185 | 159 186 186 | 159 186 187 | 186 209 188 | 186 209 189 | 186 209 190 | 186 209 191 | 186 209 192 | 186 209 193 | 186 209 194 | 186 209 195 | 186 209 196 | 186 209 197 | 186 209 198 | 186 209 199 | 186 209 200 | 186 209 201 | 186 209 202 | 186 209 203 | 186 209 204 | 186 209 205 | 186 209 206 | 186 209 207 | 186 209 208 | 186 209 209 | 186 209 210 | 209 245 211 | 209 245 212 | 209 245 213 | 209 245 214 | 209 245 215 | 209 245 216 | 209 245 217 | 209 245 218 | 209 245 219 | 209 245 220 | 209 245 221 | 209 245 222 | 209 245 223 | 209 245 224 | 209 245 225 | 209 245 226 | 209 245 227 | 209 245 228 | 209 245 229 | 209 245 230 | 209 245 231 | 209 245 232 | 209 245 233 | 209 245 234 | 209 245 235 | 209 245 236 | 209 245 237 | 209 245 238 | 209 245 239 | 209 245 240 | 209 245 241 | 209 245 242 | 209 245 243 | 209 245 244 | 209 245 245 | 209 245 246 | 245 278 247 | 245 278 248 | 245 278 249 | 245 278 250 | 245 278 251 | 245 278 252 | 245 278 253 | 245 278 254 | 245 278 255 | 245 
278 256 | 245 278 257 | 245 278 258 | 245 278 259 | 245 278 260 | 245 278 261 | 245 278 262 | 245 278 263 | 245 278 264 | 245 278 265 | 245 278 266 | 245 278 267 | 245 278 268 | 245 278 269 | 245 278 270 | 245 278 271 | 245 278 272 | 245 278 273 | 245 278 274 | 245 278 275 | 245 278 276 | 245 278 277 | 245 278 278 | 245 278 279 | 278 487 280 | 278 487 281 | 278 487 282 | 278 487 283 | 278 487 284 | 278 487 285 | 278 487 286 | 278 487 287 | 278 487 288 | 278 487 289 | 278 487 290 | 278 487 291 | 278 487 292 | 278 487 293 | 278 487 294 | 278 487 295 | 278 487 296 | 278 487 297 | 278 487 298 | 278 487 299 | 278 487 300 | 278 487 301 | 278 487 302 | 278 487 303 | 278 487 304 | 278 487 305 | 278 487 306 | 278 487 307 | 278 487 308 | 278 487 309 | 278 487 310 | 278 487 311 | 278 487 312 | 278 487 313 | 278 487 314 | 278 487 315 | 278 487 316 | 278 487 317 | 278 487 318 | 278 487 319 | 278 487 320 | 278 487 321 | 278 487 322 | 278 487 323 | 278 487 324 | 278 487 325 | 278 487 326 | 278 487 327 | 278 487 328 | 278 487 329 | 278 487 330 | 278 487 331 | 278 487 332 | 278 487 333 | 278 487 334 | 278 487 335 | 278 487 336 | 278 487 337 | 278 487 338 | 278 487 339 | 278 487 340 | 278 487 341 | 278 487 342 | 278 487 343 | 278 487 344 | 278 487 345 | 278 487 346 | 278 487 347 | 278 487 348 | 278 487 349 | 278 487 350 | 278 487 351 | 278 487 352 | 278 487 353 | 278 487 354 | 278 487 355 | 278 487 356 | 278 487 357 | 278 487 358 | 278 487 359 | 278 487 360 | 278 487 361 | 278 487 362 | 278 487 363 | 278 487 364 | 278 487 365 | 278 487 366 | 278 487 367 | 278 487 368 | 278 487 369 | 278 487 370 | 278 487 371 | 278 487 372 | 278 487 373 | 278 487 374 | 278 487 375 | 278 487 376 | 278 487 377 | 278 487 378 | 278 487 379 | 278 487 380 | 278 487 381 | 278 487 382 | 278 487 383 | 278 487 384 | 278 487 385 | 278 487 386 | 278 487 387 | 278 487 388 | 278 487 389 | 278 487 390 | 278 487 391 | 278 487 392 | 278 487 393 | 278 487 394 | 278 487 395 | 278 487 396 | 278 487 397 | 278 487 398 | 
278 487 399 | 278 487 400 | 278 487 401 | 278 487 402 | 278 487 403 | 278 487 404 | 278 487 405 | 278 487 406 | 278 487 407 | 278 487 408 | 278 487 409 | 278 487 410 | 278 487 411 | 278 487 412 | 278 487 413 | 278 487 414 | 278 487 415 | 278 487 416 | 278 487 417 | 278 487 418 | 278 487 419 | 278 487 420 | 278 487 421 | 278 487 422 | 278 487 423 | 278 487 424 | 278 487 425 | 278 487 426 | 278 487 427 | 278 487 428 | 278 487 429 | 278 487 430 | 278 487 431 | 278 487 432 | 278 487 433 | 278 487 434 | 278 487 435 | 278 487 436 | 278 487 437 | 278 487 438 | 278 487 439 | 278 487 440 | 278 487 441 | 278 487 442 | 278 487 443 | 278 487 444 | 278 487 445 | 278 487 446 | 278 487 447 | 278 487 448 | 278 487 449 | 278 487 450 | 278 487 451 | 278 487 452 | 278 487 453 | 278 487 454 | 278 487 455 | 278 487 456 | 278 487 457 | 278 487 458 | 278 487 459 | 278 487 460 | 278 487 461 | 278 487 462 | 278 487 463 | 278 487 464 | 278 487 465 | 278 487 466 | 278 487 467 | 278 487 468 | 278 487 469 | 278 487 470 | 278 487 471 | 278 487 472 | 278 487 473 | 278 487 474 | 278 487 475 | 278 487 476 | 278 487 477 | 278 487 478 | 278 487 479 | 278 487 480 | 278 487 481 | 278 487 482 | 278 487 483 | 278 487 484 | 278 487 485 | 278 487 486 | 278 487 487 | 278 487 488 | 487 652 489 | 487 652 490 | 487 652 491 | 487 652 492 | 487 652 493 | 487 652 494 | 487 652 495 | 487 652 496 | 487 652 497 | 487 652 498 | 487 652 499 | 487 652 500 | 487 652 501 | 487 652 502 | 487 652 503 | 487 652 504 | 487 652 505 | 487 652 506 | 487 652 507 | 487 652 508 | 487 652 509 | 487 652 510 | 487 652 511 | 487 652 512 | 487 652 513 | 487 652 514 | 487 652 515 | 487 652 516 | 487 652 517 | 487 652 518 | 487 652 519 | 487 652 520 | 487 652 521 | 487 652 522 | 487 652 523 | 487 652 524 | 487 652 525 | 487 652 526 | 487 652 527 | 487 652 528 | 487 652 529 | 487 652 530 | 487 652 531 | 487 652 532 | 487 652 533 | 487 652 534 | 487 652 535 | 487 652 536 | 487 652 537 | 487 652 538 | 487 652 539 | 487 652 540 | 487 652 541 
| 487 652 542 | 487 652 543 | 487 652 544 | 487 652 545 | 487 652 546 | 487 652 547 | 487 652 548 | 487 652 549 | 487 652 550 | 487 652 551 | 487 652 552 | 487 652 553 | 487 652 554 | 487 652 555 | 487 652 556 | 487 652 557 | 487 652 558 | 487 652 559 | 487 652 560 | 487 652 561 | 487 652 562 | 487 652 563 | 487 652 564 | 487 652 565 | 487 652 566 | 487 652 567 | 487 652 568 | 487 652 569 | 487 652 570 | 487 652 571 | 487 652 572 | 487 652 573 | 487 652 574 | 487 652 575 | 487 652 576 | 487 652 577 | 487 652 578 | 487 652 579 | 487 652 580 | 487 652 581 | 487 652 582 | 487 652 583 | 487 652 584 | 487 652 585 | 487 652 586 | 487 652 587 | 487 652 588 | 487 652 589 | 487 652 590 | 487 652 591 | 487 652 592 | 487 652 593 | 487 652 594 | 487 652 595 | 487 652 596 | 487 652 597 | 487 652 598 | 487 652 599 | 487 652 600 | 487 652 601 | 487 652 602 | 487 652 603 | 487 652 604 | 487 652 605 | 487 652 606 | 487 652 607 | 487 652 608 | 487 652 609 | 487 652 610 | 487 652 611 | 487 652 612 | 487 652 613 | 487 652 614 | 487 652 615 | 487 652 616 | 487 652 617 | 487 652 618 | 487 652 619 | 487 652 620 | 487 652 621 | 487 652 622 | 487 652 623 | 487 652 624 | 487 652 625 | 487 652 626 | 487 652 627 | 487 652 628 | 487 652 629 | 487 652 630 | 487 652 631 | 487 652 632 | 487 652 633 | 487 652 634 | 487 652 635 | 487 652 636 | 487 652 637 | 487 652 638 | 487 652 639 | 487 652 640 | 487 652 641 | 487 652 642 | 487 652 643 | 487 652 644 | 487 652 645 | 487 652 646 | 487 652 647 | 487 652 648 | 487 652 649 | 487 652 650 | 487 652 651 | 487 652 652 | 487 652 653 | 652 697 654 | 652 697 655 | 652 697 656 | 652 697 657 | 652 697 658 | 652 697 659 | 652 697 660 | 652 697 661 | 652 697 662 | 652 697 663 | 652 697 664 | 652 697 665 | 652 697 666 | 652 697 667 | 652 697 668 | 652 697 669 | 652 697 670 | 652 697 671 | 652 697 672 | 652 697 673 | 652 697 674 | 652 697 675 | 652 697 676 | 652 697 677 | 652 697 678 | 652 697 679 | 652 697 680 | 652 697 681 | 652 697 682 | 652 697 683 | 652 697 
684 | 652 697 685 | 652 697 686 | 652 697 687 | 652 697 688 | 652 697 689 | 652 697 690 | 652 697 691 | 652 697 692 | 652 697 693 | 652 697 694 | 652 697 695 | 652 697 696 | 652 697 697 | 652 697 698 | 697 717 699 | 697 717 700 | 697 717 701 | 697 717 702 | 697 717 703 | 697 717 704 | 697 717 705 | 697 717 706 | 697 717 707 | 697 717 708 | 697 717 709 | 697 717 710 | 697 717 711 | 697 717 712 | 697 717 713 | 697 717 714 | 697 717 715 | 697 717 716 | 697 717 717 | 697 717 718 | 717 735 719 | 717 735 720 | 717 735 721 | 717 735 722 | 717 735 723 | 717 735 724 | 717 735 725 | 717 735 726 | 717 735 727 | 717 735 728 | 717 735 729 | 717 735 730 | 717 735 731 | 717 735 732 | 717 735 733 | 717 735 734 | 717 735 735 | 717 735 736 | 735 742 737 | 735 742 738 | 735 742 739 | 735 742 740 | 735 742 741 | 735 742 742 | 735 742 743 | 742 926 744 | 742 926 745 | 742 926 746 | 742 926 747 | 742 926 748 | 742 926 749 | 742 926 750 | 742 926 751 | 742 926 752 | 742 926 753 | 742 926 754 | 742 926 755 | 742 926 756 | 742 926 757 | 742 926 758 | 742 926 759 | 742 926 760 | 742 926 761 | 742 926 762 | 742 926 763 | 742 926 764 | 742 926 765 | 742 926 766 | 742 926 767 | 742 926 768 | 742 926 769 | 742 926 770 | 742 926 771 | 742 926 772 | 742 926 773 | 742 926 774 | 742 926 775 | 742 926 776 | 742 926 777 | 742 926 778 | 742 926 779 | 742 926 780 | 742 926 781 | 742 926 782 | 742 926 783 | 742 926 784 | 742 926 785 | 742 926 786 | 742 926 787 | 742 926 788 | 742 926 789 | 742 926 790 | 742 926 791 | 742 926 792 | 742 926 793 | 742 926 794 | 742 926 795 | 742 926 796 | 742 926 797 | 742 926 798 | 742 926 799 | 742 926 800 | 742 926 801 | 742 926 802 | 742 926 803 | 742 926 804 | 742 926 805 | 742 926 806 | 742 926 807 | 742 926 808 | 742 926 809 | 742 926 810 | 742 926 811 | 742 926 812 | 742 926 813 | 742 926 814 | 742 926 815 | 742 926 816 | 742 926 817 | 742 926 818 | 742 926 819 | 742 926 820 | 742 926 821 | 742 926 822 | 742 926 823 | 742 926 824 | 742 926 825 | 742 926 826 | 742 
926 827 | 742 926 828 | 742 926 829 | 742 926 830 | 742 926 831 | 742 926 832 | 742 926 833 | 742 926 834 | 742 926 835 | 742 926 836 | 742 926 837 | 742 926 838 | 742 926 839 | 742 926 840 | 742 926 841 | 742 926 842 | 742 926 843 | 742 926 844 | 742 926 845 | 742 926 846 | 742 926 847 | 742 926 848 | 742 926 849 | 742 926 850 | 742 926 851 | 742 926 852 | 742 926 853 | 742 926 854 | 742 926 855 | 742 926 856 | 742 926 857 | 742 926 858 | 742 926 859 | 742 926 860 | 742 926 861 | 742 926 862 | 742 926 863 | 742 926 864 | 742 926 865 | 742 926 866 | 742 926 867 | 742 926 868 | 742 926 869 | 742 926 870 | 742 926 871 | 742 926 872 | 742 926 873 | 742 926 874 | 742 926 875 | 742 926 876 | 742 926 877 | 742 926 878 | 742 926 879 | 742 926 880 | 742 926 881 | 742 926 882 | 742 926 883 | 742 926 884 | 742 926 885 | 742 926 886 | 742 926 887 | 742 926 888 | 742 926 889 | 742 926 890 | 742 926 891 | 742 926 892 | 742 926 893 | 742 926 894 | 742 926 895 | 742 926 896 | 742 926 897 | 742 926 898 | 742 926 899 | 742 926 900 | 742 926 901 | 742 926 902 | 742 926 903 | 742 926 904 | 742 926 905 | 742 926 906 | 742 926 907 | 742 926 908 | 742 926 909 | 742 926 910 | 742 926 911 | 742 926 912 | 742 926 913 | 742 926 914 | 742 926 915 | 742 926 916 | 742 926 917 | 742 926 918 | 742 926 919 | 742 926 920 | 742 926 921 | 742 926 922 | 742 926 923 | 742 926 924 | 742 926 925 | 742 926 926 | 742 926 927 | 926 966 928 | 926 966 929 | 926 966 930 | 926 966 931 | 926 966 932 | 926 966 933 | 926 966 934 | 926 966 935 | 926 966 936 | 926 966 937 | 926 966 938 | 926 966 939 | 926 966 940 | 926 966 941 | 926 966 942 | 926 966 943 | 926 966 944 | 926 966 945 | 926 966 946 | 926 966 947 | 926 966 948 | 926 966 949 | 926 966 950 | 926 966 951 | 926 966 952 | 926 966 953 | 926 966 954 | 926 966 955 | 926 966 956 | 926 966 957 | 926 966 958 | 926 966 959 | 926 966 960 | 926 966 961 | 926 966 962 | 926 966 963 | 926 966 964 | 926 966 965 | 926 966 966 | 926 966 967 | 966 980 968 | 966 980 969 | 
966 980 970 | 966 980 971 | 966 980 972 | 966 980 973 | 966 980 974 | 966 980 975 | 966 980 976 | 966 980 977 | 966 980 978 | 966 980 979 | 966 980 980 | 966 980 981 | 980 1040 982 | 980 1040 983 | 980 1040 984 | 980 1040 985 | 980 1040 986 | 980 1040 987 | 980 1040 988 | 980 1040 989 | 980 1040 990 | 980 1040 991 | 980 1040 992 | 980 1040 993 | 980 1040 994 | 980 1040 995 | 980 1040 996 | 980 1040 997 | 980 1040 998 | 980 1040 999 | 980 1040 1000 | 980 1040 1001 | 980 1040 1002 | 980 1040 1003 | 980 1040 1004 | 980 1040 1005 | 980 1040 1006 | 980 1040 1007 | 980 1040 1008 | 980 1040 1009 | 980 1040 1010 | 980 1040 1011 | 980 1040 1012 | 980 1040 1013 | 980 1040 1014 | 980 1040 1015 | 980 1040 1016 | 980 1040 1017 | 980 1040 1018 | 980 1040 1019 | 980 1040 1020 | 980 1040 1021 | 980 1040 1022 | 980 1040 1023 | 980 1040 1024 | 980 1040 1025 | 980 1040 1026 | 980 1040 1027 | 980 1040 1028 | 980 1040 1029 | 980 1040 1030 | 980 1040 1031 | 980 1040 1032 | 980 1040 1033 | 980 1040 1034 | 980 1040 1035 | 980 1040 1036 | 980 1040 1037 | 980 1040 1038 | 980 1040 1039 | 980 1040 1040 | 980 1040 1041 | 1040 1058 1042 | 1040 1058 1043 | 1040 1058 1044 | 1040 1058 1045 | 1040 1058 1046 | 1040 1058 1047 | 1040 1058 1048 | 1040 1058 1049 | 1040 1058 1050 | 1040 1058 1051 | 1040 1058 1052 | 1040 1058 1053 | 1040 1058 1054 | 1040 1058 1055 | 1040 1058 1056 | 1040 1058 1057 | 1040 1058 1058 | 1040 1058 1059 | 1058 1244 1060 | 1058 1244 1061 | 1058 1244 1062 | 1058 1244 1063 | 1058 1244 1064 | 1058 1244 1065 | 1058 1244 1066 | 1058 1244 1067 | 1058 1244 1068 | 1058 1244 1069 | 1058 1244 1070 | 1058 1244 1071 | 1058 1244 1072 | 1058 1244 1073 | 1058 1244 1074 | 1058 1244 1075 | 1058 1244 1076 | 1058 1244 1077 | 1058 1244 1078 | 1058 1244 1079 | 1058 1244 1080 | 1058 1244 1081 | 1058 1244 1082 | 1058 1244 1083 | 1058 1244 1084 | 1058 1244 1085 | 1058 1244 1086 | 1058 1244 1087 | 1058 1244 1088 | 1058 1244 1089 | 1058 1244 1090 | 1058 1244 1091 | 1058 1244 1092 | 1058 1244 1093 | 1058 
1244 1094 | 1058 1244 1095 | 1058 1244 1096 | 1058 1244 1097 | 1058 1244 1098 | 1058 1244 1099 | 1058 1244 1100 | 1058 1244 1101 | 1058 1244 1102 | 1058 1244 1103 | 1058 1244 1104 | 1058 1244 1105 | 1058 1244 1106 | 1058 1244 1107 | 1058 1244 1108 | 1058 1244 1109 | 1058 1244 1110 | 1058 1244 1111 | 1058 1244 1112 | 1058 1244 1113 | 1058 1244 1114 | 1058 1244 1115 | 1058 1244 1116 | 1058 1244 1117 | 1058 1244 1118 | 1058 1244 1119 | 1058 1244 1120 | 1058 1244 1121 | 1058 1244 1122 | 1058 1244 1123 | 1058 1244 1124 | 1058 1244 1125 | 1058 1244 1126 | 1058 1244 1127 | 1058 1244 1128 | 1058 1244 1129 | 1058 1244 1130 | 1058 1244 1131 | 1058 1244 1132 | 1058 1244 1133 | 1058 1244 1134 | 1058 1244 1135 | 1058 1244 1136 | 1058 1244 1137 | 1058 1244 1138 | 1058 1244 1139 | 1058 1244 1140 | 1058 1244 1141 | 1058 1244 1142 | 1058 1244 1143 | 1058 1244 1144 | 1058 1244 1145 | 1058 1244 1146 | 1058 1244 1147 | 1058 1244 1148 | 1058 1244 1149 | 1058 1244 1150 | 1058 1244 1151 | 1058 1244 1152 | 1058 1244 1153 | 1058 1244 1154 | 1058 1244 1155 | 1058 1244 1156 | 1058 1244 1157 | 1058 1244 1158 | 1058 1244 1159 | 1058 1244 1160 | 1058 1244 1161 | 1058 1244 1162 | 1058 1244 1163 | 1058 1244 1164 | 1058 1244 1165 | 1058 1244 1166 | 1058 1244 1167 | 1058 1244 1168 | 1058 1244 1169 | 1058 1244 1170 | 1058 1244 1171 | 1058 1244 1172 | 1058 1244 1173 | 1058 1244 1174 | 1058 1244 1175 | 1058 1244 1176 | 1058 1244 1177 | 1058 1244 1178 | 1058 1244 1179 | 1058 1244 1180 | 1058 1244 1181 | 1058 1244 1182 | 1058 1244 1183 | 1058 1244 1184 | 1058 1244 1185 | 1058 1244 1186 | 1058 1244 1187 | 1058 1244 1188 | 1058 1244 1189 | 1058 1244 1190 | 1058 1244 1191 | 1058 1244 1192 | 1058 1244 1193 | 1058 1244 1194 | 1058 1244 1195 | 1058 1244 1196 | 1058 1244 1197 | 1058 1244 1198 | 1058 1244 1199 | 1058 1244 1200 | 1058 1244 1201 | 1058 1244 1202 | 1058 1244 1203 | 1058 1244 1204 | 1058 1244 1205 | 1058 1244 1206 | 1058 1244 1207 | 1058 1244 1208 | 1058 1244 1209 | 1058 1244 1210 | 1058 1244 1211 
| 1058 1244 1212 | 1058 1244 1213 | 1058 1244 1214 | 1058 1244 1215 | 1058 1244 1216 | 1058 1244 1217 | 1058 1244 1218 | 1058 1244 1219 | 1058 1244 1220 | 1058 1244 1221 | 1058 1244 1222 | 1058 1244 1223 | 1058 1244 1224 | 1058 1244 1225 | 1058 1244 1226 | 1058 1244 1227 | 1058 1244 1228 | 1058 1244 1229 | 1058 1244 1230 | 1058 1244 1231 | 1058 1244 1232 | 1058 1244 1233 | 1058 1244 1234 | 1058 1244 1235 | 1058 1244 1236 | 1058 1244 1237 | 1058 1244 1238 | 1058 1244 1239 | 1058 1244 1240 | 1058 1244 1241 | 1058 1244 1242 | 1058 1244 1243 | 1058 1244 1244 | 1058 1244 1245 | 1244 1312 1246 | 1244 1312 1247 | 1244 1312 1248 | 1244 1312 1249 | 1244 1312 1250 | 1244 1312 1251 | 1244 1312 1252 | 1244 1312 1253 | 1244 1312 1254 | 1244 1312 1255 | 1244 1312 1256 | 1244 1312 1257 | 1244 1312 1258 | 1244 1312 1259 | 1244 1312 1260 | 1244 1312 1261 | 1244 1312 1262 | 1244 1312 1263 | 1244 1312 1264 | 1244 1312 1265 | 1244 1312 1266 | 1244 1312 1267 | 1244 1312 1268 | 1244 1312 1269 | 1244 1312 1270 | 1244 1312 1271 | 1244 1312 1272 | 1244 1312 1273 | 1244 1312 1274 | 1244 1312 1275 | 1244 1312 1276 | 1244 1312 1277 | 1244 1312 1278 | 1244 1312 1279 | 1244 1312 1280 | 1244 1312 1281 | 1244 1312 1282 | 1244 1312 1283 | 1244 1312 1284 | 1244 1312 1285 | 1244 1312 1286 | 1244 1312 1287 | 1244 1312 1288 | 1244 1312 1289 | 1244 1312 1290 | 1244 1312 1291 | 1244 1312 1292 | 1244 1312 1293 | 1244 1312 1294 | 1244 1312 1295 | 1244 1312 1296 | 1244 1312 1297 | 1244 1312 1298 | 1244 1312 1299 | 1244 1312 1300 | 1244 1312 1301 | 1244 1312 1302 | 1244 1312 1303 | 1244 1312 1304 | 1244 1312 1305 | 1244 1312 1306 | 1244 1312 1307 | 1244 1312 1308 | 1244 1312 1309 | 1244 1312 1310 | 1244 1312 1311 | 1244 1312 1312 | 1244 1312 1313 | 1312 1416 1314 | 1312 1416 1315 | 1312 1416 1316 | 1312 1416 1317 | 1312 1416 1318 | 1312 1416 1319 | 1312 1416 1320 | 1312 1416 1321 | 1312 1416 1322 | 1312 1416 1323 | 1312 1416 1324 | 1312 1416 1325 | 1312 1416 1326 | 1312 1416 1327 | 1312 1416 1328 | 1312 
1416 1329 | 1312 1416 1330 | 1312 1416 1331 | 1312 1416 1332 | 1312 1416 1333 | 1312 1416 1334 | 1312 1416 1335 | 1312 1416 1336 | 1312 1416 1337 | 1312 1416 1338 | 1312 1416 1339 | 1312 1416 1340 | 1312 1416 1341 | 1312 1416 1342 | 1312 1416 1343 | 1312 1416 1344 | 1312 1416 1345 | 1312 1416 1346 | 1312 1416 1347 | 1312 1416 1348 | 1312 1416 1349 | 1312 1416 1350 | 1312 1416 1351 | 1312 1416 1352 | 1312 1416 1353 | 1312 1416 1354 | 1312 1416 1355 | 1312 1416 1356 | 1312 1416 1357 | 1312 1416 1358 | 1312 1416 1359 | 1312 1416 1360 | 1312 1416 1361 | 1312 1416 1362 | 1312 1416 1363 | 1312 1416 1364 | 1312 1416 1365 | 1312 1416 1366 | 1312 1416 1367 | 1312 1416 1368 | 1312 1416 1369 | 1312 1416 1370 | 1312 1416 1371 | 1312 1416 1372 | 1312 1416 1373 | 1312 1416 1374 | 1312 1416 1375 | 1312 1416 1376 | 1312 1416 1377 | 1312 1416 1378 | 1312 1416 1379 | 1312 1416 1380 | 1312 1416 1381 | 1312 1416 1382 | 1312 1416 1383 | 1312 1416 1384 | 1312 1416 1385 | 1312 1416 1386 | 1312 1416 1387 | 1312 1416 1388 | 1312 1416 1389 | 1312 1416 1390 | 1312 1416 1391 | 1312 1416 1392 | 1312 1416 1393 | 1312 1416 1394 | 1312 1416 1395 | 1312 1416 1396 | 1312 1416 1397 | 1312 1416 1398 | 1312 1416 1399 | 1312 1416 1400 | 1312 1416 1401 | 1312 1416 1402 | 1312 1416 1403 | 1312 1416 1404 | 1312 1416 1405 | 1312 1416 1406 | 1312 1416 1407 | 1312 1416 1408 | 1312 1416 1409 | 1312 1416 1410 | 1312 1416 1411 | 1312 1416 1412 | 1312 1416 1413 | 1312 1416 1414 | 1312 1416 1415 | 1312 1416 1416 | 1312 1416 1417 | 1416 1428 1418 | 1416 1428 1419 | 1416 1428 1420 | 1416 1428 1421 | 1416 1428 1422 | 1416 1428 1423 | 1416 1428 1424 | 1416 1428 1425 | 1416 1428 1426 | 1416 1428 1427 | 1416 1428 1428 | 1416 1428 1429 | 1428 1442 1430 | 1428 1442 1431 | 1428 1442 1432 | 1428 1442 1433 | 1428 1442 1434 | 1428 1442 1435 | 1428 1442 1436 | 1428 1442 1437 | 1428 1442 1438 | 1428 1442 1439 | 1428 1442 1440 | 1428 1442 1441 | 1428 1442 1442 | 1428 1442 1443 | 1442 1479 1444 | 1442 1479 1445 | 1442 1479 1446 
| 1442 1479 1447 | 1442 1479 1448 | 1442 1479 1449 | 1442 1479 1450 | 1442 1479 1451 | 1442 1479 1452 | 1442 1479 1453 | 1442 1479 1454 | 1442 1479 1455 | 1442 1479 1456 | 1442 1479 1457 | 1442 1479 1458 | 1442 1479 1459 | 1442 1479 1460 | 1442 1479 1461 | 1442 1479 1462 | 1442 1479 1463 | 1442 1479 1464 | 1442 1479 1465 | 1442 1479 1466 | 1442 1479 1467 | 1442 1479 1468 | 1442 1479 1469 | 1442 1479 1470 | 1442 1479 1471 | 1442 1479 1472 | 1442 1479 1473 | 1442 1479 1474 | 1442 1479 1475 | 1442 1479 1476 | 1442 1479 1477 | 1442 1479 1478 | 1442 1479 1479 | 1442 1479 1480 | 1479 1495 1481 | 1479 1495 1482 | 1479 1495 1483 | 1479 1495 1484 | 1479 1495 1485 | 1479 1495 1486 | 1479 1495 1487 | 1479 1495 1488 | 1479 1495 1489 | 1479 1495 1490 | 1479 1495 1491 | 1479 1495 1492 | 1479 1495 1493 | 1479 1495 1494 | 1479 1495 1495 | 1479 1495 1496 | 1495 1642 1497 | 1495 1642 1498 | 1495 1642 1499 | 1495 1642 1500 | 1495 1642 1501 | 1495 1642 1502 | 1495 1642 1503 | 1495 1642 1504 | 1495 1642 1505 | 1495 1642 1506 | 1495 1642 1507 | 1495 1642 1508 | 1495 1642 1509 | 1495 1642 1510 | 1495 1642 1511 | 1495 1642 1512 | 1495 1642 1513 | 1495 1642 1514 | 1495 1642 1515 | 1495 1642 1516 | 1495 1642 1517 | 1495 1642 1518 | 1495 1642 1519 | 1495 1642 1520 | 1495 1642 1521 | 1495 1642 1522 | 1495 1642 1523 | 1495 1642 1524 | 1495 1642 1525 | 1495 1642 1526 | 1495 1642 1527 | 1495 1642 1528 | 1495 1642 1529 | 1495 1642 1530 | 1495 1642 1531 | 1495 1642 1532 | 1495 1642 1533 | 1495 1642 1534 | 1495 1642 1535 | 1495 1642 1536 | 1495 1642 1537 | 1495 1642 1538 | 1495 1642 1539 | 1495 1642 1540 | 1495 1642 1541 | 1495 1642 1542 | 1495 1642 1543 | 1495 1642 1544 | 1495 1642 1545 | 1495 1642 1546 | 1495 1642 1547 | 1495 1642 1548 | 1495 1642 1549 | 1495 1642 1550 | 1495 1642 1551 | 1495 1642 1552 | 1495 1642 1553 | 1495 1642 1554 | 1495 1642 1555 | 1495 1642 1556 | 1495 1642 1557 | 1495 1642 1558 | 1495 1642 1559 | 1495 1642 1560 | 1495 1642 1561 | 1495 1642 1562 | 1495 1642 1563 | 1495 
1642 1564 | 1495 1642 1565 | 1495 1642 1566 | 1495 1642 1567 | 1495 1642 1568 | 1495 1642 1569 | 1495 1642 1570 | 1495 1642 1571 | 1495 1642 1572 | 1495 1642 1573 | 1495 1642 1574 | 1495 1642 1575 | 1495 1642 1576 | 1495 1642 1577 | 1495 1642 1578 | 1495 1642 1579 | 1495 1642 1580 | 1495 1642 1581 | 1495 1642 1582 | 1495 1642 1583 | 1495 1642 1584 | 1495 1642 1585 | 1495 1642 1586 | 1495 1642 1587 | 1495 1642 1588 | 1495 1642 1589 | 1495 1642 1590 | 1495 1642 1591 | 1495 1642 1592 | 1495 1642 1593 | 1495 1642 1594 | 1495 1642 1595 | 1495 1642 1596 | 1495 1642 1597 | 1495 1642 1598 | 1495 1642 1599 | 1495 1642 1600 | 1495 1642 1601 | 1495 1642 1602 | 1495 1642 1603 | 1495 1642 1604 | 1495 1642 1605 | 1495 1642 1606 | 1495 1642 1607 | 1495 1642 1608 | 1495 1642 1609 | 1495 1642 1610 | 1495 1642 1611 | 1495 1642 1612 | 1495 1642 1613 | 1495 1642 1614 | 1495 1642 1615 | 1495 1642 1616 | 1495 1642 1617 | 1495 1642 1618 | 1495 1642 1619 | 1495 1642 1620 | 1495 1642 1621 | 1495 1642 1622 | 1495 1642 1623 | 1495 1642 1624 | 1495 1642 1625 | 1495 1642 1626 | 1495 1642 1627 | 1495 1642 1628 | 1495 1642 1629 | 1495 1642 1630 | 1495 1642 1631 | 1495 1642 1632 | 1495 1642 1633 | 1495 1642 1634 | 1495 1642 1635 | 1495 1642 1636 | 1495 1642 1637 | 1495 1642 1638 | 1495 1642 1639 | 1495 1642 1640 | 1495 1642 1641 | 1495 1642 1642 | 1495 1642 1643 | 1642 1694 1644 | 1642 1694 1645 | 1642 1694 1646 | 1642 1694 1647 | 1642 1694 1648 | 1642 1694 1649 | 1642 1694 1650 | 1642 1694 1651 | 1642 1694 1652 | 1642 1694 1653 | 1642 1694 1654 | 1642 1694 1655 | 1642 1694 1656 | 1642 1694 1657 | 1642 1694 1658 | 1642 1694 1659 | 1642 1694 1660 | 1642 1694 1661 | 1642 1694 1662 | 1642 1694 1663 | 1642 1694 1664 | 1642 1694 1665 | 1642 1694 1666 | 1642 1694 1667 | 1642 1694 1668 | 1642 1694 1669 | 1642 1694 1670 | 1642 1694 1671 | 1642 1694 1672 | 1642 1694 1673 | 1642 1694 1674 | 1642 1694 1675 | 1642 1694 1676 | 1642 1694 1677 | 1642 1694 1678 | 1642 1694 1679 | 1642 1694 1680 | 1642 1694 1681 
| 1642 1694 1682 | 1642 1694 1683 | 1642 1694 1684 | 1642 1694 1685 | 1642 1694 1686 | 1642 1694 1687 | 1642 1694 1688 | 1642 1694 1689 | 1642 1694 1690 | 1642 1694 1691 | 1642 1694 1692 | 1642 1694 1693 | 1642 1694 1694 | 1642 1694 1695 | 1694 1716 1696 | 1694 1716 1697 | 1694 1716 1698 | 1694 1716 1699 | 1694 1716 1700 | 1694 1716 1701 | 1694 1716 1702 | 1694 1716 1703 | 1694 1716 1704 | 1694 1716 1705 | 1694 1716 1706 | 1694 1716 1707 | 1694 1716 1708 | 1694 1716 1709 | 1694 1716 1710 | 1694 1716 1711 | 1694 1716 1712 | 1694 1716 1713 | 1694 1716 1714 | 1694 1716 1715 | 1694 1716 1716 | 1694 1716 1717 | 1716 1731 1718 | 1716 1731 1719 | 1716 1731 1720 | 1716 1731 1721 | 1716 1731 1722 | 1716 1731 1723 | 1716 1731 1724 | 1716 1731 1725 | 1716 1731 1726 | 1716 1731 1727 | 1716 1731 1728 | 1716 1731 1729 | 1716 1731 1730 | 1716 1731 1731 | 1716 1731 1732 | 1731 1732 1733 | -------------------------------------------------------------------------------- /structFiles/seqBoundsFiles/msls_melbourne_query_seqBounds.txt: -------------------------------------------------------------------------------- 1 | 0 31 2 | 0 31 3 | 0 31 4 | 0 31 5 | 0 31 6 | 0 31 7 | 0 31 8 | 0 31 9 | 0 31 10 | 0 31 11 | 0 31 12 | 0 31 13 | 0 31 14 | 0 31 15 | 0 31 16 | 0 31 17 | 0 31 18 | 0 31 19 | 0 31 20 | 0 31 21 | 0 31 22 | 0 31 23 | 0 31 24 | 0 31 25 | 0 31 26 | 0 31 27 | 0 31 28 | 0 31 29 | 0 31 30 | 0 31 31 | 0 31 32 | 31 119 33 | 31 119 34 | 31 119 35 | 31 119 36 | 31 119 37 | 31 119 38 | 31 119 39 | 31 119 40 | 31 119 41 | 31 119 42 | 31 119 43 | 31 119 44 | 31 119 45 | 31 119 46 | 31 119 47 | 31 119 48 | 31 119 49 | 31 119 50 | 31 119 51 | 31 119 52 | 31 119 53 | 31 119 54 | 31 119 55 | 31 119 56 | 31 119 57 | 31 119 58 | 31 119 59 | 31 119 60 | 31 119 61 | 31 119 62 | 31 119 63 | 31 119 64 | 31 119 65 | 31 119 66 | 31 119 67 | 31 119 68 | 31 119 69 | 31 119 70 | 31 119 71 | 31 119 72 | 31 119 73 | 31 119 74 | 31 119 75 | 31 119 76 | 31 119 77 | 31 119 78 | 31 119 79 | 31 119 80 | 31 
119 81 | 31 119 82 | 31 119 83 | 31 119 84 | 31 119 85 | 31 119 86 | 31 119 87 | 31 119 88 | 31 119 89 | 31 119 90 | 31 119 91 | 31 119 92 | 31 119 93 | 31 119 94 | 31 119 95 | 31 119 96 | 31 119 97 | 31 119 98 | 31 119 99 | 31 119 100 | 31 119 101 | 31 119 102 | 31 119 103 | 31 119 104 | 31 119 105 | 31 119 106 | 31 119 107 | 31 119 108 | 31 119 109 | 31 119 110 | 31 119 111 | 31 119 112 | 31 119 113 | 31 119 114 | 31 119 115 | 31 119 116 | 31 119 117 | 31 119 118 | 31 119 119 | 31 119 120 | 119 138 121 | 119 138 122 | 119 138 123 | 119 138 124 | 119 138 125 | 119 138 126 | 119 138 127 | 119 138 128 | 119 138 129 | 119 138 130 | 119 138 131 | 119 138 132 | 119 138 133 | 119 138 134 | 119 138 135 | 119 138 136 | 119 138 137 | 119 138 138 | 119 138 139 | 138 212 140 | 138 212 141 | 138 212 142 | 138 212 143 | 138 212 144 | 138 212 145 | 138 212 146 | 138 212 147 | 138 212 148 | 138 212 149 | 138 212 150 | 138 212 151 | 138 212 152 | 138 212 153 | 138 212 154 | 138 212 155 | 138 212 156 | 138 212 157 | 138 212 158 | 138 212 159 | 138 212 160 | 138 212 161 | 138 212 162 | 138 212 163 | 138 212 164 | 138 212 165 | 138 212 166 | 138 212 167 | 138 212 168 | 138 212 169 | 138 212 170 | 138 212 171 | 138 212 172 | 138 212 173 | 138 212 174 | 138 212 175 | 138 212 176 | 138 212 177 | 138 212 178 | 138 212 179 | 138 212 180 | 138 212 181 | 138 212 182 | 138 212 183 | 138 212 184 | 138 212 185 | 138 212 186 | 138 212 187 | 138 212 188 | 138 212 189 | 138 212 190 | 138 212 191 | 138 212 192 | 138 212 193 | 138 212 194 | 138 212 195 | 138 212 196 | 138 212 197 | 138 212 198 | 138 212 199 | 138 212 200 | 138 212 201 | 138 212 202 | 138 212 203 | 138 212 204 | 138 212 205 | 138 212 206 | 138 212 207 | 138 212 208 | 138 212 209 | 138 212 210 | 138 212 211 | 138 212 212 | 138 212 213 | 212 616 214 | 212 616 215 | 212 616 216 | 212 616 217 | 212 616 218 | 212 616 219 | 212 616 220 | 212 616 221 | 212 616 222 | 212 616 223 | 212 616 224 | 212 616 225 | 212 616 226 | 212 616 227 | 212 
616 228 | 212 616 229 | 212 616 230 | 212 616 231 | 212 616 232 | 212 616 233 | 212 616 234 | 212 616 235 | 212 616 236 | 212 616 237 | 212 616 238 | 212 616 239 | 212 616 240 | 212 616 241 | 212 616 242 | 212 616 243 | 212 616 244 | 212 616 245 | 212 616 246 | 212 616 247 | 212 616 248 | 212 616 249 | 212 616 250 | 212 616 251 | 212 616 252 | 212 616 253 | 212 616 254 | 212 616 255 | 212 616 256 | 212 616 257 | 212 616 258 | 212 616 259 | 212 616 260 | 212 616 261 | 212 616 262 | 212 616 263 | 212 616 264 | 212 616 265 | 212 616 266 | 212 616 267 | 212 616 268 | 212 616 269 | 212 616 270 | 212 616 271 | 212 616 272 | 212 616 273 | 212 616 274 | 212 616 275 | 212 616 276 | 212 616 277 | 212 616 278 | 212 616 279 | 212 616 280 | 212 616 281 | 212 616 282 | 212 616 283 | 212 616 284 | 212 616 285 | 212 616 286 | 212 616 287 | 212 616 288 | 212 616 289 | 212 616 290 | 212 616 291 | 212 616 292 | 212 616 293 | 212 616 294 | 212 616 295 | 212 616 296 | 212 616 297 | 212 616 298 | 212 616 299 | 212 616 300 | 212 616 301 | 212 616 302 | 212 616 303 | 212 616 304 | 212 616 305 | 212 616 306 | 212 616 307 | 212 616 308 | 212 616 309 | 212 616 310 | 212 616 311 | 212 616 312 | 212 616 313 | 212 616 314 | 212 616 315 | 212 616 316 | 212 616 317 | 212 616 318 | 212 616 319 | 212 616 320 | 212 616 321 | 212 616 322 | 212 616 323 | 212 616 324 | 212 616 325 | 212 616 326 | 212 616 327 | 212 616 328 | 212 616 329 | 212 616 330 | 212 616 331 | 212 616 332 | 212 616 333 | 212 616 334 | 212 616 335 | 212 616 336 | 212 616 337 | 212 616 338 | 212 616 339 | 212 616 340 | 212 616 341 | 212 616 342 | 212 616 343 | 212 616 344 | 212 616 345 | 212 616 346 | 212 616 347 | 212 616 348 | 212 616 349 | 212 616 350 | 212 616 351 | 212 616 352 | 212 616 353 | 212 616 354 | 212 616 355 | 212 616 356 | 212 616 357 | 212 616 358 | 212 616 359 | 212 616 360 | 212 616 361 | 212 616 362 | 212 616 363 | 212 616 364 | 212 616 365 | 212 616 366 | 212 616 367 | 212 616 368 | 212 616 369 | 212 616 370 | 
212 616 371 | 212 616 372 | 212 616 373 | 212 616 374 | 212 616 375 | 212 616 376 | 212 616 377 | 212 616 378 | 212 616 379 | 212 616 380 | 212 616 381 | 212 616 382 | 212 616 383 | 212 616 384 | 212 616 385 | 212 616 386 | 212 616 387 | 212 616 388 | 212 616 389 | 212 616 390 | 212 616 391 | 212 616 392 | 212 616 393 | 212 616 394 | 212 616 395 | 212 616 396 | 212 616 397 | 212 616 398 | 212 616 399 | 212 616 400 | 212 616 401 | 212 616 402 | 212 616 403 | 212 616 404 | 212 616 405 | 212 616 406 | 212 616 407 | 212 616 408 | 212 616 409 | 212 616 410 | 212 616 411 | 212 616 412 | 212 616 413 | 212 616 414 | 212 616 415 | 212 616 416 | 212 616 417 | 212 616 418 | 212 616 419 | 212 616 420 | 212 616 421 | 212 616 422 | 212 616 423 | 212 616 424 | 212 616 425 | 212 616 426 | 212 616 427 | 212 616 428 | 212 616 429 | 212 616 430 | 212 616 431 | 212 616 432 | 212 616 433 | 212 616 434 | 212 616 435 | 212 616 436 | 212 616 437 | 212 616 438 | 212 616 439 | 212 616 440 | 212 616 441 | 212 616 442 | 212 616 443 | 212 616 444 | 212 616 445 | 212 616 446 | 212 616 447 | 212 616 448 | 212 616 449 | 212 616 450 | 212 616 451 | 212 616 452 | 212 616 453 | 212 616 454 | 212 616 455 | 212 616 456 | 212 616 457 | 212 616 458 | 212 616 459 | 212 616 460 | 212 616 461 | 212 616 462 | 212 616 463 | 212 616 464 | 212 616 465 | 212 616 466 | 212 616 467 | 212 616 468 | 212 616 469 | 212 616 470 | 212 616 471 | 212 616 472 | 212 616 473 | 212 616 474 | 212 616 475 | 212 616 476 | 212 616 477 | 212 616 478 | 212 616 479 | 212 616 480 | 212 616 481 | 212 616 482 | 212 616 483 | 212 616 484 | 212 616 485 | 212 616 486 | 212 616 487 | 212 616 488 | 212 616 489 | 212 616 490 | 212 616 491 | 212 616 492 | 212 616 493 | 212 616 494 | 212 616 495 | 212 616 496 | 212 616 497 | 212 616 498 | 212 616 499 | 212 616 500 | 212 616 501 | 212 616 502 | 212 616 503 | 212 616 504 | 212 616 505 | 212 616 506 | 212 616 507 | 212 616 508 | 212 616 509 | 212 616 510 | 212 616 511 | 212 616 512 | 212 616 513 
| 212 616 514 | 212 616 515 | 212 616 516 | 212 616 517 | 212 616 518 | 212 616 519 | 212 616 520 | 212 616 521 | 212 616 522 | 212 616 523 | 212 616 524 | 212 616 525 | 212 616 526 | 212 616 527 | 212 616 528 | 212 616 529 | 212 616 530 | 212 616 531 | 212 616 532 | 212 616 533 | 212 616 534 | 212 616 535 | 212 616 536 | 212 616 537 | 212 616 538 | 212 616 539 | 212 616 540 | 212 616 541 | 212 616 542 | 212 616 543 | 212 616 544 | 212 616 545 | 212 616 546 | 212 616 547 | 212 616 548 | 212 616 549 | 212 616 550 | 212 616 551 | 212 616 552 | 212 616 553 | 212 616 554 | 212 616 555 | 212 616 556 | 212 616 557 | 212 616 558 | 212 616 559 | 212 616 560 | 212 616 561 | 212 616 562 | 212 616 563 | 212 616 564 | 212 616 565 | 212 616 566 | 212 616 567 | 212 616 568 | 212 616 569 | 212 616 570 | 212 616 571 | 212 616 572 | 212 616 573 | 212 616 574 | 212 616 575 | 212 616 576 | 212 616 577 | 212 616 578 | 212 616 579 | 212 616 580 | 212 616 581 | 212 616 582 | 212 616 583 | 212 616 584 | 212 616 585 | 212 616 586 | 212 616 587 | 212 616 588 | 212 616 589 | 212 616 590 | 212 616 591 | 212 616 592 | 212 616 593 | 212 616 594 | 212 616 595 | 212 616 596 | 212 616 597 | 212 616 598 | 212 616 599 | 212 616 600 | 212 616 601 | 212 616 602 | 212 616 603 | 212 616 604 | 212 616 605 | 212 616 606 | 212 616 607 | 212 616 608 | 212 616 609 | 212 616 610 | 212 616 611 | 212 616 612 | 212 616 613 | 212 616 614 | 212 616 615 | 212 616 616 | 212 616 617 | 616 652 618 | 616 652 619 | 616 652 620 | 616 652 621 | 616 652 622 | 616 652 623 | 616 652 624 | 616 652 625 | 616 652 626 | 616 652 627 | 616 652 628 | 616 652 629 | 616 652 630 | 616 652 631 | 616 652 632 | 616 652 633 | 616 652 634 | 616 652 635 | 616 652 636 | 616 652 637 | 616 652 638 | 616 652 639 | 616 652 640 | 616 652 641 | 616 652 642 | 616 652 643 | 616 652 644 | 616 652 645 | 616 652 646 | 616 652 647 | 616 652 648 | 616 652 649 | 616 652 650 | 616 652 651 | 616 652 652 | 616 652 653 | 652 777 654 | 652 777 655 | 652 777 
656 | 652 777 657 | 652 777 658 | 652 777 659 | 652 777 660 | 652 777 661 | 652 777 662 | 652 777 663 | 652 777 664 | 652 777 665 | 652 777 666 | 652 777 667 | 652 777 668 | 652 777 669 | 652 777 670 | 652 777 671 | 652 777 672 | 652 777 673 | 652 777 674 | 652 777 675 | 652 777 676 | 652 777 677 | 652 777 678 | 652 777 679 | 652 777 680 | 652 777 681 | 652 777 682 | 652 777 683 | 652 777 684 | 652 777 685 | 652 777 686 | 652 777 687 | 652 777 688 | 652 777 689 | 652 777 690 | 652 777 691 | 652 777 692 | 652 777 693 | 652 777 694 | 652 777 695 | 652 777 696 | 652 777 697 | 652 777 698 | 652 777 699 | 652 777 700 | 652 777 701 | 652 777 702 | 652 777 703 | 652 777 704 | 652 777 705 | 652 777 706 | 652 777 707 | 652 777 708 | 652 777 709 | 652 777 710 | 652 777 711 | 652 777 712 | 652 777 713 | 652 777 714 | 652 777 715 | 652 777 716 | 652 777 717 | 652 777 718 | 652 777 719 | 652 777 720 | 652 777 721 | 652 777 722 | 652 777 723 | 652 777 724 | 652 777 725 | 652 777 726 | 652 777 727 | 652 777 728 | 652 777 729 | 652 777 730 | 652 777 731 | 652 777 732 | 652 777 733 | 652 777 734 | 652 777 735 | 652 777 736 | 652 777 737 | 652 777 738 | 652 777 739 | 652 777 740 | 652 777 741 | 652 777 742 | 652 777 743 | 652 777 744 | 652 777 745 | 652 777 746 | 652 777 747 | 652 777 748 | 652 777 749 | 652 777 750 | 652 777 751 | 652 777 752 | 652 777 753 | 652 777 754 | 652 777 755 | 652 777 756 | 652 777 757 | 652 777 758 | 652 777 759 | 652 777 760 | 652 777 761 | 652 777 762 | 652 777 763 | 652 777 764 | 652 777 765 | 652 777 766 | 652 777 767 | 652 777 768 | 652 777 769 | 652 777 770 | 652 777 771 | 652 777 772 | 652 777 773 | 652 777 774 | 652 777 775 | 652 777 776 | 652 777 777 | 652 777 778 | 777 806 779 | 777 806 780 | 777 806 781 | 777 806 782 | 777 806 783 | 777 806 784 | 777 806 785 | 777 806 786 | 777 806 787 | 777 806 788 | 777 806 789 | 777 806 790 | 777 806 791 | 777 806 792 | 777 806 793 | 777 806 794 | 777 806 795 | 777 806 796 | 777 806 797 | 777 806 798 | 777 
806 799 | 777 806 800 | 777 806 801 | 777 806 802 | 777 806 803 | 777 806 804 | 777 806 805 | 777 806 806 | 777 806 807 | 806 882 808 | 806 882 809 | 806 882 810 | 806 882 811 | 806 882 812 | 806 882 813 | 806 882 814 | 806 882 815 | 806 882 816 | 806 882 817 | 806 882 818 | 806 882 819 | 806 882 820 | 806 882 821 | 806 882 822 | 806 882 823 | 806 882 824 | 806 882 825 | 806 882 826 | 806 882 827 | 806 882 828 | 806 882 829 | 806 882 830 | 806 882 831 | 806 882 832 | 806 882 833 | 806 882 834 | 806 882 835 | 806 882 836 | 806 882 837 | 806 882 838 | 806 882 839 | 806 882 840 | 806 882 841 | 806 882 842 | 806 882 843 | 806 882 844 | 806 882 845 | 806 882 846 | 806 882 847 | 806 882 848 | 806 882 849 | 806 882 850 | 806 882 851 | 806 882 852 | 806 882 853 | 806 882 854 | 806 882 855 | 806 882 856 | 806 882 857 | 806 882 858 | 806 882 859 | 806 882 860 | 806 882 861 | 806 882 862 | 806 882 863 | 806 882 864 | 806 882 865 | 806 882 866 | 806 882 867 | 806 882 868 | 806 882 869 | 806 882 870 | 806 882 871 | 806 882 872 | 806 882 873 | 806 882 874 | 806 882 875 | 806 882 876 | 806 882 877 | 806 882 878 | 806 882 879 | 806 882 880 | 806 882 881 | 806 882 882 | 806 882 883 | 882 1124 884 | 882 1124 885 | 882 1124 886 | 882 1124 887 | 882 1124 888 | 882 1124 889 | 882 1124 890 | 882 1124 891 | 882 1124 892 | 882 1124 893 | 882 1124 894 | 882 1124 895 | 882 1124 896 | 882 1124 897 | 882 1124 898 | 882 1124 899 | 882 1124 900 | 882 1124 901 | 882 1124 902 | 882 1124 903 | 882 1124 904 | 882 1124 905 | 882 1124 906 | 882 1124 907 | 882 1124 908 | 882 1124 909 | 882 1124 910 | 882 1124 911 | 882 1124 912 | 882 1124 913 | 882 1124 914 | 882 1124 915 | 882 1124 916 | 882 1124 917 | 882 1124 918 | 882 1124 919 | 882 1124 920 | 882 1124 921 | 882 1124 922 | 882 1124 923 | 882 1124 924 | 882 1124 925 | 882 1124 926 | 882 1124 927 | 882 1124 928 | 882 1124 929 | 882 1124 930 | 882 1124 931 | 882 1124 932 | 882 1124 933 | 882 1124 934 | 882 1124 935 | 882 1124 936 | 882 1124 937 | 882 
1124 938 | 882 1124 939 | 882 1124 940 | 882 1124 941 | 882 1124 942 | 882 1124 943 | 882 1124 944 | 882 1124 945 | 882 1124 946 | 882 1124 947 | 882 1124 948 | 882 1124 949 | 882 1124 950 | 882 1124 951 | 882 1124 952 | 882 1124 953 | 882 1124 954 | 882 1124 955 | 882 1124 956 | 882 1124 957 | 882 1124 958 | 882 1124 959 | 882 1124 960 | 882 1124 961 | 882 1124 962 | 882 1124 963 | 882 1124 964 | 882 1124 965 | 882 1124 966 | 882 1124 967 | 882 1124 968 | 882 1124 969 | 882 1124 970 | 882 1124 971 | 882 1124 972 | 882 1124 973 | 882 1124 974 | 882 1124 975 | 882 1124 976 | 882 1124 977 | 882 1124 978 | 882 1124 979 | 882 1124 980 | 882 1124 981 | 882 1124 982 | 882 1124 983 | 882 1124 984 | 882 1124 985 | 882 1124 986 | 882 1124 987 | 882 1124 988 | 882 1124 989 | 882 1124 990 | 882 1124 991 | 882 1124 992 | 882 1124 993 | 882 1124 994 | 882 1124 995 | 882 1124 996 | 882 1124 997 | 882 1124 998 | 882 1124 999 | 882 1124 1000 | 882 1124 1001 | 882 1124 1002 | 882 1124 1003 | 882 1124 1004 | 882 1124 1005 | 882 1124 1006 | 882 1124 1007 | 882 1124 1008 | 882 1124 1009 | 882 1124 1010 | 882 1124 1011 | 882 1124 1012 | 882 1124 1013 | 882 1124 1014 | 882 1124 1015 | 882 1124 1016 | 882 1124 1017 | 882 1124 1018 | 882 1124 1019 | 882 1124 1020 | 882 1124 1021 | 882 1124 1022 | 882 1124 1023 | 882 1124 1024 | 882 1124 1025 | 882 1124 1026 | 882 1124 1027 | 882 1124 1028 | 882 1124 1029 | 882 1124 1030 | 882 1124 1031 | 882 1124 1032 | 882 1124 1033 | 882 1124 1034 | 882 1124 1035 | 882 1124 1036 | 882 1124 1037 | 882 1124 1038 | 882 1124 1039 | 882 1124 1040 | 882 1124 1041 | 882 1124 1042 | 882 1124 1043 | 882 1124 1044 | 882 1124 1045 | 882 1124 1046 | 882 1124 1047 | 882 1124 1048 | 882 1124 1049 | 882 1124 1050 | 882 1124 1051 | 882 1124 1052 | 882 1124 1053 | 882 1124 1054 | 882 1124 1055 | 882 1124 1056 | 882 1124 1057 | 882 1124 1058 | 882 1124 1059 | 882 1124 1060 | 882 1124 1061 | 882 1124 1062 | 882 1124 1063 | 882 1124 1064 | 882 1124 1065 | 882 1124 1066 | 
882 1124 1067 | 882 1124 1068 | 882 1124 1069 | 882 1124 1070 | 882 1124 1071 | 882 1124 1072 | 882 1124 1073 | 882 1124 1074 | 882 1124 1075 | 882 1124 1076 | 882 1124 1077 | 882 1124 1078 | 882 1124 1079 | 882 1124 1080 | 882 1124 1081 | 882 1124 1082 | 882 1124 1083 | 882 1124 1084 | 882 1124 1085 | 882 1124 1086 | 882 1124 1087 | 882 1124 1088 | 882 1124 1089 | 882 1124 1090 | 882 1124 1091 | 882 1124 1092 | 882 1124 1093 | 882 1124 1094 | 882 1124 1095 | 882 1124 1096 | 882 1124 1097 | 882 1124 1098 | 882 1124 1099 | 882 1124 1100 | 882 1124 1101 | 882 1124 1102 | 882 1124 1103 | 882 1124 1104 | 882 1124 1105 | 882 1124 1106 | 882 1124 1107 | 882 1124 1108 | 882 1124 1109 | 882 1124 1110 | 882 1124 1111 | 882 1124 1112 | 882 1124 1113 | 882 1124 1114 | 882 1124 1115 | 882 1124 1116 | 882 1124 1117 | 882 1124 1118 | 882 1124 1119 | 882 1124 1120 | 882 1124 1121 | 882 1124 1122 | 882 1124 1123 | 882 1124 1124 | 882 1124 1125 | 1124 1141 1126 | 1124 1141 1127 | 1124 1141 1128 | 1124 1141 1129 | 1124 1141 1130 | 1124 1141 1131 | 1124 1141 1132 | 1124 1141 1133 | 1124 1141 1134 | 1124 1141 1135 | 1124 1141 1136 | 1124 1141 1137 | 1124 1141 1138 | 1124 1141 1139 | 1124 1141 1140 | 1124 1141 1141 | 1124 1141 1142 | 1141 1266 1143 | 1141 1266 1144 | 1141 1266 1145 | 1141 1266 1146 | 1141 1266 1147 | 1141 1266 1148 | 1141 1266 1149 | 1141 1266 1150 | 1141 1266 1151 | 1141 1266 1152 | 1141 1266 1153 | 1141 1266 1154 | 1141 1266 1155 | 1141 1266 1156 | 1141 1266 1157 | 1141 1266 1158 | 1141 1266 1159 | 1141 1266 1160 | 1141 1266 1161 | 1141 1266 1162 | 1141 1266 1163 | 1141 1266 1164 | 1141 1266 1165 | 1141 1266 1166 | 1141 1266 1167 | 1141 1266 1168 | 1141 1266 1169 | 1141 1266 1170 | 1141 1266 1171 | 1141 1266 1172 | 1141 1266 1173 | 1141 1266 1174 | 1141 1266 1175 | 1141 1266 1176 | 1141 1266 1177 | 1141 1266 1178 | 1141 1266 1179 | 1141 1266 1180 | 1141 1266 1181 | 1141 1266 1182 | 1141 1266 1183 | 1141 1266 1184 | 1141 1266 1185 | 1141 1266 1186 | 1141 1266 1187 | 
1141 1266 1188 | 1141 1266 1189 | 1141 1266 1190 | 1141 1266 1191 | 1141 1266 1192 | 1141 1266 1193 | 1141 1266 1194 | 1141 1266 1195 | 1141 1266 1196 | 1141 1266 1197 | 1141 1266 1198 | 1141 1266 1199 | 1141 1266 1200 | 1141 1266 1201 | 1141 1266 1202 | 1141 1266 1203 | 1141 1266 1204 | 1141 1266 1205 | 1141 1266 1206 | 1141 1266 1207 | 1141 1266 1208 | 1141 1266 1209 | 1141 1266 1210 | 1141 1266 1211 | 1141 1266 1212 | 1141 1266 1213 | 1141 1266 1214 | 1141 1266 1215 | 1141 1266 1216 | 1141 1266 1217 | 1141 1266 1218 | 1141 1266 1219 | 1141 1266 1220 | 1141 1266 1221 | 1141 1266 1222 | 1141 1266 1223 | 1141 1266 1224 | 1141 1266 1225 | 1141 1266 1226 | 1141 1266 1227 | 1141 1266 1228 | 1141 1266 1229 | 1141 1266 1230 | 1141 1266 1231 | 1141 1266 1232 | 1141 1266 1233 | 1141 1266 1234 | 1141 1266 1235 | 1141 1266 1236 | 1141 1266 1237 | 1141 1266 1238 | 1141 1266 1239 | 1141 1266 1240 | 1141 1266 1241 | 1141 1266 1242 | 1141 1266 1243 | 1141 1266 1244 | 1141 1266 1245 | 1141 1266 1246 | 1141 1266 1247 | 1141 1266 1248 | 1141 1266 1249 | 1141 1266 1250 | 1141 1266 1251 | 1141 1266 1252 | 1141 1266 1253 | 1141 1266 1254 | 1141 1266 1255 | 1141 1266 1256 | 1141 1266 1257 | 1141 1266 1258 | 1141 1266 1259 | 1141 1266 1260 | 1141 1266 1261 | 1141 1266 1262 | 1141 1266 1263 | 1141 1266 1264 | 1141 1266 1265 | 1141 1266 1266 | 1141 1266 1267 | 1266 1284 1268 | 1266 1284 1269 | 1266 1284 1270 | 1266 1284 1271 | 1266 1284 1272 | 1266 1284 1273 | 1266 1284 1274 | 1266 1284 1275 | 1266 1284 1276 | 1266 1284 1277 | 1266 1284 1278 | 1266 1284 1279 | 1266 1284 1280 | 1266 1284 1281 | 1266 1284 1282 | 1266 1284 1283 | 1266 1284 1284 | 1266 1284 1285 | 1284 1362 1286 | 1284 1362 1287 | 1284 1362 1288 | 1284 1362 1289 | 1284 1362 1290 | 1284 1362 1291 | 1284 1362 1292 | 1284 1362 1293 | 1284 1362 1294 | 1284 1362 1295 | 1284 1362 1296 | 1284 1362 1297 | 1284 1362 1298 | 1284 1362 1299 | 1284 1362 1300 | 1284 1362 1301 | 1284 1362 1302 | 1284 1362 1303 | 1284 1362 1304 | 1284 1362 
1305 | 1284 1362 1306 | 1284 1362 1307 | 1284 1362 1308 | 1284 1362 1309 | 1284 1362 1310 | 1284 1362 1311 | 1284 1362 1312 | 1284 1362 1313 | 1284 1362 1314 | 1284 1362 1315 | 1284 1362 1316 | 1284 1362 1317 | 1284 1362 1318 | 1284 1362 1319 | 1284 1362 1320 | 1284 1362 1321 | 1284 1362 1322 | 1284 1362 1323 | 1284 1362 1324 | 1284 1362 1325 | 1284 1362 1326 | 1284 1362 1327 | 1284 1362 1328 | 1284 1362 1329 | 1284 1362 1330 | 1284 1362 1331 | 1284 1362 1332 | 1284 1362 1333 | 1284 1362 1334 | 1284 1362 1335 | 1284 1362 1336 | 1284 1362 1337 | 1284 1362 1338 | 1284 1362 1339 | 1284 1362 1340 | 1284 1362 1341 | 1284 1362 1342 | 1284 1362 1343 | 1284 1362 1344 | 1284 1362 1345 | 1284 1362 1346 | 1284 1362 1347 | 1284 1362 1348 | 1284 1362 1349 | 1284 1362 1350 | 1284 1362 1351 | 1284 1362 1352 | 1284 1362 1353 | 1284 1362 1354 | 1284 1362 1355 | 1284 1362 1356 | 1284 1362 1357 | 1284 1362 1358 | 1284 1362 1359 | 1284 1362 1360 | 1284 1362 1361 | 1284 1362 1362 | 1284 1362 1363 | 1362 1371 1364 | 1362 1371 1365 | 1362 1371 1366 | 1362 1371 1367 | 1362 1371 1368 | 1362 1371 1369 | 1362 1371 1370 | 1362 1371 1371 | 1362 1371 1372 | 1371 1389 1373 | 1371 1389 1374 | 1371 1389 1375 | 1371 1389 1376 | 1371 1389 1377 | 1371 1389 1378 | 1371 1389 1379 | 1371 1389 1380 | 1371 1389 1381 | 1371 1389 1382 | 1371 1389 1383 | 1371 1389 1384 | 1371 1389 1385 | 1371 1389 1386 | 1371 1389 1387 | 1371 1389 1388 | 1371 1389 1389 | 1371 1389 1390 | 1389 1655 1391 | 1389 1655 1392 | 1389 1655 1393 | 1389 1655 1394 | 1389 1655 1395 | 1389 1655 1396 | 1389 1655 1397 | 1389 1655 1398 | 1389 1655 1399 | 1389 1655 1400 | 1389 1655 1401 | 1389 1655 1402 | 1389 1655 1403 | 1389 1655 1404 | 1389 1655 1405 | 1389 1655 1406 | 1389 1655 1407 | 1389 1655 1408 | 1389 1655 1409 | 1389 1655 1410 | 1389 1655 1411 | 1389 1655 1412 | 1389 1655 1413 | 1389 1655 1414 | 1389 1655 1415 | 1389 1655 1416 | 1389 1655 1417 | 1389 1655 1418 | 1389 1655 1419 | 1389 1655 1420 | 1389 1655 1421 | 1389 1655 1422 | 
1389 1655 1423 | 1389 1655 1424 | 1389 1655 1425 | 1389 1655 1426 | 1389 1655 1427 | 1389 1655 1428 | 1389 1655 1429 | 1389 1655 1430 | 1389 1655 1431 | 1389 1655 1432 | 1389 1655 1433 | 1389 1655 1434 | 1389 1655 1435 | 1389 1655 1436 | 1389 1655 1437 | 1389 1655 1438 | 1389 1655 1439 | 1389 1655 1440 | 1389 1655 1441 | 1389 1655 1442 | 1389 1655 1443 | 1389 1655 1444 | 1389 1655 1445 | 1389 1655 1446 | 1389 1655 1447 | 1389 1655 1448 | 1389 1655 1449 | 1389 1655 1450 | 1389 1655 1451 | 1389 1655 1452 | 1389 1655 1453 | 1389 1655 1454 | 1389 1655 1455 | 1389 1655 1456 | 1389 1655 1457 | 1389 1655 1458 | 1389 1655 1459 | 1389 1655 1460 | 1389 1655 1461 | 1389 1655 1462 | 1389 1655 1463 | 1389 1655 1464 | 1389 1655 1465 | 1389 1655 1466 | 1389 1655 1467 | 1389 1655 1468 | 1389 1655 1469 | 1389 1655 1470 | 1389 1655 1471 | 1389 1655 1472 | 1389 1655 1473 | 1389 1655 1474 | 1389 1655 1475 | 1389 1655 1476 | 1389 1655 1477 | 1389 1655 1478 | 1389 1655 1479 | 1389 1655 1480 | 1389 1655 1481 | 1389 1655 1482 | 1389 1655 1483 | 1389 1655 1484 | 1389 1655 1485 | 1389 1655 1486 | 1389 1655 1487 | 1389 1655 1488 | 1389 1655 1489 | 1389 1655 1490 | 1389 1655 1491 | 1389 1655 1492 | 1389 1655 1493 | 1389 1655 1494 | 1389 1655 1495 | 1389 1655 1496 | 1389 1655 1497 | 1389 1655 1498 | 1389 1655 1499 | 1389 1655 1500 | 1389 1655 1501 | 1389 1655 1502 | 1389 1655 1503 | 1389 1655 1504 | 1389 1655 1505 | 1389 1655 1506 | 1389 1655 1507 | 1389 1655 1508 | 1389 1655 1509 | 1389 1655 1510 | 1389 1655 1511 | 1389 1655 1512 | 1389 1655 1513 | 1389 1655 1514 | 1389 1655 1515 | 1389 1655 1516 | 1389 1655 1517 | 1389 1655 1518 | 1389 1655 1519 | 1389 1655 1520 | 1389 1655 1521 | 1389 1655 1522 | 1389 1655 1523 | 1389 1655 1524 | 1389 1655 1525 | 1389 1655 1526 | 1389 1655 1527 | 1389 1655 1528 | 1389 1655 1529 | 1389 1655 1530 | 1389 1655 1531 | 1389 1655 1532 | 1389 1655 1533 | 1389 1655 1534 | 1389 1655 1535 | 1389 1655 1536 | 1389 1655 1537 | 1389 1655 1538 | 1389 1655 1539 | 1389 1655 
1540 | 1389 1655 1541 | 1389 1655 1542 | 1389 1655 1543 | 1389 1655 1544 | 1389 1655 1545 | 1389 1655 1546 | 1389 1655 1547 | 1389 1655 1548 | 1389 1655 1549 | 1389 1655 1550 | 1389 1655 1551 | 1389 1655 1552 | 1389 1655 1553 | 1389 1655 1554 | 1389 1655 1555 | 1389 1655 1556 | 1389 1655 1557 | 1389 1655 1558 | 1389 1655 1559 | 1389 1655 1560 | 1389 1655 1561 | 1389 1655 1562 | 1389 1655 1563 | 1389 1655 1564 | 1389 1655 1565 | 1389 1655 1566 | 1389 1655 1567 | 1389 1655 1568 | 1389 1655 1569 | 1389 1655 1570 | 1389 1655 1571 | 1389 1655 1572 | 1389 1655 1573 | 1389 1655 1574 | 1389 1655 1575 | 1389 1655 1576 | 1389 1655 1577 | 1389 1655 1578 | 1389 1655 1579 | 1389 1655 1580 | 1389 1655 1581 | 1389 1655 1582 | 1389 1655 1583 | 1389 1655 1584 | 1389 1655 1585 | 1389 1655 1586 | 1389 1655 1587 | 1389 1655 1588 | 1389 1655 1589 | 1389 1655 1590 | 1389 1655 1591 | 1389 1655 1592 | 1389 1655 1593 | 1389 1655 1594 | 1389 1655 1595 | 1389 1655 1596 | 1389 1655 1597 | 1389 1655 1598 | 1389 1655 1599 | 1389 1655 1600 | 1389 1655 1601 | 1389 1655 1602 | 1389 1655 1603 | 1389 1655 1604 | 1389 1655 1605 | 1389 1655 1606 | 1389 1655 1607 | 1389 1655 1608 | 1389 1655 1609 | 1389 1655 1610 | 1389 1655 1611 | 1389 1655 1612 | 1389 1655 1613 | 1389 1655 1614 | 1389 1655 1615 | 1389 1655 1616 | 1389 1655 1617 | 1389 1655 1618 | 1389 1655 1619 | 1389 1655 1620 | 1389 1655 1621 | 1389 1655 1622 | 1389 1655 1623 | 1389 1655 1624 | 1389 1655 1625 | 1389 1655 1626 | 1389 1655 1627 | 1389 1655 1628 | 1389 1655 1629 | 1389 1655 1630 | 1389 1655 1631 | 1389 1655 1632 | 1389 1655 1633 | 1389 1655 1634 | 1389 1655 1635 | 1389 1655 1636 | 1389 1655 1637 | 1389 1655 1638 | 1389 1655 1639 | 1389 1655 1640 | 1389 1655 1641 | 1389 1655 1642 | 1389 1655 1643 | 1389 1655 1644 | 1389 1655 1645 | 1389 1655 1646 | 1389 1655 1647 | 1389 1655 1648 | 1389 1655 1649 | 1389 1655 1650 | 1389 1655 1651 | 1389 1655 1652 | 1389 1655 1653 | 1389 1655 1654 | 1389 1655 1655 | 1389 1655 1656 | 1655 1684 1657 | 
1655 1684 1658 | 1655 1684 1659 | 1655 1684 1660 | 1655 1684 1661 | 1655 1684 1662 | 1655 1684 1663 | 1655 1684 1664 | 1655 1684 1665 | 1655 1684 1666 | 1655 1684 1667 | 1655 1684 1668 | 1655 1684 1669 | 1655 1684 1670 | 1655 1684 1671 | 1655 1684 1672 | 1655 1684 1673 | 1655 1684 1674 | 1655 1684 1675 | 1655 1684 1676 | 1655 1684 1677 | 1655 1684 1678 | 1655 1684 1679 | 1655 1684 1680 | 1655 1684 1681 | 1655 1684 1682 | 1655 1684 1683 | 1655 1684 1684 | 1655 1684 1685 | 1684 1822 1686 | 1684 1822 1687 | 1684 1822 1688 | 1684 1822 1689 | 1684 1822 1690 | 1684 1822 1691 | 1684 1822 1692 | 1684 1822 1693 | 1684 1822 1694 | 1684 1822 1695 | 1684 1822 1696 | 1684 1822 1697 | 1684 1822 1698 | 1684 1822 1699 | 1684 1822 1700 | 1684 1822 1701 | 1684 1822 1702 | 1684 1822 1703 | 1684 1822 1704 | 1684 1822 1705 | 1684 1822 1706 | 1684 1822 1707 | 1684 1822 1708 | 1684 1822 1709 | 1684 1822 1710 | 1684 1822 1711 | 1684 1822 1712 | 1684 1822 1713 | 1684 1822 1714 | 1684 1822 1715 | 1684 1822 1716 | 1684 1822 1717 | 1684 1822 1718 | 1684 1822 1719 | 1684 1822 1720 | 1684 1822 1721 | 1684 1822 1722 | 1684 1822 1723 | 1684 1822 1724 | 1684 1822 1725 | 1684 1822 1726 | 1684 1822 1727 | 1684 1822 1728 | 1684 1822 1729 | 1684 1822 1730 | 1684 1822 1731 | 1684 1822 1732 | 1684 1822 1733 | 1684 1822 1734 | 1684 1822 1735 | 1684 1822 1736 | 1684 1822 1737 | 1684 1822 1738 | 1684 1822 1739 | 1684 1822 1740 | 1684 1822 1741 | 1684 1822 1742 | 1684 1822 1743 | 1684 1822 1744 | 1684 1822 1745 | 1684 1822 1746 | 1684 1822 1747 | 1684 1822 1748 | 1684 1822 1749 | 1684 1822 1750 | 1684 1822 1751 | 1684 1822 1752 | 1684 1822 1753 | 1684 1822 1754 | 1684 1822 1755 | 1684 1822 1756 | 1684 1822 1757 | 1684 1822 1758 | 1684 1822 1759 | 1684 1822 1760 | 1684 1822 1761 | 1684 1822 1762 | 1684 1822 1763 | 1684 1822 1764 | 1684 1822 1765 | 1684 1822 1766 | 1684 1822 1767 | 1684 1822 1768 | 1684 1822 1769 | 1684 1822 1770 | 1684 1822 1771 | 1684 1822 1772 | 1684 1822 1773 | 1684 1822 1774 | 1684 1822 
1775 | 1684 1822 1776 | 1684 1822 1777 | 1684 1822 1778 | 1684 1822 1779 | 1684 1822 1780 | 1684 1822 1781 | 1684 1822 1782 | 1684 1822 1783 | 1684 1822 1784 | 1684 1822 1785 | 1684 1822 1786 | 1684 1822 1787 | 1684 1822 1788 | 1684 1822 1789 | 1684 1822 1790 | 1684 1822 1791 | 1684 1822 1792 | 1684 1822 1793 | 1684 1822 1794 | 1684 1822 1795 | 1684 1822 1796 | 1684 1822 1797 | 1684 1822 1798 | 1684 1822 1799 | 1684 1822 1800 | 1684 1822 1801 | 1684 1822 1802 | 1684 1822 1803 | 1684 1822 1804 | 1684 1822 1805 | 1684 1822 1806 | 1684 1822 1807 | 1684 1822 1808 | 1684 1822 1809 | 1684 1822 1810 | 1684 1822 1811 | 1684 1822 1812 | 1684 1822 1813 | 1684 1822 1814 | 1684 1822 1815 | 1684 1822 1816 | 1684 1822 1817 | 1684 1822 1818 | 1684 1822 1819 | 1684 1822 1820 | 1684 1822 1821 | 1684 1822 1822 | 1684 1822 1823 | 1822 1836 1824 | 1822 1836 1825 | 1822 1836 1826 | 1822 1836 1827 | 1822 1836 1828 | 1822 1836 1829 | 1822 1836 1830 | 1822 1836 1831 | 1822 1836 1832 | 1822 1836 1833 | 1822 1836 1834 | 1822 1836 1835 | 1822 1836 1836 | 1822 1836 1837 | 1836 1841 1838 | 1836 1841 1839 | 1836 1841 1840 | 1836 1841 1841 | 1836 1841 1842 | 1841 1864 1843 | 1841 1864 1844 | 1841 1864 1845 | 1841 1864 1846 | 1841 1864 1847 | 1841 1864 1848 | 1841 1864 1849 | 1841 1864 1850 | 1841 1864 1851 | 1841 1864 1852 | 1841 1864 1853 | 1841 1864 1854 | 1841 1864 1855 | 1841 1864 1856 | 1841 1864 1857 | 1841 1864 1858 | 1841 1864 1859 | 1841 1864 1860 | 1841 1864 1861 | 1841 1864 1862 | 1841 1864 1863 | 1841 1864 1864 | 1841 1864 1865 | 1864 2230 1866 | 1864 2230 1867 | 1864 2230 1868 | 1864 2230 1869 | 1864 2230 1870 | 1864 2230 1871 | 1864 2230 1872 | 1864 2230 1873 | 1864 2230 1874 | 1864 2230 1875 | 1864 2230 1876 | 1864 2230 1877 | 1864 2230 1878 | 1864 2230 1879 | 1864 2230 1880 | 1864 2230 1881 | 1864 2230 1882 | 1864 2230 1883 | 1864 2230 1884 | 1864 2230 1885 | 1864 2230 1886 | 1864 2230 1887 | 1864 2230 1888 | 1864 2230 1889 | 1864 2230 1890 | 1864 2230 1891 | 1864 2230 1892 | 
1864 2230 1893 | 1864 2230 1894 | 1864 2230 1895 | 1864 2230 1896 | 1864 2230 1897 | 1864 2230 1898 | 1864 2230 1899 | 1864 2230 1900 | 1864 2230 1901 | 1864 2230 1902 | 1864 2230 1903 | 1864 2230 1904 | 1864 2230 1905 | 1864 2230 1906 | 1864 2230 1907 | 1864 2230 1908 | 1864 2230 1909 | 1864 2230 1910 | 1864 2230 1911 | 1864 2230 1912 | 1864 2230 1913 | 1864 2230 1914 | 1864 2230 1915 | 1864 2230 1916 | 1864 2230 1917 | 1864 2230 1918 | 1864 2230 1919 | 1864 2230 1920 | 1864 2230 1921 | 1864 2230 1922 | 1864 2230 1923 | 1864 2230 1924 | 1864 2230 1925 | 1864 2230 1926 | 1864 2230 1927 | 1864 2230 1928 | 1864 2230 1929 | 1864 2230 1930 | 1864 2230 1931 | 1864 2230 1932 | 1864 2230 1933 | 1864 2230 1934 | 1864 2230 1935 | 1864 2230 1936 | 1864 2230 1937 | 1864 2230 1938 | 1864 2230 1939 | 1864 2230 1940 | 1864 2230 1941 | 1864 2230 1942 | 1864 2230 1943 | 1864 2230 1944 | 1864 2230 1945 | 1864 2230 1946 | 1864 2230 1947 | 1864 2230 1948 | 1864 2230 1949 | 1864 2230 1950 | 1864 2230 1951 | 1864 2230 1952 | 1864 2230 1953 | 1864 2230 1954 | 1864 2230 1955 | 1864 2230 1956 | 1864 2230 1957 | 1864 2230 1958 | 1864 2230 1959 | 1864 2230 1960 | 1864 2230 1961 | 1864 2230 1962 | 1864 2230 1963 | 1864 2230 1964 | 1864 2230 1965 | 1864 2230 1966 | 1864 2230 1967 | 1864 2230 1968 | 1864 2230 1969 | 1864 2230 1970 | 1864 2230 1971 | 1864 2230 1972 | 1864 2230 1973 | 1864 2230 1974 | 1864 2230 1975 | 1864 2230 1976 | 1864 2230 1977 | 1864 2230 1978 | 1864 2230 1979 | 1864 2230 1980 | 1864 2230 1981 | 1864 2230 1982 | 1864 2230 1983 | 1864 2230 1984 | 1864 2230 1985 | 1864 2230 1986 | 1864 2230 1987 | 1864 2230 1988 | 1864 2230 1989 | 1864 2230 1990 | 1864 2230 1991 | 1864 2230 1992 | 1864 2230 1993 | 1864 2230 1994 | 1864 2230 1995 | 1864 2230 1996 | 1864 2230 1997 | 1864 2230 1998 | 1864 2230 1999 | 1864 2230 2000 | 1864 2230 2001 | 1864 2230 2002 | 1864 2230 2003 | 1864 2230 2004 | 1864 2230 2005 | 1864 2230 2006 | 1864 2230 2007 | 1864 2230 2008 | 1864 2230 2009 | 1864 2230 
2010 | 1864 2230 2011 | 1864 2230 2012 | 1864 2230 2013 | 1864 2230 2014 | 1864 2230 2015 | 1864 2230 2016 | 1864 2230 2017 | 1864 2230 2018 | 1864 2230 2019 | 1864 2230 2020 | 1864 2230 2021 | 1864 2230 2022 | 1864 2230 2023 | 1864 2230 2024 | 1864 2230 2025 | 1864 2230 2026 | 1864 2230 2027 | 1864 2230 2028 | 1864 2230 2029 | 1864 2230 2030 | 1864 2230 2031 | 1864 2230 2032 | 1864 2230 2033 | 1864 2230 2034 | 1864 2230 2035 | 1864 2230 2036 | 1864 2230 2037 | 1864 2230 2038 | 1864 2230 2039 | 1864 2230 2040 | 1864 2230 2041 | 1864 2230 2042 | 1864 2230 2043 | 1864 2230 2044 | 1864 2230 2045 | 1864 2230 2046 | 1864 2230 2047 | 1864 2230 2048 | 1864 2230 2049 | 1864 2230 2050 | 1864 2230 2051 | 1864 2230 2052 | 1864 2230 2053 | 1864 2230 2054 | 1864 2230 2055 | 1864 2230 2056 | 1864 2230 2057 | 1864 2230 2058 | 1864 2230 2059 | 1864 2230 2060 | 1864 2230 2061 | 1864 2230 2062 | 1864 2230 2063 | 1864 2230 2064 | 1864 2230 2065 | 1864 2230 2066 | 1864 2230 2067 | 1864 2230 2068 | 1864 2230 2069 | 1864 2230 2070 | 1864 2230 2071 | 1864 2230 2072 | 1864 2230 2073 | 1864 2230 2074 | 1864 2230 2075 | 1864 2230 2076 | 1864 2230 2077 | 1864 2230 2078 | 1864 2230 2079 | 1864 2230 2080 | 1864 2230 2081 | 1864 2230 2082 | 1864 2230 2083 | 1864 2230 2084 | 1864 2230 2085 | 1864 2230 2086 | 1864 2230 2087 | 1864 2230 2088 | 1864 2230 2089 | 1864 2230 2090 | 1864 2230 2091 | 1864 2230 2092 | 1864 2230 2093 | 1864 2230 2094 | 1864 2230 2095 | 1864 2230 2096 | 1864 2230 2097 | 1864 2230 2098 | 1864 2230 2099 | 1864 2230 2100 | 1864 2230 2101 | 1864 2230 2102 | 1864 2230 2103 | 1864 2230 2104 | 1864 2230 2105 | 1864 2230 2106 | 1864 2230 2107 | 1864 2230 2108 | 1864 2230 2109 | 1864 2230 2110 | 1864 2230 2111 | 1864 2230 2112 | 1864 2230 2113 | 1864 2230 2114 | 1864 2230 2115 | 1864 2230 2116 | 1864 2230 2117 | 1864 2230 2118 | 1864 2230 2119 | 1864 2230 2120 | 1864 2230 2121 | 1864 2230 2122 | 1864 2230 2123 | 1864 2230 2124 | 1864 2230 2125 | 1864 2230 2126 | 1864 2230 2127 | 
1864 2230 2128 | 1864 2230 2129 | 1864 2230 2130 | 1864 2230 2131 | 1864 2230 2132 | 1864 2230 2133 | 1864 2230 2134 | 1864 2230 2135 | 1864 2230 2136 | 1864 2230 2137 | 1864 2230 2138 | 1864 2230 2139 | 1864 2230 2140 | 1864 2230 2141 | 1864 2230 2142 | 1864 2230 2143 | 1864 2230 2144 | 1864 2230 2145 | 1864 2230 2146 | 1864 2230 2147 | 1864 2230 2148 | 1864 2230 2149 | 1864 2230 2150 | 1864 2230 2151 | 1864 2230 2152 | 1864 2230 2153 | 1864 2230 2154 | 1864 2230 2155 | 1864 2230 2156 | 1864 2230 2157 | 1864 2230 2158 | 1864 2230 2159 | 1864 2230 2160 | 1864 2230 2161 | 1864 2230 2162 | 1864 2230 2163 | 1864 2230 2164 | 1864 2230 2165 | 1864 2230 2166 | 1864 2230 2167 | 1864 2230 2168 | 1864 2230 2169 | 1864 2230 2170 | 1864 2230 2171 | 1864 2230 2172 | 1864 2230 2173 | 1864 2230 2174 | 1864 2230 2175 | 1864 2230 2176 | 1864 2230 2177 | 1864 2230 2178 | 1864 2230 2179 | 1864 2230 2180 | 1864 2230 2181 | 1864 2230 2182 | 1864 2230 2183 | 1864 2230 2184 | 1864 2230 2185 | 1864 2230 2186 | 1864 2230 2187 | 1864 2230 2188 | 1864 2230 2189 | 1864 2230 2190 | 1864 2230 2191 | 1864 2230 2192 | 1864 2230 2193 | 1864 2230 2194 | 1864 2230 2195 | 1864 2230 2196 | 1864 2230 2197 | 1864 2230 2198 | 1864 2230 2199 | 1864 2230 2200 | 1864 2230 2201 | 1864 2230 2202 | 1864 2230 2203 | 1864 2230 2204 | 1864 2230 2205 | 1864 2230 2206 | 1864 2230 2207 | 1864 2230 2208 | 1864 2230 2209 | 1864 2230 2210 | 1864 2230 2211 | 1864 2230 2212 | 1864 2230 2213 | 1864 2230 2214 | 1864 2230 2215 | 1864 2230 2216 | 1864 2230 2217 | 1864 2230 2218 | 1864 2230 2219 | 1864 2230 2220 | 1864 2230 2221 | 1864 2230 2222 | 1864 2230 2223 | 1864 2230 2224 | 1864 2230 2225 | 1864 2230 2226 | 1864 2230 2227 | 1864 2230 2228 | 1864 2230 2229 | 1864 2230 2230 | 1864 2230 2231 | 2230 2545 2232 | 2230 2545 2233 | 2230 2545 2234 | 2230 2545 2235 | 2230 2545 2236 | 2230 2545 2237 | 2230 2545 2238 | 2230 2545 2239 | 2230 2545 2240 | 2230 2545 2241 | 2230 2545 2242 | 2230 2545 2243 | 2230 2545 2244 | 2230 2545 
2245 | 2230 2545 2246 | 2230 2545 2247 | 2230 2545 2248 | 2230 2545 2249 | 2230 2545 2250 | 2230 2545 2251 | 2230 2545 2252 | 2230 2545 2253 | 2230 2545 2254 | 2230 2545 2255 | 2230 2545 2256 | 2230 2545 2257 | 2230 2545 2258 | 2230 2545 2259 | 2230 2545 2260 | 2230 2545 2261 | 2230 2545 2262 | 2230 2545 2263 | 2230 2545 2264 | 2230 2545 2265 | 2230 2545 2266 | 2230 2545 2267 | 2230 2545 2268 | 2230 2545 2269 | 2230 2545 2270 | 2230 2545 2271 | 2230 2545 2272 | 2230 2545 2273 | 2230 2545 2274 | 2230 2545 2275 | 2230 2545 2276 | 2230 2545 2277 | 2230 2545 2278 | 2230 2545 2279 | 2230 2545 2280 | 2230 2545 2281 | 2230 2545 2282 | 2230 2545 2283 | 2230 2545 2284 | 2230 2545 2285 | 2230 2545 2286 | 2230 2545 2287 | 2230 2545 2288 | 2230 2545 2289 | 2230 2545 2290 | 2230 2545 2291 | 2230 2545 2292 | 2230 2545 2293 | 2230 2545 2294 | 2230 2545 2295 | 2230 2545 2296 | 2230 2545 2297 | 2230 2545 2298 | 2230 2545 2299 | 2230 2545 2300 | 2230 2545 2301 | 2230 2545 2302 | 2230 2545 2303 | 2230 2545 2304 | 2230 2545 2305 | 2230 2545 2306 | 2230 2545 2307 | 2230 2545 2308 | 2230 2545 2309 | 2230 2545 2310 | 2230 2545 2311 | 2230 2545 2312 | 2230 2545 2313 | 2230 2545 2314 | 2230 2545 2315 | 2230 2545 2316 | 2230 2545 2317 | 2230 2545 2318 | 2230 2545 2319 | 2230 2545 2320 | 2230 2545 2321 | 2230 2545 2322 | 2230 2545 2323 | 2230 2545 2324 | 2230 2545 2325 | 2230 2545 2326 | 2230 2545 2327 | 2230 2545 2328 | 2230 2545 2329 | 2230 2545 2330 | 2230 2545 2331 | 2230 2545 2332 | 2230 2545 2333 | 2230 2545 2334 | 2230 2545 2335 | 2230 2545 2336 | 2230 2545 2337 | 2230 2545 2338 | 2230 2545 2339 | 2230 2545 2340 | 2230 2545 2341 | 2230 2545 2342 | 2230 2545 2343 | 2230 2545 2344 | 2230 2545 2345 | 2230 2545 2346 | 2230 2545 2347 | 2230 2545 2348 | 2230 2545 2349 | 2230 2545 2350 | 2230 2545 2351 | 2230 2545 2352 | 2230 2545 2353 | 2230 2545 2354 | 2230 2545 2355 | 2230 2545 2356 | 2230 2545 2357 | 2230 2545 2358 | 2230 2545 2359 | 2230 2545 2360 | 2230 2545 2361 | 2230 2545 2362 | 
2230 2545 2363 | 2230 2545 2364 | 2230 2545 2365 | 2230 2545 2366 | 2230 2545 2367 | 2230 2545 2368 | 2230 2545 2369 | 2230 2545 2370 | 2230 2545 2371 | 2230 2545 2372 | 2230 2545 2373 | 2230 2545 2374 | 2230 2545 2375 | 2230 2545 2376 | 2230 2545 2377 | 2230 2545 2378 | 2230 2545 2379 | 2230 2545 2380 | 2230 2545 2381 | 2230 2545 2382 | 2230 2545 2383 | 2230 2545 2384 | 2230 2545 2385 | 2230 2545 2386 | 2230 2545 2387 | 2230 2545 2388 | 2230 2545 2389 | 2230 2545 2390 | 2230 2545 2391 | 2230 2545 2392 | 2230 2545 2393 | 2230 2545 2394 | 2230 2545 2395 | 2230 2545 2396 | 2230 2545 2397 | 2230 2545 2398 | 2230 2545 2399 | 2230 2545 2400 | 2230 2545 2401 | 2230 2545 2402 | 2230 2545 2403 | 2230 2545 2404 | 2230 2545 2405 | 2230 2545 2406 | 2230 2545 2407 | 2230 2545 2408 | 2230 2545 2409 | 2230 2545 2410 | 2230 2545 2411 | 2230 2545 2412 | 2230 2545 2413 | 2230 2545 2414 | 2230 2545 2415 | 2230 2545 2416 | 2230 2545 2417 | 2230 2545 2418 | 2230 2545 2419 | 2230 2545 2420 | 2230 2545 2421 | 2230 2545 2422 | 2230 2545 2423 | 2230 2545 2424 | 2230 2545 2425 | 2230 2545 2426 | 2230 2545 2427 | 2230 2545 2428 | 2230 2545 2429 | 2230 2545 2430 | 2230 2545 2431 | 2230 2545 2432 | 2230 2545 2433 | 2230 2545 2434 | 2230 2545 2435 | 2230 2545 2436 | 2230 2545 2437 | 2230 2545 2438 | 2230 2545 2439 | 2230 2545 2440 | 2230 2545 2441 | 2230 2545 2442 | 2230 2545 2443 | 2230 2545 2444 | 2230 2545 2445 | 2230 2545 2446 | 2230 2545 2447 | 2230 2545 2448 | 2230 2545 2449 | 2230 2545 2450 | 2230 2545 2451 | 2230 2545 2452 | 2230 2545 2453 | 2230 2545 2454 | 2230 2545 2455 | 2230 2545 2456 | 2230 2545 2457 | 2230 2545 2458 | 2230 2545 2459 | 2230 2545 2460 | 2230 2545 2461 | 2230 2545 2462 | 2230 2545 2463 | 2230 2545 2464 | 2230 2545 2465 | 2230 2545 2466 | 2230 2545 2467 | 2230 2545 2468 | 2230 2545 2469 | 2230 2545 2470 | 2230 2545 2471 | 2230 2545 2472 | 2230 2545 2473 | 2230 2545 2474 | 2230 2545 2475 | 2230 2545 2476 | 2230 2545 2477 | 2230 2545 2478 | 2230 2545 2479 | 2230 2545 
2480 | 2230 2545 2481 | 2230 2545 2482 | 2230 2545 2483 | 2230 2545 2484 | 2230 2545 2485 | 2230 2545 2486 | 2230 2545 2487 | 2230 2545 2488 | 2230 2545 2489 | 2230 2545 2490 | 2230 2545 2491 | 2230 2545 2492 | 2230 2545 2493 | 2230 2545 2494 | 2230 2545 2495 | 2230 2545 2496 | 2230 2545 2497 | 2230 2545 2498 | 2230 2545 2499 | 2230 2545 2500 | 2230 2545 2501 | 2230 2545 2502 | 2230 2545 2503 | 2230 2545 2504 | 2230 2545 2505 | 2230 2545 2506 | 2230 2545 2507 | 2230 2545 2508 | 2230 2545 2509 | 2230 2545 2510 | 2230 2545 2511 | 2230 2545 2512 | 2230 2545 2513 | 2230 2545 2514 | 2230 2545 2515 | 2230 2545 2516 | 2230 2545 2517 | 2230 2545 2518 | 2230 2545 2519 | 2230 2545 2520 | 2230 2545 2521 | 2230 2545 2522 | 2230 2545 2523 | 2230 2545 2524 | 2230 2545 2525 | 2230 2545 2526 | 2230 2545 2527 | 2230 2545 2528 | 2230 2545 2529 | 2230 2545 2530 | 2230 2545 2531 | 2230 2545 2532 | 2230 2545 2533 | 2230 2545 2534 | 2230 2545 2535 | 2230 2545 2536 | 2230 2545 2537 | 2230 2545 2538 | 2230 2545 2539 | 2230 2545 2540 | 2230 2545 2541 | 2230 2545 2542 | 2230 2545 2543 | 2230 2545 2544 | 2230 2545 2545 | 2230 2545 2546 | 2545 2551 2547 | 2545 2551 2548 | 2545 2551 2549 | 2545 2551 2550 | 2545 2551 2551 | 2545 2551 2552 | 2551 2556 2553 | 2551 2556 2554 | 2551 2556 2555 | 2551 2556 2556 | 2551 2556 2557 | 2556 2753 2558 | 2556 2753 2559 | 2556 2753 2560 | 2556 2753 2561 | 2556 2753 2562 | 2556 2753 2563 | 2556 2753 2564 | 2556 2753 2565 | 2556 2753 2566 | 2556 2753 2567 | 2556 2753 2568 | 2556 2753 2569 | 2556 2753 2570 | 2556 2753 2571 | 2556 2753 2572 | 2556 2753 2573 | 2556 2753 2574 | 2556 2753 2575 | 2556 2753 2576 | 2556 2753 2577 | 2556 2753 2578 | 2556 2753 2579 | 2556 2753 2580 | 2556 2753 2581 | 2556 2753 2582 | 2556 2753 2583 | 2556 2753 2584 | 2556 2753 2585 | 2556 2753 2586 | 2556 2753 2587 | 2556 2753 2588 | 2556 2753 2589 | 2556 2753 2590 | 2556 2753 2591 | 2556 2753 2592 | 2556 2753 2593 | 2556 2753 2594 | 2556 2753 2595 | 2556 2753 2596 | 2556 2753 2597 | 
2556 2753 2598 | 2556 2753 2599 | 2556 2753 2600 | 2556 2753 2601 | 2556 2753 2602 | 2556 2753 2603 | 2556 2753 2604 | 2556 2753 2605 | 2556 2753 2606 | 2556 2753 2607 | 2556 2753 2608 | 2556 2753 2609 | 2556 2753 2610 | 2556 2753 2611 | 2556 2753 2612 | 2556 2753 2613 | 2556 2753 2614 | 2556 2753 2615 | 2556 2753 2616 | 2556 2753 2617 | 2556 2753 2618 | 2556 2753 2619 | 2556 2753 2620 | 2556 2753 2621 | 2556 2753 2622 | 2556 2753 2623 | 2556 2753 2624 | 2556 2753 2625 | 2556 2753 2626 | 2556 2753 2627 | 2556 2753 2628 | 2556 2753 2629 | 2556 2753 2630 | 2556 2753 2631 | 2556 2753 2632 | 2556 2753 2633 | 2556 2753 2634 | 2556 2753 2635 | 2556 2753 2636 | 2556 2753 2637 | 2556 2753 2638 | 2556 2753 2639 | 2556 2753 2640 | 2556 2753 2641 | 2556 2753 2642 | 2556 2753 2643 | 2556 2753 2644 | 2556 2753 2645 | 2556 2753 2646 | 2556 2753 2647 | 2556 2753 2648 | 2556 2753 2649 | 2556 2753 2650 | 2556 2753 2651 | 2556 2753 2652 | 2556 2753 2653 | 2556 2753 2654 | 2556 2753 2655 | 2556 2753 2656 | 2556 2753 2657 | 2556 2753 2658 | 2556 2753 2659 | 2556 2753 2660 | 2556 2753 2661 | 2556 2753 2662 | 2556 2753 2663 | 2556 2753 2664 | 2556 2753 2665 | 2556 2753 2666 | 2556 2753 2667 | 2556 2753 2668 | 2556 2753 2669 | 2556 2753 2670 | 2556 2753 2671 | 2556 2753 2672 | 2556 2753 2673 | 2556 2753 2674 | 2556 2753 2675 | 2556 2753 2676 | 2556 2753 2677 | 2556 2753 2678 | 2556 2753 2679 | 2556 2753 2680 | 2556 2753 2681 | 2556 2753 2682 | 2556 2753 2683 | 2556 2753 2684 | 2556 2753 2685 | 2556 2753 2686 | 2556 2753 2687 | 2556 2753 2688 | 2556 2753 2689 | 2556 2753 2690 | 2556 2753 2691 | 2556 2753 2692 | 2556 2753 2693 | 2556 2753 2694 | 2556 2753 2695 | 2556 2753 2696 | 2556 2753 2697 | 2556 2753 2698 | 2556 2753 2699 | 2556 2753 2700 | 2556 2753 2701 | 2556 2753 2702 | 2556 2753 2703 | 2556 2753 2704 | 2556 2753 2705 | 2556 2753 2706 | 2556 2753 2707 | 2556 2753 2708 | 2556 2753 2709 | 2556 2753 2710 | 2556 2753 2711 | 2556 2753 2712 | 2556 2753 2713 | 2556 2753 2714 | 2556 2753 
2715 | 2556 2753 2716 | 2556 2753 2717 | 2556 2753 2718 | 2556 2753 2719 | 2556 2753 2720 | 2556 2753 2721 | 2556 2753 2722 | 2556 2753 2723 | 2556 2753 2724 | 2556 2753 2725 | 2556 2753 2726 | 2556 2753 2727 | 2556 2753 2728 | 2556 2753 2729 | 2556 2753 2730 | 2556 2753 2731 | 2556 2753 2732 | 2556 2753 2733 | 2556 2753 2734 | 2556 2753 2735 | 2556 2753 2736 | 2556 2753 2737 | 2556 2753 2738 | 2556 2753 2739 | 2556 2753 2740 | 2556 2753 2741 | 2556 2753 2742 | 2556 2753 2743 | 2556 2753 2744 | 2556 2753 2745 | 2556 2753 2746 | 2556 2753 2747 | 2556 2753 2748 | 2556 2753 2749 | 2556 2753 2750 | 2556 2753 2751 | 2556 2753 2752 | 2556 2753 2753 | 2556 2753 2754 | 2753 2765 2755 | 2753 2765 2756 | 2753 2765 2757 | 2753 2765 2758 | 2753 2765 2759 | 2753 2765 2760 | 2753 2765 2761 | 2753 2765 2762 | 2753 2765 2763 | 2753 2765 2764 | 2753 2765 2765 | 2753 2765 2766 | 2765 2773 2767 | 2765 2773 2768 | 2765 2773 2769 | 2765 2773 2770 | 2765 2773 2771 | 2765 2773 2772 | 2765 2773 2773 | 2765 2773 2774 | 2773 2782 2775 | 2773 2782 2776 | 2773 2782 2777 | 2773 2782 2778 | 2773 2782 2779 | 2773 2782 2780 | 2773 2782 2781 | 2773 2782 2782 | 2773 2782 2783 | 2782 2803 2784 | 2782 2803 2785 | 2782 2803 2786 | 2782 2803 2787 | 2782 2803 2788 | 2782 2803 2789 | 2782 2803 2790 | 2782 2803 2791 | 2782 2803 2792 | 2782 2803 2793 | 2782 2803 2794 | 2782 2803 2795 | 2782 2803 2796 | 2782 2803 2797 | 2782 2803 2798 | 2782 2803 2799 | 2782 2803 2800 | 2782 2803 2801 | 2782 2803 2802 | 2782 2803 2803 | 2782 2803 2804 | 2803 3084 2805 | 2803 3084 2806 | 2803 3084 2807 | 2803 3084 2808 | 2803 3084 2809 | 2803 3084 2810 | 2803 3084 2811 | 2803 3084 2812 | 2803 3084 2813 | 2803 3084 2814 | 2803 3084 2815 | 2803 3084 2816 | 2803 3084 2817 | 2803 3084 2818 | 2803 3084 2819 | 2803 3084 2820 | 2803 3084 2821 | 2803 3084 2822 | 2803 3084 2823 | 2803 3084 2824 | 2803 3084 2825 | 2803 3084 2826 | 2803 3084 2827 | 2803 3084 2828 | 2803 3084 2829 | 2803 3084 2830 | 2803 3084 2831 | 2803 3084 2832 | 
2803 3084 2833 | 2803 3084 2834 | 2803 3084 2835 | 2803 3084 2836 | 2803 3084 2837 | 2803 3084 2838 | 2803 3084 2839 | 2803 3084 2840 | 2803 3084 2841 | 2803 3084 2842 | 2803 3084 2843 | 2803 3084 2844 | 2803 3084 2845 | 2803 3084 2846 | 2803 3084 2847 | 2803 3084 2848 | 2803 3084 2849 | 2803 3084 2850 | 2803 3084 2851 | 2803 3084 2852 | 2803 3084 2853 | 2803 3084 2854 | 2803 3084 2855 | 2803 3084 2856 | 2803 3084 2857 | 2803 3084 2858 | 2803 3084 2859 | 2803 3084 2860 | 2803 3084 2861 | 2803 3084 2862 | 2803 3084 2863 | 2803 3084 2864 | 2803 3084 2865 | 2803 3084 2866 | 2803 3084 2867 | 2803 3084 2868 | 2803 3084 2869 | 2803 3084 2870 | 2803 3084 2871 | 2803 3084 2872 | 2803 3084 2873 | 2803 3084 2874 | 2803 3084 2875 | 2803 3084 2876 | 2803 3084 2877 | 2803 3084 2878 | 2803 3084 2879 | 2803 3084 2880 | 2803 3084 2881 | 2803 3084 2882 | 2803 3084 2883 | 2803 3084 2884 | 2803 3084 2885 | 2803 3084 2886 | 2803 3084 2887 | 2803 3084 2888 | 2803 3084 2889 | 2803 3084 2890 | 2803 3084 2891 | 2803 3084 2892 | 2803 3084 2893 | 2803 3084 2894 | 2803 3084 2895 | 2803 3084 2896 | 2803 3084 2897 | 2803 3084 2898 | 2803 3084 2899 | 2803 3084 2900 | 2803 3084 2901 | 2803 3084 2902 | 2803 3084 2903 | 2803 3084 2904 | 2803 3084 2905 | 2803 3084 2906 | 2803 3084 2907 | 2803 3084 2908 | 2803 3084 2909 | 2803 3084 2910 | 2803 3084 2911 | 2803 3084 2912 | 2803 3084 2913 | 2803 3084 2914 | 2803 3084 2915 | 2803 3084 2916 | 2803 3084 2917 | 2803 3084 2918 | 2803 3084 2919 | 2803 3084 2920 | 2803 3084 2921 | 2803 3084 2922 | 2803 3084 2923 | 2803 3084 2924 | 2803 3084 2925 | 2803 3084 2926 | 2803 3084 2927 | 2803 3084 2928 | 2803 3084 2929 | 2803 3084 2930 | 2803 3084 2931 | 2803 3084 2932 | 2803 3084 2933 | 2803 3084 2934 | 2803 3084 2935 | 2803 3084 2936 | 2803 3084 2937 | 2803 3084 2938 | 2803 3084 2939 | 2803 3084 2940 | 2803 3084 2941 | 2803 3084 2942 | 2803 3084 2943 | 2803 3084 2944 | 2803 3084 2945 | 2803 3084 2946 | 2803 3084 2947 | 2803 3084 2948 | 2803 3084 2949 | 2803 3084 
2950 | 2803 3084 2951 | 2803 3084 2952 | 2803 3084 2953 | 2803 3084 2954 | 2803 3084 2955 | 2803 3084 2956 | 2803 3084 2957 | 2803 3084 2958 | 2803 3084 2959 | 2803 3084 2960 | 2803 3084 2961 | 2803 3084 2962 | 2803 3084 2963 | 2803 3084 2964 | 2803 3084 2965 | 2803 3084 2966 | 2803 3084 2967 | 2803 3084 2968 | 2803 3084 2969 | 2803 3084 2970 | 2803 3084 2971 | 2803 3084 2972 | 2803 3084 2973 | 2803 3084 2974 | 2803 3084 2975 | 2803 3084 2976 | 2803 3084 2977 | 2803 3084 2978 | 2803 3084 2979 | 2803 3084 2980 | 2803 3084 2981 | 2803 3084 2982 | 2803 3084 2983 | 2803 3084 2984 | 2803 3084 2985 | 2803 3084 2986 | 2803 3084 2987 | 2803 3084 2988 | 2803 3084 2989 | 2803 3084 2990 | 2803 3084 2991 | 2803 3084 2992 | 2803 3084 2993 | 2803 3084 2994 | 2803 3084 2995 | 2803 3084 2996 | 2803 3084 2997 | 2803 3084 2998 | 2803 3084 2999 | 2803 3084 3000 | 2803 3084 3001 | 2803 3084 3002 | 2803 3084 3003 | 2803 3084 3004 | 2803 3084 3005 | 2803 3084 3006 | 2803 3084 3007 | 2803 3084 3008 | 2803 3084 3009 | 2803 3084 3010 | 2803 3084 3011 | 2803 3084 3012 | 2803 3084 3013 | 2803 3084 3014 | 2803 3084 3015 | 2803 3084 3016 | 2803 3084 3017 | 2803 3084 3018 | 2803 3084 3019 | 2803 3084 3020 | 2803 3084 3021 | 2803 3084 3022 | 2803 3084 3023 | 2803 3084 3024 | 2803 3084 3025 | 2803 3084 3026 | 2803 3084 3027 | 2803 3084 3028 | 2803 3084 3029 | 2803 3084 3030 | 2803 3084 3031 | 2803 3084 3032 | 2803 3084 3033 | 2803 3084 3034 | 2803 3084 3035 | 2803 3084 3036 | 2803 3084 3037 | 2803 3084 3038 | 2803 3084 3039 | 2803 3084 3040 | 2803 3084 3041 | 2803 3084 3042 | 2803 3084 3043 | 2803 3084 3044 | 2803 3084 3045 | 2803 3084 3046 | 2803 3084 3047 | 2803 3084 3048 | 2803 3084 3049 | 2803 3084 3050 | 2803 3084 3051 | 2803 3084 3052 | 2803 3084 3053 | 2803 3084 3054 | 2803 3084 3055 | 2803 3084 3056 | 2803 3084 3057 | 2803 3084 3058 | 2803 3084 3059 | 2803 3084 3060 | 2803 3084 3061 | 2803 3084 3062 | 2803 3084 3063 | 2803 3084 3064 | 2803 3084 3065 | 2803 3084 3066 | 2803 3084 3067 | 
2803 3084 3068 | 2803 3084 3069 | 2803 3084 3070 | 2803 3084 3071 | 2803 3084 3072 | 2803 3084 3073 | 2803 3084 3074 | 2803 3084 3075 | 2803 3084 3076 | 2803 3084 3077 | 2803 3084 3078 | 2803 3084 3079 | 2803 3084 3080 | 2803 3084 3081 | 2803 3084 3082 | 2803 3084 3083 | 2803 3084 3084 | 2803 3084 3085 | 3084 3142 3086 | 3084 3142 3087 | 3084 3142 3088 | 3084 3142 3089 | 3084 3142 3090 | 3084 3142 3091 | 3084 3142 3092 | 3084 3142 3093 | 3084 3142 3094 | 3084 3142 3095 | 3084 3142 3096 | 3084 3142 3097 | 3084 3142 3098 | 3084 3142 3099 | 3084 3142 3100 | 3084 3142 3101 | 3084 3142 3102 | 3084 3142 3103 | 3084 3142 3104 | 3084 3142 3105 | 3084 3142 3106 | 3084 3142 3107 | 3084 3142 3108 | 3084 3142 3109 | 3084 3142 3110 | 3084 3142 3111 | 3084 3142 3112 | 3084 3142 3113 | 3084 3142 3114 | 3084 3142 3115 | 3084 3142 3116 | 3084 3142 3117 | 3084 3142 3118 | 3084 3142 3119 | 3084 3142 3120 | 3084 3142 3121 | 3084 3142 3122 | 3084 3142 3123 | 3084 3142 3124 | 3084 3142 3125 | 3084 3142 3126 | 3084 3142 3127 | 3084 3142 3128 | 3084 3142 3129 | 3084 3142 3130 | 3084 3142 3131 | 3084 3142 3132 | 3084 3142 3133 | 3084 3142 3134 | 3084 3142 3135 | 3084 3142 3136 | 3084 3142 3137 | 3084 3142 3138 | 3084 3142 3139 | 3084 3142 3140 | 3084 3142 3141 | 3084 3142 3142 | 3084 3142 3143 | 3142 3302 3144 | 3142 3302 3145 | 3142 3302 3146 | 3142 3302 3147 | 3142 3302 3148 | 3142 3302 3149 | 3142 3302 3150 | 3142 3302 3151 | 3142 3302 3152 | 3142 3302 3153 | 3142 3302 3154 | 3142 3302 3155 | 3142 3302 3156 | 3142 3302 3157 | 3142 3302 3158 | 3142 3302 3159 | 3142 3302 3160 | 3142 3302 3161 | 3142 3302 3162 | 3142 3302 3163 | 3142 3302 3164 | 3142 3302 3165 | 3142 3302 3166 | 3142 3302 3167 | 3142 3302 3168 | 3142 3302 3169 | 3142 3302 3170 | 3142 3302 3171 | 3142 3302 3172 | 3142 3302 3173 | 3142 3302 3174 | 3142 3302 3175 | 3142 3302 3176 | 3142 3302 3177 | 3142 3302 3178 | 3142 3302 3179 | 3142 3302 3180 | 3142 3302 3181 | 3142 3302 3182 | 3142 3302 3183 | 3142 3302 3184 | 3142 3302 
3185 | 3142 3302 3186 | 3142 3302 3187 | 3142 3302 3188 | 3142 3302 3189 | 3142 3302 3190 | 3142 3302 3191 | 3142 3302 3192 | 3142 3302 3193 | 3142 3302 3194 | 3142 3302 3195 | 3142 3302 3196 | 3142 3302 3197 | 3142 3302 3198 | 3142 3302 3199 | 3142 3302 3200 | 3142 3302 3201 | 3142 3302 3202 | 3142 3302 3203 | 3142 3302 3204 | 3142 3302 3205 | 3142 3302 3206 | 3142 3302 3207 | 3142 3302 3208 | 3142 3302 3209 | 3142 3302 3210 | 3142 3302 3211 | 3142 3302 3212 | 3142 3302 3213 | 3142 3302 3214 | 3142 3302 3215 | 3142 3302 3216 | 3142 3302 3217 | 3142 3302 3218 | 3142 3302 3219 | 3142 3302 3220 | 3142 3302 3221 | 3142 3302 3222 | 3142 3302 3223 | 3142 3302 3224 | 3142 3302 3225 | 3142 3302 3226 | 3142 3302 3227 | 3142 3302 3228 | 3142 3302 3229 | 3142 3302 3230 | 3142 3302 3231 | 3142 3302 3232 | 3142 3302 3233 | 3142 3302 3234 | 3142 3302 3235 | 3142 3302 3236 | 3142 3302 3237 | 3142 3302 3238 | 3142 3302 3239 | 3142 3302 3240 | 3142 3302 3241 | 3142 3302 3242 | 3142 3302 3243 | 3142 3302 3244 | 3142 3302 3245 | 3142 3302 3246 | 3142 3302 3247 | 3142 3302 3248 | 3142 3302 3249 | 3142 3302 3250 | 3142 3302 3251 | 3142 3302 3252 | 3142 3302 3253 | 3142 3302 3254 | 3142 3302 3255 | 3142 3302 3256 | 3142 3302 3257 | 3142 3302 3258 | 3142 3302 3259 | 3142 3302 3260 | 3142 3302 3261 | 3142 3302 3262 | 3142 3302 3263 | 3142 3302 3264 | 3142 3302 3265 | 3142 3302 3266 | 3142 3302 3267 | 3142 3302 3268 | 3142 3302 3269 | 3142 3302 3270 | 3142 3302 3271 | 3142 3302 3272 | 3142 3302 3273 | 3142 3302 3274 | 3142 3302 3275 | 3142 3302 3276 | 3142 3302 3277 | 3142 3302 3278 | 3142 3302 3279 | 3142 3302 3280 | 3142 3302 3281 | 3142 3302 3282 | 3142 3302 3283 | 3142 3302 3284 | 3142 3302 3285 | 3142 3302 3286 | 3142 3302 3287 | 3142 3302 3288 | 3142 3302 3289 | 3142 3302 3290 | 3142 3302 3291 | 3142 3302 3292 | 3142 3302 3293 | 3142 3302 3294 | 3142 3302 3295 | 3142 3302 3296 | 3142 3302 3297 | 3142 3302 3298 | 3142 3302 3299 | 3142 3302 3300 | 3142 3302 3301 | 3142 3302 3302 | 
3142 3302 3303 | 3302 3313 3304 | 3302 3313 3305 | 3302 3313 3306 | 3302 3313 3307 | 3302 3313 3308 | 3302 3313 3309 | 3302 3313 3310 | 3302 3313 3311 | 3302 3313 3312 | 3302 3313 3313 | 3302 3313 3314 | 3313 3374 3315 | 3313 3374 3316 | 3313 3374 3317 | 3313 3374 3318 | 3313 3374 3319 | 3313 3374 3320 | 3313 3374 3321 | 3313 3374 3322 | 3313 3374 3323 | 3313 3374 3324 | 3313 3374 3325 | 3313 3374 3326 | 3313 3374 3327 | 3313 3374 3328 | 3313 3374 3329 | 3313 3374 3330 | 3313 3374 3331 | 3313 3374 3332 | 3313 3374 3333 | 3313 3374 3334 | 3313 3374 3335 | 3313 3374 3336 | 3313 3374 3337 | 3313 3374 3338 | 3313 3374 3339 | 3313 3374 3340 | 3313 3374 3341 | 3313 3374 3342 | 3313 3374 3343 | 3313 3374 3344 | 3313 3374 3345 | 3313 3374 3346 | 3313 3374 3347 | 3313 3374 3348 | 3313 3374 3349 | 3313 3374 3350 | 3313 3374 3351 | 3313 3374 3352 | 3313 3374 3353 | 3313 3374 3354 | 3313 3374 3355 | 3313 3374 3356 | 3313 3374 3357 | 3313 3374 3358 | 3313 3374 3359 | 3313 3374 3360 | 3313 3374 3361 | 3313 3374 3362 | 3313 3374 3363 | 3313 3374 3364 | 3313 3374 3365 | 3313 3374 3366 | 3313 3374 3367 | 3313 3374 3368 | 3313 3374 3369 | 3313 3374 3370 | 3313 3374 3371 | 3313 3374 3372 | 3313 3374 3373 | 3313 3374 3374 | 3313 3374 3375 | 3374 3423 3376 | 3374 3423 3377 | 3374 3423 3378 | 3374 3423 3379 | 3374 3423 3380 | 3374 3423 3381 | 3374 3423 3382 | 3374 3423 3383 | 3374 3423 3384 | 3374 3423 3385 | 3374 3423 3386 | 3374 3423 3387 | 3374 3423 3388 | 3374 3423 3389 | 3374 3423 3390 | 3374 3423 3391 | 3374 3423 3392 | 3374 3423 3393 | 3374 3423 3394 | 3374 3423 3395 | 3374 3423 3396 | 3374 3423 3397 | 3374 3423 3398 | 3374 3423 3399 | 3374 3423 3400 | 3374 3423 3401 | 3374 3423 3402 | 3374 3423 3403 | 3374 3423 3404 | 3374 3423 3405 | 3374 3423 3406 | 3374 3423 3407 | 3374 3423 3408 | 3374 3423 3409 | 3374 3423 3410 | 3374 3423 3411 | 3374 3423 3412 | 3374 3423 3413 | 3374 3423 3414 | 3374 3423 3415 | 3374 3423 3416 | 3374 3423 3417 | 3374 3423 3418 | 3374 3423 3419 | 3374 3423 
3420 | 3374 3423 3421 | 3374 3423 3422 | 3374 3423 3423 | 3374 3423 3424 | 3423 3434 3425 | 3423 3434 3426 | 3423 3434 3427 | 3423 3434 3428 | 3423 3434 3429 | 3423 3434 3430 | 3423 3434 3431 | 3423 3434 3432 | 3423 3434 3433 | 3423 3434 3434 | 3423 3434 3435 | 3434 3446 3436 | 3434 3446 3437 | 3434 3446 3438 | 3434 3446 3439 | 3434 3446 3440 | 3434 3446 3441 | 3434 3446 3442 | 3434 3446 3443 | 3434 3446 3444 | 3434 3446 3445 | 3434 3446 3446 | 3434 3446 3447 | 3446 3489 3448 | 3446 3489 3449 | 3446 3489 3450 | 3446 3489 3451 | 3446 3489 3452 | 3446 3489 3453 | 3446 3489 3454 | 3446 3489 3455 | 3446 3489 3456 | 3446 3489 3457 | 3446 3489 3458 | 3446 3489 3459 | 3446 3489 3460 | 3446 3489 3461 | 3446 3489 3462 | 3446 3489 3463 | 3446 3489 3464 | 3446 3489 3465 | 3446 3489 3466 | 3446 3489 3467 | 3446 3489 3468 | 3446 3489 3469 | 3446 3489 3470 | 3446 3489 3471 | 3446 3489 3472 | 3446 3489 3473 | 3446 3489 3474 | 3446 3489 3475 | 3446 3489 3476 | 3446 3489 3477 | 3446 3489 3478 | 3446 3489 3479 | 3446 3489 3480 | 3446 3489 3481 | 3446 3489 3482 | 3446 3489 3483 | 3446 3489 3484 | 3446 3489 3485 | 3446 3489 3486 | 3446 3489 3487 | 3446 3489 3488 | 3446 3489 3489 | 3446 3489 3490 | 3489 3635 3491 | 3489 3635 3492 | 3489 3635 3493 | 3489 3635 3494 | 3489 3635 3495 | 3489 3635 3496 | 3489 3635 3497 | 3489 3635 3498 | 3489 3635 3499 | 3489 3635 3500 | 3489 3635 3501 | 3489 3635 3502 | 3489 3635 3503 | 3489 3635 3504 | 3489 3635 3505 | 3489 3635 3506 | 3489 3635 3507 | 3489 3635 3508 | 3489 3635 3509 | 3489 3635 3510 | 3489 3635 3511 | 3489 3635 3512 | 3489 3635 3513 | 3489 3635 3514 | 3489 3635 3515 | 3489 3635 3516 | 3489 3635 3517 | 3489 3635 3518 | 3489 3635 3519 | 3489 3635 3520 | 3489 3635 3521 | 3489 3635 3522 | 3489 3635 3523 | 3489 3635 3524 | 3489 3635 3525 | 3489 3635 3526 | 3489 3635 3527 | 3489 3635 3528 | 3489 3635 3529 | 3489 3635 3530 | 3489 3635 3531 | 3489 3635 3532 | 3489 3635 3533 | 3489 3635 3534 | 3489 3635 3535 | 3489 3635 3536 | 3489 3635 3537 | 
3489 3635 3538 | 3489 3635 3539 | 3489 3635 3540 | 3489 3635 3541 | 3489 3635 3542 | 3489 3635 3543 | 3489 3635 3544 | 3489 3635 3545 | 3489 3635 3546 | 3489 3635 3547 | 3489 3635 3548 | 3489 3635 3549 | 3489 3635 3550 | 3489 3635 3551 | 3489 3635 3552 | 3489 3635 3553 | 3489 3635 3554 | 3489 3635 3555 | 3489 3635 3556 | 3489 3635 3557 | 3489 3635 3558 | 3489 3635 3559 | 3489 3635 3560 | 3489 3635 3561 | 3489 3635 3562 | 3489 3635 3563 | 3489 3635 3564 | 3489 3635 3565 | 3489 3635 3566 | 3489 3635 3567 | 3489 3635 3568 | 3489 3635 3569 | 3489 3635 3570 | 3489 3635 3571 | 3489 3635 3572 | 3489 3635 3573 | 3489 3635 3574 | 3489 3635 3575 | 3489 3635 3576 | 3489 3635 3577 | 3489 3635 3578 | 3489 3635 3579 | 3489 3635 3580 | 3489 3635 3581 | 3489 3635 3582 | 3489 3635 3583 | 3489 3635 3584 | 3489 3635 3585 | 3489 3635 3586 | 3489 3635 3587 | 3489 3635 3588 | 3489 3635 3589 | 3489 3635 3590 | 3489 3635 3591 | 3489 3635 3592 | 3489 3635 3593 | 3489 3635 3594 | 3489 3635 3595 | 3489 3635 3596 | 3489 3635 3597 | 3489 3635 3598 | 3489 3635 3599 | 3489 3635 3600 | 3489 3635 3601 | 3489 3635 3602 | 3489 3635 3603 | 3489 3635 3604 | 3489 3635 3605 | 3489 3635 3606 | 3489 3635 3607 | 3489 3635 3608 | 3489 3635 3609 | 3489 3635 3610 | 3489 3635 3611 | 3489 3635 3612 | 3489 3635 3613 | 3489 3635 3614 | 3489 3635 3615 | 3489 3635 3616 | 3489 3635 3617 | 3489 3635 3618 | 3489 3635 3619 | 3489 3635 3620 | 3489 3635 3621 | 3489 3635 3622 | 3489 3635 3623 | 3489 3635 3624 | 3489 3635 3625 | 3489 3635 3626 | 3489 3635 3627 | 3489 3635 3628 | 3489 3635 3629 | 3489 3635 3630 | 3489 3635 3631 | 3489 3635 3632 | 3489 3635 3633 | 3489 3635 3634 | 3489 3635 3635 | 3489 3635 3636 | 3635 3642 3637 | 3635 3642 3638 | 3635 3642 3639 | 3635 3642 3640 | 3635 3642 3641 | 3635 3642 3642 | 3635 3642 3643 | 3642 3665 3644 | 3642 3665 3645 | 3642 3665 3646 | 3642 3665 3647 | 3642 3665 3648 | 3642 3665 3649 | 3642 3665 3650 | 3642 3665 3651 | 3642 3665 3652 | 3642 3665 3653 | 3642 3665 3654 | 3642 3665 
3655 | 3642 3665 3656 | 3642 3665 3657 | 3642 3665 3658 | 3642 3665 3659 | 3642 3665 3660 | 3642 3665 3661 | 3642 3665 3662 | 3642 3665 3663 | 3642 3665 3664 | 3642 3665 3665 | 3642 3665 3666 | 3665 3683 3667 | 3665 3683 3668 | 3665 3683 3669 | 3665 3683 3670 | 3665 3683 3671 | 3665 3683 3672 | 3665 3683 3673 | 3665 3683 3674 | 3665 3683 3675 | 3665 3683 3676 | 3665 3683 3677 | 3665 3683 3678 | 3665 3683 3679 | 3665 3683 3680 | 3665 3683 3681 | 3665 3683 3682 | 3665 3683 3683 | 3665 3683 3684 | 3683 3771 3685 | 3683 3771 3686 | 3683 3771 3687 | 3683 3771 3688 | 3683 3771 3689 | 3683 3771 3690 | 3683 3771 3691 | 3683 3771 3692 | 3683 3771 3693 | 3683 3771 3694 | 3683 3771 3695 | 3683 3771 3696 | 3683 3771 3697 | 3683 3771 3698 | 3683 3771 3699 | 3683 3771 3700 | 3683 3771 3701 | 3683 3771 3702 | 3683 3771 3703 | 3683 3771 3704 | 3683 3771 3705 | 3683 3771 3706 | 3683 3771 3707 | 3683 3771 3708 | 3683 3771 3709 | 3683 3771 3710 | 3683 3771 3711 | 3683 3771 3712 | 3683 3771 3713 | 3683 3771 3714 | 3683 3771 3715 | 3683 3771 3716 | 3683 3771 3717 | 3683 3771 3718 | 3683 3771 3719 | 3683 3771 3720 | 3683 3771 3721 | 3683 3771 3722 | 3683 3771 3723 | 3683 3771 3724 | 3683 3771 3725 | 3683 3771 3726 | 3683 3771 3727 | 3683 3771 3728 | 3683 3771 3729 | 3683 3771 3730 | 3683 3771 3731 | 3683 3771 3732 | 3683 3771 3733 | 3683 3771 3734 | 3683 3771 3735 | 3683 3771 3736 | 3683 3771 3737 | 3683 3771 3738 | 3683 3771 3739 | 3683 3771 3740 | 3683 3771 3741 | 3683 3771 3742 | 3683 3771 3743 | 3683 3771 3744 | 3683 3771 3745 | 3683 3771 3746 | 3683 3771 3747 | 3683 3771 3748 | 3683 3771 3749 | 3683 3771 3750 | 3683 3771 3751 | 3683 3771 3752 | 3683 3771 3753 | 3683 3771 3754 | 3683 3771 3755 | 3683 3771 3756 | 3683 3771 3757 | 3683 3771 3758 | 3683 3771 3759 | 3683 3771 3760 | 3683 3771 3761 | 3683 3771 3762 | 3683 3771 3763 | 3683 3771 3764 | 3683 3771 3765 | 3683 3771 3766 | 3683 3771 3767 | 3683 3771 3768 | 3683 3771 3769 | 3683 3771 3770 | 3683 3771 3771 | 3683 3771 3772 | 
3771 3796 3773 | 3771 3796 3774 | 3771 3796 3775 | 3771 3796 3776 | 3771 3796 3777 | 3771 3796 3778 | 3771 3796 3779 | 3771 3796 3780 | 3771 3796 3781 | 3771 3796 3782 | 3771 3796 3783 | 3771 3796 3784 | 3771 3796 3785 | 3771 3796 3786 | 3771 3796 3787 | 3771 3796 3788 | 3771 3796 3789 | 3771 3796 3790 | 3771 3796 3791 | 3771 3796 3792 | 3771 3796 3793 | 3771 3796 3794 | 3771 3796 3795 | 3771 3796 3796 | 3771 3796 3797 | 3796 3838 3798 | 3796 3838 3799 | 3796 3838 3800 | 3796 3838 3801 | 3796 3838 3802 | 3796 3838 3803 | 3796 3838 3804 | 3796 3838 3805 | 3796 3838 3806 | 3796 3838 3807 | 3796 3838 3808 | 3796 3838 3809 | 3796 3838 3810 | 3796 3838 3811 | 3796 3838 3812 | 3796 3838 3813 | 3796 3838 3814 | 3796 3838 3815 | 3796 3838 3816 | 3796 3838 3817 | 3796 3838 3818 | 3796 3838 3819 | 3796 3838 3820 | 3796 3838 3821 | 3796 3838 3822 | 3796 3838 3823 | 3796 3838 3824 | 3796 3838 3825 | 3796 3838 3826 | 3796 3838 3827 | 3796 3838 3828 | 3796 3838 3829 | 3796 3838 3830 | 3796 3838 3831 | 3796 3838 3832 | 3796 3838 3833 | 3796 3838 3834 | 3796 3838 3835 | 3796 3838 3836 | 3796 3838 3837 | 3796 3838 3838 | 3796 3838 3839 | 3838 3856 3840 | 3838 3856 3841 | 3838 3856 3842 | 3838 3856 3843 | 3838 3856 3844 | 3838 3856 3845 | 3838 3856 3846 | 3838 3856 3847 | 3838 3856 3848 | 3838 3856 3849 | 3838 3856 3850 | 3838 3856 3851 | 3838 3856 3852 | 3838 3856 3853 | 3838 3856 3854 | 3838 3856 3855 | 3838 3856 3856 | 3838 3856 3857 | 3856 3878 3858 | 3856 3878 3859 | 3856 3878 3860 | 3856 3878 3861 | 3856 3878 3862 | 3856 3878 3863 | 3856 3878 3864 | 3856 3878 3865 | 3856 3878 3866 | 3856 3878 3867 | 3856 3878 3868 | 3856 3878 3869 | 3856 3878 3870 | 3856 3878 3871 | 3856 3878 3872 | 3856 3878 3873 | 3856 3878 3874 | 3856 3878 3875 | 3856 3878 3876 | 3856 3878 3877 | 3856 3878 3878 | 3856 3878 3879 | 3878 3881 3880 | 3878 3881 3881 | 3878 3881 3882 | 3881 3911 3883 | 3881 3911 3884 | 3881 3911 3885 | 3881 3911 3886 | 3881 3911 3887 | 3881 3911 3888 | 3881 3911 3889 | 3881 3911 
3890 | 3881 3911 3891 | 3881 3911 3892 | 3881 3911 3893 | 3881 3911 3894 | 3881 3911 3895 | 3881 3911 3896 | 3881 3911 3897 | 3881 3911 3898 | 3881 3911 3899 | 3881 3911 3900 | 3881 3911 3901 | 3881 3911 3902 | 3881 3911 3903 | 3881 3911 3904 | 3881 3911 3905 | 3881 3911 3906 | 3881 3911 3907 | 3881 3911 3908 | 3881 3911 3909 | 3881 3911 3910 | 3881 3911 3911 | 3881 3911 3912 | 3911 3950 3913 | 3911 3950 3914 | 3911 3950 3915 | 3911 3950 3916 | 3911 3950 3917 | 3911 3950 3918 | 3911 3950 3919 | 3911 3950 3920 | 3911 3950 3921 | 3911 3950 3922 | 3911 3950 3923 | 3911 3950 3924 | 3911 3950 3925 | 3911 3950 3926 | 3911 3950 3927 | 3911 3950 3928 | 3911 3950 3929 | 3911 3950 3930 | 3911 3950 3931 | 3911 3950 3932 | 3911 3950 3933 | 3911 3950 3934 | 3911 3950 3935 | 3911 3950 3936 | 3911 3950 3937 | 3911 3950 3938 | 3911 3950 3939 | 3911 3950 3940 | 3911 3950 3941 | 3911 3950 3942 | 3911 3950 3943 | 3911 3950 3944 | 3911 3950 3945 | 3911 3950 3946 | 3911 3950 3947 | 3911 3950 3948 | 3911 3950 3949 | 3911 3950 3950 | 3911 3950 3951 | 3950 3962 3952 | 3950 3962 3953 | 3950 3962 3954 | 3950 3962 3955 | 3950 3962 3956 | 3950 3962 3957 | 3950 3962 3958 | 3950 3962 3959 | 3950 3962 3960 | 3950 3962 3961 | 3950 3962 3962 | 3950 3962 3963 | 3962 4146 3964 | 3962 4146 3965 | 3962 4146 3966 | 3962 4146 3967 | 3962 4146 3968 | 3962 4146 3969 | 3962 4146 3970 | 3962 4146 3971 | 3962 4146 3972 | 3962 4146 3973 | 3962 4146 3974 | 3962 4146 3975 | 3962 4146 3976 | 3962 4146 3977 | 3962 4146 3978 | 3962 4146 3979 | 3962 4146 3980 | 3962 4146 3981 | 3962 4146 3982 | 3962 4146 3983 | 3962 4146 3984 | 3962 4146 3985 | 3962 4146 3986 | 3962 4146 3987 | 3962 4146 3988 | 3962 4146 3989 | 3962 4146 3990 | 3962 4146 3991 | 3962 4146 3992 | 3962 4146 3993 | 3962 4146 3994 | 3962 4146 3995 | 3962 4146 3996 | 3962 4146 3997 | 3962 4146 3998 | 3962 4146 3999 | 3962 4146 4000 | 3962 4146 4001 | 3962 4146 4002 | 3962 4146 4003 | 3962 4146 4004 | 3962 4146 4005 | 3962 4146 4006 | 3962 4146 4007 | 
3962 4146 4008 | 3962 4146 4009 | 3962 4146 4010 | 3962 4146 4011 | 3962 4146 4012 | 3962 4146 4013 | 3962 4146 4014 | 3962 4146 4015 | 3962 4146 4016 | 3962 4146 4017 | 3962 4146 4018 | 3962 4146 4019 | 3962 4146 4020 | 3962 4146 4021 | 3962 4146 4022 | 3962 4146 4023 | 3962 4146 4024 | 3962 4146 4025 | 3962 4146 4026 | 3962 4146 4027 | 3962 4146 4028 | 3962 4146 4029 | 3962 4146 4030 | 3962 4146 4031 | 3962 4146 4032 | 3962 4146 4033 | 3962 4146 4034 | 3962 4146 4035 | 3962 4146 4036 | 3962 4146 4037 | 3962 4146 4038 | 3962 4146 4039 | 3962 4146 4040 | 3962 4146 4041 | 3962 4146 4042 | 3962 4146 4043 | 3962 4146 4044 | 3962 4146 4045 | 3962 4146 4046 | 3962 4146 4047 | 3962 4146 4048 | 3962 4146 4049 | 3962 4146 4050 | 3962 4146 4051 | 3962 4146 4052 | 3962 4146 4053 | 3962 4146 4054 | 3962 4146 4055 | 3962 4146 4056 | 3962 4146 4057 | 3962 4146 4058 | 3962 4146 4059 | 3962 4146 4060 | 3962 4146 4061 | 3962 4146 4062 | 3962 4146 4063 | 3962 4146 4064 | 3962 4146 4065 | 3962 4146 4066 | 3962 4146 4067 | 3962 4146 4068 | 3962 4146 4069 | 3962 4146 4070 | 3962 4146 4071 | 3962 4146 4072 | 3962 4146 4073 | 3962 4146 4074 | 3962 4146 4075 | 3962 4146 4076 | 3962 4146 4077 | 3962 4146 4078 | 3962 4146 4079 | 3962 4146 4080 | 3962 4146 4081 | 3962 4146 4082 | 3962 4146 4083 | 3962 4146 4084 | 3962 4146 4085 | 3962 4146 4086 | 3962 4146 4087 | 3962 4146 4088 | 3962 4146 4089 | 3962 4146 4090 | 3962 4146 4091 | 3962 4146 4092 | 3962 4146 4093 | 3962 4146 4094 | 3962 4146 4095 | 3962 4146 4096 | 3962 4146 4097 | 3962 4146 4098 | 3962 4146 4099 | 3962 4146 4100 | 3962 4146 4101 | 3962 4146 4102 | 3962 4146 4103 | 3962 4146 4104 | 3962 4146 4105 | 3962 4146 4106 | 3962 4146 4107 | 3962 4146 4108 | 3962 4146 4109 | 3962 4146 4110 | 3962 4146 4111 | 3962 4146 4112 | 3962 4146 4113 | 3962 4146 4114 | 3962 4146 4115 | 3962 4146 4116 | 3962 4146 4117 | 3962 4146 4118 | 3962 4146 4119 | 3962 4146 4120 | 3962 4146 4121 | 3962 4146 4122 | 3962 4146 4123 | 3962 4146 4124 | 3962 4146 
4125 | 3962 4146 4126 | 3962 4146 4127 | 3962 4146 4128 | 3962 4146 4129 | 3962 4146 4130 | 3962 4146 4131 | 3962 4146 4132 | 3962 4146 4133 | 3962 4146 4134 | 3962 4146 4135 | 3962 4146 4136 | 3962 4146 4137 | 3962 4146 4138 | 3962 4146 4139 | 3962 4146 4140 | 3962 4146 4141 | 3962 4146 4142 | 3962 4146 4143 | 3962 4146 4144 | 3962 4146 4145 | 3962 4146 4146 | 3962 4146 4147 | 4146 4149 4148 | 4146 4149 4149 | 4146 4149 4150 | 4149 4155 4151 | 4149 4155 4152 | 4149 4155 4153 | 4149 4155 4154 | 4149 4155 4155 | 4149 4155 4156 | 4155 4178 4157 | 4155 4178 4158 | 4155 4178 4159 | 4155 4178 4160 | 4155 4178 4161 | 4155 4178 4162 | 4155 4178 4163 | 4155 4178 4164 | 4155 4178 4165 | 4155 4178 4166 | 4155 4178 4167 | 4155 4178 4168 | 4155 4178 4169 | 4155 4178 4170 | 4155 4178 4171 | 4155 4178 4172 | 4155 4178 4173 | 4155 4178 4174 | 4155 4178 4175 | 4155 4178 4176 | 4155 4178 4177 | 4155 4178 4178 | 4155 4178 4179 | 4178 4273 4180 | 4178 4273 4181 | 4178 4273 4182 | 4178 4273 4183 | 4178 4273 4184 | 4178 4273 4185 | 4178 4273 4186 | 4178 4273 4187 | 4178 4273 4188 | 4178 4273 4189 | 4178 4273 4190 | 4178 4273 4191 | 4178 4273 4192 | 4178 4273 4193 | 4178 4273 4194 | 4178 4273 4195 | 4178 4273 4196 | 4178 4273 4197 | 4178 4273 4198 | 4178 4273 4199 | 4178 4273 4200 | 4178 4273 4201 | 4178 4273 4202 | 4178 4273 4203 | 4178 4273 4204 | 4178 4273 4205 | 4178 4273 4206 | 4178 4273 4207 | 4178 4273 4208 | 4178 4273 4209 | 4178 4273 4210 | 4178 4273 4211 | 4178 4273 4212 | 4178 4273 4213 | 4178 4273 4214 | 4178 4273 4215 | 4178 4273 4216 | 4178 4273 4217 | 4178 4273 4218 | 4178 4273 4219 | 4178 4273 4220 | 4178 4273 4221 | 4178 4273 4222 | 4178 4273 4223 | 4178 4273 4224 | 4178 4273 4225 | 4178 4273 4226 | 4178 4273 4227 | 4178 4273 4228 | 4178 4273 4229 | 4178 4273 4230 | 4178 4273 4231 | 4178 4273 4232 | 4178 4273 4233 | 4178 4273 4234 | 4178 4273 4235 | 4178 4273 4236 | 4178 4273 4237 | 4178 4273 4238 | 4178 4273 4239 | 4178 4273 4240 | 4178 4273 4241 | 4178 4273 4242 | 
4178 4273 4243 | 4178 4273 4244 | 4178 4273 4245 | 4178 4273 4246 | 4178 4273 4247 | 4178 4273 4248 | 4178 4273 4249 | 4178 4273 4250 | 4178 4273 4251 | 4178 4273 4252 | 4178 4273 4253 | 4178 4273 4254 | 4178 4273 4255 | 4178 4273 4256 | 4178 4273 4257 | 4178 4273 4258 | 4178 4273 4259 | 4178 4273 4260 | 4178 4273 4261 | 4178 4273 4262 | 4178 4273 4263 | 4178 4273 4264 | 4178 4273 4265 | 4178 4273 4266 | 4178 4273 4267 | 4178 4273 4268 | 4178 4273 4269 | 4178 4273 4270 | 4178 4273 4271 | 4178 4273 4272 | 4178 4273 4273 | 4178 4273 4274 | 4273 4303 4275 | 4273 4303 4276 | 4273 4303 4277 | 4273 4303 4278 | 4273 4303 4279 | 4273 4303 4280 | 4273 4303 4281 | 4273 4303 4282 | 4273 4303 4283 | 4273 4303 4284 | 4273 4303 4285 | 4273 4303 4286 | 4273 4303 4287 | 4273 4303 4288 | 4273 4303 4289 | 4273 4303 4290 | 4273 4303 4291 | 4273 4303 4292 | 4273 4303 4293 | 4273 4303 4294 | 4273 4303 4295 | 4273 4303 4296 | 4273 4303 4297 | 4273 4303 4298 | 4273 4303 4299 | 4273 4303 4300 | 4273 4303 4301 | 4273 4303 4302 | 4273 4303 4303 | 4273 4303 4304 | 4303 4310 4305 | 4303 4310 4306 | 4303 4310 4307 | 4303 4310 4308 | 4303 4310 4309 | 4303 4310 4310 | 4303 4310 4311 | 4310 4474 4312 | 4310 4474 4313 | 4310 4474 4314 | 4310 4474 4315 | 4310 4474 4316 | 4310 4474 4317 | 4310 4474 4318 | 4310 4474 4319 | 4310 4474 4320 | 4310 4474 4321 | 4310 4474 4322 | 4310 4474 4323 | 4310 4474 4324 | 4310 4474 4325 | 4310 4474 4326 | 4310 4474 4327 | 4310 4474 4328 | 4310 4474 4329 | 4310 4474 4330 | 4310 4474 4331 | 4310 4474 4332 | 4310 4474 4333 | 4310 4474 4334 | 4310 4474 4335 | 4310 4474 4336 | 4310 4474 4337 | 4310 4474 4338 | 4310 4474 4339 | 4310 4474 4340 | 4310 4474 4341 | 4310 4474 4342 | 4310 4474 4343 | 4310 4474 4344 | 4310 4474 4345 | 4310 4474 4346 | 4310 4474 4347 | 4310 4474 4348 | 4310 4474 4349 | 4310 4474 4350 | 4310 4474 4351 | 4310 4474 4352 | 4310 4474 4353 | 4310 4474 4354 | 4310 4474 4355 | 4310 4474 4356 | 4310 4474 4357 | 4310 4474 4358 | 4310 4474 4359 | 4310 4474 
4360 | 4310 4474 4361 | 4310 4474 4362 | 4310 4474 4363 | 4310 4474 4364 | 4310 4474 4365 | 4310 4474 4366 | 4310 4474 4367 | 4310 4474 4368 | 4310 4474 4369 | 4310 4474 4370 | 4310 4474 4371 | 4310 4474 4372 | 4310 4474 4373 | 4310 4474 4374 | 4310 4474 4375 | 4310 4474 4376 | 4310 4474 4377 | 4310 4474 4378 | 4310 4474 4379 | 4310 4474 4380 | 4310 4474 4381 | 4310 4474 4382 | 4310 4474 4383 | 4310 4474 4384 | 4310 4474 4385 | 4310 4474 4386 | 4310 4474 4387 | 4310 4474 4388 | 4310 4474 4389 | 4310 4474 4390 | 4310 4474 4391 | 4310 4474 4392 | 4310 4474 4393 | 4310 4474 4394 | 4310 4474 4395 | 4310 4474 4396 | 4310 4474 4397 | 4310 4474 4398 | 4310 4474 4399 | 4310 4474 4400 | 4310 4474 4401 | 4310 4474 4402 | 4310 4474 4403 | 4310 4474 4404 | 4310 4474 4405 | 4310 4474 4406 | 4310 4474 4407 | 4310 4474 4408 | 4310 4474 4409 | 4310 4474 4410 | 4310 4474 4411 | 4310 4474 4412 | 4310 4474 4413 | 4310 4474 4414 | 4310 4474 4415 | 4310 4474 4416 | 4310 4474 4417 | 4310 4474 4418 | 4310 4474 4419 | 4310 4474 4420 | 4310 4474 4421 | 4310 4474 4422 | 4310 4474 4423 | 4310 4474 4424 | 4310 4474 4425 | 4310 4474 4426 | 4310 4474 4427 | 4310 4474 4428 | 4310 4474 4429 | 4310 4474 4430 | 4310 4474 4431 | 4310 4474 4432 | 4310 4474 4433 | 4310 4474 4434 | 4310 4474 4435 | 4310 4474 4436 | 4310 4474 4437 | 4310 4474 4438 | 4310 4474 4439 | 4310 4474 4440 | 4310 4474 4441 | 4310 4474 4442 | 4310 4474 4443 | 4310 4474 4444 | 4310 4474 4445 | 4310 4474 4446 | 4310 4474 4447 | 4310 4474 4448 | 4310 4474 4449 | 4310 4474 4450 | 4310 4474 4451 | 4310 4474 4452 | 4310 4474 4453 | 4310 4474 4454 | 4310 4474 4455 | 4310 4474 4456 | 4310 4474 4457 | 4310 4474 4458 | 4310 4474 4459 | 4310 4474 4460 | 4310 4474 4461 | 4310 4474 4462 | 4310 4474 4463 | 4310 4474 4464 | 4310 4474 4465 | 4310 4474 4466 | 4310 4474 4467 | 4310 4474 4468 | 4310 4474 4469 | 4310 4474 4470 | 4310 4474 4471 | 4310 4474 4472 | 4310 4474 4473 | 4310 4474 4474 | 4310 4474 4475 | 
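The `seqmatch` score computed by `aggregateMatchScores` in test.py builds the full frame-to-frame L2 distance matrix between one database sequence and one query sequence with `torch.cdist`, then keeps only its main diagonal, i.e. the distance between frame t of each sequence, and averages it. The NumPy sketch below reproduces that score for a single pair, assuming both inputs are `(seqL, D)` arrays; `seq_match_score` and `seq_match_score_aligned` are hypothetical helper names, not part of the repository.

```python
import numpy as np

def seq_match_score(db_seq, q_seq):
    """Sequence distance for one (db, query) pair, as in aggregateMatchScores.

    db_seq, q_seq: float arrays of shape (seqL, D), one descriptor per frame.
    Builds the full seqL x seqL frame-to-frame L2 distance matrix (what
    torch.cdist returns for a single batch element), then averages its main
    diagonal, i.e. the distances between frame t of db_seq and frame t of q_seq.
    Lower values mean a better match.
    """
    full = np.linalg.norm(db_seq[:, None, :] - q_seq[None, :, :], axis=-1)
    return float(np.diagonal(full).mean())

def seq_match_score_aligned(db_seq, q_seq):
    """Same score computed only on the aligned pairs: O(seqL) instead of O(seqL^2)."""
    return float(np.linalg.norm(db_seq - q_seq, axis=1).mean())

rng = np.random.default_rng(0)
db = rng.standard_normal((5, 8))  # 5 frames, 8-dim descriptors (illustrative sizes)
q = rng.standard_normal((5, 8))
print(seq_match_score(db, q), seq_match_score_aligned(db, q))
```

Because only diagonal entries are used, the score compares frames in lockstep: it assumes the query and database sequences advance at a similar frame spacing, and the off-diagonal distances computed by `cdist` are discarded.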
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import torch
from torch.utils.data import DataLoader
import faiss
from tqdm import tqdm
from sklearn.manifold import TSNE
from scipy.spatial.distance import cdist
import numpy as np
import time

def aggregateMatchScores(dbDesc,qDesc,device,refCandidates=None):
    numDb, numQ = dbDesc.shape[0], qDesc.shape[0]

    # score every db sequence, or only the top-K prior candidates per query
    if refCandidates is None:
        shape = [numDb,numQ]
    else:
        shape = refCandidates.transpose().shape

    dMat_seq = -1*torch.ones(shape,device=device)

    for j in tqdm(range(numQ), total=numQ, leave=True):
        if refCandidates is not None:
            rCands = refCandidates[j]
        else:
            rCands = torch.arange(numDb)
        for i,r in enumerate(rCands):
            # frame-to-frame L2 distances between db sequence r and query sequence j
            dMat = torch.cdist(dbDesc[r].unsqueeze(0),qDesc[j].unsqueeze(0))
            # keep only temporally aligned frame pairs (t, t) and average them
            dMat_seq[i,j] = torch.diagonal(dMat,0,1,2).mean(-1)

    return dMat_seq.detach().cpu().numpy()

def getRecallAtN(n_values, predictions, gt):
    correct_at_n = np.zeros(len(n_values))
    numQWithoutGt = 0
    #TODO can we do this on the matrix in one go?
    for qIx, pred in enumerate(predictions):
        # queries without any ground-truth match are excluded from the recall
        if len(gt[qIx]) == 0:
            numQWithoutGt += 1
            continue
        for i,n in enumerate(n_values):
            # if in top N then also in top NN, where NN > N
            # (np.isin replaces np.in1d, which is deprecated in NumPy 2.0)
            if np.any(np.isin(pred[:n], gt[qIx])):
                correct_at_n[i:] += 1
                break
    return correct_at_n / (len(gt)-numQWithoutGt)

def test(opt, model, encoder_dim, device, eval_set, writer, epoch=0, extract_noEval=False):
    # TODO what if features dont fit in memory?
    test_data_loader = DataLoader(dataset=eval_set,
                num_workers=opt.threads, batch_size=opt.cacheBatchSize, shuffle=False,
                pin_memory=not opt.nocuda)

    model.eval()
    with torch.no_grad():
        print('====> Extracting Features')
        pool_size = encoder_dim
        if opt.pooling.lower() == 'seqnet':
            pool_size = opt.outDims
        if 'seqmatch' in opt.pooling.lower():
            dbFeat = torch.empty((len(eval_set), opt.seqL, pool_size),device=device)
        else:
            dbFeat = torch.empty((len(eval_set), pool_size),device=device)

        durs_batch = []
        for iteration, (input, indices) in tqdm(enumerate(test_data_loader, 1),total=len(test_data_loader), leave=False):
            t1 = time.time()
            input = input.float().to(device)
            if opt.pooling.lower() == 's1+seqmatch':
                shapeOrig = input.shape
                input = input.reshape([-1,input.shape[-1]])
                seq_encoding = model.pool(input).reshape(shapeOrig)
            else:
                seq_encoding = model.pool(input)
            if 'seqmatch' in opt.pooling.lower():
                dbFeat[indices,:,:] = seq_encoding
            else:
                dbFeat[indices, :] = seq_encoding
            if iteration % 50 == 0 or len(test_data_loader) <= 10:
                print("==> Batch ({}/{})".format(iteration,
                    len(test_data_loader)), flush=True)
            durs_batch.append(time.time()-t1)
            del input
    del test_data_loader
    print("Average batch time:", np.mean(durs_batch), np.std(durs_batch))

    # extracted for both db and query, now split in own sets
    qFeat = dbFeat[eval_set.dbStruct.numDb:]
    dbFeat = dbFeat[:eval_set.dbStruct.numDb]
    print(dbFeat.shape, qFeat.shape)

    qFeat_np = qFeat.detach().cpu().numpy().astype('float32')
    dbFeat_np = dbFeat.detach().cpu().numpy().astype('float32')

    db_emb, q_emb = None, None
    if opt.numSamples2Project != -1 and writer is not None:
        db_emb = TSNE(n_components=2).fit_transform(dbFeat_np[:opt.numSamples2Project])
        q_emb = TSNE(n_components=2).fit_transform(qFeat_np[:opt.numSamples2Project])

    if extract_noEval:
        return np.vstack([dbFeat_np,qFeat_np]), db_emb, q_emb, None, None

    n_values = [1,5,10,20,100]

    if 'seqmatch' in opt.pooling.lower():
        print('====> Performing sequence score aggregation')
        if opt.predictionsFile is not None:
            predPrior = np.load(opt.predictionsFile)['preds']
            predPriorTopK = predPrior[:,:20]
        else:
            predPriorTopK = None
        dMatSeq = aggregateMatchScores(dbFeat,qFeat,device,refCandidates=predPriorTopK)
        predictions = np.argsort(dMatSeq,axis=0)[:max(n_values),:].transpose()
        bestDists = dMatSeq[predictions[:,0],np.arange(dMatSeq.shape[1])]
        if opt.predictionsFile is not None:
            # map positions within the top-K prior back to database indices
            predictions = np.array([predPriorTopK[qIdx][predictions[qIdx]] for qIdx in range(predictions.shape[0])])
            print("Preds:",predictions.shape)
    else:
        print('====> Building faiss index')
        faiss_index = faiss.IndexFlatL2(pool_size)
        faiss_index.add(dbFeat_np)

        distances, predictions = faiss_index.search(qFeat_np, max(n_values))
        bestDists = distances[:,0]

    print('====> Calculating recall @ N')

    # for each query get those within threshold distance
    gt,gtDists = eval_set.get_positives(retDists=True)
    gtDistsMat = cdist(eval_set.dbStruct.utmDb,eval_set.dbStruct.utmQ)

    # compute recall for different loc radii
    rAtL = []
    for locRad in [1,5,10,20,40,100,200]:
        gtAtL = gtDistsMat <= locRad
        gtAtL = [np.argwhere(gtAtL[:,qIx]).flatten() for qIx in range(gtDistsMat.shape[1])]
        rAtL.append(getRecallAtN(n_values, predictions, gtAtL))

    recall_at_n = getRecallAtN(n_values, predictions, gt)

    recalls = {} #make dict for output
    for i,n in enumerate(n_values):
        recalls[n] = recall_at_n[i]
        print("====> Recall@{}: {:.4f}".format(n, recall_at_n[i]))
        if writer is not None:
            writer.add_scalar('Val/Recall@' + str(n), recall_at_n[i], epoch)

    return recalls, db_emb, q_emb, rAtL, predictions

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Subset
from tqdm import tqdm
import numpy as np
from os.path import join
from os import remove
import h5py
from math import ceil

def train(opt, model, encoder_dim, device, dataset, criterion, optimizer, train_set, whole_train_set, whole_training_data_loader, epoch, writer):
    epoch_loss = 0
    startIter = 1 # keep track of batch iter across subsets for logging

    if opt.cacheRefreshRate > 0:
        subsetN = ceil(len(train_set) / opt.cacheRefreshRate)
        #TODO randomise the arange before splitting?
        subsetIdx = np.array_split(np.arange(len(train_set)), subsetN)
    else:
        subsetN = 1
        subsetIdx = [np.arange(len(train_set))]

    nBatches = (len(train_set) + opt.batchSize - 1) // opt.batchSize

    for subIter in range(subsetN):
        print('====> Building Cache')
        model.eval()
        with h5py.File(train_set.cache, mode='w') as h5:
            pool_size = encoder_dim
            if opt.pooling.lower() == 'seqnet':
                pool_size = opt.outDims
            h5feat = h5.create_dataset("features", [len(whole_train_set), pool_size], dtype=np.float32)
            with torch.no_grad():
                for iteration, (input, indices) in tqdm(enumerate(whole_training_data_loader, 1),total=len(whole_training_data_loader), leave=False):
                    image_encoding = (input).float().to(device)
                    seq_encoding = model.pool(image_encoding)
                    h5feat[indices.detach().numpy(), :] = seq_encoding.detach().cpu().numpy()
                    del input, image_encoding, seq_encoding

        sub_train_set = Subset(dataset=train_set, indices=subsetIdx[subIter])

        training_data_loader = DataLoader(dataset=sub_train_set, num_workers=opt.threads,
                    batch_size=opt.batchSize, shuffle=True,
                    collate_fn=dataset.collate_fn, pin_memory=not opt.nocuda)

        print('Allocated:', torch.cuda.memory_allocated())
        print('Cached:', torch.cuda.memory_reserved())

        model.train()
        for iteration, (query, positives, negatives,
                negCounts, indices) in tqdm(enumerate(training_data_loader, startIter),total=len(training_data_loader),leave=False):
            loss = 0
            if query is None:
                continue # in case we get an empty batch

            B = query.shape[0]
            nNeg = torch.sum(negCounts)

            input = torch.cat([query,positives,negatives]).float()
            input = input.to(device)
            seq_encoding = model.pool(input)

            seqQ, seqP, seqN = torch.split(seq_encoding, [B, B, nNeg])

            optimizer.zero_grad()

            # calculate loss for each Query, Positive, Negative triplet
            # due to potential difference in number of negatives have to
            # do it per query, per negative
            for i, negCount in enumerate(negCounts):
                for n in range(negCount):
                    negIx = (torch.sum(negCounts[:i]) + n).item()
                    loss += criterion(seqQ[i:i+1], seqP[i:i+1], seqN[negIx:negIx+1])

            loss /= nNeg.float().to(device) # normalise by actual number of negatives
            loss.backward()
            optimizer.step()
            del input, seq_encoding, seqQ, seqP, seqN
            del query, positives, negatives

            batch_loss = loss.item()
            epoch_loss += batch_loss

            if iteration % 50 == 0 or nBatches <= 10:
                print("==> Epoch[{}]({}/{}): Loss: {:.4f}".format(epoch, iteration,
                    nBatches, batch_loss), flush=True)
                writer.add_scalar('Train/Loss', batch_loss,
                        ((epoch-1) * nBatches) + iteration)
                writer.add_scalar('Train/nNeg', nNeg,
                        ((epoch-1) * nBatches) + iteration)
                print('Allocated:', torch.cuda.memory_allocated())
                # memory_cached() is deprecated; memory_reserved() reports the same value
                print('Cached:', torch.cuda.memory_reserved())

        startIter += len(training_data_loader)
        del training_data_loader, loss
        optimizer.zero_grad()
        torch.cuda.empty_cache()
        remove(train_set.cache) # delete HDF5 cache

    avg_loss = epoch_loss / nBatches

    print("===> Epoch {} Complete: Avg. Loss: {:.4f}".format(epoch, avg_loss),
            flush=True)
    writer.add_scalar('Train/AvgLoss', avg_loss, epoch)

--------------------------------------------------------------------------------
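In train.py the triplet loss is accumulated per query and per negative because queries can carry different numbers of negatives; the negative's row in `seqN` is `sum(negCounts[:i]) + n`. Since negatives are stacked query by query, that expression is simply a running counter. The NumPy sketch below builds the same (query, negative) index pairs in one pass and checks them against the nested loop; `triplet_indices` and `triplet_indices_loop` are hypothetical helpers, not part of the repository.

```python
import numpy as np

def triplet_indices(neg_counts):
    """Index pairs visited by the per-query, per-negative loop in train.py.

    neg_counts: 1-D sequence of non-negative ints, negatives per query.
    Returns (query_idx, neg_idx): triplet k uses seqQ[query_idx[k]],
    seqP[query_idx[k]] and seqN[neg_idx[k]].
    """
    neg_counts = np.asarray(neg_counts, dtype=np.int64)
    # each query index is repeated once per negative it owns
    query_idx = np.repeat(np.arange(len(neg_counts)), neg_counts)
    # negatives are stacked query by query, so sum(negCounts[:i]) + n counts 0, 1, 2, ...
    neg_idx = np.arange(int(neg_counts.sum()), dtype=np.int64)
    return query_idx, neg_idx

def triplet_indices_loop(neg_counts):
    """Reference: the nested loop from train.py, written with plain Python ints."""
    q_idx, n_idx = [], []
    for i, neg_count in enumerate(neg_counts):
        for n in range(neg_count):
            q_idx.append(i)
            n_idx.append(sum(neg_counts[:i]) + n)
    return np.array(q_idx, dtype=np.int64), np.array(n_idx, dtype=np.int64)

q_idx, n_idx = triplet_indices([2, 0, 3])
print(q_idx.tolist())  # [0, 0, 2, 2, 2]
print(n_idx.tolist())  # [0, 1, 2, 3, 4]
```

With these indices (converted via `torch.from_numpy`), a single call such as `criterion(seqQ[q_idx], seqP[q_idx], seqN)` could stand in for the Python loop, but it only reproduces `loss` from train.py if `criterion` sums over triplets, for example `torch.nn.TripletMarginLoss(reduction='sum')`. How `criterion` is configured is decided outside train.py, so that equivalence is an assumption.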