├── .gitignore ├── README.md ├── example_data ├── example.png └── original_data │ ├── img │ ├── 03.tif │ └── 21.tif │ └── shp │ ├── 21 │ ├── 21.cpg │ ├── 21.dbf │ ├── 21.prj │ ├── 21.sbn │ ├── 21.sbx │ ├── 21.shp │ └── 21.shx │ └── 03 │ ├── 03.cpg │ ├── 03.dbf │ ├── 03.prj │ ├── 03.sbn │ ├── 03.sbx │ ├── 03.shp │ └── 03.shx ├── pycococreatortools └── pycococreatortools.py ├── shape_to_coco.py ├── slice_dataset.py ├── tif_process.py └── visualize_coco.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # shp2coco 2 | shp2coco is a tool to help create `COCO` datasets from `.shp` file (ArcGIS format).
3 | 4 | It includes:
5 | 1: Mask the TIF image with the shapefile.
6 | 2: Crop the TIF image and the mask into tiles.
7 | 3: Split the dataset into training, eval and test subsets.
8 | 4: Generate annotations as uncompressed RLE ("crowd") and as polygons, in the format COCO requires.
9 | 10 | This project is based on [geotool](https://github.com/Kindron/geotool) and [pycococreator](https://github.com/waspinator/pycococreator) 11 | 12 | ## Usage: 13 | If you need to generate annotations in the COCO format, try the following:
14 | `python shape_to_coco.py`
15 | If you need to visualize annotations, try the following:
16 | `python visualize_coco.py`
17 | 18 | ## Example: 19 | ![example](https://github.com/DuncanChen2018/shp2coco/blob/master/example_data/example.png) 20 | 21 | ## Thanks to the Third-Party Libraries 22 | [geotool](https://github.com/Kindron/geotool)
23 | [pycococreator](https://github.com/waspinator/pycococreator)
24 | -------------------------------------------------------------------------------- /example_data/example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dingyuan-Chen/shp2coco/a88fc4b9e649ab11e50dc026bc386fa778b54eef/example_data/example.png -------------------------------------------------------------------------------- /example_data/original_data/img/03.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dingyuan-Chen/shp2coco/a88fc4b9e649ab11e50dc026bc386fa778b54eef/example_data/original_data/img/03.tif -------------------------------------------------------------------------------- /example_data/original_data/img/21.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dingyuan-Chen/shp2coco/a88fc4b9e649ab11e50dc026bc386fa778b54eef/example_data/original_data/img/21.tif -------------------------------------------------------------------------------- /example_data/original_data/shp/03/03.cpg: -------------------------------------------------------------------------------- 1 | UTF-8 -------------------------------------------------------------------------------- /example_data/original_data/shp/03/03.dbf: -------------------------------------------------------------------------------- 1 | w^AIdN 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 -------------------------------------------------------------------------------- /example_data/original_data/shp/03/03.prj: -------------------------------------------------------------------------------- 1 | GEOGCS["GCS_WGS_1984",DATUM["D_WGS_1984",SPHEROID["WGS_1984",6378137.0,298.257223563]],PRIMEM["Greenwich",0.0],UNIT["Degree",0.0174532925199433]] 
-------------------------------------------------------------------------------- /example_data/original_data/shp/03/03.sbn: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dingyuan-Chen/shp2coco/a88fc4b9e649ab11e50dc026bc386fa778b54eef/example_data/original_data/shp/03/03.sbn -------------------------------------------------------------------------------- /example_data/original_data/shp/03/03.sbx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dingyuan-Chen/shp2coco/a88fc4b9e649ab11e50dc026bc386fa778b54eef/example_data/original_data/shp/03/03.sbx -------------------------------------------------------------------------------- /example_data/original_data/shp/03/03.shp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dingyuan-Chen/shp2coco/a88fc4b9e649ab11e50dc026bc386fa778b54eef/example_data/original_data/shp/03/03.shp -------------------------------------------------------------------------------- /example_data/original_data/shp/03/03.shx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dingyuan-Chen/shp2coco/a88fc4b9e649ab11e50dc026bc386fa778b54eef/example_data/original_data/shp/03/03.shx -------------------------------------------------------------------------------- /example_data/original_data/shp/21/21.cpg: -------------------------------------------------------------------------------- 1 | UTF-8 -------------------------------------------------------------------------------- /example_data/original_data/shp/21/21.dbf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dingyuan-Chen/shp2coco/a88fc4b9e649ab11e50dc026bc386fa778b54eef/example_data/original_data/shp/21/21.dbf 
-------------------------------------------------------------------------------- /example_data/original_data/shp/21/21.prj: -------------------------------------------------------------------------------- 1 | GEOGCS["GCS_WGS_1984",DATUM["D_WGS_1984",SPHEROID["WGS_1984",6378137.0,298.257223563]],PRIMEM["Greenwich",0.0],UNIT["Degree",0.0174532925199433]] -------------------------------------------------------------------------------- /example_data/original_data/shp/21/21.sbn: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dingyuan-Chen/shp2coco/a88fc4b9e649ab11e50dc026bc386fa778b54eef/example_data/original_data/shp/21/21.sbn -------------------------------------------------------------------------------- /example_data/original_data/shp/21/21.sbx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dingyuan-Chen/shp2coco/a88fc4b9e649ab11e50dc026bc386fa778b54eef/example_data/original_data/shp/21/21.sbx -------------------------------------------------------------------------------- /example_data/original_data/shp/21/21.shp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dingyuan-Chen/shp2coco/a88fc4b9e649ab11e50dc026bc386fa778b54eef/example_data/original_data/shp/21/21.shp -------------------------------------------------------------------------------- /example_data/original_data/shp/21/21.shx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dingyuan-Chen/shp2coco/a88fc4b9e649ab11e50dc026bc386fa778b54eef/example_data/original_data/shp/21/21.shx -------------------------------------------------------------------------------- /pycococreatortools/pycococreatortools.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import re 5 | 
import datetime
import numpy as np
from itertools import groupby
from skimage import measure
from PIL import Image
from pycocotools import mask

# Natural-sort helpers: natrual_key("img10") -> ['img', 10, ''], so that
# sorted(names, key=natrual_key) orders "img2" before "img10".
convert = lambda text: int(text) if text.isdigit() else text.lower()
natrual_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]


def resize_binary_mask(array, new_size):
    """Resize a binary mask and return it as a boolean array.

    Args:
        array: 2-D mask expected to hold 0/1 (or boolean) values.
        new_size: target size in PIL order, ``(width, height)``.

    Nearest-neighbour resampling keeps the result strictly binary. Recent
    Pillow versions default to an interpolating filter for 8-bit images, and
    after the ``bool`` cast every partially covered edge pixel would turn
    into foreground, growing each object.
    """
    image = Image.fromarray(array.astype(np.uint8) * 255)
    image = image.resize(new_size, Image.NEAREST)
    return np.asarray(image).astype(np.bool_)


def close_contour(contour):
    """Return ``contour`` with its first point appended if it is not already closed."""
    if not np.array_equal(contour[0], contour[-1]):
        contour = np.vstack((contour, contour[0]))
    return contour


def binary_mask_to_rle(binary_mask):
    """Encode a binary mask as an uncompressed COCO RLE.

    Runs are counted over the mask flattened in column-major (Fortran) order.
    ``counts`` always starts with a background run, so a leading 0 is
    inserted when the first pixel is foreground.

    Returns:
        dict with ``counts`` (list of run lengths) and ``size`` ([rows, cols]).
    """
    rle = {'counts': [], 'size': list(binary_mask.shape)}
    counts = rle.get('counts')
    for i, (value, elements) in enumerate(groupby(binary_mask.ravel(order='F'))):
        if i == 0 and value == 1:
            counts.append(0)
        counts.append(len(list(elements)))

    return rle


def binary_mask_to_polygon(binary_mask, tolerance=0):
    """Converts a binary mask to COCO polygon representation

    Args:
        binary_mask: a 2D binary numpy array where '1's represent the object
        tolerance: Maximum distance from original points of polygon to approximated
            polygonal chain. If tolerance is 0, the original coordinate array is returned.

    Returns:
        list of polygons, each a flat ``[x0, y0, x1, y1, ...]`` list; polygons
        with fewer than 3 vertices after simplification are skipped.
    """
    polygons = []
    # pad mask to close contours of shapes which start and end at an edge
    padded_binary_mask = np.pad(binary_mask, pad_width=1, mode='constant', constant_values=0)
    contours = measure.find_contours(padded_binary_mask, 0.5)
    # Undo the padding offset contour by contour: find_contours returns arrays
    # of different lengths, and np.subtract on that ragged list raises on
    # NumPy >= 1.24.
    contours = [np.subtract(contour, 1) for contour in contours]
    for contour in contours:
        contour = close_contour(contour)
        contour = measure.approximate_polygon(contour, tolerance)
        if len(contour) < 3:
            continue
        # find_contours yields (row, col); COCO wants (x, y).
        contour = np.flip(contour, axis=1)
        segmentation = contour.ravel().tolist()
        # after padding and subtracting 1 we may get -0.5 points in our segmentation
        segmentation = [0 if i < 0 else i for i in segmentation]
        polygons.append(segmentation)

    return polygons


def create_image_info(image_id, file_name, image_size,
                      date_captured=datetime.datetime.utcnow().isoformat(' '),
                      license_id=1, coco_url="", flickr_url=""):
    """Build a COCO ``images`` entry.

    ``image_size`` is ``(width, height)`` (PIL order). Note that the default
    ``date_captured`` is evaluated once, when this module is imported.
    """
    image_info = {
        "id": image_id,
        "file_name": file_name,
        "width": image_size[0],
        "height": image_size[1],
        "date_captured": date_captured,
        "license": license_id,
        "coco_url": coco_url,
        "flickr_url": flickr_url
    }

    return image_info


def create_annotation_info(annotation_id, image_id, category_info, binary_mask,
                           image_size=None, tolerance=2, bounding_box=None):
    """Build one COCO ``annotations`` entry from a binary instance mask.

    Args:
        annotation_id: unique id of this annotation.
        image_id: id of the image the annotation belongs to.
        category_info: dict with ``id`` (category id) and ``is_crowd``
            (truthy -> stored as RLE with iscrowd=1, otherwise as polygons).
        binary_mask: 2-D 0/1 mask of the instance.
        image_size: optional ``(width, height)``; the mask is resized to it first.
        tolerance: polygon simplification tolerance (pixels).
        bounding_box: optional ``[x, y, w, h]`` (array or sequence); computed
            from the mask when None.

    Returns:
        The annotation dict, or None when the mask is empty or yields no polygon.
    """
    if image_size is not None:
        binary_mask = resize_binary_mask(binary_mask, image_size)

    binary_mask_encoded = mask.encode(np.asfortranarray(binary_mask.astype(np.uint8)))

    area = mask.area(binary_mask_encoded)
    if area < 1:
        return None

    if bounding_box is None:
        bounding_box = mask.toBbox(binary_mask_encoded)

    if category_info["is_crowd"]:
        is_crowd = 1
        segmentation = binary_mask_to_rle(binary_mask)
    else:
        is_crowd = 0
        segmentation = binary_mask_to_polygon(binary_mask, tolerance)
        if not segmentation:
            return None

    annotation_info = {
        "id": annotation_id,
        "image_id": image_id,
        "category_id": category_info["id"],
        "iscrowd": is_crowd,
        "area": area.tolist(),
        # np.asarray also accepts a caller-supplied list/tuple bbox.
        "bbox": np.asarray(bounding_box).tolist(),
        "segmentation": segmentation,
        "width": binary_mask.shape[1],
        "height": binary_mask.shape[0],
    }

    return annotation_info

# --------------------------------------------------------------------------------
# /shape_to_coco.py:
# --------------------------------------------------------------------------------
#!/usr/bin/env python3

import datetime
import json
import os
import re
import fnmatch
from PIL import Image
import numpy as np
from pycococreatortools import pycococreatortools
from tif_process import *
from slice_dataset import slice

# root path for saving the tif and shp file.
ROOT = r'./example_data/original_data'
img_path = 'img'
shp_path = 'shp'
# root path for saving the mask.
ROOT_DIR = ROOT + '/dataset'
IMAGE_DIR = os.path.join(ROOT_DIR, "greenhouse_2019")
ANNOTATION_DIR = os.path.join(ROOT_DIR, "annotations")

clip_size = 512

INFO = {
    "description": "Greenhouse Dataset",
    "url": "",
    "version": "0.1.0",
    "year": 2019,
    "contributor": "DuncanChen",
    "date_created": datetime.datetime.utcnow().isoformat(' ')
}

LICENSES = [
    {
        "id": 1,
        "name": "",
        "url": ""
    }
]

CATEGORIES = [
    {
        'id': 1,
        'name': 'greenhouse',
        'supercategory': 'building',
    },
]


def filter_for_jpeg(root, files):
    """Return ``root``-joined paths of the entries in ``files`` that are TIFF images.

    The name is inherited from pycococreator (which filtered JPEGs); here the
    accepted patterns are ``*.tiff`` and ``*.tif``.
    """
    # file_types = ['*.jpeg', '*.jpg']
    file_types = ['*.tiff', '*.tif']
    file_types = r'|'.join([fnmatch.translate(x) for x in file_types])
    files = [os.path.join(root, f) for f in files]
    files = [f for f in files if re.match(file_types, f)]

    return files


def filter_for_annotations(root, files, image_filename):
    """Return the ``*.tif`` masks in ``files`` that belong to ``image_filename``.

    A mask belongs to an image when the part of its base name before the first
    '_' equals the image base name (e.g. ``12_greenhouse_0.tif`` -> ``12.tif``).
    """
    # file_types = ['*.png']
    file_types = ['*.tif']
    file_types = r'|'.join([fnmatch.translate(x) for x in file_types])
    basename_no_extension = os.path.splitext(os.path.basename(image_filename))[0]
    files = [os.path.join(root, f) for f in files]
    files = [f for f in files if re.match(file_types, f)]
    files = [f for f in files if basename_no_extension == os.path.splitext(os.path.basename(f))[0].split('_', 1)[0]]

    return files


def from_mask_to_coco(root, MARK, IMAGE, ANNOTATION):
    """Write a COCO instances JSON for one dataset split.

    Args:
        root: dataset root holding the split folders.
        MARK: split name, e.g. 'train', 'eval' or 'test'.
        IMAGE: image sub-folder name inside the split folder.
        ANNOTATION: mask sub-folder name inside the split folder.

    Output is ``<root>/<MARK>/instances_greenhouse_<MARK>2019.json``. If the
    split folder does not exist, a message is printed and nothing is written.
    """
    ROOT_DIR = root + '/' + MARK
    IMAGE_DIR = ROOT_DIR + '/' + IMAGE
    ANNOTATION_DIR = ROOT_DIR + '/' + ANNOTATION
    if not os.path.exists(ROOT_DIR):
        print(ROOT_DIR + ' does not exist!')
        return

    coco_output = {
        "info": INFO,
        "licenses": LICENSES,
        "categories": CATEGORIES,
        "images": [],
        "annotations": []
    }

    image_id = 1
    segmentation_id = 1

    # filter for tif images
    for image_root, _, image_names in os.walk(IMAGE_DIR):
        image_files = filter_for_jpeg(image_root, image_names)

        # go through each image
        for image_filename in image_files:
            # Only the size is needed; close the file handle immediately.
            with Image.open(image_filename) as image:
                image_size = image.size
            image_info = pycococreatortools.create_image_info(
                image_id, os.path.basename(image_filename), image_size)
            coco_output["images"].append(image_info)

            # filter for the masks that belong to this image
            for ann_root, _, ann_names in os.walk(ANNOTATION_DIR):
                annotation_files = filter_for_annotations(ann_root, ann_names, image_filename)

                # go through each associated annotation
                for annotation_filename in annotation_files:

                    print(annotation_filename)
                    class_id = [x['id'] for x in CATEGORIES if x['name'] in annotation_filename][0]

                    category_info = {'id': class_id, 'is_crowd': 'crowd' in image_filename}
                    with Image.open(annotation_filename) as annotation_image:
                        binary_mask = np.asarray(annotation_image.convert('1')).astype(np.uint8)

                    annotation_info = pycococreatortools.create_annotation_info(
                        segmentation_id, image_id, category_info, binary_mask,
                        image_size, tolerance=2)

                    if annotation_info is not None:
                        coco_output["annotations"].append(annotation_info)

                    segmentation_id = segmentation_id + 1

            image_id = image_id + 1

    with open('{}/instances_greenhouse_{}2019.json'.format(ROOT_DIR, MARK), 'w') as output_json_file:
        json.dump(coco_output, output_json_file)


def main():
    """Tile the source rasters, split the tiles, then write one COCO JSON per split."""
    clip_from_file(clip_size, ROOT, img_path, shp_path)
    slice(ROOT_DIR, train=0.6, eval=0.2, test=0.2)
    for mark in ('train', 'eval', 'test'):
        from_mask_to_coco(ROOT_DIR, mark, "greenhouse_2019", "annotations")


if __name__ == "__main__":
    main()

# --------------------------------------------------------------------------------
# /slice_dataset.py:
# --------------------------------------------------------------------------------
import os
import numpy as np
import shutil
import re
import fnmatch

ann_path = 'annotations'
img_path = 'greenhouse_2019'


def filter_for_annotations(root, files, image_filename):
    """Return the ``*.tif`` masks in ``files`` whose name prefix (before '_') matches the image base name."""
    # file_types = ['*.png']
    file_types = ['*.tif']
    file_types = r'|'.join([fnmatch.translate(x) for x in file_types])
    basename_no_extension = os.path.splitext(os.path.basename(image_filename))[0]
    files = [os.path.join(root, f) for f in files]
    files = [f for f in files if re.match(file_types, f)]
    files = [f for f in files if basename_no_extension == os.path.splitext(os.path.basename(f))[0].split('_', 1)[0]]

    return files


def copy_data(input_path, id, num, mark = 'train'):
    """Copy the selected images, and their masks, into a split folder.

    Args:
        input_path: dataset root containing the ``img_path`` and ``ann_path`` folders.
        id: indices into the (sorted) image listing; the first ``num`` are copied.
        num: number of images to copy; nothing is done when it is 0.
        mark: split folder name created under ``input_path``.

    Also writes ``<input_path>/<mark>/<mark>.txt`` containing ``str(id)``.
    """
    if num == 0:
        return

    # Sorted so that the indices mean the same thing on every call; plain
    # os.listdir order is arbitrary.
    img_list = sorted(os.listdir(os.path.join(input_path, img_path)))
    ann_list = os.listdir(os.path.join(input_path, ann_path))
    dst_img_dir = os.path.join(input_path, mark, img_path)
    dst_ann_dir = os.path.join(input_path, mark, ann_path)
    os.makedirs(dst_img_dir, exist_ok=True)
    os.makedirs(dst_ann_dir, exist_ok=True)

    for i in range(num):
        name = img_list[id[i]]
        shutil.copy(os.path.join(input_path, img_path, name), os.path.join(dst_img_dir, name))
        print('From src: ' + img_path + '/' + name + ' =>dst:' + '/' + mark + '/' + img_path
              + '/' + name)
        for annotation_file in filter_for_annotations(input_path, ann_list, name):
            base = os.path.basename(annotation_file)
            shutil.copy(os.path.join(input_path, ann_path, base), os.path.join(dst_ann_dir, base))

    with open(os.path.join(input_path, mark, mark + '.txt'), 'w') as f:
        f.write(str(id))


def slice(input_path, train=0.8, eval=0.2, test=0.0):
    """
    slice the dataset into training, eval and test sub_dataset.
    :param input_path: path to the original dataset.
    :param train: the ratio of the training subset.
    :param eval: the ratio of the eval subset (ignored when test == 0: every
        image not used for training goes to eval).
    :param test: the ratio of the test subset; it receives the remainder.
    """
    num_list = len(os.listdir(input_path + '/' + img_path))
    n_train = int(num_list * train)
    if test == 0:
        n_eval = num_list - n_train
        n_test = 0
    else:
        n_eval = int(num_list * eval)
        n_test = num_list - n_train - n_eval

    img_id = np.arange(num_list)
    np.random.shuffle(img_id)
    train_id, eval_id, test_id = img_id[:n_train], img_id[n_train: n_train+n_eval], img_id[n_train+n_eval:]
    copy_data(input_path, train_id, n_train, 'train')
    copy_data(input_path, eval_id, n_eval, 'eval')
    copy_data(input_path, test_id, n_test, 'test')


if __name__ == '__main__':
    input_path = r'./example_data/original_data/dataset'
    # slice(input_path, train=0.6, eval=0.2, test=0.2)
    slice(input_path)

# --------------------------------------------------------------------------------
# /tif_process.py:
# --------------------------------------------------------------------------------
# Date:2019.04.10
# Author: DuncanChen
# A tool implementation on gdal and geotool API
# functions:
# 1. get mask raster with shapefile
# 2.
#    clip raster and shapefile with grid

from PIL import Image, ImageDraw
import os
from osgeo import gdal, gdalnumeric
# Imported from the osgeo package: the legacy top-level ``import ogr`` shim
# is not shipped by recent GDAL releases.
from osgeo import ogr
import numpy as np
import glob
gdal.UseExceptions()


class GeoTiff(object):
    """A tool for Remote Sensing Image: a GDAL raster plus tiling/masking helpers.

    Example usage::

        tif = GeoTiff('xx.tif')
        # rasterize a shapefile into tif.mask (one integer id per feature)
        tif.mask_tif_with_shapefile('shapefile.shp')
        # cut mask and raster into geo-referenced 512x512 tiles, ids from 0
        tif.clip_mask_with_grid(512, 0, 'out_dir/annotations')
        tif.clip_tif_with_grid(512, 0, 'out_dir/greenhouse_2019')
    """

    def __init__(self, tif_path):
        """Open ``tif_path`` with GDAL and cache its size and geotransform."""
        self.dataset = gdal.Open(tif_path)
        self.bands_count = self.dataset.RasterCount
        # get each band
        self.bands = [self.dataset.GetRasterBand(i + 1) for i in range(self.bands_count)]
        self.col = self.dataset.RasterXSize  # width in pixels
        self.row = self.dataset.RasterYSize  # height in pixels
        self.geotransform = self.dataset.GetGeoTransform()
        self.src_path = tif_path
        self.mask = None  # set by mask_tif_with_shapefile, shape (1, row, col)
        self.mark = None

    def get_left_top(self):
        """Return the upper-left corner as ``(y, x)`` geo coordinates."""
        return self.geotransform[3], self.geotransform[0]

    def get_pixel_height_width(self):
        """Return the absolute pixel size as ``(height, width)`` in geo units."""
        return abs(self.geotransform[5]), abs(self.geotransform[1])

    def __getitem__(self, *args):
        """Read a raster window: ``tif[r0:r1, c0:c1]``.

        Omitted bounds default to the raster edges, so ``tif[r0:, :]`` reads
        from row ``r0`` to the last row. Slice steps are ignored.

        Returns:
            ndarray from ``Dataset.ReadAsArray``: (rows, cols) for a
            single-band raster, (bands, rows, cols) otherwise.
        """
        if isinstance(args[0], tuple) and len(args[0]) == 2:
            row_slice, col_slice = args[0]
            start_row = 0 if row_slice.start is None else row_slice.start
            start_col = 0 if col_slice.start is None else col_slice.start
            # An open stop means "to the edge". It used to read self.row rows
            # *starting at* start_row, running past the raster.
            stop_row = self.row if row_slice.stop is None else row_slice.stop
            stop_col = self.col if col_slice.stop is None else col_slice.stop
            return self.dataset.ReadAsArray(start_col, start_row,
                                            stop_col - start_col, stop_row - start_row)
        raise NotImplementedError('the param should be [a: b, c: d] !')

    def clip_tif_with_grid(self, clip_size, begin_id, out_dir):
        """Cut the raster into ``clip_size`` x ``clip_size`` geo-referenced tiles.

        Tiles are written as ``<begin_id + i * col_num + j>.tif`` (row-major)
        and converted to uint8. Partial tiles at the right/bottom edges are
        dropped.

        Args:
            clip_size: tile edge length in pixels.
            begin_id: id of the first tile.
            out_dir: output directory (created if missing).

        Returns:
            Number of tiles, ``row_num * col_num``.

        Raises:
            IOError: if writing a tile fails (original error chained).
        """
        if not os.path.exists(out_dir):
            # check the dir
            os.makedirs(out_dir)
            print('create dir', out_dir)

        row_num = self.row // clip_size
        col_num = self.col // clip_size

        gtiffDriver = gdal.GetDriverByName('GTiff')
        if gtiffDriver is None:
            raise ValueError("Can't find GeoTiff Driver")

        count = 1
        for i in range(row_num):
            for j in range(col_num):
                tile = np.array(self[i * clip_size: (i + 1) * clip_size, j * clip_size: (j + 1) * clip_size])
                # uint8 rather than int8: 8-bit pixel values above 127 must not wrap.
                tile = tile.astype(np.uint8)
                save_path = os.path.join(out_dir, '%d.tif' % (begin_id + i * col_num + j))
                try:
                    save_image_with_georef(tile, gtiffDriver,
                                           self.dataset, j * clip_size, i * clip_size, save_path)
                except Exception as err:
                    raise IOError('clip failed!%d' % count) from err
                print('clip successfully!(%d/%d)' % (count, row_num * col_num))
                count += 1

        return row_num * col_num

    def clip_mask_with_grid(self, clip_size, begin_id, out_dir):
        """Cut ``self.mask`` into tiles and write one 0/255 uint8 mask per instance.

        For tile ``t = begin_id + i * col_num + j`` and the k-th instance id
        present in it, ``<t>_greenhouse_<k>.tif`` is written to ``out_dir``.

        Raises:
            ValueError: if ``mask_tif_with_shapefile`` has not been called.
            IOError: if writing a mask fails (original error chained).
        """
        if self.mask is None:
            raise ValueError('mask is empty, call mask_tif_with_shapefile first')

        if not os.path.exists(out_dir):
            # check the dir
            os.makedirs(out_dir)
            print('create dir', out_dir)

        row_num = self.row // clip_size
        col_num = self.col // clip_size

        gtiffDriver = gdal.GetDriverByName('GTiff')
        if gtiffDriver is None:
            raise ValueError("Can't find GeoTiff Driver")

        count = 1
        for i in range(row_num):
            for j in range(col_num):
                tile = np.array(self.mask[0, i * clip_size: (i + 1) * clip_size, j * clip_size: (j + 1) * clip_size])
                # Instance ids present in the tile (0 is background). Filter
                # rather than dropping the first unique value: a tile fully
                # covered by one instance has no 0 and would lose that instance.
                instance_ids = np.unique(tile)
                instance_ids = instance_ids[instance_ids > 0]
                for k, instance_value in enumerate(instance_ids):
                    # uint8: 255 does not fit in int8 (OverflowError on NumPy 2).
                    instance_mask = np.zeros(tile.shape, dtype=np.uint8)
                    instance_mask[tile == instance_value] = 255
                    save_path = os.path.join(out_dir, '%d_%s_%d.tif' % (begin_id + i * col_num + j, 'greenhouse', k))
                    try:
                        save_image_with_georef(instance_mask, gtiffDriver,
                                               self.dataset, j * clip_size, i * clip_size, save_path)
                    except Exception as err:
                        raise IOError('clip failed!%d' % count) from err
                    print('clip mask successfully!(%d/%d)' % (count, row_num * col_num))
                    count += 1

    def world2Pixel(self, x, y):
        """Convert a geo coordinate to a pixel ``(x, y)`` of this raster.

        Uses the dataset geotransform (rotation terms are ignored). The result
        is truncated toward zero and clamped to ``[0, col]`` x ``[0, row]``.
        Points outside the raster are pinned to the border; the old ``abs()``
        mirrored them back inside.
        """
        origin_x, pixel_w = self.geotransform[0], self.geotransform[1]
        origin_y, pixel_h = self.geotransform[3], self.geotransform[5]
        pixel_x = int((x - origin_x) / pixel_w)
        pixel_y = int((y - origin_y) / pixel_h)
        pixel_x = min(max(pixel_x, 0), self.col)
        pixel_y = min(max(pixel_y, 0), self.row)
        return pixel_x, pixel_y

    def _geometry_to_pixels(self, geometry):
        """Return the vertices of a simple OGR geometry (ring/line/point) as pixel tuples."""
        return [self.world2Pixel(geometry.GetX(k), geometry.GetY(k))
                for k in range(geometry.GetPointCount())]

    def mask_tif_with_shapefile(self, shapefile_path, label=255):
        """Rasterize a shapefile into an instance-id mask stored in ``self.mask``.

        Feature ``i`` (0-based) is drawn with value ``i + 1`` on a 32-bit
        canvas of the raster's size, so 0 is background. Polygons,
        multipolygons (outer ring of each part), linestrings and points are
        supported. Shapefile coordinates must already be in the raster's
        reference system; nothing is reprojected.

        Args:
            shapefile_path: path of the ESRI Shapefile.
            label: unused; kept for backward compatibility.

        Raises:
            IOError: if the shapefile cannot be opened.
            ValueError: if the GTiff driver is unavailable.
        """
        driver = ogr.GetDriverByName('ESRI Shapefile')
        dataSource = driver.Open(shapefile_path, 0)
        if dataSource is None:
            raise IOError('could not open!')
        gtiffDriver = gdal.GetDriverByName('GTiff')
        if gtiffDriver is None:
            raise ValueError("Can't find GeoTiff Driver")

        layer = dataSource.GetLayer(0)

        # initialize mask drawing ("I" = 32-bit signed integer pixels)
        rasterPoly = Image.new("I", (self.col, self.row), 0)
        rasterize = ImageDraw.Draw(rasterPoly)

        for i in range(layer.GetFeatureCount()):
            feature = layer.GetFeature(i)
            geom = feature.GetGeometryRef()
            if geom is None:
                # feature without geometry: nothing to draw
                feature.Destroy()
                continue
            feature_type = geom.GetGeometryName()
            fill = i + 1

            # The old test `feature_type == 'POLYGON' or 'MULTIPOLYGON'` was
            # always true, so lines and points never reached their branches.
            if feature_type in ('POLYGON', 'MULTIPOLYGON'):
                for j in range(geom.GetGeometryCount()):
                    ring = geom.GetGeometryRef(j)
                    if feature_type == 'MULTIPOLYGON':
                        # each part is a polygon; draw its outer ring
                        ring = ring.GetGeometryRef(0)
                    # NOTE(review): for a POLYGON every ring, including interior
                    # rings (holes), is filled with the same id.
                    rasterize.polygon(self._geometry_to_pixels(ring), fill)
            elif feature_type == 'LINESTRING':
                rasterize.line(self._geometry_to_pixels(geom), fill)
            elif feature_type == 'POINT':
                rasterize.point(self._geometry_to_pixels(geom), fill)

            feature.Destroy()  # geom is owned by the feature; both are done

        dataSource = None  # close the shapefile

        mask = np.array(rasterPoly)
        self.mask = mask[np.newaxis, :]  # extend an axis to three: (1, row, col)

    def clip_tif_and_shapefile(self, clip_size, begin_id, shapefile_path, out_dir):
        """Rasterize ``shapefile_path``, then tile masks and image into ``out_dir``.

        Masks go to ``out_dir/annotations`` and images to
        ``out_dir/greenhouse_2019``. Returns the number of image tiles.
        """
        self.mask_tif_with_shapefile(shapefile_path)
        self.clip_mask_with_grid(clip_size=clip_size, begin_id=begin_id, out_dir=out_dir + '/annotations')
        pic_id = self.clip_tif_with_grid(clip_size=clip_size, begin_id=begin_id, out_dir=out_dir + '/greenhouse_2019')
        return pic_id


def channel_first_to_last(image):
    """Transpose a 3-D array from (channel, height, width) to (height, width, channel)."""
    new_image = np.transpose(image, axes=[1, 2, 0])
    return new_image


def channel_last_to_first(image):
    """Transpose a 3-D array from (height, width, channel) to (channel, height, width)."""
    new_image = np.transpose(image, axes=[2, 0, 1])
    return new_image


def save_image_with_georef(image, driver, original_ds, offset_x=0, offset_y=0, save_path=None):
    """Write ``image`` to ``save_path`` with the geo-reference of ``original_ds``.

    The geotransform is shifted by ``(offset_x, offset_y)`` pixels, so a tile
    cut from ``original_ds`` at that offset keeps its real-world position. The
    file is written with ``image``'s own dtype.

    Args:
        image: 2-D (rows, cols) or 3-D (bands, rows, cols) ndarray.
        driver: GDAL driver used for ``CreateCopy`` (e.g. GTiff).
        original_ds: GDAL dataset providing projection/geotransform/metadata.
        offset_x: column offset of ``image`` inside ``original_ds``.
        offset_y: row offset of ``image`` inside ``original_ds``.
        save_path: output file path.
    """
    # Wrap the array in a dataset and give it the shifted geo reference.
    ds = gdalnumeric.OpenArray(image)
    gdalnumeric.CopyDatasetInfo(original_ds, ds, xoff=offset_x, yoff=offset_y)
    # CreateCopy writes all pixels. The old band writes that followed it only
    # touched the in-memory dataset and never reached the file.
    out_ds = driver.CreateCopy(save_path, ds)
    out_ds = None  # close -> flush to disk
    ds = None


def define_ref_predict(tif_dir, mask_dir, save_dir):
    """Copy the geo-reference of each raster in ``tif_dir`` onto a mask image.

    The ``*.tif`` files in ``tif_dir`` and the ``*.png`` / ``*.jpg`` / ``*.tif``
    files in ``mask_dir`` are each sorted and paired by position. Every mask is
    written to ``save_dir`` under the base name of its paired raster.

    Raises:
        ValueError: if there are fewer masks than rasters.
    """
    tif_list = glob.glob(os.path.join(tif_dir, '*.tif'))

    mask_list = glob.glob(os.path.join(mask_dir, '*.png'))
    mask_list += (glob.glob(os.path.join(mask_dir, '*.jpg')))
    mask_list += (glob.glob(os.path.join(mask_dir, '*.tif')))

    tif_list.sort()
    mask_list.sort()
    if len(mask_list) < len(tif_list):
        raise ValueError('found %d rasters but only %d masks' % (len(tif_list), len(mask_list)))

    os.makedirs(save_dir, exist_ok=True)
    gtiffDriver = gdal.GetDriverByName('GTiff')
    if gtiffDriver is None:
        raise ValueError("Can't find GeoTiff Driver")
    for tif_file, mask_file in zip(tif_list, mask_list):
        # basename works on every OS; split('\\') only worked on Windows paths.
        save_path = os.path.join(save_dir, os.path.basename(tif_file))
        tif = GeoTiff(tif_file)
        with Image.open(mask_file) as mask_image:
            mask = np.array(mask_image)
        if mask.ndim == 3:
            mask = channel_last_to_first(mask)  # (H, W, C) -> (C, H, W)
        save_image_with_georef(mask, gtiffDriver, tif.dataset, save_path=save_path)


class GeoShaplefile(object):
    """Summary of the first layer of an ESRI Shapefile.

    Attributes are filled by ``open_shapefile``: ``layer`` (OGR layer 0),
    ``minX/maxX/minY/maxY`` (layer extent), ``feature_num`` (feature count)
    and ``feature_type`` (geometry name of feature 0, "" if unknown).
    """

    def __init__(self, file_path=""):
        self.file_path = file_path
        self.data_source = None
        self.layer = ""
        self.minX, self.maxX, self.minY, self.maxY = (0, 0, 0, 0)
        self.feature_type = ""
        self.feature_num = 0
        self.open_shapefile()

    def open_shapefile(self):
        """Open ``file_path`` and read extent, feature count and geometry type.

        Raises:
            IOError: if the shapefile cannot be opened.
        """
        driver = ogr.GetDriverByName('ESRI Shapefile')
        dataSource = driver.Open(self.file_path, 0)
        if dataSource is None:
            raise IOError('could not open!')
        gtiffDriver = gdal.GetDriverByName('GTiff')
        if gtiffDriver is None:
            raise ValueError("Can't find GeoTiff Driver")

        # Keep the data source alive alongside the layer: an OGR layer is only
        # valid while its data source is open.
        self.data_source = dataSource
        self.layer = dataSource.GetLayer(0)
        self.minX, self.maxX, self.minY, self.maxY = self.layer.GetExtent()
        self.feature_num = self.layer.GetFeatureCount()  # get poly count
        if self.feature_num > 0:
            # Hold the feature in a local: its geometry is owned by it.
            feature = self.layer.GetFeature(0)
            geom = feature.GetGeometryRef()
            if geom is not None:
                self.feature_type = geom.GetGeometryName()


def clip_from_file(clip_size, root, img_path, shp_path):
    """Tile every raster in ``root/img_path`` together with its shapefile.

    For image ``<name>.<ext>`` the shapefile is read from
    ``root/shp_path/<name>/<name>.shp``. Tiles go to ``root/dataset``, with ids
    running on from one image to the next. Images are processed in sorted
    order, so ids are reproducible.
    """
    img_list = sorted(os.listdir(root + '/' + img_path))
    pic_id = 0
    for img_name in img_list:
        tif = GeoTiff(root + '/' + img_path + '/' + img_name)
        img_id = img_name.split('.', 1)[0]
        pic_num = tif.clip_tif_and_shapefile(clip_size, pic_id, root + '/' + shp_path + '/' + img_id + '/' + img_id + '.shp', root + '/dataset')
        pic_id += pic_num


if __name__ == '__main__':
    root = r'./example_data/original_data'
    img_path = 'img'
    shp_path = 'shp'
    clip_from_file(512, root, img_path, shp_path)

# --------------------------------------------------------------------------------
# /visualize_coco.py:
# --------------------------------------------------------------------------------
from pycocotools.coco import COCO
import numpy as np
import skimage.io as io
import matplotlib.pyplot as plt
import pylab
import os

ROOT_DIR = r'./example_data/original_data/dataset/eval'
image_directory = os.path.join(ROOT_DIR, "greenhouse_2019")
annotation_file = os.path.join(ROOT_DIR, "instances_greenhouse_eval2019.json")


def main():
    """Display the first annotated eval image with its annotations drawn on top."""
    example_coco = COCO(annotation_file)

    # 'greenhouse' is the category written by shape_to_coco.py. The former
    # 'square' (from pycococreator's demo) matched nothing, so no filtering happened.
    category_ids = example_coco.getCatIds(catNms=['greenhouse'])
    image_ids = example_coco.getImgIds(catIds=category_ids)
    if not image_ids:
        print('No annotated image found in ' + annotation_file)
        return
    image_data = example_coco.loadImgs(image_ids[0])[0]

    # rcParams only affect figures created afterwards, so set the size first.
    pylab.rcParams['figure.figsize'] = (8.0, 10.0)
    image = io.imread(image_directory + '/' + image_data['file_name'])
    plt.imshow(image)
    plt.axis('off')
    annotation_ids = example_coco.getAnnIds(imgIds=image_data['id'], catIds=category_ids, iscrowd=None)
    annotations = example_coco.loadAnns(annotation_ids)
    example_coco.showAnns(annotations)
    plt.show()


if __name__ == '__main__':
    main()