├── .gitignore
├── README.md
├── example_data
├── example.png
└── original_data
│ ├── img
│ ├── 03.tif
│ └── 21.tif
│ └── shp
│ ├── 21
│ ├── 21.cpg
│ ├── 21.dbf
│ ├── 21.prj
│ ├── 21.sbn
│ ├── 21.sbx
│ ├── 21.shp
│ └── 21.shx
│ └── 03
│ ├── 03.cpg
│ ├── 03.dbf
│ ├── 03.prj
│ ├── 03.sbn
│ ├── 03.sbx
│ ├── 03.shp
│ └── 03.shx
├── pycococreatortools
└── pycococreatortools.py
├── shape_to_coco.py
├── slice_dataset.py
├── tif_process.py
└── visualize_coco.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # shp2coco
2 | shp2coco is a tool to help create `COCO` datasets from `.shp` file (ArcGIS format).
3 |
4 | It includes:
 5 | 1. mask the tif with the shape file.
 6 | 2. crop the tif and mask into tiles.
 7 | 3. slice the dataset into training, eval and test subsets.
 8 | 4. generate annotations in uncompressed RLE ("crowd") and polygon formats as required by COCO.
9 |
 10 | This project is based on [geotool](https://github.com/Kindron/geotool) and [pycococreator](https://github.com/waspinator/pycococreator).
11 |
12 | ## Usage:
13 | If you need to generate annotations in the COCO format, try the following:
14 | `python shape_to_coco.py`
15 | If you need to visualize annotations, try the following:
16 | `python visualize_coco.py`
17 |
18 | ## Example:
19 | 
20 |
21 | ## Thanks to the Third Party Libs
22 | [geotool](https://github.com/Kindron/geotool)
23 | [pycococreator](https://github.com/waspinator/pycococreator)
24 |
--------------------------------------------------------------------------------
/example_data/example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dingyuan-Chen/shp2coco/a88fc4b9e649ab11e50dc026bc386fa778b54eef/example_data/example.png
--------------------------------------------------------------------------------
/example_data/original_data/img/03.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dingyuan-Chen/shp2coco/a88fc4b9e649ab11e50dc026bc386fa778b54eef/example_data/original_data/img/03.tif
--------------------------------------------------------------------------------
/example_data/original_data/img/21.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dingyuan-Chen/shp2coco/a88fc4b9e649ab11e50dc026bc386fa778b54eef/example_data/original_data/img/21.tif
--------------------------------------------------------------------------------
/example_data/original_data/shp/03/03.cpg:
--------------------------------------------------------------------------------
1 | UTF-8
--------------------------------------------------------------------------------
/example_data/original_data/shp/03/03.dbf:
--------------------------------------------------------------------------------
1 | w^ A Id N
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
--------------------------------------------------------------------------------
/example_data/original_data/shp/03/03.prj:
--------------------------------------------------------------------------------
1 | GEOGCS["GCS_WGS_1984",DATUM["D_WGS_1984",SPHEROID["WGS_1984",6378137.0,298.257223563]],PRIMEM["Greenwich",0.0],UNIT["Degree",0.0174532925199433]]
--------------------------------------------------------------------------------
/example_data/original_data/shp/03/03.sbn:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dingyuan-Chen/shp2coco/a88fc4b9e649ab11e50dc026bc386fa778b54eef/example_data/original_data/shp/03/03.sbn
--------------------------------------------------------------------------------
/example_data/original_data/shp/03/03.sbx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dingyuan-Chen/shp2coco/a88fc4b9e649ab11e50dc026bc386fa778b54eef/example_data/original_data/shp/03/03.sbx
--------------------------------------------------------------------------------
/example_data/original_data/shp/03/03.shp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dingyuan-Chen/shp2coco/a88fc4b9e649ab11e50dc026bc386fa778b54eef/example_data/original_data/shp/03/03.shp
--------------------------------------------------------------------------------
/example_data/original_data/shp/03/03.shx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dingyuan-Chen/shp2coco/a88fc4b9e649ab11e50dc026bc386fa778b54eef/example_data/original_data/shp/03/03.shx
--------------------------------------------------------------------------------
/example_data/original_data/shp/21/21.cpg:
--------------------------------------------------------------------------------
1 | UTF-8
--------------------------------------------------------------------------------
/example_data/original_data/shp/21/21.dbf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dingyuan-Chen/shp2coco/a88fc4b9e649ab11e50dc026bc386fa778b54eef/example_data/original_data/shp/21/21.dbf
--------------------------------------------------------------------------------
/example_data/original_data/shp/21/21.prj:
--------------------------------------------------------------------------------
1 | GEOGCS["GCS_WGS_1984",DATUM["D_WGS_1984",SPHEROID["WGS_1984",6378137.0,298.257223563]],PRIMEM["Greenwich",0.0],UNIT["Degree",0.0174532925199433]]
--------------------------------------------------------------------------------
/example_data/original_data/shp/21/21.sbn:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dingyuan-Chen/shp2coco/a88fc4b9e649ab11e50dc026bc386fa778b54eef/example_data/original_data/shp/21/21.sbn
--------------------------------------------------------------------------------
/example_data/original_data/shp/21/21.sbx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dingyuan-Chen/shp2coco/a88fc4b9e649ab11e50dc026bc386fa778b54eef/example_data/original_data/shp/21/21.sbx
--------------------------------------------------------------------------------
/example_data/original_data/shp/21/21.shp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dingyuan-Chen/shp2coco/a88fc4b9e649ab11e50dc026bc386fa778b54eef/example_data/original_data/shp/21/21.shp
--------------------------------------------------------------------------------
/example_data/original_data/shp/21/21.shx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dingyuan-Chen/shp2coco/a88fc4b9e649ab11e50dc026bc386fa778b54eef/example_data/original_data/shp/21/21.shx
--------------------------------------------------------------------------------
/pycococreatortools/pycococreatortools.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import os
4 | import re
5 | import datetime
6 | import numpy as np
7 | from itertools import groupby
8 | from skimage import measure
9 | from PIL import Image
10 | from pycocotools import mask
11 |
def convert(text):
    """Return int(text) for an all-digit chunk, otherwise the lower-cased text."""
    return int(text) if text.isdigit() else text.lower()

def natrual_key(key):
    """Sort key for natural ("human") ordering, e.g. 'img2' sorts before 'img10'.

    NOTE: the misspelling "natrual" is kept for backward compatibility.
    """
    return [convert(c) for c in re.split('([0-9]+)', key)]
14 |
def resize_binary_mask(array, new_size):
    """Resize a binary mask to *new_size* (width, height) and return it as bool.

    The 0/1 mask is scaled to 0/255 so PIL treats it as a grayscale image.
    """
    as_image = Image.fromarray(array.astype(np.uint8) * 255)
    resized = as_image.resize(new_size)
    return np.asarray(resized).astype(np.bool_)
19 |
def close_contour(contour):
    """Return *contour* with its first point appended if it is not closed."""
    is_closed = np.array_equal(contour[0], contour[-1])
    return contour if is_closed else np.vstack((contour, contour[0]))
24 |
def binary_mask_to_rle(binary_mask):
    """Encode a binary mask as uncompressed COCO RLE (column-major run lengths)."""
    counts = []
    # COCO RLE always starts with a run of zeros; emit an explicit empty
    # zero-run when the very first pixel is foreground.
    for index, (value, run) in enumerate(groupby(binary_mask.ravel(order='F'))):
        if index == 0 and value == 1:
            counts.append(0)
        counts.append(sum(1 for _ in run))
    return {'counts': counts, 'size': list(binary_mask.shape)}
34 |
def binary_mask_to_polygon(binary_mask, tolerance=0):
    """Converts a binary mask to COCO polygon representation.

    Args:
        binary_mask: a 2D binary numpy array where '1's represent the object.
        tolerance: maximum distance from original points of polygon to the
            approximated polygonal chain. If tolerance is 0, the original
            coordinate array is returned.

    Returns:
        list of flat [x0, y0, x1, y1, ...] polygons, one per contour.
    """
    polygons = []
    # pad mask to close contours of shapes which start and end at an edge
    padded_binary_mask = np.pad(binary_mask, pad_width=1, mode='constant', constant_values=0)
    contours = measure.find_contours(padded_binary_mask, 0.5)
    for contour in contours:
        # undo the one-pixel padding offset per contour; the old
        # np.subtract(contours, 1) over the whole list built a ragged array,
        # which raises on NumPy >= 1.24
        contour = np.subtract(contour, 1)
        contour = close_contour(contour)
        contour = measure.approximate_polygon(contour, tolerance)
        if len(contour) < 3:
            continue
        # find_contours yields (row, col); COCO polygons are (x, y)
        contour = np.flip(contour, axis=1)
        segmentation = contour.ravel().tolist()
        # after padding and subtracting 1 we may get -0.5 points in our segmentation
        segmentation = [0 if i < 0 else i for i in segmentation]
        polygons.append(segmentation)

    return polygons
61 |
def create_image_info(image_id, file_name, image_size,
                      date_captured=None,
                      license_id=1, coco_url="", flickr_url=""):
    """Build a COCO "images" record.

    Args:
        image_id: unique integer id of the image.
        file_name: image file name (basename, not a full path).
        image_size: (width, height) pair, e.g. PIL's Image.size.
        date_captured: ISO timestamp string. Defaults to the current UTC time
            at call time; the old default evaluated utcnow() once at import
            time, stamping every image with the module-load timestamp.
        license_id: id into the top-level "licenses" list.
        coco_url: optional COCO-hosted URL.
        flickr_url: optional Flickr URL.

    Returns:
        dict in the COCO "images" format.
    """
    if date_captured is None:
        date_captured = datetime.datetime.utcnow().isoformat(' ')

    return {
        "id": image_id,
        "file_name": file_name,
        "width": image_size[0],
        "height": image_size[1],
        "date_captured": date_captured,
        "license": license_id,
        "coco_url": coco_url,
        "flickr_url": flickr_url,
    }
78 |
def create_annotation_info(annotation_id, image_id, category_info, binary_mask,
                           image_size=None, tolerance=2, bounding_box=None):
    """Build a COCO "annotations" record from a binary mask.

    Args:
        annotation_id: unique integer id for this annotation.
        image_id: id of the image the mask belongs to.
        category_info: dict with keys 'id' and 'is_crowd'.
        binary_mask: 2D numpy array, 1 = object.
        image_size: optional (width, height) to resize the mask to first.
        tolerance: polygon approximation tolerance (non-crowd only).
        bounding_box: optional precomputed [x, y, w, h]; derived from the
            mask when omitted.

    Returns:
        dict in the COCO "annotations" format, or None when the mask is
        empty or no polygon could be extracted.
    """
    if image_size is not None:
        binary_mask = resize_binary_mask(binary_mask, image_size)

    encoded = mask.encode(np.asfortranarray(binary_mask.astype(np.uint8)))

    area = mask.area(encoded)
    if area < 1:
        # nothing to annotate
        return None

    if bounding_box is None:
        bounding_box = mask.toBbox(encoded)

    if category_info["is_crowd"]:
        # crowd regions are stored as uncompressed RLE
        is_crowd = 1
        segmentation = binary_mask_to_rle(binary_mask)
    else:
        is_crowd = 0
        segmentation = binary_mask_to_polygon(binary_mask, tolerance)
        if not segmentation:
            return None

    return {
        "id": annotation_id,
        "image_id": image_id,
        "category_id": category_info["id"],
        "iscrowd": is_crowd,
        "area": area.tolist(),
        "bbox": bounding_box.tolist(),
        "segmentation": segmentation,
        "width": binary_mask.shape[1],
        "height": binary_mask.shape[0],
    }
116 |
--------------------------------------------------------------------------------
/shape_to_coco.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import datetime
4 | import json
5 | import os
6 | import re
7 | import fnmatch
8 | from PIL import Image
9 | import numpy as np
10 | from pycococreatortools import pycococreatortools
11 | from tif_process import *
12 | from slice_dataset import slice
13 |
14 | # root path for saving the tif and shp file.
15 | ROOT = r'./example_data/original_data'
16 | img_path = 'img'
17 | shp_path = 'shp'
18 | # root path for saving the mask.
19 | ROOT_DIR = ROOT + '/dataset'
20 | IMAGE_DIR = os.path.join(ROOT_DIR, "greenhouse_2019")
21 | ANNOTATION_DIR = os.path.join(ROOT_DIR, "annotations")
22 |
23 | clip_size = 512
24 |
25 | INFO = {
26 | "description": "Greenhouse Dataset",
27 | "url": "",
28 | "version": "0.1.0",
29 | "year": 2019,
30 | "contributor": "DuncanChen",
31 | "date_created": datetime.datetime.utcnow().isoformat(' ')
32 | }
33 |
34 | LICENSES = [
35 | {
36 | "id": 1,
37 | "name": "",
38 | "url": ""
39 | }
40 | ]
41 |
42 | CATEGORIES = [
43 | {
44 | 'id': 1,
45 | 'name': 'greenhouse',
46 | 'supercategory': 'building',
47 | },
48 | ]
49 |
def filter_for_jpeg(root, files):
    """Return full paths of the TIFF images among *files*.

    (The name is inherited from the pycococreator example, which filtered
    jpegs; this project filters tifs.)
    """
    pattern = r'|'.join(fnmatch.translate(ext) for ext in ('*.tiff', '*.tif'))
    paths = (os.path.join(root, name) for name in files)
    return [path for path in paths if re.match(pattern, path)]
58 |
def filter_for_annotations(root, files, image_filename):
    """Return the annotation tif paths belonging to *image_filename*.

    A mask named '<id>_<class>_<k>.tif' belongs to the image '<id>.tif'.
    """
    pattern = fnmatch.translate('*.tif')
    image_stem = os.path.splitext(os.path.basename(image_filename))[0]
    candidates = [os.path.join(root, name) for name in files]
    candidates = [path for path in candidates if re.match(pattern, path)]
    # match on the '<id>' prefix before the first underscore
    return [path for path in candidates
            if os.path.splitext(os.path.basename(path))[0].split('_', 1)[0] == image_stem]
71 |
def from_mask_to_coco(root, MARK, IMAGE, ANNOTATION):
    """Convert one dataset split into a COCO-format json file.

    Args:
        root: dataset root directory.
        MARK: split name ('train', 'eval' or 'test'); also the sub-directory.
        IMAGE: image sub-directory name inside the split.
        ANNOTATION: annotation sub-directory name inside the split.

    Writes <root>/<MARK>/instances_greenhouse_<MARK>2019.json.
    """
    split_dir = root + '/' + MARK
    image_dir = split_dir + '/' + IMAGE
    annotation_dir = split_dir + '/' + ANNOTATION
    if not os.path.exists(split_dir):
        # guard clause; message previously read "does not exit!"
        print(split_dir + ' does not exist!')
        return

    coco_output = {
        "info": INFO,
        "licenses": LICENSES,
        "categories": CATEGORIES,
        "images": [],
        "annotations": []
    }

    image_id = 1
    segmentation_id = 1

    # walk the image directory (named image_root so the *root* parameter
    # is not shadowed, as it was in the original loops)
    for image_root, _, files in os.walk(image_dir):
        image_files = filter_for_jpeg(image_root, files)

        for image_filename in image_files:
            image = Image.open(image_filename)
            image_info = pycococreatortools.create_image_info(
                image_id, os.path.basename(image_filename), image.size)
            coco_output["images"].append(image_info)

            # collect the mask tifs associated with this image
            for ann_root, _, ann_files in os.walk(annotation_dir):
                annotation_files = filter_for_annotations(ann_root, ann_files, image_filename)

                for annotation_filename in annotation_files:
                    print(annotation_filename)
                    # category whose name appears in the mask file name
                    class_id = [x['id'] for x in CATEGORIES if x['name'] in annotation_filename][0]

                    category_info = {'id': class_id, 'is_crowd': 'crowd' in image_filename}
                    binary_mask = np.asarray(Image.open(annotation_filename)
                                             .convert('1')).astype(np.uint8)

                    annotation_info = pycococreatortools.create_annotation_info(
                        segmentation_id, image_id, category_info, binary_mask,
                        image.size, tolerance=2)

                    if annotation_info is not None:
                        coco_output["annotations"].append(annotation_info)

                    segmentation_id = segmentation_id + 1

            image_id = image_id + 1

    with open('{}/instances_greenhouse_{}2019.json'.format(split_dir, MARK), 'w') as output_json_file:
        json.dump(coco_output, output_json_file)
128 |
def main():
    """Run the full pipeline: clip the tifs, split the dataset, emit COCO json."""
    clip_from_file(clip_size, ROOT, img_path, shp_path)
    slice(ROOT_DIR, train=0.6, eval=0.2, test=0.2)
    for split in ('train', 'eval', 'test'):
        from_mask_to_coco(ROOT_DIR, split, "greenhouse_2019", "annotations")

if __name__ == "__main__":
    main()
138 |
--------------------------------------------------------------------------------
/slice_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import shutil
4 | import re
5 | import fnmatch
6 |
7 | ann_path = 'annotations'
8 | img_path = 'greenhouse_2019'
9 |
def filter_for_annotations(root, files, image_filename):
    """Pick the annotation tifs that belong to *image_filename*.

    A mask named '<id>_<class>_<k>.tif' belongs to the image '<id>.tif'.
    """
    tif_pattern = fnmatch.translate('*.tif')
    stem = os.path.splitext(os.path.basename(image_filename))[0]
    matched = []
    for name in files:
        full_path = os.path.join(root, name)
        if not re.match(tif_pattern, full_path):
            continue
        if os.path.splitext(os.path.basename(full_path))[0].split('_', 1)[0] == stem:
            matched.append(full_path)
    return matched
22 |
def copy_data(input_path, id, num, mark='train'):
    """Copy *num* images (chosen by index) plus their masks into the *mark* split.

    Args:
        input_path: dataset root containing the image and annotation dirs.
        id: sequence of indices into the image-directory listing.
        num: how many entries of *id* to copy; 0 disables this split.
        mark: split name, used as the destination sub-directory.

    Also writes <mark>/<mark>.txt recording the chosen indices.
    """
    if num == 0:
        return

    # renamed from 'list', which shadowed the builtin
    img_list = os.listdir(input_path + '/' + img_path)
    ann_list = os.listdir(input_path + '/' + ann_path)
    dst_img_dir = input_path + '/' + mark + '/' + img_path
    dst_ann_dir = input_path + '/' + mark + '/' + ann_path
    if not os.path.isdir(dst_img_dir):
        os.makedirs(dst_img_dir)
    if not os.path.isdir(dst_ann_dir):
        os.makedirs(dst_ann_dir)

    for i in range(num):
        name = img_list[id[i]]
        shutil.copy(input_path + '/' + img_path + '/' + name, dst_img_dir + '/' + name)
        print('From src: ' + img_path + '/' + name + ' =>dst:' + '/' + mark + '/' + img_path
              + '/' + name)
        annotation_files = filter_for_annotations(input_path, ann_list, name)
        for ann in annotation_files:
            base = os.path.basename(ann)
            shutil.copy(input_path + '/' + ann_path + '/' + base, dst_ann_dir + '/' + base)

    # 'with' closes the handle deterministically (the original leaked it on error)
    with open(input_path + '/' + mark + '/' + mark + '.txt', 'w') as f:
        f.write(str(id))
45 |
def slice(input_path, train=0.8, eval=0.2, test=0.0):
    """
    slice the dataset into training, eval and test sub_dataset.
    :param input_path: path to the original dataset.
    :param train: the ratio of the training subset.
    :param eval: the ratio of the eval subset.
    :param test: the ratio of the test subset.
    """
    # NOTE(review): the function name shadows builtins.slice and the *eval*
    # parameter shadows builtins.eval; both kept for backward compatibility.
    img_list = os.listdir(input_path + '/' + img_path)  # was 'list' (builtin shadow)
    num_imgs = len(img_list)
    n_train = int(num_imgs * train)
    if test == 0:
        # no test split: everything that is not train goes to eval
        n_eval = num_imgs - n_train
        n_test = 0
    else:
        n_eval = int(num_imgs * eval)
        n_test = num_imgs - n_train - n_eval

    img_id = np.arange(num_imgs)
    np.random.shuffle(img_id)
    train_id = img_id[:n_train]
    eval_id = img_id[n_train: n_train + n_eval]
    test_id = img_id[n_train + n_eval:]
    copy_data(input_path, train_id, n_train, 'train')
    copy_data(input_path, eval_id, n_eval, 'eval')
    copy_data(input_path, test_id, n_test, 'test')
71 |
if __name__ == '__main__':
    # default 0.8 / 0.2 train/eval split of the clipped dataset
    dataset_path = r'./example_data/original_data/dataset'
    # slice(dataset_path, train=0.6, eval=0.2, test=0.2)
    slice(dataset_path)
76 |
--------------------------------------------------------------------------------
/tif_process.py:
--------------------------------------------------------------------------------
1 | # Date:2019.04.10
2 | # Author: DuncanChen
3 | # A tool implementation on gdal and geotool API
4 | # functions:
5 | # 1. get mask raster with shapefile
6 | # 2. clip raster and shapefile with grid
7 |
8 | from PIL import Image, ImageDraw
9 | import os
10 | from osgeo import gdal, gdalnumeric
11 | import numpy as np
12 | import ogr
13 | import glob
14 | gdal.UseExceptions()
15 |
16 |
class GeoTiff(object):
    """A GDAL-backed tool for clipping and masking remote-sensing imagery.

    Examples::
        >>> tif = GeoTiff('xx.tif')
        # clip the tif into grid tiles, preserving the geo reference
        >>> tif.clip_tif_with_grid(512, 0, 'out_dir')
        # rasterize a shapefile as an instance mask over this tif
        >>> tif.mask_tif_with_shapefile('shapefile.shp')
        # mask, then clip both image and mask in one call
        >>> tif.clip_tif_and_shapefile(512, 0, 'shapefile.shp', 'out_dir')
    """

    def __init__(self, tif_path):
        """
        Args:
            tif_path: path to the tif file.
        """
        self.dataset = gdal.Open(tif_path)
        self.bands_count = self.dataset.RasterCount
        # get each band
        self.bands = [self.dataset.GetRasterBand(i + 1) for i in range(self.bands_count)]
        self.col = self.dataset.RasterXSize
        self.row = self.dataset.RasterYSize
        self.geotransform = self.dataset.GetGeoTransform()
        self.src_path = tif_path
        self.mask = None  # set by mask_tif_with_shapefile: [1, row, col] instance ids
        self.mark = None  # reserved for skipping empty tiles (currently unused)

    def get_left_top(self):
        # geotransform = (ulX, pixel_w, rot, ulY, rot, pixel_h)
        return self.geotransform[3], self.geotransform[0]

    def get_pixel_height_width(self):
        return abs(self.geotransform[5]), abs(self.geotransform[1])

    def __getitem__(self, *args):
        """
        Args:
            *args: a pair of slices, i.e. self[r0:r1, c0:c1]

        Returns:
            res: image block, array, [bands..., height, width]
        """
        if isinstance(args[0], tuple) and len(args[0]) == 2:
            # unpack the row/col slices; missing bounds default to the full extent
            start_row, end_row = args[0][0].start, args[0][0].stop
            start_col, end_col = args[0][1].start, args[0][1].stop
            start_row = 0 if start_row is None else start_row
            start_col = 0 if start_col is None else start_col
            num_row = self.row if end_row is None else (end_row - start_row)
            num_col = self.col if end_col is None else (end_col - start_col)
            # dataset read image array (ReadAsArray takes x/col first)
            res = self.dataset.ReadAsArray(start_col, start_row, num_col, num_row)
            return res
        else:
            raise NotImplementedError('the param should be [a: b, c: d] !')

    def clip_tif_with_grid(self, clip_size, begin_id, out_dir):
        """Clip the image into clip_size x clip_size tiles saved under out_dir.

        Args:
            clip_size: tile edge length in pixels.
            begin_id: id of the first tile; output names continue from it.
            out_dir: output directory, created if missing.

        Returns:
            the number of tiles written (row_num * col_num).
        """
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
            print('create dir', out_dir)

        # partial tiles at the right/bottom edges are dropped
        row_num = int(self.row / clip_size)
        col_num = int(self.col / clip_size)

        gtiffDriver = gdal.GetDriverByName('GTiff')
        if gtiffDriver is None:
            raise ValueError("Can't find GeoTiff Driver")

        count = 1
        for i in range(row_num):
            for j in range(col_num):
                clipped_image = np.array(self[i * clip_size: (i + 1) * clip_size, j * clip_size: (j + 1) * clip_size])
                clipped_image = clipped_image.astype(np.int8)

                try:
                    save_path = os.path.join(out_dir, '%d.tif' % (begin_id + i * col_num + j))
                    save_image_with_georef(clipped_image, gtiffDriver,
                                           self.dataset, j * clip_size, i * clip_size, save_path)
                    print('clip successfully!(%d/%d)' % (count, row_num * col_num))
                    count += 1
                except Exception as err:
                    # chain the original error so the real cause is not lost
                    raise IOError('clip failed!%d' % count) from err

        return row_num * col_num

    def clip_mask_with_grid(self, clip_size, begin_id, out_dir):
        """Clip self.mask into per-instance binary tiles saved under out_dir.

        Each instance id present in a tile produces one
        '<tile>_greenhouse_<k>.tif' binary mask. Requires
        mask_tif_with_shapefile() to have been called first.

        Args:
            clip_size: tile edge length in pixels.
            begin_id: id of the first tile, matching clip_tif_with_grid.
            out_dir: output directory, created if missing.
        """
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
            print('create dir', out_dir)

        row_num = int(self.row / clip_size)
        col_num = int(self.col / clip_size)

        gtiffDriver = gdal.GetDriverByName('GTiff')
        if gtiffDriver is None:
            raise ValueError("Can't find GeoTiff Driver")

        count = 1
        for i in range(row_num):
            for j in range(col_num):
                clipped_image = np.array(self.mask[0, i * clip_size: (i + 1) * clip_size, j * clip_size: (j + 1) * clip_size])
                ins_list = np.unique(clipped_image)
                # drop the background value (first of the sorted unique ids)
                ins_list = ins_list[1:]
                for id in range(len(ins_list)):
                    # NOTE(review): 255 does not fit in int8 and wraps to -1;
                    # save_image_with_georef re-casts to int8 anyway, so the
                    # round-trip is consistent, but confirm readers expect this.
                    bg_img = np.zeros((clipped_image.shape)).astype(np.int8)
                    if ins_list[id] > 0:
                        bg_img[np.where(clipped_image == ins_list[id])] = 255
                        try:
                            save_path = os.path.join(out_dir, '%d_%s_%d.tif' % (begin_id + i * col_num + j, 'greenhouse', id))
                            save_image_with_georef(bg_img, gtiffDriver,
                                                   self.dataset, j * clip_size, i * clip_size, save_path)
                            print('clip mask successfully!(%d/%d)' % (count, row_num * col_num))
                            count += 1
                        except Exception as err:
                            raise IOError('clip failed!%d' % count) from err

    def world2Pixel(self, x, y):
        """
        Uses a gdal geomatrix (gdal.GetGeoTransform()) to calculate
        the pixel location of a geospatial coordinate; clamped to the raster.
        """
        ulY, ulX = self.get_left_top()
        distY, distX = self.get_pixel_height_width()

        pixel_x = abs(int((x - ulX) / distX))
        pixel_y = abs(int((ulY - y) / distY))
        pixel_y = self.row if pixel_y > self.row else pixel_y
        pixel_x = self.col if pixel_x > self.col else pixel_x
        return pixel_x, pixel_y

    def mask_tif_with_shapefile(self, shapefile_path, label=255):
        """
        mask tif with shape file; supports point, line, polygon and multipolygon.
        Each feature is rasterized with value i+1, so self.mask holds per-feature
        instance ids rather than *label*.

        Args:
            shapefile_path: path to the ESRI shapefile.
            label: legacy parameter; only flipped in sign internally and not
                used as the rasterized value.
        """
        driver = ogr.GetDriverByName('ESRI Shapefile')
        dataSource = driver.Open(shapefile_path, 0)
        if dataSource is None:
            raise IOError('could not open!')
        gtiffDriver = gdal.GetDriverByName('GTiff')
        if gtiffDriver is None:
            raise ValueError("Can't find GeoTiff Driver")

        layer = dataSource.GetLayer(0)
        # Convert the layer extent to image pixel coordinates
        minX, maxX, minY, maxY = layer.GetExtent()
        ulX, ulY = self.world2Pixel(minX, maxY)

        # initialize mask drawing ("I" = 32-bit integer pixels)
        rasterPoly = Image.new("I", (self.col, self.row), 0)
        rasterize = ImageDraw.Draw(rasterPoly)

        feature_num = layer.GetFeatureCount()  # get poly count
        for i in range(feature_num):
            points = []  # store points
            pixels = []  # store pixels
            feature = layer.GetFeature(i)
            geom = feature.GetGeometryRef()
            feature_type = geom.GetGeometryName()

            # BUGFIX: was `feature_type == 'POLYGON' or 'MULTIPOLYGON'`, which
            # is always truthy, so the LINESTRING/POINT branch below was dead.
            if feature_type in ('POLYGON', 'MULTIPOLYGON'):
                # multi polygon operation
                # 1. use label to mask the max polygon
                # 2. use -label to mask the other polygon
                for j in range(geom.GetGeometryCount()):
                    sub_polygon = geom.GetGeometryRef(j)
                    if feature_type == 'MULTIPOLYGON':
                        sub_polygon = sub_polygon.GetGeometryRef(0)
                    for p_i in range(sub_polygon.GetPointCount()):
                        px = sub_polygon.GetX(p_i)
                        py = sub_polygon.GetY(p_i)
                        points.append((px, py))

                    for p in points:
                        origin_pixel_x, origin_pixel_y = self.world2Pixel(p[0], p[1])
                        # the pixel in new image
                        new_pixel_x, new_pixel_y = origin_pixel_x, origin_pixel_y
                        pixels.append((new_pixel_x, new_pixel_y))

                    # fill the ring with the 1-based feature id (instance id)
                    rasterize.polygon(pixels, i + 1)
                    pixels = []
                    points = []
                    if feature_type != 'MULTIPOLYGON':
                        label = -abs(label)

                # restore the label value
                label = abs(label)
            else:
                for j in range(geom.GetPointCount()):
                    px = geom.GetX(j)
                    py = geom.GetY(j)
                    points.append((px, py))

                for p in points:
                    origin_pixel_x, origin_pixel_y = self.world2Pixel(p[0], p[1])
                    # the pixel in new image
                    new_pixel_x, new_pixel_y = origin_pixel_x, origin_pixel_y
                    pixels.append((new_pixel_x, new_pixel_y))

            feature.Destroy()  # delete feature

            if feature_type == 'LINESTRING':
                rasterize.line(pixels, i + 1)
            if feature_type == 'POINT':
                # pixel x, y
                rasterize.point(pixels, i + 1)

        mask = np.array(rasterPoly)
        self.mask = mask[np.newaxis, :]  # extend an axis to three

    def clip_tif_and_shapefile(self, clip_size, begin_id, shapefile_path, out_dir):
        """Mask this tif with the shapefile, then clip both mask and image.

        Returns:
            the number of image tiles written.
        """
        self.mask_tif_with_shapefile(shapefile_path)
        self.clip_mask_with_grid(clip_size=clip_size, begin_id=begin_id, out_dir=out_dir + '/annotations')
        pic_id = self.clip_tif_with_grid(clip_size=clip_size, begin_id=begin_id, out_dir=out_dir + '/greenhouse_2019')
        return pic_id
265 |
def channel_first_to_last(image):
    """Move the channel axis from position 0 to the end.

    Args:
        image: 3-D numpy array of shape [channel, height, width]

    Returns:
        3-D numpy array of shape [height, width, channel]
    """
    return np.transpose(image, axes=[1, 2, 0])
277 |
def channel_last_to_first(image):
    """Move the channel axis from the end to position 0.

    Args:
        image: 3-D numpy array of shape [height, width, channel]

    Returns:
        3-D numpy array of shape [channel, height, width]
    """
    return np.transpose(image, axes=[2, 0, 1])
289 |
def save_image_with_georef(image, driver, original_ds, offset_x=0, offset_y=0, save_path=None):
    """Save *image* as a raster carrying geo reference copied from *original_ds*.

    Args:
        image: ndarray, either [bands, rows, cols] or [rows, cols].
        driver: gdal IO driver used to create the copy (e.g. GTiff).
        original_ds: dataset providing the geo reference.
        offset_x: x (pixel) offset of this clip inside original_ds.
        offset_y: y (pixel) offset of this clip inside original_ds.
        save_path: str, image save path.
    """
    # build an in-memory dataset and attach the shifted geo reference
    ds = gdalnumeric.OpenArray(image)
    gdalnumeric.CopyDatasetInfo(original_ds, ds, xoff=offset_x, yoff=offset_y)
    driver.CreateCopy(save_path, ds)
    # then write the pixel data band by band as int8
    clip = image.astype(np.int8)
    if len(image.shape) == 3:
        for band_index in range(image.shape[0]):
            ds.GetRasterBand(band_index + 1).WriteArray(clip[band_index])
    else:
        ds.GetRasterBand(1).WriteArray(clip)
    del ds
317 |
def define_ref_predict(tif_dir, mask_dir, save_dir):
    """
    Attach the geo reference of the rasters in tif_dir to the mask images.

    Args:
        tif_dir: dir holding the georeferenced source rasters (*.tif).
        mask_dir: dir holding the mask images (*.png / *.jpg / *.tif).
        save_dir: output dir for the georeferenced masks.

    Note:
        tif and mask lists are paired by sorted order, so the file names
        must correspond one-to-one.
    """
    tif_list = glob.glob(os.path.join(tif_dir, '*.tif'))

    mask_list = glob.glob(os.path.join(mask_dir, '*.png'))
    mask_list += glob.glob(os.path.join(mask_dir, '*.jpg'))
    mask_list += glob.glob(os.path.join(mask_dir, '*.tif'))

    tif_list.sort()
    mask_list.sort()

    os.makedirs(save_dir, exist_ok=True)
    gtiffDriver = gdal.GetDriverByName('GTiff')
    if gtiffDriver is None:
        raise ValueError("Can't find GeoTiff Driver")
    for i in range(len(tif_list)):
        # os.path.basename is portable; the old split('\\') only worked on Windows
        save_name = os.path.basename(tif_list[i])
        save_path = os.path.join(save_dir, save_name)
        tif = GeoTiff(tif_list[i])
        mask = np.array(Image.open(mask_list[i]))
        mask = channel_last_to_first(mask)
        save_image_with_georef(mask, gtiffDriver, tif.dataset, save_path=save_path)
349 |
class GeoShaplefile(object):
    """Lightweight reader of an ESRI shapefile layer's metadata.

    NOTE: the class name misspells "Shapefile"; kept for compatibility.
    """

    def __init__(self, file_path=""):
        self.file_path = file_path
        self.layer = ""
        self.minX, self.maxX, self.minY, self.maxY = (0, 0, 0, 0)
        self.feature_type = ""
        self.feature_num = 0
        self.open_shapefile()

    def open_shapefile(self):
        """Open the shapefile and cache layer extent, feature count and type."""
        shp_driver = ogr.GetDriverByName('ESRI Shapefile')
        data_source = shp_driver.Open(self.file_path, 0)
        if data_source is None:
            raise IOError('could not open!')
        gtiff_driver = gdal.GetDriverByName('GTiff')
        if gtiff_driver is None:
            raise ValueError("Can't find GeoTiff Driver")

        self.layer = data_source.GetLayer(0)
        self.minX, self.maxX, self.minY, self.maxY = self.layer.GetExtent()
        self.feature_num = self.layer.GetFeatureCount()  # get poly count
        if self.feature_num > 0:
            # geometry type of the first feature stands in for the layer's type
            first_feature = self.layer.GetFeature(0)
            geometry = first_feature.GetGeometryRef()
            self.feature_type = geometry.GetGeometryName()
375 |
def clip_from_file(clip_size, root, img_path, shp_path):
    """Clip every tif under root/img_path (with its shapefile) into tiles.

    The shapefile for '<id>.tif' is expected at root/shp_path/<id>/<id>.shp;
    output tiles accumulate under root/dataset with globally unique ids.
    """
    img_list = os.listdir(root + '/' + img_path)
    pic_id = 0
    for img_name in img_list:
        tif = GeoTiff(root + '/' + img_path + '/' + img_name)
        img_id = img_name.split('.', 1)[0]
        shapefile = root + '/' + shp_path + '/' + img_id + '/' + img_id + '.shp'
        pic_num = tif.clip_tif_and_shapefile(clip_size, pic_id, shapefile, root + '/dataset')
        pic_id += pic_num

if __name__ == '__main__':
    root = r'./example_data/original_data'
    img_path = 'img'
    shp_path = 'shp'
    clip_from_file(512, root, img_path, shp_path)
391 |
--------------------------------------------------------------------------------
/visualize_coco.py:
--------------------------------------------------------------------------------
from pycocotools.coco import COCO
import numpy as np
import skimage.io as io
import matplotlib.pyplot as plt
import pylab
import os

# Visualize one image of the eval split with its COCO annotations overlaid.
ROOT_DIR = r'./example_data/original_data/dataset/eval'
image_directory = os.path.join(ROOT_DIR, "greenhouse_2019")
annotation_file = os.path.join(ROOT_DIR, "instances_greenhouse_eval2019.json")

example_coco = COCO(annotation_file)

# query the category this dataset actually contains; 'square' was a leftover
# from the pycococreator demo and always produced an empty id list
category_ids = example_coco.getCatIds(catNms=['greenhouse'])
image_ids = example_coco.getImgIds(catIds=category_ids)
image_data = example_coco.loadImgs(image_ids[0])[0]

# set the figure size *before* drawing so it applies to this plot
# (originally it was set after imshow and had no effect)
pylab.rcParams['figure.figsize'] = (8.0, 10.0)
image = io.imread(os.path.join(image_directory, image_data['file_name']))
plt.imshow(image)
plt.axis('off')
annotation_ids = example_coco.getAnnIds(imgIds=image_data['id'], catIds=category_ids, iscrowd=None)
annotations = example_coco.loadAnns(annotation_ids)
example_coco.showAnns(annotations)
plt.show()
--------------------------------------------------------------------------------