├── .gitignore
├── datasets
│   ├── __init__.py
│   ├── background_datasets.py
│   ├── base_datasets.py
│   ├── data_manager.py
│   ├── panoptic_coco_categories.json
│   ├── coco_categories.py
│   └── object_datasets.py
├── segment_layout
│   ├── __init__.py
│   ├── existing_layout_generator.py
│   ├── random_bbox_layout_generator.py
│   └── fine_grained_bbox_layout_generator.py
├── segmentation_generator
│   ├── utils.py
│   ├── __init__.py
│   ├── generate_image_and_mask.py
│   ├── panoptic_coco_categories.json
│   └── segmentataion_synthesizer.py
├── requirements.txt
├── requirements_referring_expression_generation.txt
├── .gitattributes
├── assets
│   ├── ovd.png
│   ├── ref.png
│   ├── fcgc.png
│   ├── pipeline.png
│   ├── results.png
│   ├── sfcsgc.png
│   ├── teaser.png
│   ├── comparison.png
│   └── intra-class.png
├── requirements_relight_and_blending.txt
├── scripts
│   └── generate_with_batch.py
├── README.md
├── referring_expression_generation
│   └── inference.py
└── relighting_and_blending
    └── inference.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/segment_layout/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/segmentation_generator/utils.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/segmentation_generator/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | open_clip_torch
2 | pandas
3 | pycocotools
4 | imageio[pyav]
--------------------------------------------------------------------------------
/requirements_referring_expression_generation.txt:
--------------------------------------------------------------------------------
1 | torch
2 | vllm
3 | google-cloud-storage
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.pdf filter=lfs diff=lfs merge=lfs -text
2 | *.png filter=lfs diff=lfs merge=lfs -text
3 |
--------------------------------------------------------------------------------
/assets/ovd.png:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:900c4744d0d78ea674ec2e1ba0e77d34046d5bf3206697e82ddc313b2e69b8a2
3 | size 474129
4 |
--------------------------------------------------------------------------------
/assets/ref.png:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:5c57b5a85ae1778a44b7af9c0a49a5506e3b306151923f1dc1f7b6276d3026ab
3 | size 504939
4 |
--------------------------------------------------------------------------------
/assets/fcgc.png:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:a5638531abd756b072c6135233e750be5ba4f8ea269fc5eb283f3813d5e206ce
3 | size 1039387
4 |
--------------------------------------------------------------------------------
/assets/pipeline.png:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:76725df1ae1ddc3906464b997cce6a8b90548fa65a1cb159e0f0435133efa14a
3 | size 872622
4 |
--------------------------------------------------------------------------------
/assets/results.png:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:ca283f0a17b8958b1857fe24fbd1edc6b25b9c261b49112c2732c1e0e08cde95
3 | size 387757
4 |
--------------------------------------------------------------------------------
/assets/sfcsgc.png:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:d862973fcb1de835f907486a06aa1c1db87b0b946a53b0c1db3ba99a5a95dd63
3 | size 708186
4 |
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:7d6bc7e1b3a3d6258b180c5396cde1db6da76e0b1a118868bf963cf10274ba24
3 | size 2273149
4 |
--------------------------------------------------------------------------------
/assets/comparison.png:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:7ce424160c8d85ea1baf949734a2f7891d5d9bd848cd702fa3fbb04d4001e39f
3 | size 1307979
4 |
--------------------------------------------------------------------------------
/assets/intra-class.png:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:2c615f4a058fa08f4d1bf98f468164f752d452b7f00c411d1c066c596bc3c90e
3 | size 715132
4 |
--------------------------------------------------------------------------------
/requirements_relight_and_blending.txt:
--------------------------------------------------------------------------------
1 | diffusers==0.32.1
2 | transformers==4.48.0
3 | opencv-python
4 | safetensors
5 | pillow==10.2.0
6 | einops
7 | torch
8 | peft
9 | gradio==3.41.2
10 | protobuf==3.20
11 | google-cloud-storage
12 | scikit-image
--------------------------------------------------------------------------------
/datasets/background_datasets.py:
--------------------------------------------------------------------------------
1 | from .base_datasets import BaseBackgroundDataset
2 | import os
3 | import glob
4 | import pandas as pd
5 |
6 |
7 | class BG20KDataset(BaseBackgroundDataset):
8 | def __init__(self, dataset_path):
9 | dataset_name = "BG20k"
10 | data_type = "background"
11 | super().__init__(dataset_name, data_type, dataset_path)
12 |         # Build the metadata table: one row per background image.
13 |         rows = []
14 |         for index_counter, image_path in enumerate(glob.glob(self.dataset_path + "/train/*.jpg")):
15 |             rows.append({
16 |                 'idx_for_curr_dataset': index_counter,
17 |                 'category': "no category",
18 |                 'image_path': image_path,
19 |                 'dataset_name': dataset_name,
20 |                 'data_type': data_type
21 |             })
22 |         self.metadata_table = pd.DataFrame(
23 |             rows,
24 |             columns=['idx_for_curr_dataset', 'category', 'image_path', 'dataset_name', 'data_type']
25 |         )
26 |
27 | class NoBGDataset(BaseBackgroundDataset):
28 | pass
--------------------------------------------------------------------------------
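Usage sketch (illustrative, not part of the repository): constructing the BG-20k background dataset and reading one record. The dataset root is a placeholder; images are expected under <root>/train/*.jpg.

    from datasets.background_datasets import BG20KDataset

    dataset = BG20KDataset("/path/to/BG20k")   # placeholder path
    print(len(dataset))                        # number of background images found
    sample = dataset[0]                        # {'metadata': row, 'image': PIL RGB image}
    print(sample['metadata']['image_path'], sample['image'].size)
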
/segment_layout/existing_layout_generator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from PIL import Image
3 | import numpy as np
4 | import pandas as pd
5 | import random
6 | import ast
7 |
8 | def calculate_image_areas(image_paths):
9 | components = []
10 | for path in image_paths:
11 | with Image.open(path) as img:
12 | width, height = img.size
13 | components.append((width, height))
14 | return components
15 |
16 | def compute_similarity(component, box):
17 |     """
18 |     Compute a similarity score based on area and aspect ratio (lower is better).
19 |     """
20 |     # Unpack dimensions: component is (width, height), box is [x_min, y_min, x_max, y_max]
21 |     comp_w, comp_h = component
22 |     box_w = box[2] - box[0]
23 |     box_h = box[3] - box[1]
24 | 
25 |     # Calculate areas and aspect ratios (width / height)
26 |     comp_area = comp_w * comp_h
27 |     box_area = box_w * box_h
28 | 
29 |     comp_aspect_ratio = comp_w / comp_h
30 |     box_aspect_ratio = box_w / box_h
31 | 
32 |     # Area difference (normalized)
33 |     area_diff = abs(comp_area - box_area) / max(comp_area, box_area)
34 | 
35 |     # Aspect ratio difference (normalized)
36 |     aspect_ratio_diff = abs(comp_aspect_ratio - box_aspect_ratio)
37 | 
38 |     # Weighted score (lower is better)
39 |     score = 0.8 * area_diff + 0.2 * aspect_ratio_diff
40 |     return score
41 |
42 |
43 | def postprocess_bboxes(bbox_predictions, width, height):
44 | # Scale the normalized bbox predictions to image dimensions
45 | processed_bboxes = []
46 | for bbox in bbox_predictions:
47 | xmin = bbox[0] * width
48 | ymin = bbox[1] * height
49 | xmax = bbox[2] * width
50 | ymax = bbox[3] * height
51 | processed_bboxes.append([xmin, ymin, xmax, ymax])
52 |
53 | return processed_bboxes
54 |
55 | class ExistingLayoutGenerator():
56 | def __init__(self, table_path, device=None):
57 | # Determine device
58 | if device is None:
59 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
60 | else:
61 | self.device = device
62 |
63 |         # Load the table of existing layouts from a CSV file.
64 | self.layout_table = pd.read_csv(table_path)
65 |
66 | def predict_wi_area(self, num_bboxes, segment_list, width, height):
67 |         # Not yet implemented: will match segments to boxes using their areas.
68 | pass
69 |
70 | def predict_wo_area(self, num_bboxes, width, height):
71 |         # Keep only layouts that contain at least num_bboxes boxes.
72 |         layout_table_filtered = self.layout_table[self.layout_table['num_bboxes'] >= num_bboxes]
73 |         # Sample one layout from the filtered table and take its first num_bboxes boxes.
74 |         bboxes_result = ast.literal_eval(layout_table_filtered.sample(n=1)['bboxes'].item())
75 |         correspond_bboxes_result = bboxes_result[0:num_bboxes]
76 |         # Sort the selected boxes by area, largest first.
77 |         correspond_bboxes_result = sorted(
78 |             correspond_bboxes_result, key=lambda bbox: (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]), reverse=True
79 |         )
80 |         # Alternative: correspond_bboxes_result = random.sample(bboxes_result, num_bboxes)
81 |
82 | outputs = postprocess_bboxes(correspond_bboxes_result, width, height)
83 | return outputs
84 |
--------------------------------------------------------------------------------
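Usage sketch (illustrative, not part of the repository): sampling a layout with ExistingLayoutGenerator. The CSV path is a placeholder; per the code above, the table needs a 'num_bboxes' column and a 'bboxes' column holding a stringified list of normalized [x_min, y_min, x_max, y_max] boxes.

    from segment_layout.existing_layout_generator import ExistingLayoutGenerator

    generator = ExistingLayoutGenerator("existing_layouts.csv")  # placeholder path
    # Three boxes in pixel coordinates for a 640x480 canvas, sorted largest-first.
    bboxes = generator.predict_wo_area(num_bboxes=3, width=640, height=480)
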
/datasets/base_datasets.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import numpy as np
3 | from torch.utils.data import Dataset
4 |
5 |
6 | def convert_rgba_to_rgb_and_mask(rgba_image: Image.Image):
7 |     # Validate that the input image is in RGBA mode
8 | if rgba_image.mode != "RGBA":
9 | raise ValueError("Input image must be in RGBA mode")
10 | # Convert to RGB
11 | rgb_image = rgba_image.convert("RGB")
12 |
13 | # Create a mask based on the alpha channel
14 | alpha_channel = np.array(rgba_image.getchannel('A'))
15 | mask = Image.fromarray(alpha_channel) # This will be a grayscale image
16 | return rgb_image, mask
17 |
18 |
19 | def get_bounding_box(mask):
20 |     return mask.getbbox()  # bounding box of the mask's nonzero region, or None if empty
21 |
22 | def get_image_and_mask(image_path, mask_path):
23 | if image_path == mask_path:
24 | image_with_mask = Image.open(image_path)
25 | if image_with_mask.mode == "RGBA":
26 | return convert_rgba_to_rgb_and_mask(image_with_mask)
27 | else:
28 | raise ValueError("Image and mask are the same, but image is not in RGBA mode")
29 | else:
30 | return Image.open(image_path), Image.open(mask_path)
31 |
32 |
33 | class BaseObjectDataset(Dataset):
34 | def __init__(self, dataset_name, data_type, dataset_path, filtering_annotations_path=None):
35 | self.dataset_name = dataset_name
36 | self.dataset_path = dataset_path
37 | self.data_type = data_type
38 | self.metadata_table = None
39 | self.filtering_annotations_path = filtering_annotations_path
40 |
41 |
42 | def __len__(self):
43 | return len(self.metadata_table)
44 |
45 | def __getitem__(self, index):
46 | data_package = {}
47 | metadata = self.metadata_table[self.metadata_table['idx_for_curr_dataset'] == index].iloc[0]
48 | data_package['metadata'] = metadata
49 | data_package['image'], data_package['mask'] = get_image_and_mask(metadata['image_path'], metadata['mask_path'])
50 | return data_package
51 |
52 | def return_metadata_table(self):
53 |         # The metadata table holds the information of each mask, such as its category and path.
54 |         # Each mask is stored as a row in the table.
55 | return self.metadata_table
56 |
57 |
58 | class BaseBackgroundDataset(Dataset):
59 | def __init__(self, dataset_name, data_type, dataset_path):
60 | self.dataset_name = dataset_name
61 | self.dataset_path = dataset_path
62 | self.data_type = data_type
63 | self.metadata_table = None
64 |
65 | def __len__(self):
66 | return len(self.metadata_table)
67 |
68 | def __getitem__(self, index):
69 | data_package = {}
70 | metadata = self.metadata_table[self.metadata_table['idx_for_curr_dataset'] == index].iloc[0]
71 | data_package['metadata'] = metadata
72 | data_package['image'] = Image.open(metadata['image_path']).convert("RGB")
73 | return data_package
74 |
75 | def return_metadata_table(self):
76 |         # The metadata table holds the information of each background image, such as its category and path.
77 |         # Each image is stored as a row in the table.
78 | return self.metadata_table
--------------------------------------------------------------------------------
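Usage sketch (illustrative, not part of the repository): when the image and mask share one RGBA file, get_image_and_mask splits the alpha channel off as a grayscale mask. The path is a placeholder.

    from datasets.base_datasets import get_image_and_mask

    # The file must be RGBA when image and mask paths coincide.
    image, mask = get_image_and_mask("object.png", "object.png")
    print(image.mode, mask.mode)  # "RGB" "L"
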
/segment_layout/random_bbox_layout_generator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from PIL import Image
3 | import numpy as np
4 | import pandas as pd
5 | import random
6 | import ast
7 |
8 | def postprocess_bboxes(bbox_predictions, width, height):
9 | # Scale the normalized bbox predictions to image dimensions.
10 | processed_bboxes = []
11 | for bbox in bbox_predictions:
12 | xmin = bbox[0] * width
13 | ymin = bbox[1] * height
14 | xmax = bbox[2] * width
15 | ymax = bbox[3] * height
16 | processed_bboxes.append([xmin, ymin, xmax, ymax])
17 | return processed_bboxes
18 |
19 | def compute_iou(bbox1, bbox2):
20 | """
21 | Compute Intersection over Union (IoU) for two boxes in normalized coordinates.
22 | Each bbox is in the format [x_min, y_min, x_max, y_max].
23 | """
24 | x_left = max(bbox1[0], bbox2[0])
25 | y_top = max(bbox1[1], bbox2[1])
26 | x_right = min(bbox1[2], bbox2[2])
27 | y_bottom = min(bbox1[3], bbox2[3])
28 |
29 | if x_right < x_left or y_bottom < y_top:
30 | return 0.0
31 |
32 | intersection_area = (x_right - x_left) * (y_bottom - y_top)
33 | area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
34 | area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
35 | union_area = area1 + area2 - intersection_area
36 | if union_area == 0:
37 | return 0.0
38 | return intersection_area / union_area
39 |
40 | class RandomBoundingBoxLayoutGenerator():
41 | def __init__(self, device=None):
42 | # Determine device.
43 | if device is None:
44 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
45 | else:
46 | self.device = device
47 |
48 | def generate(self, num_bboxes, width, height, min_box_scale=0.05, max_box_scale=0.6, max_overlap=0.5):
49 | """
50 | Generate num_bboxes random bounding boxes in a controlled way with limited overlap.
51 | The boxes are defined in normalized coordinates and then scaled to the given width and height.
52 | min_box_scale and max_box_scale define the fraction of the image that a box can take in width and height.
53 | max_overlap defines the maximum allowed Intersection over Union (IoU) between any two boxes.
54 | """
55 | bboxes = []
56 | max_attempts_total = num_bboxes * 50 # Maximum number of candidate generations.
57 | attempts = 0
58 |
59 | while len(bboxes) < num_bboxes and attempts < max_attempts_total:
60 | # Randomly select box width and height as a fraction of the full image.
61 | box_w = random.uniform(min_box_scale, max_box_scale)
62 | box_h = random.uniform(min_box_scale, max_box_scale)
63 | # Choose top-left coordinates so that the box fits within [0,1].
64 | x_min = random.uniform(0, 1 - box_w)
65 | y_min = random.uniform(0, 1 - box_h)
66 | x_max = x_min + box_w
67 |             y_max = y_min + box_h
68 | candidate_box = [x_min, y_min, x_max, y_max]
69 |
70 | # Check overlap with all previously accepted boxes.
71 | valid = True
72 | for existing_box in bboxes:
73 | if compute_iou(candidate_box, existing_box) > max_overlap:
74 | valid = False
75 | break
76 |
77 | if valid:
78 | bboxes.append(candidate_box)
79 | attempts += 1
80 |
81 | if len(bboxes) < num_bboxes:
82 | print(f"Warning: Only generated {len(bboxes)} boxes with the given overlap constraints.")
83 |
84 | processed_bboxes = postprocess_bboxes(bboxes, width, height)
85 | return processed_bboxes
--------------------------------------------------------------------------------
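Usage sketch (illustrative, not part of the repository): generating a random layout with bounded pairwise overlap.

    from segment_layout.random_bbox_layout_generator import RandomBoundingBoxLayoutGenerator

    generator = RandomBoundingBoxLayoutGenerator()
    # Five boxes on a 640x480 canvas, each side spanning 5-60% of the image,
    # with at most 30% pairwise IoU; returned in pixel coordinates.
    bboxes = generator.generate(num_bboxes=5, width=640, height=480, max_overlap=0.3)
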
/datasets/data_manager.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 |
3 | class DataManager():
4 | def __init__(self, available_object_datasets, available_background_datasets, filtering_setting):
5 | # Process object datasets.
6 | self.available_object_datasets = available_object_datasets
7 | self.filtered_object_data = {} # To store filtered data for each object dataset.
8 | self.accumulated_object_metadata_table = pd.DataFrame()
9 | unified_object_categories = set()
10 |
11 | for dataset_name, curr_dataset in self.available_object_datasets.items():
12 | # Retrieve the original metadata table.
13 | metadata_table = curr_dataset.return_metadata_table()
14 | count_before = len(metadata_table)
15 |
16 |             # Apply filtering based on the provided settings; if any rows lack filtering annotations, keep the dataset unfiltered.
17 | if metadata_table['filtering_annotation'].isna().any():
18 | filtered_table = metadata_table
19 | count_after = count_before
20 | else:
21 | filter_mask = pd.Series(True, index=metadata_table.index)
22 | for metric, condition in filtering_setting.items():
23 | if condition == "filter":
24 | filter_mask &= metadata_table['filtering_annotation'].apply(
25 | lambda x: x.get(metric, False) if pd.notna(x) else False
26 | )
27 | filtered_table = metadata_table[filter_mask]
28 | count_after = len(filtered_table)
29 |
30 | # Print counts for this dataset.
31 | print(f"Object dataset '{dataset_name}':")
32 | print(f" Count before filtering: {count_before}")
33 | print(f" Count after filtering: {count_after}")
34 |
35 | # Store the filtered table for later access and accumulate in one table.
36 | self.filtered_object_data[dataset_name] = filtered_table
37 | self.accumulated_object_metadata_table = pd.concat(
38 | [self.accumulated_object_metadata_table, filtered_table]
39 | )
40 |
41 | # Collect unified object categories.
42 | unified_object_categories.update(curr_dataset.categories)
43 |
44 | # Prepare a sorted list of unified object categories and a mapping to indices.
45 | self.unified_object_categories = sorted(list(unified_object_categories))
46 | self.unified_object_categories_to_idx = {category: idx for idx, category in enumerate(self.unified_object_categories)}
47 |
48 | # Process background datasets.
49 | self.available_background_datasets = available_background_datasets
50 | self.accumulated_background_metadata_table = pd.DataFrame()
51 | self.background_data = {} # Store each background dataset's metadata table.
52 | for dataset_name, curr_dataset in self.available_background_datasets.items():
53 | metadata_table = curr_dataset.return_metadata_table()
54 | self.background_data[dataset_name] = metadata_table
55 | self.accumulated_background_metadata_table = pd.concat(
56 | [self.accumulated_background_metadata_table, metadata_table]
57 | )
58 | print(f"Background dataset '{dataset_name}' has {len(metadata_table)} records.")
59 |
60 | print("Finished processing object and background datasets.")
61 |
62 | # object
63 | def query_object_metadata(self, query):
64 | return self.accumulated_object_metadata_table.query(query)
65 |
66 | def get_object_by_metadata(self, metadata):
67 | dataset_name = metadata["dataset_name"]
68 | idx_for_curr_dataset = metadata["idx_for_curr_dataset"]
69 | return self.available_object_datasets[dataset_name][idx_for_curr_dataset]
70 |
71 | def get_random_object_metadata(self, rng):
72 | # First uniformly select an object dataset.
73 | dataset_name = rng.choice(list(self.filtered_object_data.keys()))
74 | dataset_table = self.filtered_object_data[dataset_name]
75 | # Then uniformly sample from that dataset's metadata table.
76 | idx = rng.choice(len(dataset_table))
77 | selected_metadata = dataset_table.iloc[idx].copy()
78 | # Add dataset information to the metadata.
79 | selected_metadata["dataset_name"] = dataset_name
80 | return selected_metadata
81 |
82 | # background
83 | def query_background_metadata(self, query):
84 | return self.accumulated_background_metadata_table.query(query)
85 |
86 | def get_background_by_metadata(self, metadata):
87 | dataset_name = metadata["dataset_name"]
88 | idx_for_curr_dataset = metadata["idx_for_curr_dataset"]
89 | return self.available_background_datasets[dataset_name][idx_for_curr_dataset]
90 |
91 | def get_random_background_metadata(self, rng):
92 | # First uniformly select a background dataset.
93 | dataset_name = rng.choice(list(self.background_data.keys()))
94 | metadata_table = self.background_data[dataset_name]
95 | # Then uniformly sample from that dataset's metadata table.
96 | idx = rng.choice(len(metadata_table))
97 | selected_metadata = metadata_table.iloc[idx].copy()
98 | # Add dataset information to the metadata.
99 | selected_metadata["dataset_name"] = dataset_name
100 | return selected_metadata
101 |
102 | def get_random_object_metadata_by_category(self, rng, category):
103 | df = self.accumulated_object_metadata_table
104 | # vectorized filter
105 | cat_df = df[df['category'] == category]
106 | if cat_df.empty:
107 | raise ValueError(f"No objects found for category '{category}'")
108 | # sample a single row by integer location
109 | i = rng.integers(0, len(cat_df))
110 | return cat_df.iloc[i].copy()
--------------------------------------------------------------------------------
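Usage sketch (illustrative, not part of the repository): wiring datasets into the DataManager and sampling metadata. The dataset dicts and the 'clip_score' metric name are placeholders; filtering_setting maps metric names found in the per-row 'filtering_annotation' dicts to the value "filter".

    import numpy as np
    from datasets.data_manager import DataManager

    # object_datasets / background_datasets: dicts of name -> constructed dataset objects.
    manager = DataManager(object_datasets, background_datasets,
                          filtering_setting={"clip_score": "filter"})  # placeholder metric
    rng = np.random.default_rng(0)
    obj_meta = manager.get_random_object_metadata(rng)
    obj = manager.get_object_by_metadata(obj_meta)       # {'metadata', 'image', 'mask'}
    bg_meta = manager.get_random_background_metadata(rng)
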
/segment_layout/fine_grained_bbox_layout_generator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import matplotlib.pyplot as plt
4 | import matplotlib.patches as patches
5 | from shapely.geometry import box as shapely_box
6 | from shapely.ops import unary_union
7 | import math
8 |
9 | # Function to scale normalized bounding boxes to pixel coordinates
10 | def postprocess_bboxes(bboxes, width, height):
11 | processed = []
12 | for bbox in bboxes:
13 | x_min, y_min, x_max, y_max = bbox
14 | processed.append([int(x_min * width),
15 | int(y_min * height),
16 | int(min(x_max * width, width)),
17 | int(min(y_max * height, height))])
18 | return processed
19 |
20 | def compute_visible_ratios(bboxes):
21 | """
22 | For a given list of boxes (in generation order), compute the visible area ratio
23 | for each box (visible_area / total_area), assuming later boxes occlude earlier ones.
24 | """
25 | ratios = []
26 | polys = [shapely_box(*bbox) for bbox in bboxes]
27 | n = len(polys)
28 | for i in range(n):
29 | poly = polys[i]
30 | if i == n - 1: # last box is fully visible
31 | ratios.append(1.0)
32 | else:
33 | occluders = polys[i+1:]
34 | union_occlusion = unary_union(occluders)
35 | intersection = poly.intersection(union_occlusion)
36 | occluded_area = intersection.area
37 | visible_area = poly.area - occluded_area
38 | ratio = visible_area / poly.area if poly.area > 0 else 0
39 | ratios.append(ratio)
40 | return ratios
41 |
42 | def sample_box(min_area, max_area, aspect_range=(0.5, 2.0)):
43 | """
44 | Sample a box (in normalized coordinates) with an area in [min_area, max_area] and
45 | an aspect ratio in aspect_range.
46 | Returns [x_min, y_min, x_max, y_max] (all between 0 and 1), or None if failed.
47 | """
48 | # Sample a desired area uniformly
49 | area = random.uniform(min_area, max_area)
50 | # Sample an aspect ratio (width/height)
51 | aspect = random.uniform(aspect_range[0], aspect_range[1])
52 | # Compute width and height from area and aspect ratio:
53 | # area = width * height, and width = aspect * height
54 | # => height = sqrt(area / aspect), width = sqrt(area * aspect)
55 | h = math.sqrt(area / aspect)
56 | w = math.sqrt(area * aspect)
57 | # Ensure the box fits inside [0,1] by sampling a top-left coordinate.
58 | if w > 1 or h > 1:
59 | return None
60 | x_min = random.uniform(0, 1 - w)
61 | y_min = random.uniform(0, 1 - h)
62 | return [x_min, y_min, x_min + w, y_min + h]
63 |
64 | class FineGrainedBoundingBoxLayoutGenerator():
65 | def __init__(self, device=None):
66 | if device is None:
67 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
68 | else:
69 | self.device = device
70 |
71 | def generate(self, num_large, num_mid, num_small, width, height,
72 |                  # COCO-style area thresholds, normalized by the 640x480 reference image area (307200 px):
73 |                  # small < 32*32 px, medium in [32*32, 96*96) px, large >= 96*96 px.
74 |                  # Note: the small range below starts at 32*16 px rather than 0.
75 |                  small_area_range=((32 * 16)/(640*480), (32 * 32)/(640*480)),  # ~[0.00167, 0.00333)
76 |                  mid_area_range=((32*32)/(640*480), (96 * 96)/(640*480)),      # ~[0.00333, 0.03)
77 |                  large_area_range=((96 * 96)/(640*480), 0.5),                  # ~[0.03, 0.5]
78 | aspect_range=(0.5, 2.0),
79 | min_avg_display_area=0.8,
80 | min_single_display_area=0.5
81 | ):
82 | """
83 | Generate a layout of boxes in three groups:
84 | - Large boxes: number=num_large, area in large_area_range
85 | - Mid boxes: number=num_mid, area in mid_area_range
86 | - Small boxes: number=num_small, area in small_area_range
87 |
88 | Boxes are generated sequentially (large first, then mid, then small).
89 | Overlap among boxes (regardless of group) is taken into account via visible area constraints.
90 | """
91 | total_boxes = num_large + num_mid + num_small
92 | max_layout_attempts = total_boxes * 100 # maximum attempts to generate the full layout
93 | layout_attempt = 0
94 |
95 | candidate_bboxes = []
96 | groups = [
97 | (num_large, large_area_range),
98 | (num_mid, mid_area_range),
99 | (num_small, small_area_range)
100 | ]
101 |
102 | while layout_attempt < max_layout_attempts:
103 | candidate_bboxes = []
104 | valid = True
105 | # For each group in order:
106 | for (num, area_range) in groups:
107 | for _ in range(num):
108 | max_box_attempts = 100
109 | box_found = False
110 | for _ in range(max_box_attempts):
111 | candidate_box = sample_box(area_range[0], area_range[1], aspect_range)
112 | if candidate_box is None:
113 | continue
114 | temp_bboxes = candidate_bboxes + [candidate_box]
115 | ratios = compute_visible_ratios(temp_bboxes)
116 | avg_ratio = sum(ratios) / len(ratios)
117 | if avg_ratio >= min_avg_display_area and all(r >= min_single_display_area for r in ratios):
118 | candidate_bboxes.append(candidate_box)
119 | box_found = True
120 | break
121 | if not box_found:
122 | valid = False
123 | break # break out if a box in this group cannot be generated
124 | if not valid:
125 | break
126 | if valid and len(candidate_bboxes) == total_boxes:
127 | return postprocess_bboxes(candidate_bboxes, width, height)
128 | layout_attempt += 1
129 |
130 | raise ValueError(f"After {max_layout_attempts} attempts, failed to generate a layout satisfying the constraints.")
--------------------------------------------------------------------------------
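Usage sketch (illustrative, not part of the repository): generating a size-stratified layout. The call raises ValueError if no layout satisfying the visibility constraints is found within the attempt budget.

    from segment_layout.fine_grained_bbox_layout_generator import FineGrainedBoundingBoxLayoutGenerator

    generator = FineGrainedBoundingBoxLayoutGenerator()
    # One large, two mid, three small boxes on a 1024x1024 canvas; every box keeps
    # >= 50% of its area visible and the average visible ratio stays >= 80%.
    bboxes = generator.generate(num_large=1, num_mid=2, num_small=3, width=1024, height=1024)
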
/segmentation_generator/generate_image_and_mask.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import numpy as np
3 |
4 | def generate_image_and_mask(data_manager, image_metadata, resize_mode, containSmallObjectMask=False, small_object_ratio=(36 * 36)/(640 * 480)):
5 | width = image_metadata["width"]
6 | height = image_metadata["height"]
7 | # Create a blank canvas and ground truth mask
8 | canvas = Image.fromarray(np.zeros((height, width, 3), dtype=np.uint8))
9 | ground_truth_mask = Image.new("L", (width, height), 0)
10 |
11 | # This will collect segment IDs for small objects based on the bounding box area.
12 | small_object_ids = []
13 |
14 | # Add background if available
15 | if "background" in image_metadata:
16 | background_metadata = image_metadata["background"]["background_metadata"]
17 | background_image = data_manager.get_background_by_metadata(background_metadata)['image']
18 | canvas = paste_background(background_image, canvas)
19 |
20 | # Process objects
21 | if "objects" in image_metadata:
22 | for object_info in image_metadata["objects"]:
23 | object_metadata = object_info["object_metadata"]
24 | object_position = object_info["object_position"] # either a bounding box or center point
25 | segment_id = object_info["segment_id"]
26 |
27 | # Use the bounding box area if available (i.e., when object_position has more than 2 elements).
28 | if len(object_position) != 2:
29 | bbox_area = (object_position[2] - object_position[0]) * (object_position[3] - object_position[1])
30 | if bbox_area / (width * height) <= small_object_ratio:
31 | small_object_ids.append(segment_id)
32 | # Optionally, you might decide what to do if only a center point is provided.
33 | # For example, you could skip the small object check or assign a default small area.
34 |
35 | # Get the object image and its mask from the data manager.
36 | object_data_package = data_manager.get_object_by_metadata(object_metadata)
37 | image_obj, mask_obj = object_data_package['image'], object_data_package['mask']
38 |
39 | # If augmentation information is provided, apply the augmentation
40 | aug_params = object_info.get("augmentation", None)
41 | if aug_params is not None:
42 | image_obj, mask_obj = apply_augmentation(image_obj, mask_obj, aug_params)
43 |
44 | # Paste using the appropriate method based on how the object is defined.
45 | if len(object_position) == 2:
46 | paste_segment_wi_center_point(image_obj, mask_obj, canvas, ground_truth_mask, object_position, segment_id)
47 | else:
48 | paste_segment_wi_bbox(image_obj, mask_obj, canvas, ground_truth_mask, object_position, segment_id, resize_mode)
49 |
50 | # If a small object mask is desired, create it by keeping only the visible pixels for small objects.
51 | if containSmallObjectMask:
52 | ground_truth_mask_np = np.array(ground_truth_mask)
53 | # Retain only the pixels whose segment ID is in the small_object_ids list.
54 | small_object_mask_np = np.where(np.isin(ground_truth_mask_np, small_object_ids), ground_truth_mask_np, 0).astype(np.uint8)
55 |         return canvas, ground_truth_mask, Image.fromarray(small_object_mask_np)
56 | else:
57 | return canvas, ground_truth_mask
58 |
59 | def paste_background(background, canvas):
60 |     # Resize the background to the canvas size if necessary and return it as the new canvas
61 | if canvas.size != background.size:
62 | background = background.resize(canvas.size)
63 | return background
64 |
65 | def normalize_mask(mask: Image) -> Image:
66 | """
67 | Convert a mask to mode 'L' and ensure its pixel values are in [0, 255].
68 | """
69 | if mask.mode != 'L':
70 | mask = mask.convert('L')
71 | mask_np = np.array(mask)
72 | if mask_np.max() <= 1:
73 | mask_np = (mask_np * 255).astype(np.uint8)
74 | return Image.fromarray(mask_np)
75 |
76 | def paste_segment_wi_center_point(image: Image, mask: Image, canvas: Image,
77 | ground_truth_mask: Image, position: tuple, segment_id):
78 | # Ensure correct image modes
79 | if image.mode != 'RGB':
80 | image = image.convert('RGB')
81 | mask = normalize_mask(mask)
82 |
83 | # Get the bounding box of the nonzero (segmented) area
84 | bbox = mask.getbbox()
85 | if bbox is None:
86 | raise ValueError("No segmented area found in the mask.")
87 |
88 | # Crop both the image and mask to the bounding box
89 | cropped_image = image.crop(bbox)
90 | cropped_mask = mask.crop(bbox)
91 |
92 | # Compute offset so that the segment's center aligns with the given position
93 | segment_center = ((bbox[0] + bbox[2]) // 2, (bbox[1] + bbox[3]) // 2)
94 | offset = (int(position[0] - segment_center[0]), int(position[1] - segment_center[1]))
95 |
96 | # Build a segment ID mask using vectorized operations (avoid slow Python loops)
97 | cropped_mask_np = np.array(cropped_mask)
98 | segment_mask_np = np.where(cropped_mask_np > 0, segment_id, 0).astype(np.uint8)
99 | segment_id_mask = Image.fromarray(segment_mask_np)
100 |
101 | # Paste the cropped image and the segment ID mask onto their respective canvases
102 | canvas.paste(cropped_image, offset, cropped_mask)
103 | ground_truth_mask.paste(segment_id_mask, offset, cropped_mask)
104 |
105 | def paste_segment_wi_bbox(image, mask, canvas,
106 | ground_truth_mask, bbox, segment_id,
107 | resize_mode):
108 | # Ensure the image is in RGB mode
109 | if image.mode != 'RGB':
110 | image = image.convert('RGB')
111 | mask = normalize_mask(mask)
112 |
113 | # Get the bounding box of the segmented area to crop out unnecessary parts
114 | original_object_bbox = mask.getbbox()
115 | if original_object_bbox is None:
116 | raise ValueError("No segmented area found in the mask.")
117 |
118 | cropped_image = image.crop(original_object_bbox)
119 | cropped_mask = mask.crop(original_object_bbox)
120 |
121 | # Calculate the target region size based on the provided bounding box
122 | x_min, y_min, x_max, y_max = bbox
123 | bbox_width = int(round(x_max - x_min))
124 | bbox_height = int(round(y_max - y_min))
125 |
126 | if resize_mode == "full":
127 | # Stretch the image to fully fit the bounding box
128 | new_width, new_height = bbox_width, bbox_height
129 | offset = (int(round(x_min)), int(round(y_min)))
130 | elif resize_mode == "fit":
131 | # Maintain the aspect ratio of the cropped image
132 | orig_width, orig_height = cropped_image.size
133 | ratio = orig_width / orig_height
134 | # Compute the maximum size that fits within the bounding box while preserving the ratio
135 | if bbox_width / bbox_height > ratio:
136 | new_height = bbox_height
137 | new_width = int(round(bbox_height * ratio))
138 | else:
139 | new_width = bbox_width
140 | new_height = int(round(bbox_width / ratio))
141 | # Center the image within the bounding box
142 | offset = (int(round(x_min + (bbox_width - new_width) / 2)),
143 | int(round(y_min + (bbox_height - new_height) / 2)))
144 | else:
145 | raise ValueError("resize_mode must be 'full' or 'fit'.")
146 |
147 |     # Guard against degenerate boxes: clamp the target size to at least 1x1 pixel
148 |     if new_width <= 0 or new_height <= 0:
149 |         new_width = max(1, new_width)
150 |         new_height = max(1, new_height)
151 | # Resize the cropped image and mask to the calculated dimensions
152 | resized_image = cropped_image.resize((new_width, new_height), resample=Image.BICUBIC)
153 | resized_mask = cropped_mask.resize((new_width, new_height), resample=Image.NEAREST)
154 |
155 | # Create a segment ID mask using vectorized operations
156 | resized_mask_np = np.array(resized_mask)
157 | segment_id_mask_np = np.where(resized_mask_np > 0, segment_id, 0).astype(np.uint8)
158 | segment_id_mask = Image.fromarray(segment_id_mask_np)
159 |
160 | # Paste the resized image and segment ID mask onto the canvas and ground truth mask
161 | canvas.paste(resized_image, offset, resized_mask)
162 | ground_truth_mask.paste(segment_id_mask, offset, resized_mask)
163 |
164 | def apply_augmentation(image: Image, mask: Image, aug_params: dict):
165 | """
166 | Apply scaling, horizontal and vertical flips, and rotation to both the image and its mask.
167 | Scaling uses bicubic interpolation for the image and nearest for the mask;
168 | rotation is performed with expansion to preserve the full transformed segment.
169 | """
170 | scale = aug_params.get("scale", 1.0)
171 | flip_horizontal = aug_params.get("flip_horizontal", False)
172 | flip_vertical = aug_params.get("flip_vertical", False)
173 | rotate_angle = aug_params.get("rotate", 0)
174 | # Scaling: resize the image and mask if the scale factor is different from 1.0
175 | if scale != 1.0:
176 | new_size = (int(round(image.width * scale)), int(round(image.height * scale)))
177 | image = image.resize(new_size, resample=Image.BICUBIC)
178 | mask = mask.resize(new_size, resample=Image.NEAREST)
179 |
180 | # Horizontal flip: if the flag is True, perform a horizontal flip
181 | if flip_horizontal:
182 | image = image.transpose(Image.FLIP_LEFT_RIGHT)
183 | mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
184 |
185 | # Vertical flip: if the flag is True, perform a vertical flip
186 | if flip_vertical:
187 | image = image.transpose(Image.FLIP_TOP_BOTTOM)
188 | mask = mask.transpose(Image.FLIP_TOP_BOTTOM)
189 |
190 | # Rotation: if the rotation angle is non-zero, rotate the image and mask.
191 | if rotate_angle != 0:
192 | image = image.rotate(rotate_angle, expand=True, resample=Image.BICUBIC)
193 | mask = mask.rotate(rotate_angle, expand=True, resample=Image.NEAREST)
194 |
195 | return image, mask
196 |
197 |
--------------------------------------------------------------------------------
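Usage sketch (illustrative, not part of the repository): the image_metadata structure that generate_image_and_mask consumes, reconstructed from the code above; all field values are placeholders.

    from segmentation_generator.generate_image_and_mask import generate_image_and_mask

    image_metadata = {
        "width": 640,
        "height": 480,
        "background": {"background_metadata": bg_meta},  # as returned by a DataManager
        "objects": [{
            "object_metadata": obj_meta,                 # as returned by a DataManager
            "object_position": [50, 60, 250, 300],       # bbox [x_min, y_min, x_max, y_max] or (cx, cy)
            "segment_id": 1,
            "augmentation": {"scale": 1.2, "flip_horizontal": True, "rotate": 15},
        }],
    }
    canvas, ground_truth_mask = generate_image_and_mask(data_manager, image_metadata, resize_mode="fit")
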
/datasets/panoptic_coco_categories.json:
--------------------------------------------------------------------------------
1 | [{"supercategory": "person", "color": [220, 20, 60], "isthing": 1, "id": 1, "name": "person"}, {"supercategory": "vehicle", "color": [119, 11, 32], "isthing": 1, "id": 2, "name": "bicycle"}, {"supercategory": "vehicle", "color": [0, 0, 142], "isthing": 1, "id": 3, "name": "car"}, {"supercategory": "vehicle", "color": [0, 0, 230], "isthing": 1, "id": 4, "name": "motorcycle"}, {"supercategory": "vehicle", "color": [106, 0, 228], "isthing": 1, "id": 5, "name": "airplane"}, {"supercategory": "vehicle", "color": [0, 60, 100], "isthing": 1, "id": 6, "name": "bus"}, {"supercategory": "vehicle", "color": [0, 80, 100], "isthing": 1, "id": 7, "name": "train"}, {"supercategory": "vehicle", "color": [0, 0, 70], "isthing": 1, "id": 8, "name": "truck"}, {"supercategory": "vehicle", "color": [0, 0, 192], "isthing": 1, "id": 9, "name": "boat"}, {"supercategory": "outdoor", "color": [250, 170, 30], "isthing": 1, "id": 10, "name": "traffic light"}, {"supercategory": "outdoor", "color": [100, 170, 30], "isthing": 1, "id": 11, "name": "fire hydrant"}, {"supercategory": "outdoor", "color": [220, 220, 0], "isthing": 1, "id": 13, "name": "stop sign"}, {"supercategory": "outdoor", "color": [175, 116, 175], "isthing": 1, "id": 14, "name": "parking meter"}, {"supercategory": "outdoor", "color": [250, 0, 30], "isthing": 1, "id": 15, "name": "bench"}, {"supercategory": "animal", "color": [165, 42, 42], "isthing": 1, "id": 16, "name": "bird"}, {"supercategory": "animal", "color": [255, 77, 255], "isthing": 1, "id": 17, "name": "cat"}, {"supercategory": "animal", "color": [0, 226, 252], "isthing": 1, "id": 18, "name": "dog"}, {"supercategory": "animal", "color": [182, 182, 255], "isthing": 1, "id": 19, "name": "horse"}, {"supercategory": "animal", "color": [0, 82, 0], "isthing": 1, "id": 20, "name": "sheep"}, {"supercategory": "animal", "color": [120, 166, 157], "isthing": 1, "id": 21, "name": "cow"}, {"supercategory": "animal", "color": [110, 76, 0], "isthing": 1, "id": 22, "name": "elephant"}, {"supercategory": "animal", "color": [174, 57, 255], "isthing": 1, "id": 23, "name": "bear"}, {"supercategory": "animal", "color": [199, 100, 0], "isthing": 1, "id": 24, "name": "zebra"}, {"supercategory": "animal", "color": [72, 0, 118], "isthing": 1, "id": 25, "name": "giraffe"}, {"supercategory": "accessory", "color": [255, 179, 240], "isthing": 1, "id": 27, "name": "backpack"}, {"supercategory": "accessory", "color": [0, 125, 92], "isthing": 1, "id": 28, "name": "umbrella"}, {"supercategory": "accessory", "color": [209, 0, 151], "isthing": 1, "id": 31, "name": "handbag"}, {"supercategory": "accessory", "color": [188, 208, 182], "isthing": 1, "id": 32, "name": "tie"}, {"supercategory": "accessory", "color": [0, 220, 176], "isthing": 1, "id": 33, "name": "suitcase"}, {"supercategory": "sports", "color": [255, 99, 164], "isthing": 1, "id": 34, "name": "frisbee"}, {"supercategory": "sports", "color": [92, 0, 73], "isthing": 1, "id": 35, "name": "skis"}, {"supercategory": "sports", "color": [133, 129, 255], "isthing": 1, "id": 36, "name": "snowboard"}, {"supercategory": "sports", "color": [78, 180, 255], "isthing": 1, "id": 37, "name": "sports ball"}, {"supercategory": "sports", "color": [0, 228, 0], "isthing": 1, "id": 38, "name": "kite"}, {"supercategory": "sports", "color": [174, 255, 243], "isthing": 1, "id": 39, "name": "baseball bat"}, {"supercategory": "sports", "color": [45, 89, 255], "isthing": 1, "id": 40, "name": "baseball glove"}, {"supercategory": "sports", "color": [134, 134, 103], "isthing": 1, "id": 41, 
"name": "skateboard"}, {"supercategory": "sports", "color": [145, 148, 174], "isthing": 1, "id": 42, "name": "surfboard"}, {"supercategory": "sports", "color": [255, 208, 186], "isthing": 1, "id": 43, "name": "tennis racket"}, {"supercategory": "kitchen", "color": [197, 226, 255], "isthing": 1, "id": 44, "name": "bottle"}, {"supercategory": "kitchen", "color": [171, 134, 1], "isthing": 1, "id": 46, "name": "wine glass"}, {"supercategory": "kitchen", "color": [109, 63, 54], "isthing": 1, "id": 47, "name": "cup"}, {"supercategory": "kitchen", "color": [207, 138, 255], "isthing": 1, "id": 48, "name": "fork"}, {"supercategory": "kitchen", "color": [151, 0, 95], "isthing": 1, "id": 49, "name": "knife"}, {"supercategory": "kitchen", "color": [9, 80, 61], "isthing": 1, "id": 50, "name": "spoon"}, {"supercategory": "kitchen", "color": [84, 105, 51], "isthing": 1, "id": 51, "name": "bowl"}, {"supercategory": "food", "color": [74, 65, 105], "isthing": 1, "id": 52, "name": "banana"}, {"supercategory": "food", "color": [166, 196, 102], "isthing": 1, "id": 53, "name": "apple"}, {"supercategory": "food", "color": [208, 195, 210], "isthing": 1, "id": 54, "name": "sandwich"}, {"supercategory": "food", "color": [255, 109, 65], "isthing": 1, "id": 55, "name": "orange"}, {"supercategory": "food", "color": [0, 143, 149], "isthing": 1, "id": 56, "name": "broccoli"}, {"supercategory": "food", "color": [179, 0, 194], "isthing": 1, "id": 57, "name": "carrot"}, {"supercategory": "food", "color": [209, 99, 106], "isthing": 1, "id": 58, "name": "hot dog"}, {"supercategory": "food", "color": [5, 121, 0], "isthing": 1, "id": 59, "name": "pizza"}, {"supercategory": "food", "color": [227, 255, 205], "isthing": 1, "id": 60, "name": "donut"}, {"supercategory": "food", "color": [147, 186, 208], "isthing": 1, "id": 61, "name": "cake"}, {"supercategory": "furniture", "color": [153, 69, 1], "isthing": 1, "id": 62, "name": "chair"}, {"supercategory": "furniture", "color": [3, 95, 161], "isthing": 1, "id": 63, "name": "couch"}, {"supercategory": "furniture", "color": [163, 255, 0], "isthing": 1, "id": 64, "name": "potted plant"}, {"supercategory": "furniture", "color": [119, 0, 170], "isthing": 1, "id": 65, "name": "bed"}, {"supercategory": "furniture", "color": [0, 182, 199], "isthing": 1, "id": 67, "name": "dining table"}, {"supercategory": "furniture", "color": [0, 165, 120], "isthing": 1, "id": 70, "name": "toilet"}, {"supercategory": "electronic", "color": [183, 130, 88], "isthing": 1, "id": 72, "name": "tv"}, {"supercategory": "electronic", "color": [95, 32, 0], "isthing": 1, "id": 73, "name": "laptop"}, {"supercategory": "electronic", "color": [130, 114, 135], "isthing": 1, "id": 74, "name": "mouse"}, {"supercategory": "electronic", "color": [110, 129, 133], "isthing": 1, "id": 75, "name": "remote"}, {"supercategory": "electronic", "color": [166, 74, 118], "isthing": 1, "id": 76, "name": "keyboard"}, {"supercategory": "electronic", "color": [219, 142, 185], "isthing": 1, "id": 77, "name": "cell phone"}, {"supercategory": "appliance", "color": [79, 210, 114], "isthing": 1, "id": 78, "name": "microwave"}, {"supercategory": "appliance", "color": [178, 90, 62], "isthing": 1, "id": 79, "name": "oven"}, {"supercategory": "appliance", "color": [65, 70, 15], "isthing": 1, "id": 80, "name": "toaster"}, {"supercategory": "appliance", "color": [127, 167, 115], "isthing": 1, "id": 81, "name": "sink"}, {"supercategory": "appliance", "color": [59, 105, 106], "isthing": 1, "id": 82, "name": "refrigerator"}, {"supercategory": "indoor", 
"color": [142, 108, 45], "isthing": 1, "id": 84, "name": "book"}, {"supercategory": "indoor", "color": [196, 172, 0], "isthing": 1, "id": 85, "name": "clock"}, {"supercategory": "indoor", "color": [95, 54, 80], "isthing": 1, "id": 86, "name": "vase"}, {"supercategory": "indoor", "color": [128, 76, 255], "isthing": 1, "id": 87, "name": "scissors"}, {"supercategory": "indoor", "color": [201, 57, 1], "isthing": 1, "id": 88, "name": "teddy bear"}, {"supercategory": "indoor", "color": [246, 0, 122], "isthing": 1, "id": 89, "name": "hair drier"}, {"supercategory": "indoor", "color": [191, 162, 208], "isthing": 1, "id": 90, "name": "toothbrush"}, {"supercategory": "textile", "color": [255, 255, 128], "isthing": 0, "id": 92, "name": "banner"}, {"supercategory": "textile", "color": [147, 211, 203], "isthing": 0, "id": 93, "name": "blanket"}, {"supercategory": "building", "color": [150, 100, 100], "isthing": 0, "id": 95, "name": "bridge"}, {"supercategory": "raw-material", "color": [168, 171, 172], "isthing": 0, "id": 100, "name": "cardboard"}, {"supercategory": "furniture-stuff", "color": [146, 112, 198], "isthing": 0, "id": 107, "name": "counter"}, {"supercategory": "textile", "color": [210, 170, 100], "isthing": 0, "id": 109, "name": "curtain"}, {"supercategory": "furniture-stuff", "color": [92, 136, 89], "isthing": 0, "id": 112, "name": "door-stuff"}, {"supercategory": "floor", "color": [218, 88, 184], "isthing": 0, "id": 118, "name": "floor-wood"}, {"supercategory": "plant", "color": [241, 129, 0], "isthing": 0, "id": 119, "name": "flower"}, {"supercategory": "food-stuff", "color": [217, 17, 255], "isthing": 0, "id": 122, "name": "fruit"}, {"supercategory": "ground", "color": [124, 74, 181], "isthing": 0, "id": 125, "name": "gravel"}, {"supercategory": "building", "color": [70, 70, 70], "isthing": 0, "id": 128, "name": "house"}, {"supercategory": "furniture-stuff", "color": [255, 228, 255], "isthing": 0, "id": 130, "name": "light"}, {"supercategory": "furniture-stuff", "color": [154, 208, 0], "isthing": 0, "id": 133, "name": "mirror-stuff"}, {"supercategory": "structural", "color": [193, 0, 92], "isthing": 0, "id": 138, "name": "net"}, {"supercategory": "textile", "color": [76, 91, 113], "isthing": 0, "id": 141, "name": "pillow"}, {"supercategory": "ground", "color": [255, 180, 195], "isthing": 0, "id": 144, "name": "platform"}, {"supercategory": "ground", "color": [106, 154, 176], "isthing": 0, "id": 145, "name": "playingfield"}, {"supercategory": "ground", "color": [230, 150, 140], "isthing": 0, "id": 147, "name": "railroad"}, {"supercategory": "water", "color": [60, 143, 255], "isthing": 0, "id": 148, "name": "river"}, {"supercategory": "ground", "color": [128, 64, 128], "isthing": 0, "id": 149, "name": "road"}, {"supercategory": "building", "color": [92, 82, 55], "isthing": 0, "id": 151, "name": "roof"}, {"supercategory": "ground", "color": [254, 212, 124], "isthing": 0, "id": 154, "name": "sand"}, {"supercategory": "water", "color": [73, 77, 174], "isthing": 0, "id": 155, "name": "sea"}, {"supercategory": "furniture-stuff", "color": [255, 160, 98], "isthing": 0, "id": 156, "name": "shelf"}, {"supercategory": "ground", "color": [255, 255, 255], "isthing": 0, "id": 159, "name": "snow"}, {"supercategory": "furniture-stuff", "color": [104, 84, 109], "isthing": 0, "id": 161, "name": "stairs"}, {"supercategory": "building", "color": [169, 164, 131], "isthing": 0, "id": 166, "name": "tent"}, {"supercategory": "textile", "color": [225, 199, 255], "isthing": 0, "id": 168, "name": "towel"}, 
{"supercategory": "wall", "color": [137, 54, 74], "isthing": 0, "id": 171, "name": "wall-brick"}, {"supercategory": "wall", "color": [135, 158, 223], "isthing": 0, "id": 175, "name": "wall-stone"}, {"supercategory": "wall", "color": [7, 246, 231], "isthing": 0, "id": 176, "name": "wall-tile"}, {"supercategory": "wall", "color": [107, 255, 200], "isthing": 0, "id": 177, "name": "wall-wood"}, {"supercategory": "water", "color": [58, 41, 149], "isthing": 0, "id": 178, "name": "water-other"}, {"supercategory": "window", "color": [183, 121, 142], "isthing": 0, "id": 180, "name": "window-blind"}, {"supercategory": "window", "color": [255, 73, 97], "isthing": 0, "id": 181, "name": "window-other"}, {"supercategory": "plant", "color": [107, 142, 35], "isthing": 0, "id": 184, "name": "tree-merged"}, {"supercategory": "structural", "color": [190, 153, 153], "isthing": 0, "id": 185, "name": "fence-merged"}, {"supercategory": "ceiling", "color": [146, 139, 141], "isthing": 0, "id": 186, "name": "ceiling-merged"}, {"supercategory": "sky", "color": [70, 130, 180], "isthing": 0, "id": 187, "name": "sky-other-merged"}, {"supercategory": "furniture-stuff", "color": [134, 199, 156], "isthing": 0, "id": 188, "name": "cabinet-merged"}, {"supercategory": "furniture-stuff", "color": [209, 226, 140], "isthing": 0, "id": 189, "name": "table-merged"}, {"supercategory": "floor", "color": [96, 36, 108], "isthing": 0, "id": 190, "name": "floor-other-merged"}, {"supercategory": "ground", "color": [96, 96, 96], "isthing": 0, "id": 191, "name": "pavement-merged"}, {"supercategory": "solid", "color": [64, 170, 64], "isthing": 0, "id": 192, "name": "mountain-merged"}, {"supercategory": "plant", "color": [152, 251, 152], "isthing": 0, "id": 193, "name": "grass-merged"}, {"supercategory": "ground", "color": [208, 229, 228], "isthing": 0, "id": 194, "name": "dirt-merged"}, {"supercategory": "raw-material", "color": [206, 186, 171], "isthing": 0, "id": 195, "name": "paper-merged"}, {"supercategory": "food-stuff", "color": [152, 161, 64], "isthing": 0, "id": 196, "name": "food-other-merged"}, {"supercategory": "building", "color": [116, 112, 0], "isthing": 0, "id": 197, "name": "building-other-merged"}, {"supercategory": "solid", "color": [0, 114, 143], "isthing": 0, "id": 198, "name": "rock-merged"}, {"supercategory": "wall", "color": [102, 102, 156], "isthing": 0, "id": 199, "name": "wall-other-merged"}, {"supercategory": "textile", "color": [250, 141, 255], "isthing": 0, "id": 200, "name": "rug-merged"}]
2 |
--------------------------------------------------------------------------------
/segmentation_generator/panoptic_coco_categories.json:
--------------------------------------------------------------------------------
1 | [{"supercategory": "person", "color": [220, 20, 60], "isthing": 1, "id": 1, "name": "person"}, {"supercategory": "vehicle", "color": [119, 11, 32], "isthing": 1, "id": 2, "name": "bicycle"}, {"supercategory": "vehicle", "color": [0, 0, 142], "isthing": 1, "id": 3, "name": "car"}, {"supercategory": "vehicle", "color": [0, 0, 230], "isthing": 1, "id": 4, "name": "motorcycle"}, {"supercategory": "vehicle", "color": [106, 0, 228], "isthing": 1, "id": 5, "name": "airplane"}, {"supercategory": "vehicle", "color": [0, 60, 100], "isthing": 1, "id": 6, "name": "bus"}, {"supercategory": "vehicle", "color": [0, 80, 100], "isthing": 1, "id": 7, "name": "train"}, {"supercategory": "vehicle", "color": [0, 0, 70], "isthing": 1, "id": 8, "name": "truck"}, {"supercategory": "vehicle", "color": [0, 0, 192], "isthing": 1, "id": 9, "name": "boat"}, {"supercategory": "outdoor", "color": [250, 170, 30], "isthing": 1, "id": 10, "name": "traffic light"}, {"supercategory": "outdoor", "color": [100, 170, 30], "isthing": 1, "id": 11, "name": "fire hydrant"}, {"supercategory": "outdoor", "color": [220, 220, 0], "isthing": 1, "id": 13, "name": "stop sign"}, {"supercategory": "outdoor", "color": [175, 116, 175], "isthing": 1, "id": 14, "name": "parking meter"}, {"supercategory": "outdoor", "color": [250, 0, 30], "isthing": 1, "id": 15, "name": "bench"}, {"supercategory": "animal", "color": [165, 42, 42], "isthing": 1, "id": 16, "name": "bird"}, {"supercategory": "animal", "color": [255, 77, 255], "isthing": 1, "id": 17, "name": "cat"}, {"supercategory": "animal", "color": [0, 226, 252], "isthing": 1, "id": 18, "name": "dog"}, {"supercategory": "animal", "color": [182, 182, 255], "isthing": 1, "id": 19, "name": "horse"}, {"supercategory": "animal", "color": [0, 82, 0], "isthing": 1, "id": 20, "name": "sheep"}, {"supercategory": "animal", "color": [120, 166, 157], "isthing": 1, "id": 21, "name": "cow"}, {"supercategory": "animal", "color": [110, 76, 0], "isthing": 1, "id": 22, "name": "elephant"}, {"supercategory": "animal", "color": [174, 57, 255], "isthing": 1, "id": 23, "name": "bear"}, {"supercategory": "animal", "color": [199, 100, 0], "isthing": 1, "id": 24, "name": "zebra"}, {"supercategory": "animal", "color": [72, 0, 118], "isthing": 1, "id": 25, "name": "giraffe"}, {"supercategory": "accessory", "color": [255, 179, 240], "isthing": 1, "id": 27, "name": "backpack"}, {"supercategory": "accessory", "color": [0, 125, 92], "isthing": 1, "id": 28, "name": "umbrella"}, {"supercategory": "accessory", "color": [209, 0, 151], "isthing": 1, "id": 31, "name": "handbag"}, {"supercategory": "accessory", "color": [188, 208, 182], "isthing": 1, "id": 32, "name": "tie"}, {"supercategory": "accessory", "color": [0, 220, 176], "isthing": 1, "id": 33, "name": "suitcase"}, {"supercategory": "sports", "color": [255, 99, 164], "isthing": 1, "id": 34, "name": "frisbee"}, {"supercategory": "sports", "color": [92, 0, 73], "isthing": 1, "id": 35, "name": "skis"}, {"supercategory": "sports", "color": [133, 129, 255], "isthing": 1, "id": 36, "name": "snowboard"}, {"supercategory": "sports", "color": [78, 180, 255], "isthing": 1, "id": 37, "name": "sports ball"}, {"supercategory": "sports", "color": [0, 228, 0], "isthing": 1, "id": 38, "name": "kite"}, {"supercategory": "sports", "color": [174, 255, 243], "isthing": 1, "id": 39, "name": "baseball bat"}, {"supercategory": "sports", "color": [45, 89, 255], "isthing": 1, "id": 40, "name": "baseball glove"}, {"supercategory": "sports", "color": [134, 134, 103], "isthing": 1, "id": 41, 
"name": "skateboard"}, {"supercategory": "sports", "color": [145, 148, 174], "isthing": 1, "id": 42, "name": "surfboard"}, {"supercategory": "sports", "color": [255, 208, 186], "isthing": 1, "id": 43, "name": "tennis racket"}, {"supercategory": "kitchen", "color": [197, 226, 255], "isthing": 1, "id": 44, "name": "bottle"}, {"supercategory": "kitchen", "color": [171, 134, 1], "isthing": 1, "id": 46, "name": "wine glass"}, {"supercategory": "kitchen", "color": [109, 63, 54], "isthing": 1, "id": 47, "name": "cup"}, {"supercategory": "kitchen", "color": [207, 138, 255], "isthing": 1, "id": 48, "name": "fork"}, {"supercategory": "kitchen", "color": [151, 0, 95], "isthing": 1, "id": 49, "name": "knife"}, {"supercategory": "kitchen", "color": [9, 80, 61], "isthing": 1, "id": 50, "name": "spoon"}, {"supercategory": "kitchen", "color": [84, 105, 51], "isthing": 1, "id": 51, "name": "bowl"}, {"supercategory": "food", "color": [74, 65, 105], "isthing": 1, "id": 52, "name": "banana"}, {"supercategory": "food", "color": [166, 196, 102], "isthing": 1, "id": 53, "name": "apple"}, {"supercategory": "food", "color": [208, 195, 210], "isthing": 1, "id": 54, "name": "sandwich"}, {"supercategory": "food", "color": [255, 109, 65], "isthing": 1, "id": 55, "name": "orange"}, {"supercategory": "food", "color": [0, 143, 149], "isthing": 1, "id": 56, "name": "broccoli"}, {"supercategory": "food", "color": [179, 0, 194], "isthing": 1, "id": 57, "name": "carrot"}, {"supercategory": "food", "color": [209, 99, 106], "isthing": 1, "id": 58, "name": "hot dog"}, {"supercategory": "food", "color": [5, 121, 0], "isthing": 1, "id": 59, "name": "pizza"}, {"supercategory": "food", "color": [227, 255, 205], "isthing": 1, "id": 60, "name": "donut"}, {"supercategory": "food", "color": [147, 186, 208], "isthing": 1, "id": 61, "name": "cake"}, {"supercategory": "furniture", "color": [153, 69, 1], "isthing": 1, "id": 62, "name": "chair"}, {"supercategory": "furniture", "color": [3, 95, 161], "isthing": 1, "id": 63, "name": "couch"}, {"supercategory": "furniture", "color": [163, 255, 0], "isthing": 1, "id": 64, "name": "potted plant"}, {"supercategory": "furniture", "color": [119, 0, 170], "isthing": 1, "id": 65, "name": "bed"}, {"supercategory": "furniture", "color": [0, 182, 199], "isthing": 1, "id": 67, "name": "dining table"}, {"supercategory": "furniture", "color": [0, 165, 120], "isthing": 1, "id": 70, "name": "toilet"}, {"supercategory": "electronic", "color": [183, 130, 88], "isthing": 1, "id": 72, "name": "tv"}, {"supercategory": "electronic", "color": [95, 32, 0], "isthing": 1, "id": 73, "name": "laptop"}, {"supercategory": "electronic", "color": [130, 114, 135], "isthing": 1, "id": 74, "name": "mouse"}, {"supercategory": "electronic", "color": [110, 129, 133], "isthing": 1, "id": 75, "name": "remote"}, {"supercategory": "electronic", "color": [166, 74, 118], "isthing": 1, "id": 76, "name": "keyboard"}, {"supercategory": "electronic", "color": [219, 142, 185], "isthing": 1, "id": 77, "name": "cell phone"}, {"supercategory": "appliance", "color": [79, 210, 114], "isthing": 1, "id": 78, "name": "microwave"}, {"supercategory": "appliance", "color": [178, 90, 62], "isthing": 1, "id": 79, "name": "oven"}, {"supercategory": "appliance", "color": [65, 70, 15], "isthing": 1, "id": 80, "name": "toaster"}, {"supercategory": "appliance", "color": [127, 167, 115], "isthing": 1, "id": 81, "name": "sink"}, {"supercategory": "appliance", "color": [59, 105, 106], "isthing": 1, "id": 82, "name": "refrigerator"}, {"supercategory": "indoor", 
"color": [142, 108, 45], "isthing": 1, "id": 84, "name": "book"}, {"supercategory": "indoor", "color": [196, 172, 0], "isthing": 1, "id": 85, "name": "clock"}, {"supercategory": "indoor", "color": [95, 54, 80], "isthing": 1, "id": 86, "name": "vase"}, {"supercategory": "indoor", "color": [128, 76, 255], "isthing": 1, "id": 87, "name": "scissors"}, {"supercategory": "indoor", "color": [201, 57, 1], "isthing": 1, "id": 88, "name": "teddy bear"}, {"supercategory": "indoor", "color": [246, 0, 122], "isthing": 1, "id": 89, "name": "hair drier"}, {"supercategory": "indoor", "color": [191, 162, 208], "isthing": 1, "id": 90, "name": "toothbrush"}, {"supercategory": "textile", "color": [255, 255, 128], "isthing": 0, "id": 92, "name": "banner"}, {"supercategory": "textile", "color": [147, 211, 203], "isthing": 0, "id": 93, "name": "blanket"}, {"supercategory": "building", "color": [150, 100, 100], "isthing": 0, "id": 95, "name": "bridge"}, {"supercategory": "raw-material", "color": [168, 171, 172], "isthing": 0, "id": 100, "name": "cardboard"}, {"supercategory": "furniture-stuff", "color": [146, 112, 198], "isthing": 0, "id": 107, "name": "counter"}, {"supercategory": "textile", "color": [210, 170, 100], "isthing": 0, "id": 109, "name": "curtain"}, {"supercategory": "furniture-stuff", "color": [92, 136, 89], "isthing": 0, "id": 112, "name": "door-stuff"}, {"supercategory": "floor", "color": [218, 88, 184], "isthing": 0, "id": 118, "name": "floor-wood"}, {"supercategory": "plant", "color": [241, 129, 0], "isthing": 0, "id": 119, "name": "flower"}, {"supercategory": "food-stuff", "color": [217, 17, 255], "isthing": 0, "id": 122, "name": "fruit"}, {"supercategory": "ground", "color": [124, 74, 181], "isthing": 0, "id": 125, "name": "gravel"}, {"supercategory": "building", "color": [70, 70, 70], "isthing": 0, "id": 128, "name": "house"}, {"supercategory": "furniture-stuff", "color": [255, 228, 255], "isthing": 0, "id": 130, "name": "light"}, {"supercategory": "furniture-stuff", "color": [154, 208, 0], "isthing": 0, "id": 133, "name": "mirror-stuff"}, {"supercategory": "structural", "color": [193, 0, 92], "isthing": 0, "id": 138, "name": "net"}, {"supercategory": "textile", "color": [76, 91, 113], "isthing": 0, "id": 141, "name": "pillow"}, {"supercategory": "ground", "color": [255, 180, 195], "isthing": 0, "id": 144, "name": "platform"}, {"supercategory": "ground", "color": [106, 154, 176], "isthing": 0, "id": 145, "name": "playingfield"}, {"supercategory": "ground", "color": [230, 150, 140], "isthing": 0, "id": 147, "name": "railroad"}, {"supercategory": "water", "color": [60, 143, 255], "isthing": 0, "id": 148, "name": "river"}, {"supercategory": "ground", "color": [128, 64, 128], "isthing": 0, "id": 149, "name": "road"}, {"supercategory": "building", "color": [92, 82, 55], "isthing": 0, "id": 151, "name": "roof"}, {"supercategory": "ground", "color": [254, 212, 124], "isthing": 0, "id": 154, "name": "sand"}, {"supercategory": "water", "color": [73, 77, 174], "isthing": 0, "id": 155, "name": "sea"}, {"supercategory": "furniture-stuff", "color": [255, 160, 98], "isthing": 0, "id": 156, "name": "shelf"}, {"supercategory": "ground", "color": [255, 255, 255], "isthing": 0, "id": 159, "name": "snow"}, {"supercategory": "furniture-stuff", "color": [104, 84, 109], "isthing": 0, "id": 161, "name": "stairs"}, {"supercategory": "building", "color": [169, 164, 131], "isthing": 0, "id": 166, "name": "tent"}, {"supercategory": "textile", "color": [225, 199, 255], "isthing": 0, "id": 168, "name": "towel"}, 
{"supercategory": "wall", "color": [137, 54, 74], "isthing": 0, "id": 171, "name": "wall-brick"}, {"supercategory": "wall", "color": [135, 158, 223], "isthing": 0, "id": 175, "name": "wall-stone"}, {"supercategory": "wall", "color": [7, 246, 231], "isthing": 0, "id": 176, "name": "wall-tile"}, {"supercategory": "wall", "color": [107, 255, 200], "isthing": 0, "id": 177, "name": "wall-wood"}, {"supercategory": "water", "color": [58, 41, 149], "isthing": 0, "id": 178, "name": "water-other"}, {"supercategory": "window", "color": [183, 121, 142], "isthing": 0, "id": 180, "name": "window-blind"}, {"supercategory": "window", "color": [255, 73, 97], "isthing": 0, "id": 181, "name": "window-other"}, {"supercategory": "plant", "color": [107, 142, 35], "isthing": 0, "id": 184, "name": "tree-merged"}, {"supercategory": "structural", "color": [190, 153, 153], "isthing": 0, "id": 185, "name": "fence-merged"}, {"supercategory": "ceiling", "color": [146, 139, 141], "isthing": 0, "id": 186, "name": "ceiling-merged"}, {"supercategory": "sky", "color": [70, 130, 180], "isthing": 0, "id": 187, "name": "sky-other-merged"}, {"supercategory": "furniture-stuff", "color": [134, 199, 156], "isthing": 0, "id": 188, "name": "cabinet-merged"}, {"supercategory": "furniture-stuff", "color": [209, 226, 140], "isthing": 0, "id": 189, "name": "table-merged"}, {"supercategory": "floor", "color": [96, 36, 108], "isthing": 0, "id": 190, "name": "floor-other-merged"}, {"supercategory": "ground", "color": [96, 96, 96], "isthing": 0, "id": 191, "name": "pavement-merged"}, {"supercategory": "solid", "color": [64, 170, 64], "isthing": 0, "id": 192, "name": "mountain-merged"}, {"supercategory": "plant", "color": [152, 251, 152], "isthing": 0, "id": 193, "name": "grass-merged"}, {"supercategory": "ground", "color": [208, 229, 228], "isthing": 0, "id": 194, "name": "dirt-merged"}, {"supercategory": "raw-material", "color": [206, 186, 171], "isthing": 0, "id": 195, "name": "paper-merged"}, {"supercategory": "food-stuff", "color": [152, 161, 64], "isthing": 0, "id": 196, "name": "food-other-merged"}, {"supercategory": "building", "color": [116, 112, 0], "isthing": 0, "id": 197, "name": "building-other-merged"}, {"supercategory": "solid", "color": [0, 114, 143], "isthing": 0, "id": 198, "name": "rock-merged"}, {"supercategory": "wall", "color": [102, 102, 156], "isthing": 0, "id": 199, "name": "wall-other-merged"}, {"supercategory": "textile", "color": [250, 141, 255], "isthing": 0, "id": 200, "name": "rug-merged"}]
2 |
--------------------------------------------------------------------------------
/datasets/coco_categories.py:
--------------------------------------------------------------------------------
1 | COCO_CATEGORIES = [{"supercategory": "person", "color": [220, 20, 60], "isthing": 1, "id": 1, "name": "person"}, {"supercategory": "vehicle", "color": [119, 11, 32], "isthing": 1, "id": 2, "name": "bicycle"}, {"supercategory": "vehicle", "color": [0, 0, 142], "isthing": 1, "id": 3, "name": "car"}, {"supercategory": "vehicle", "color": [0, 0, 230], "isthing": 1, "id": 4, "name": "motorcycle"}, {"supercategory": "vehicle", "color": [106, 0, 228], "isthing": 1, "id": 5, "name": "airplane"}, {"supercategory": "vehicle", "color": [0, 60, 100], "isthing": 1, "id": 6, "name": "bus"}, {"supercategory": "vehicle", "color": [0, 80, 100], "isthing": 1, "id": 7, "name": "train"}, {"supercategory": "vehicle", "color": [0, 0, 70], "isthing": 1, "id": 8, "name": "truck"}, {"supercategory": "vehicle", "color": [0, 0, 192], "isthing": 1, "id": 9, "name": "boat"}, {"supercategory": "outdoor", "color": [250, 170, 30], "isthing": 1, "id": 10, "name": "traffic light"}, {"supercategory": "outdoor", "color": [100, 170, 30], "isthing": 1, "id": 11, "name": "fire hydrant"}, {"supercategory": "outdoor", "color": [220, 220, 0], "isthing": 1, "id": 13, "name": "stop sign"}, {"supercategory": "outdoor", "color": [175, 116, 175], "isthing": 1, "id": 14, "name": "parking meter"}, {"supercategory": "outdoor", "color": [250, 0, 30], "isthing": 1, "id": 15, "name": "bench"}, {"supercategory": "animal", "color": [165, 42, 42], "isthing": 1, "id": 16, "name": "bird"}, {"supercategory": "animal", "color": [255, 77, 255], "isthing": 1, "id": 17, "name": "cat"}, {"supercategory": "animal", "color": [0, 226, 252], "isthing": 1, "id": 18, "name": "dog"}, {"supercategory": "animal", "color": [182, 182, 255], "isthing": 1, "id": 19, "name": "horse"}, {"supercategory": "animal", "color": [0, 82, 0], "isthing": 1, "id": 20, "name": "sheep"}, {"supercategory": "animal", "color": [120, 166, 157], "isthing": 1, "id": 21, "name": "cow"}, {"supercategory": "animal", "color": [110, 76, 0], "isthing": 1, "id": 22, "name": "elephant"}, {"supercategory": "animal", "color": [174, 57, 255], "isthing": 1, "id": 23, "name": "bear"}, {"supercategory": "animal", "color": [199, 100, 0], "isthing": 1, "id": 24, "name": "zebra"}, {"supercategory": "animal", "color": [72, 0, 118], "isthing": 1, "id": 25, "name": "giraffe"}, {"supercategory": "accessory", "color": [255, 179, 240], "isthing": 1, "id": 27, "name": "backpack"}, {"supercategory": "accessory", "color": [0, 125, 92], "isthing": 1, "id": 28, "name": "umbrella"}, {"supercategory": "accessory", "color": [209, 0, 151], "isthing": 1, "id": 31, "name": "handbag"}, {"supercategory": "accessory", "color": [188, 208, 182], "isthing": 1, "id": 32, "name": "tie"}, {"supercategory": "accessory", "color": [0, 220, 176], "isthing": 1, "id": 33, "name": "suitcase"}, {"supercategory": "sports", "color": [255, 99, 164], "isthing": 1, "id": 34, "name": "frisbee"}, {"supercategory": "sports", "color": [92, 0, 73], "isthing": 1, "id": 35, "name": "skis"}, {"supercategory": "sports", "color": [133, 129, 255], "isthing": 1, "id": 36, "name": "snowboard"}, {"supercategory": "sports", "color": [78, 180, 255], "isthing": 1, "id": 37, "name": "sports ball"}, {"supercategory": "sports", "color": [0, 228, 0], "isthing": 1, "id": 38, "name": "kite"}, {"supercategory": "sports", "color": [174, 255, 243], "isthing": 1, "id": 39, "name": "baseball bat"}, {"supercategory": "sports", "color": [45, 89, 255], "isthing": 1, "id": 40, "name": "baseball glove"}, {"supercategory": "sports", "color": [134, 134, 103], "isthing": 
1, "id": 41, "name": "skateboard"}, {"supercategory": "sports", "color": [145, 148, 174], "isthing": 1, "id": 42, "name": "surfboard"}, {"supercategory": "sports", "color": [255, 208, 186], "isthing": 1, "id": 43, "name": "tennis racket"}, {"supercategory": "kitchen", "color": [197, 226, 255], "isthing": 1, "id": 44, "name": "bottle"}, {"supercategory": "kitchen", "color": [171, 134, 1], "isthing": 1, "id": 46, "name": "wine glass"}, {"supercategory": "kitchen", "color": [109, 63, 54], "isthing": 1, "id": 47, "name": "cup"}, {"supercategory": "kitchen", "color": [207, 138, 255], "isthing": 1, "id": 48, "name": "fork"}, {"supercategory": "kitchen", "color": [151, 0, 95], "isthing": 1, "id": 49, "name": "knife"}, {"supercategory": "kitchen", "color": [9, 80, 61], "isthing": 1, "id": 50, "name": "spoon"}, {"supercategory": "kitchen", "color": [84, 105, 51], "isthing": 1, "id": 51, "name": "bowl"}, {"supercategory": "food", "color": [74, 65, 105], "isthing": 1, "id": 52, "name": "banana"}, {"supercategory": "food", "color": [166, 196, 102], "isthing": 1, "id": 53, "name": "apple"}, {"supercategory": "food", "color": [208, 195, 210], "isthing": 1, "id": 54, "name": "sandwich"}, {"supercategory": "food", "color": [255, 109, 65], "isthing": 1, "id": 55, "name": "orange"}, {"supercategory": "food", "color": [0, 143, 149], "isthing": 1, "id": 56, "name": "broccoli"}, {"supercategory": "food", "color": [179, 0, 194], "isthing": 1, "id": 57, "name": "carrot"}, {"supercategory": "food", "color": [209, 99, 106], "isthing": 1, "id": 58, "name": "hot dog"}, {"supercategory": "food", "color": [5, 121, 0], "isthing": 1, "id": 59, "name": "pizza"}, {"supercategory": "food", "color": [227, 255, 205], "isthing": 1, "id": 60, "name": "donut"}, {"supercategory": "food", "color": [147, 186, 208], "isthing": 1, "id": 61, "name": "cake"}, {"supercategory": "furniture", "color": [153, 69, 1], "isthing": 1, "id": 62, "name": "chair"}, {"supercategory": "furniture", "color": [3, 95, 161], "isthing": 1, "id": 63, "name": "couch"}, {"supercategory": "furniture", "color": [163, 255, 0], "isthing": 1, "id": 64, "name": "potted plant"}, {"supercategory": "furniture", "color": [119, 0, 170], "isthing": 1, "id": 65, "name": "bed"}, {"supercategory": "furniture", "color": [0, 182, 199], "isthing": 1, "id": 67, "name": "dining table"}, {"supercategory": "furniture", "color": [0, 165, 120], "isthing": 1, "id": 70, "name": "toilet"}, {"supercategory": "electronic", "color": [183, 130, 88], "isthing": 1, "id": 72, "name": "tv"}, {"supercategory": "electronic", "color": [95, 32, 0], "isthing": 1, "id": 73, "name": "laptop"}, {"supercategory": "electronic", "color": [130, 114, 135], "isthing": 1, "id": 74, "name": "mouse"}, {"supercategory": "electronic", "color": [110, 129, 133], "isthing": 1, "id": 75, "name": "remote"}, {"supercategory": "electronic", "color": [166, 74, 118], "isthing": 1, "id": 76, "name": "keyboard"}, {"supercategory": "electronic", "color": [219, 142, 185], "isthing": 1, "id": 77, "name": "cell phone"}, {"supercategory": "appliance", "color": [79, 210, 114], "isthing": 1, "id": 78, "name": "microwave"}, {"supercategory": "appliance", "color": [178, 90, 62], "isthing": 1, "id": 79, "name": "oven"}, {"supercategory": "appliance", "color": [65, 70, 15], "isthing": 1, "id": 80, "name": "toaster"}, {"supercategory": "appliance", "color": [127, 167, 115], "isthing": 1, "id": 81, "name": "sink"}, {"supercategory": "appliance", "color": [59, 105, 106], "isthing": 1, "id": 82, "name": "refrigerator"}, 
{"supercategory": "indoor", "color": [142, 108, 45], "isthing": 1, "id": 84, "name": "book"}, {"supercategory": "indoor", "color": [196, 172, 0], "isthing": 1, "id": 85, "name": "clock"}, {"supercategory": "indoor", "color": [95, 54, 80], "isthing": 1, "id": 86, "name": "vase"}, {"supercategory": "indoor", "color": [128, 76, 255], "isthing": 1, "id": 87, "name": "scissors"}, {"supercategory": "indoor", "color": [201, 57, 1], "isthing": 1, "id": 88, "name": "teddy bear"}, {"supercategory": "indoor", "color": [246, 0, 122], "isthing": 1, "id": 89, "name": "hair drier"}, {"supercategory": "indoor", "color": [191, 162, 208], "isthing": 1, "id": 90, "name": "toothbrush"}, {"supercategory": "textile", "color": [255, 255, 128], "isthing": 0, "id": 92, "name": "banner"}, {"supercategory": "textile", "color": [147, 211, 203], "isthing": 0, "id": 93, "name": "blanket"}, {"supercategory": "building", "color": [150, 100, 100], "isthing": 0, "id": 95, "name": "bridge"}, {"supercategory": "raw-material", "color": [168, 171, 172], "isthing": 0, "id": 100, "name": "cardboard"}, {"supercategory": "furniture-stuff", "color": [146, 112, 198], "isthing": 0, "id": 107, "name": "counter"}, {"supercategory": "textile", "color": [210, 170, 100], "isthing": 0, "id": 109, "name": "curtain"}, {"supercategory": "furniture-stuff", "color": [92, 136, 89], "isthing": 0, "id": 112, "name": "door-stuff"}, {"supercategory": "floor", "color": [218, 88, 184], "isthing": 0, "id": 118, "name": "floor-wood"}, {"supercategory": "plant", "color": [241, 129, 0], "isthing": 0, "id": 119, "name": "flower"}, {"supercategory": "food-stuff", "color": [217, 17, 255], "isthing": 0, "id": 122, "name": "fruit"}, {"supercategory": "ground", "color": [124, 74, 181], "isthing": 0, "id": 125, "name": "gravel"}, {"supercategory": "building", "color": [70, 70, 70], "isthing": 0, "id": 128, "name": "house"}, {"supercategory": "furniture-stuff", "color": [255, 228, 255], "isthing": 0, "id": 130, "name": "light"}, {"supercategory": "furniture-stuff", "color": [154, 208, 0], "isthing": 0, "id": 133, "name": "mirror-stuff"}, {"supercategory": "structural", "color": [193, 0, 92], "isthing": 0, "id": 138, "name": "net"}, {"supercategory": "textile", "color": [76, 91, 113], "isthing": 0, "id": 141, "name": "pillow"}, {"supercategory": "ground", "color": [255, 180, 195], "isthing": 0, "id": 144, "name": "platform"}, {"supercategory": "ground", "color": [106, 154, 176], "isthing": 0, "id": 145, "name": "playingfield"}, {"supercategory": "ground", "color": [230, 150, 140], "isthing": 0, "id": 147, "name": "railroad"}, {"supercategory": "water", "color": [60, 143, 255], "isthing": 0, "id": 148, "name": "river"}, {"supercategory": "ground", "color": [128, 64, 128], "isthing": 0, "id": 149, "name": "road"}, {"supercategory": "building", "color": [92, 82, 55], "isthing": 0, "id": 151, "name": "roof"}, {"supercategory": "ground", "color": [254, 212, 124], "isthing": 0, "id": 154, "name": "sand"}, {"supercategory": "water", "color": [73, 77, 174], "isthing": 0, "id": 155, "name": "sea"}, {"supercategory": "furniture-stuff", "color": [255, 160, 98], "isthing": 0, "id": 156, "name": "shelf"}, {"supercategory": "ground", "color": [255, 255, 255], "isthing": 0, "id": 159, "name": "snow"}, {"supercategory": "furniture-stuff", "color": [104, 84, 109], "isthing": 0, "id": 161, "name": "stairs"}, {"supercategory": "building", "color": [169, 164, 131], "isthing": 0, "id": 166, "name": "tent"}, {"supercategory": "textile", "color": [225, 199, 255], "isthing": 0, "id": 
168, "name": "towel"}, {"supercategory": "wall", "color": [137, 54, 74], "isthing": 0, "id": 171, "name": "wall-brick"}, {"supercategory": "wall", "color": [135, 158, 223], "isthing": 0, "id": 175, "name": "wall-stone"}, {"supercategory": "wall", "color": [7, 246, 231], "isthing": 0, "id": 176, "name": "wall-tile"}, {"supercategory": "wall", "color": [107, 255, 200], "isthing": 0, "id": 177, "name": "wall-wood"}, {"supercategory": "water", "color": [58, 41, 149], "isthing": 0, "id": 178, "name": "water-other"}, {"supercategory": "window", "color": [183, 121, 142], "isthing": 0, "id": 180, "name": "window-blind"}, {"supercategory": "window", "color": [255, 73, 97], "isthing": 0, "id": 181, "name": "window-other"}, {"supercategory": "plant", "color": [107, 142, 35], "isthing": 0, "id": 184, "name": "tree-merged"}, {"supercategory": "structural", "color": [190, 153, 153], "isthing": 0, "id": 185, "name": "fence-merged"}, {"supercategory": "ceiling", "color": [146, 139, 141], "isthing": 0, "id": 186, "name": "ceiling-merged"}, {"supercategory": "sky", "color": [70, 130, 180], "isthing": 0, "id": 187, "name": "sky-other-merged"}, {"supercategory": "furniture-stuff", "color": [134, 199, 156], "isthing": 0, "id": 188, "name": "cabinet-merged"}, {"supercategory": "furniture-stuff", "color": [209, 226, 140], "isthing": 0, "id": 189, "name": "table-merged"}, {"supercategory": "floor", "color": [96, 36, 108], "isthing": 0, "id": 190, "name": "floor-other-merged"}, {"supercategory": "ground", "color": [96, 96, 96], "isthing": 0, "id": 191, "name": "pavement-merged"}, {"supercategory": "solid", "color": [64, 170, 64], "isthing": 0, "id": 192, "name": "mountain-merged"}, {"supercategory": "plant", "color": [152, 251, 152], "isthing": 0, "id": 193, "name": "grass-merged"}, {"supercategory": "ground", "color": [208, 229, 228], "isthing": 0, "id": 194, "name": "dirt-merged"}, {"supercategory": "raw-material", "color": [206, 186, 171], "isthing": 0, "id": 195, "name": "paper-merged"}, {"supercategory": "food-stuff", "color": [152, 161, 64], "isthing": 0, "id": 196, "name": "food-other-merged"}, {"supercategory": "building", "color": [116, 112, 0], "isthing": 0, "id": 197, "name": "building-other-merged"}, {"supercategory": "solid", "color": [0, 114, 143], "isthing": 0, "id": 198, "name": "rock-merged"}, {"supercategory": "wall", "color": [102, 102, 156], "isthing": 0, "id": 199, "name": "wall-other-merged"}, {"supercategory": "textile", "color": [250, 141, 255], "isthing": 0, "id": 200, "name": "rug-merged"}]
2 | THING_CATEGORIES = [cat for cat in COCO_CATEGORIES if cat["isthing"] == 1]
--------------------------------------------------------------------------------
/scripts/generate_with_batch.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | # Switch execution environment to parent directory
5 | project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
6 | os.chdir(project_root)
7 | sys.path.insert(0, project_root)
8 |
9 | from datasets.object_datasets import *
10 | from datasets.background_datasets import *
11 | from datasets.data_manager import DataManager
12 | from segmentation_generator.segmentataion_synthesizer import RandomCenterPointSegmentationSynthesizer, FineGrainedBoundingBoxSegmentationSynthesizer
13 | import numpy as np
14 | import json
15 | from PIL import Image
16 | import multiprocessing  # used below for Manager/Process worker pools
17 | import random
18 | import argparse
19 | from tqdm import tqdm
20 | import concurrent.futures
21 |
22 |
23 | # Filter conditions
24 | filtering_setting_0 = {
25 | "IsInstance": "non-filter",
26 | "Integrity": "non-filter",
27 | "QualityAndRegularity": "non-filter"
28 | }
29 | filtering_setting_1 = {
30 | "IsInstance": "filter",
31 | "Integrity": "non-filter",
32 | "QualityAndRegularity": "non-filter"
33 | }
34 | filtering_setting_2 = {
35 | "IsInstance": "non-filter",
36 | "Integrity": "filter",
37 | "QualityAndRegularity": "non-filter"
38 | }
39 | filtering_setting_3 = {
40 | "IsInstance": "non-filter",
41 | "Integrity": "non-filter",
42 | "QualityAndRegularity": "filter"
43 | }
44 | filtering_setting_4 = {
45 | "IsInstance": "filter",
46 | "Integrity": "filter",
47 | "QualityAndRegularity": "filter"
48 | }
49 |
50 |
51 | def reset_folders(paths):
52 | """Deletes and recreates specified directories to ensure a clean slate."""
53 | for path in paths:
54 | if os.path.exists(path):
55 | # Uncomment below to delete existing folders if needed (requires importing shutil):
56 | # shutil.rmtree(path)
57 | pass
58 | if not os.path.exists(path):
59 | print(f"Creating: {path}")
60 | os.makedirs(path, exist_ok=True)
61 |
62 |
63 | def prepare_folders(image_save_path, mask_save_path, annotation_path):
64 | reset_folders([image_save_path, mask_save_path, annotation_path])
65 |
66 |
67 | class NumpyEncoder(json.JSONEncoder):
68 | """Custom JSON encoder to handle NumPy data types."""
69 | def default(self, obj):
70 | if isinstance(obj, np.integer):
71 | return int(obj)
72 | return super(NumpyEncoder, self).default(obj)
73 |
74 | # without bg, for relightening and blending
75 | def initialize_global_categories(annotation_path, filtering_setting, global_data_manager):
76 | """
77 | Initializes the global categories from the DataManager.
78 | This function runs in the main process so that the global categories
79 | are available when merging annotation JSONs.
80 | """
81 | categories = global_data_manager.unified_object_categories
82 | categories_to_idx = {category: idx for idx, category in enumerate(categories)}
83 | idx_to_categories = {idx: category for idx, category in enumerate(categories)}
84 | global_categories = [{"id": idx, "name": category} for idx, category in enumerate(categories)]
85 |
86 | # Save category files for reference.
87 | json.dump(categories, open(os.path.join(annotation_path, "categories.json"), "w"), indent=4)
88 | json.dump(categories_to_idx, open(os.path.join(annotation_path, "categories_to_idx.json"), "w"), indent=4)
89 | json.dump(idx_to_categories, open(os.path.join(annotation_path, "idx_to_categories.json"), "w"), indent=4)
90 |
91 | return global_categories
92 |
93 |
94 | def process_image_worker(start_idx, end_idx, worker_seed, filtering_setting, queue, image_save_path, mask_save_path, annotation_path, data_manager):
95 | """
96 | Worker function that processes a batch of images with a unique seed.
97 | For each image, the corresponding annotation is saved as a separate JSON file.
98 | """
99 | random.seed(worker_seed)
100 | np.random.seed(worker_seed)
101 |
102 | rss = FineGrainedBoundingBoxSegmentationSynthesizer(data_manager, "./", random_seed=worker_seed)
103 |
104 | # Create a subfolder for separate annotations
105 | separate_annotation_path = os.path.join(annotation_path, "separate_annotations")
106 | if not (os.path.exists(separate_annotation_path) and os.path.isdir(separate_annotation_path)):
107 | os.makedirs(separate_annotation_path, exist_ok=True)
108 |
109 | for i in range(start_idx, end_idx):
110 | image_name = f"{i}.png"
111 | mask_name = f"{i}.png"
112 | image_path_full = os.path.join(image_save_path, image_name)
113 | mask_path_full = os.path.join(mask_save_path, mask_name)
114 | separate_annot_file = os.path.join(separate_annotation_path, f"{i}.json")
115 |
116 | # Skip processing if image, mask, and annotation all exist.
117 | if os.path.exists(image_path_full) and os.path.exists(mask_path_full) and os.path.exists(separate_annot_file):
118 | queue.put(1)
119 | continue
120 |
121 | obj_nums = np.random.randint(5, 20)
122 | data = rss.sampling_metadata(1024, 1024, obj_nums, hasBackground=False, dataAugmentation=False)
123 | res = rss.generate_with_unified_format(data, containRGBA=True, containCategory=True, containSmallObjectMask=False, resize_mode="fit")
124 |
125 | # Save image and mask.
126 | Image.fromarray(res['image_rgba']).save(image_path_full)
127 | Image.fromarray(res['coco_mask']).save(mask_path_full)
128 |
129 | # Build annotation for this image.
130 | annotation = {
131 | "segments_info": res["segments_info"],
132 | "file_name": mask_name,
133 | "image_id": i,
134 | }
135 | # Save the annotation as a separate JSON file.
136 | with open(separate_annot_file, "w") as f:
137 | json.dump(annotation, f, cls=NumpyEncoder, indent=4)
138 |
139 | queue.put(1)
140 |
141 | # with bg
142 | # def process_image_worker(start_idx, end_idx, worker_seed, filtering_setting, queue, image_save_path, mask_save_path, annotation_path, data_manager):
143 | # """
144 | # Worker function that processes a batch of images with a unique seed.
145 | # For each image, the corresponding annotation is saved as a separate JSON file.
146 | # """
147 | # random.seed(worker_seed)
148 | # np.random.seed(worker_seed)
149 |
150 | # rss = FineGrainedBoundingBoxSegmentationSynthesizer(data_manager, "./", random_seed=worker_seed)
151 |
152 | # # Create a subfolder for separate annotations
153 | # separate_annotation_path = os.path.join(annotation_path, "separate_annotations")
154 | # if not (os.path.exists(separate_annotation_path) and os.path.isdir(separate_annotation_path)):
155 | # os.makedirs(separate_annotation_path, exist_ok=True)
156 |
157 | # for i in range(start_idx, end_idx):
158 | # image_name = f"{i}.jpg"
159 | # mask_name = f"{i}.png"
160 | # image_path_full = os.path.join(image_save_path, image_name)
161 | # mask_path_full = os.path.join(mask_save_path, mask_name)
162 | # separate_annot_file = os.path.join(separate_annotation_path, f"{i}.json")
163 |
164 | # # Skip processing if image, mask, and annotation all exist.
165 | # if os.path.exists(image_path_full) and os.path.exists(mask_path_full) and os.path.exists(separate_annot_file):
166 | # queue.put(1)
167 | # continue
168 |
169 | # obj_nums = np.random.randint(5, 20)
170 | # data = rss.sampling_metadata(1024, 1024, obj_nums, hasBackground=True, dataAugmentation=False)
171 | # res = rss.generate_with_unified_format(data, containRGBA=True, containCategory=True, containSmallObjectMask=False, resize_mode="fit")
172 |
173 | # # Save image and mask.
174 | # res['image'].save(image_path_full)
175 | # Image.fromarray(res['coco_mask']).save(mask_path_full)
176 |
177 | # # Build annotation for this image.
178 | # annotation = {
179 | # "segments_info": res["segments_info"],
180 | # "file_name": mask_name,
181 | # "image_id": i,
182 | # }
183 | # # Save the annotation as a separate JSON file.
184 | # with open(separate_annot_file, "w") as f:
185 | # json.dump(annotation, f, cls=NumpyEncoder, indent=4)
186 |
187 | # queue.put(1)
188 |
189 |
190 | def listener(queue, total):
191 | """Updates the tqdm progress bar as images are generated."""
192 | pbar = tqdm(total=total)
193 | while True:
194 | msg = queue.get()
195 | if msg == 'kill':
196 | break
197 | pbar.update(msg)
198 | pbar.close()
199 |
200 |
201 |
202 |
203 | class NumpyEncoder(json.JSONEncoder):
204 | """Custom JSON encoder to handle NumPy data types."""
205 | def default(self, obj):
206 | try:
207 | import numpy as np
208 | if isinstance(obj, np.integer):
209 | return int(obj)
210 | except ImportError:
211 | pass
212 | return super(NumpyEncoder, self).default(obj)
213 |
214 | # Define process_file at the module level so it is pickleable
215 | def process_file(annot_file):
216 | try:
217 | with open(annot_file, "r") as f:
218 | annot = json.load(f)
219 | # Build image info assuming a fixed size (1024x1024) and .jpg file names
220 | image_info = {
221 | "file_name": f"{annot['image_id']}.jpg",
222 | "height": 1024,
223 | "width": 1024,
224 | "id": annot["image_id"]
225 | }
226 | return annot, image_info
227 | except Exception:
228 | return None
229 |
230 | def merge_annotation_jsons(annotation_path, json_save_path, categories):
231 | """
232 | Merges separate annotation JSON files (stored in the "separate_annotations" subfolder)
233 | into a final COCO-format JSON. Also constructs the images list based on the annotation files.
234 | This version uses multiprocessing via a ProcessPoolExecutor with os.scandir for efficient directory listing.
235 | """
236 | separate_annotation_path = os.path.join(annotation_path, "separate_annotations")
237 | # Use os.scandir for efficient directory traversal
238 | json_files = [entry.path for entry in os.scandir(separate_annotation_path)
239 | if entry.is_file() and entry.name.endswith(".json")]
240 |
241 | annotations = []
242 | images = []
243 |
244 | # Adjust max_workers to a number more appropriate for your system (e.g., 4 to 8)
245 | max_workers = 100
246 | with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
247 | results = list(tqdm(executor.map(process_file, json_files),
248 | total=len(json_files),
249 | desc="Merging annotations"))
250 | for res in results:
251 | if res is not None:
252 | annot, image_info = res
253 | annotations.append(annot)
254 | images.append(image_info)
255 |
256 | # Sort the lists based on their IDs
257 | images = sorted(images, key=lambda x: x["id"])
258 | annotations = sorted(annotations, key=lambda x: x["image_id"])
259 |
260 | coco_json = {
261 | "images": images,
262 | "annotations": annotations,
263 | "categories": categories
264 | }
265 |
266 | with open(json_save_path, "w") as f:
267 | json.dump(coco_json, f, cls=NumpyEncoder, indent=4)
268 |
269 |
270 | def run_generation(args):
271 | """
272 | Runs the synthetic image generation process.
273 | Each worker saves its annotation into a separate JSON file.
274 | In the end, all these JSON files are merged into one final JSON.
275 | """
276 | image_save_path = args.image_save_path
277 | mask_save_path = args.mask_save_path
278 | annotation_path = args.annotation_path
279 | json_save_path = args.json_save_path
280 |
281 | prepare_folders(image_save_path, mask_save_path, annotation_path)
282 |
283 | # Set the filtering setting based on argument
284 | fs = str(args.filtering_setting)
285 | if fs.isdigit():
286 | fs = "filter_" + fs
287 | if fs == 'filter_0':
288 | filtering_setting = filtering_setting_0
289 | elif fs == 'filter_1':
290 | filtering_setting = filtering_setting_1
291 | elif fs == 'filter_2':
292 | filtering_setting = filtering_setting_2
293 | elif fs == 'filter_3':
294 | filtering_setting = filtering_setting_3
295 | else:
296 | filtering_setting = filtering_setting_4
297 |
298 | # Initialize global DataManager instance in the main process.
299 | global_data_manager = DataManager(available_object_datasets, available_background_datasets, filtering_setting)
300 |
301 | # Initialize global categories in the main process by passing global_data_manager.
302 | global_categories = initialize_global_categories(annotation_path, filtering_setting, global_data_manager)
303 |
304 | num_images = args.total_images
305 | num_workers = args.num_processes
306 |
307 | chunk_size = num_images // num_workers
308 | remainder = num_images % num_workers
309 |
310 | manager = multiprocessing.Manager()
311 | queue = manager.Queue()
312 |
313 | listener_process = multiprocessing.Process(target=listener, args=(queue, num_images))
314 | listener_process.start()
315 |
316 | processes = []
317 | start_index = 0
318 |
319 | for i in range(num_workers):
320 | end_index = start_index + chunk_size
321 | if i < remainder:
322 | end_index += 1
323 |
324 | worker_seed = random.randint(0, int(1e6))
325 |
326 | def worker_task(start_idx=start_index, end_idx=end_index, worker_seed=worker_seed):
327 | process_image_worker(start_idx, end_idx, worker_seed, filtering_setting, queue, image_save_path, mask_save_path, annotation_path, data_manager=global_data_manager)
328 |
329 | p = multiprocessing.Process(target=worker_task)
330 | processes.append(p)
331 | p.start()
332 | start_index = end_index
333 |
334 | for p in processes:
335 | p.join()
336 |
337 | queue.put('kill')
338 | listener_process.join()
339 |
340 | # Merge individual annotation JSON files into the final COCO JSON.
341 | merge_annotation_jsons(annotation_path, json_save_path, global_categories)
342 |
343 |
344 | def main():
345 | parser = argparse.ArgumentParser(
346 | description="Synthetic Image Generation with Segmentation Annotations"
347 | )
348 | parser.add_argument('--num_processes', type=int, default=100,
349 | help='Number of processes to use')
350 | parser.add_argument('--total_images', type=int, default=1000,
351 | help='Total number of images to generate')
352 | parser.add_argument('--filtering_setting', type=str, default='filter_4',
353 | help='Filtering setting to apply. One of "filter_0" through "filter_4", or a digit 0-4.')
354 | parser.add_argument('--image_save_path', type=str,
355 | default="/output/train",
356 | help="Path to save images")
357 | parser.add_argument('--mask_save_path', type=str,
358 | default="/output/panoptic_train",
359 | help="Path to save masks")
360 | parser.add_argument('--annotation_path', type=str,
361 | default="/output/annotations",
362 | help="Path to save separate annotation JSON files")
363 | parser.add_argument('--json_save_path', type=str,
364 | default="/output/annotations/panoptic_train.json",
365 | help="Path to save the merged COCO panoptic JSON")
366 | args = parser.parse_args()
367 |
368 | run_generation(args)
369 |
370 |
371 | if __name__ == "__main__":
372 |
373 | # using FC split data
374 | available_object_datasets = {
375 | "Synthetic": SyntheticDataset(
376 | dataset_path="/fc_10m",
377 | synthetic_annotation_path="/fc_10m/fc_object_segments_metadata.json",
378 | dataset_name="Synthetic",
379 | cache_path="./metadata_ovd_cache"
380 | ),
381 | }
382 |
383 | # using GC split data
384 | # available_object_datasets = {
385 | # "Synthetic": SyntheticDataset(
386 | # dataset_path="/gc_10m",
387 | # synthetic_annotation_path="/gc_10m/gc_object_segments_metadata.json",
388 | # dataset_name="Synthetic",
389 | # cache_path="./metadata_ovd_cache"
390 | # ),
391 | # }
392 |
393 | # mixing data from different source
394 | # available_object_datasets = {
395 | # "Synthetic_fc": SyntheticDataset(
396 | # dataset_path="/fc_10m",
397 | #         synthetic_annotation_path="/fc_10m/fc_object_segments_metadata.json",
398 | # dataset_name="Synthetic_fc",
399 | # cache_path="./metadata_fc_cache"
400 | # ),
401 | # "Synthetic_gc": SyntheticDataset(
402 | # dataset_path="/gc_10m",
403 | # synthetic_annotation_path="/gc_10m/gc_object_segments_metadata.json",
404 | # dataset_name="Synthetic_gc",
405 | # cache_path="./metadata_gc_cache"
406 | # ),
407 | # }
408 |
409 | # By default, this loads a placeholder background dataset containing only a single image.
410 | # To use real BG-20K backgrounds, point the path below to your downloaded dataset.
411 | available_background_datasets = {
412 | "BG20k": BG20KDataset("datasets/one_image_bg")
413 | # "BG20k": BG20KDataset("/your/path/to/bg20k_dataset")
414 |
415 | }
416 |
417 | main()
418 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Synthetic Object Compositions for Scalable and Accurate Learning in Detection, Segmentation, and Grounding
3 |
4 |
5 |
6 | 2 Million Diverse, Accurate, Densely Annotated Synthetic Images (FC-1M + GC-1M) + 20M Synthetic Object Segments to Supercharge Grounding-DINO, Mask2Former, and Any Detector / Segmenter / Grounding-VLM
7 |
8 |
9 |
10 |
15 |
16 |
17 | Weikai Huang¹, Jieyu Zhang¹,
18 | Taoyang Jia¹, Chenhao Zheng¹, Ziqi Gao¹,
19 | Jae Sung Park¹, Ranjay Krishna¹,²
20 | ¹University of Washington
21 | ²Allen Institute for AI
22 |
23 |
24 | ---
25 |
26 |
27 |
28 |
29 | A scalable pipeline for composing high-quality synthetic object segments into richly annotated images for object detection, instance segmentation, and visual grounding.
30 |
31 | ## 🌟 Highlights
32 |
33 | **Why SOC?** A small amount of high-quality synthetic data can outperform orders of magnitude more real data:
34 |
35 | - 🚀 **Efficient & Scalable**: Just **50K** SOC images match the gains from **20M** model-generated (GRIT) or **200K** human-annotated (V3Det) images on LVIS detection. We compose 2 million diverse images (FC-1M + GC-1M) with annotations, along with 20 million synthetic object segments across 47,000+ categories generated with FLUX.
36 | - 🎯 **Accurate Annotations**: Object-centric composition provides pixel-perfect masks, boxes, and referring expressions—no noisy pseudo-labels
37 | - 🎨 **Controllable Generation**: Synthesize targeted data for specific scenarios (e.g., intra-class referring, rare categories, domain-specific applications)
38 | - 🔄 **Complementary to Real Data**: Adding SOC to existing datasets (COCO, LVIS, V3Det, GRIT) yields consistent additive gains across all benchmarks
39 | - 💰 **Cost-Effective**: Generate unlimited training data from 20M object segments without expensive human annotation
40 | - 📈 **100K SOC surpasses larger real-data baselines**: +10.9 LVIS AP (OVD) and +8.4 gRefCOCO NAcc (VG), and it remains complementary when combined with GRIT/V3Det
41 |
42 |
43 | ---
44 |
45 | # 📊 Released Datasets
46 |
47 | We release the following datasets for research use:
48 |
49 | | Dataset Name | # Images | # Categories | Description | Download |
50 | |-------------|----------|--------------|-------------|----------|
51 | | **FC-1M** | 1,000,000 | 1,600 | Frequent Categories | [🤗 HuggingFace](https://huggingface.co/datasets/weikaih/SOC-FC-1M) |
52 | | **GC-1M** | 1,000,000 | 47,000+ | General Categories | [🤗 HuggingFace](https://huggingface.co/datasets/weikaih/SOC-GC-1M) |
53 | | **SFC-200K** | 200,000 | 1,600 | Single-category Frequent Category — same category objects with varied attributes | [🤗 HuggingFace](https://huggingface.co/datasets/weikaih/SOC-SFC-200K) |
54 | | **SGC-200K** | 200,000 | 47,000+ | Single-category General Category — same category objects with varied attributes | [🤗 HuggingFace](https://huggingface.co/datasets/weikaih/SOC-SGC-200K) |
55 |
56 |
57 | Examples of dataset types:
58 |
59 |
60 | ![FC / GC](assets/fcgc.png)
61 | ![SFC / SGC](assets/sfcsgc.png)
62 |
69 |
70 |
71 | **All datasets include:**
72 | - ✅ High-resolution images with photorealistic relighting and blending
73 | - ✅ Pixel-perfect segmentation masks
74 | - ✅ Tight bounding boxes
75 | - ✅ Category labels
76 | - ✅ Diverse referring expressions (attribute-based, spatial-based, and mixed)
77 |
78 | **Note:** Other dataset variants (e.g., SOC-LVIS, MixCOCO) contain segments from existing datasets and cannot be released. Please use the code in this repository to compose your own datasets from the released object segments.
79 |
80 | ## Object Segments
81 |
82 | We also release **20M synthetic object segments** used to compose the above datasets:
83 |
84 | | Segment Set | # Segments | # Categories | Prompts/Category | Segments/Prompt | Download |
85 | |-------------|------------|--------------|------------------|-----------------|----------|
86 | | **FC Object Segments** | 10,000,000 | 1,600 | 200 | 3 | [🤗 SOC-FC-Object-Segments-10M](https://huggingface.co/datasets/weikaih/SOC-FC-Object-Segments-10M) |
87 | | **GC Object Segments** | 10,000,000 | 47,000+ | 10 | 3 | [🤗 SOC-GC-Object-Segments-10M](https://huggingface.co/datasets/weikaih/SOC-GC-Object-Segments-10M) |
88 |
89 | Browse all sets via the collection: [🤗 HuggingFace Collection](https://huggingface.co/collections/weikaih/SOC-synthetic-object-segments-improves-detection-segmentat-682679751d20faa20800033c)
90 |
91 | ---
92 |
93 | # 📦 Installation
94 |
95 | *Notice*: We currently provide only minimal guidance for the core parts of the codebase: image composing, relighting and blending, and referring expression generation. Full documentation (with an accompanying arXiv paper) covering additional tasks and case studies will be released soon.
96 |
97 | ## Environment Setup
98 | Follow the steps below to set up the environment and use the repository:
99 | ```bash
100 | # Clone the repository
101 | git clone https://github.com/weikaih04/SOC
102 | cd ./SOC
103 |
104 | # Create and activate a Python virtual environment:
105 | conda create -n SOC python==3.10
106 | conda activate SOC
107 |
108 | # Install the required dependencies for composing images with synthetic object segments:
109 | pip install -r requirements.txt
110 |
111 | # If you want to perform relighting and blending:
112 | conda create -n SOC-relight python==3.10
113 | conda activate SOC-relight
114 | pip install -r requirements_relight_and_blending.txt
115 |
116 | # If you want to generate referring expressions:
117 | conda create -n SOC-ref python==3.10
118 | conda activate SOC-ref
119 | pip install -r requirements_referring_expression_generation.txt
120 | ```
121 |
122 |
123 | ## Background Dataset (Optional)
124 | If you plan to relight images rather than pasting object segments directly onto a background, you can use any random image as the background and set `hasBackground` to `False` in `scripts/generate_with_batch.py`.
125 |
126 | You can download the BG-20K from this repo: https://github.com/JizhiziLi/GFM.git
127 |
128 | # 🚀 Usage
129 |
130 | ## Composing Synthetic Images
131 | We provide scripts to compose images with synthetic segments:
132 |
133 | If you want to generate images for relighting and blending that only contain foreground object segments:
134 | ```bash
135 | python scripts/generate_with_batch.py \
136 | --num_processes 100 \
137 | --total_images 100000 \
138 | --filtering_setting filter_0 \
139 | --image_save_path "/output/dataset_name/train" \
140 | --mask_save_path "/output/dataset_name/panoptic_train" \
141 | --annotation_path "/output/dataset_name/annotations" \
142 | --json_save_path "/output/dataset_name/annotations/panoptic_train.json"
143 | ```
144 |
145 |
146 | ### Key parameters
147 | - --num_processes: Number of parallel workers to generate images; set based on CPU cores.
148 | - --total_images: Total images to generate.
149 | - --filtering_setting: One of filter_0..filter_4 (filter_4 = strictest). Controls which segment quality filters are applied; see the preset table below.
150 | - --image_save_path: Output path for rendered RGBA images (PNG).
151 | - --mask_save_path: Output path for color panoptic masks (PNG).
152 | - --annotation_path: Output folder for per-image JSONs and category maps.
153 | - --json_save_path: Final merged COCO-style panoptic JSON path.
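For reference, the five presets are defined at the top of `scripts/generate_with_batch.py` and toggle three segment checks:

| Preset | IsInstance | Integrity | QualityAndRegularity |
|----------|------------|-----------|----------------------|
| filter_0 | non-filter | non-filter | non-filter |
| filter_1 | filter | non-filter | non-filter |
| filter_2 | non-filter | filter | non-filter |
| filter_3 | non-filter | non-filter | filter |
| filter_4 | filter | filter | filter |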
154 |
155 | Important: At the end of scripts/generate_with_batch.py, available_object_datasets must point to your local copies of released FC/GC object segments and their metadata JSON. For example, if you downloaded SOC-FC-Object-Segments-10M to /data/fc_10m with metadata fc_object_segments_metadata.json, set:
156 | - dataset_path="/data/fc_10m"
157 | - synthetic_annotation_path="/data/fc_10m/fc_object_segments_metadata.json"
158 | Similarly for GC: use gc_object_segments_metadata.json under your GC path.
159 |
160 |
161 |
162 |
163 | Notes
164 | - We expect dataset_path to contain a category/subcategory/ID.png structure, as provided in our released object-segment datasets (see the example layout below).
165 | - The script writes per-image JSONs under annotation_path/separate_annotations and merges them into the final COCO-style panoptic JSON at json_save_path.
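For orientation, a sketch of the expected layout (the category, prompt-folder, and file names here are illustrative placeholders, not actual entries from the released sets; three segments per prompt matches the Segments/Prompt column above):

```
/data/fc_10m/
├── fc_object_segments_metadata.json   # segment metadata
├── apple/                             # category
│   └── a ripe apple on a table/       # subcategory (generation prompt)
│       ├── 0.png                      # RGBA object segment
│       ├── 1.png
│       └── 2.png
└── ...
```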
166 |
167 | Minimal example
168 | ```bash
169 | # Symlink your datasets to the default paths expected by the script (optional)
170 | ln -s /data/fc_10m /fc_10m
171 | ln -s /data/gc_10m /gc_10m
172 |
173 | # Generate a tiny sample dataset locally
174 | python scripts/generate_with_batch.py \
175 | --num_processes 4 \
176 | --total_images 20 \
177 | --filtering_setting filter_0 \
178 | --image_save_path "./out/train" \
179 | --mask_save_path "./out/panoptic_train" \
180 | --annotation_path "./out/annotations" \
181 | --json_save_path "./out/annotations/panoptic_train.json"
182 | ```
183 |
184 | If you want to generate images that paste objects directly onto backgrounds, uncomment the "with bg" variant of `process_image_worker` in `scripts/generate_with_batch.py`.
185 |
186 | ## Relighting and Blending
187 | Relight and blend images using IC-Light with mask-area-weighted blending to enhance photorealism while preserving object details and colors (a conceptual sketch of the blending idea follows the command below):
188 |
189 | ```bash
190 | python relighting_and_blending/inference.py \
191 | --dataset_path "$DATASET_PATH" \
192 | --output_data_path "$OUTPUT_DATA_PATH" \
193 | --num_splits "$NUM_SPLITS" \
194 | --split "$SPLIT" \
195 | --index_json_path "" \
196 | --illuminate_prompts_path "$ILLUMINATE_PROMPTS_PATH" \
197 | --record_path "$RECORD_PATH"
198 | ```
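For intuition only, a minimal sketch of one way mask-area-weighted blending can be realized; the function name, weighting constants, and per-object scheme are assumptions for illustration, not the implementation in `relighting_and_blending/inference.py`:

```python
import numpy as np

def mask_area_weighted_blend(original, relit, masks, min_w=0.3, max_w=0.8):
    """Blend a relit frame with the original composite per object.

    original, relit: HxWx3 float arrays; masks: list of HxW boolean arrays.
    Larger objects take on more of the relit appearance; smaller objects
    keep more of their original detail and color.
    """
    out = relit.astype(np.float32)
    total = original.shape[0] * original.shape[1]
    for m in masks:
        area_frac = m.sum() / total                      # object size in [0, 1]
        w = min_w + (max_w - min_w) * area_frac          # relit weight grows with area
        out[m] = w * relit[m] + (1.0 - w) * original[m]  # convex blend inside the mask
    return np.clip(out, 0, 255).astype(np.uint8)
```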
199 |
200 | The script currently supports both Google Cloud Storage and the local file system for inputs and outputs.
201 |
202 | Notes
203 | - Requires a CUDA GPU. Models load in half precision; 12GB+ VRAM recommended.
204 | - Weights auto-download on first run:
205 | - Stable Diffusion components from stablediffusionapi/realistic-vision-v51
206 | - Background remover briaai/RMBG-1.4
207 | - IC-Light offset iclight_sd15_fc.safetensors (downloaded to ./models if missing)
208 | - Input expectations:
209 | - dataset_path should point to the folder with RGBA foreground PNGs (e.g., ./out/train) named 0.png, 1.png, ...
210 | - A matching color panoptic mask must exist for the same id under dataset_path with "train" replaced by "panoptic_train" (e.g., ./out/panoptic_train/0.png); see the sketch after this list
211 | - illuminate_prompts_path must be a JSON file containing an array of prompt strings for relighting.
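A minimal sketch of that pairing rule (naive string replacement, shown only to illustrate the naming convention; paths are examples from the composing step above):

```python
def mask_path_for(image_path: str) -> str:
    """Derive the color panoptic mask path from a composed image path."""
    # e.g. "./out/train/0.png" -> "./out/panoptic_train/0.png"
    return image_path.replace("train", "panoptic_train")

print(mask_path_for("./out/train/0.png"))  # ./out/panoptic_train/0.png
```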
212 |
213 | Minimal example
214 | ````bash
215 | # Create a tiny illumination prompt list
216 | cat > ./illumination_prompt.json << 'JSON'
217 | [
218 | "golden hour lighting, soft shadows",
219 | "overcast daylight, diffuse light",
220 | "studio softbox lighting"
221 | ]
222 | JSON
223 |
224 | # Relight a small sample from the composed outputs
225 | python relighting_and_blending/inference.py \
226 | --dataset_path ./out/train \
227 | --output_data_path ./out/relit \
228 | --num_splits 1 \
229 | --split 0 \
230 | --illuminate_prompts_path ./illumination_prompt.json
231 | ````
232 |
233 |
234 | ## Referring Expression Generation
235 | We use an OpenAI-compatible endpoint (vLLM) and query a local model.
236 |
237 | Step 1) Start an OpenAI-compatible server (port 8080)
238 | ````bash
239 | # Example: start vLLM OpenAI server with the model used in our script
240 | python -m vllm.entrypoints.openai.api_server \
241 | --model Qwen/QwQ-32B-AWQ \
242 | --host 0.0.0.0 \
243 | --port 8080
244 | ````
245 | Notes
246 | - Our script currently assumes base_url=http://localhost:8080/v1.
247 | - Ensure your GPU/driver supports the chosen model; adjust the model name if needed. A quick connectivity check is sketched below.
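A minimal connectivity-check sketch, assuming the openai>=1.x Python client (the prompt string is arbitrary; base_url and model follow the values assumed by our script):

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8080/v1", api_key="dummy_key")
reply = client.chat.completions.create(
    model="Qwen/QwQ-32B-AWQ",
    messages=[{"role": "user", "content": "Reply with OK."}],
    max_tokens=8,
)
print(reply.choices[0].message.content)
```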
248 |
249 | Step 2) Run the generator
250 | ````bash
251 | # INPUT_FILE is the merged COCO-style JSON from the composing stage
252 | # OUTPUT_DIR will contain jsonl shards (one per job): job_0.jsonl, ...
253 | export OPENAI_API_KEY=dummy_key # any non-empty string is accepted
254 |
255 | python referring_expression_generation/inference.py \
256 | 1 \
257 | 0 \
258 | ./out/annotations/panoptic_train.json \
259 | ./out/refexp \
260 | --api_key "$OPENAI_API_KEY" \
261 | --num_workers 8
262 | ````
263 | Outputs
264 | - At least 9 expressions per image, balanced across nine buckets (attribute/spatial/reasoning × single/multi/non-target).
265 | - Writes per-job jsonl files under OUTPUT_DIR (see the inspection sketch below).
266 | - Supports local paths and GCS (gs://) for both inputs and outputs.
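To spot-check a shard, a minimal sketch assuming the chunked ODVG-style records produced by `chunk_result` in `referring_expression_generation/inference.py` (the shard path is an example):

```python
import json

with open("./out/refexp/job_0.jsonl") as f:
    for line in f:
        rec = json.loads(line)
        caption = rec["grounding"]["caption"]
        for region in rec["grounding"]["regions"]:
            start, end = region["tokens_positive"][0]  # inclusive character span
            # The span should recover the region's phrase from the chunk caption.
            print(region["phrase"], region["bboxes"], caption[start:end + 1])
```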
267 |
268 | # 📈 Results
269 |
270 | ## Task 1: Open-Vocabulary Object Detection
271 |
272 | **Model**: MM-Grounding-DINO | **Benchmarks**: LVIS v1.0 Full Val, OdinW-35
273 |
274 |
275 | ![Open-vocabulary detection results](assets/ovd.png)
276 |
277 |
278 | ### Key Findings
279 |
280 | #### 🎯 Small Amount of SOC Efficiently Brings Strong Gains
281 | With only **50K** synthetic images, SOC delivers gains comparable to orders of magnitude more real data:
282 |
283 | | Training Data | LVIS AP | APrare | Gain vs Baseline |
284 | |--------------|---------|-------------------|------------------|
285 | | Object365+GoldG (Baseline) | 20.1 | 10.1 | - |
286 | | + GRIT (20M images) | 27.1 | 17.1 | +7.0 AP |
287 | | + V3Det (200K images) | 30.6 | 24.6 | +10.5 AP |
288 | | **+ SOC-50K** | **29.8** | **23.5** | **+9.7 AP** |
289 |
290 | **SOC-50K nearly matches V3Det's gains with 4× fewer images, and exceeds GRIT's with 400× fewer!**
291 |
292 | #### 📊 Scaling Up SOC Data Leads to Better Performance
293 | Continuous improvements as we scale from 50K → 100K → 400K:
294 |
295 | | SOC Scale | LVIS AP | APrare | OdinW-35 mAP |
296 | |-----------|---------|-------------------|--------------|
297 | | 50K | 29.8 | 23.5 | 21.0 |
298 | | 100K | 31.0 (+1.2) | 26.3 (+2.8) | 21.0 |
299 | | 400K | **31.4 (+1.6)** | **27.9 (+4.4)** | **22.8 (+1.8)** |
300 |
301 | #### 🔄 SOC is Complementary to Real Datasets
302 | Adding SOC on top of large-scale real datasets yields additive gains:
303 |
304 | | Training Data | LVIS AP | APrare | OdinW-35 mAP |
305 | |--------------|---------|-------------------|--------------|
306 | | Object365+GoldG+V3Det+GRIT | 31.9 | 23.6 | - |
307 | | **+ SOC-100K** | **33.2 (+1.3)** | **29.8 (+6.2)** | **+2.8** |
308 |
309 | SOC introduces novel vocabulary and contextual variations not captured by existing real datasets.
310 |
311 | ## Task 2: Visual Grounding
312 |
313 | **Model**: MM-Grounding-DINO | **Benchmarks**: RefCOCO/+/g, gRefCOCO, DoD
314 |
315 |
316 | ![Visual grounding results](assets/ref.png)
317 |
318 |
319 | ### Key Findings
320 |
321 | #### ⚠️ Existing Large Detection and Grounding Datasets Yield Only Marginal Improvements
322 | Large-scale real datasets provide limited gains for referring expression tasks:
323 |
324 | | Training Data | gRefCOCO P@1 | gRefCOCO NAcc | DoD FULL mAP |
325 | |--------------|--------------|--------------------------|--------------|
326 | | Object365+GoldG | - | 89.3 | - |
327 | | + V3Det (200K) | +0.5 | +0.0 | - |
328 | | + GRIT (20M) | - | - | +1.4 |
329 |
330 | **Why?** V3Det lacks sentence-level supervision; GRIT uses noisy model-generated caption-box pairs.
331 |
332 | #### ✨ SOC Provides Diverse, High-Quality Referring Expressions
333 | SOC generates precise referring pairs from ground truth annotations without human labels:
334 |
335 | | Training Data | gRefCOCO NAcc | DoD FULL mAP | Gain |
336 | |--------------|--------------------------|--------------|------|
337 | | Object365+GoldG | 89.3 | - | Baseline |
338 | | **+ SOC-50K** | **93.9 (+4.6)** | **+1.0** | 50K images |
339 | | **+ SOC-100K** | **97.7 (+8.4)** | **+3.8** | 100K images |
340 |
341 | **Expression Types** (3-6 per type, balanced coverage):
342 | - **Attribute-based**: "the red apple", "charcoal-grey cat"
343 | - **Spatial-based**: "dog to the right of the bike"
344 | - **Mixed-type**: "red object to the right of the child"
345 |
346 | SOC's gains per example far outperform GRIT (20M) and V3Det (200K)!
347 |
348 | ## Task 3: Instance Segmentation
349 |
350 | **Model**: APE (LVIS pre-trained) | **Benchmark**: LVIS v1.0 Val
351 |
352 |
353 |
354 |
355 |
356 | ### Key Findings
357 |
358 | #### 🎯 SOC Continuously Improves LVIS Segmentation
359 | Two-stage fine-tuning: (1) Train on 50K SOC-LVIS → (2) Continue on LVIS train split
360 |
361 | | Training Protocol | AP | APrare | APcommon | APfrequent |
362 | |------------------|-------|-------------------|---------------------|----------------------|
363 | | LVIS only | 46.96 | 40.87 | - | - |
364 | | **SOC-50K → LVIS** | **48.48 (+1.52)** | **44.70 (+3.83)** | - | **(+0.31)** |
365 |
366 | **Why the large rare-class gain?** Synthetic data can be generated to cover underrepresented classes, mitigating LVIS's long-tail imbalance. Frequent classes already have ample real examples and benefit less.
367 |
368 | ---
369 |
370 | ## Task 4: Small-Vocabulary, Limited-Data Regimes
371 |
372 | **Model**: Mask2Former-ResNet-50 | **Benchmark**: COCO Instance Segmentation
373 |
374 | ### Key Findings
375 |
376 | #### 💰 SOC Excels in Low-Data Regimes
377 | Mixing real COCO segments with SOC synthetic segments (80 COCO categories):
378 |
379 | | COCO Data Scale | COCO Only | COCO + SOC | Gain |
380 | |----------------|-----------|------------|------|
381 | | 1% (~1K images) | - | - | **+6.59 AP** |
382 | | 10% (~10K images) | - | - | **~+3 AP** |
383 | | 50% (~50K images) | - | - | **~+3 AP** |
384 | | 100% (Full) | - | - | **~+3 AP** |
385 |
386 | **Key Insight**: The boost is particularly dramatic at 1% COCO (+6.59 AP) and stays around +3 AP at each larger data scale. SOC is most effective when real data is scarce!
387 |
388 | ---
389 |
390 | ## Task 5: Intra-Class Referring Expression
391 |
392 | **Model**: MM-Grounding-DINO | **Benchmark**: Custom intra-class benchmark (COCO + OpenImages V7)
393 |
394 |
395 | ![Intra-class referring benchmark](assets/intra-class.png)
396 |
397 |
398 | ### What is Intra-Class Referring?
399 | A challenging visual grounding task requiring fine-grained attribute discrimination among same-category instances.
400 |
401 | **Example**: In an image with multiple cars of different colors and makes, locate "the charcoal-grey sedan" (not just "car").
402 |
403 | **Why it's hard**: Models often shortcut by ignoring attributes and relying solely on category nouns.
404 |
405 | ### Evaluation Metrics
406 | - **Average Gap**: Average confidence margin between ground-truth box and highest-scoring same-category distractor
407 | - **Positive Gap Ratio**: Percentage of images where the ground-truth box receives the highest confidence among same-category candidates (both metrics are sketched in code below)
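A small sketch of how these two metrics can be computed; the function and the score tuples are hypothetical illustrations, not real model outputs:

```python
def intra_class_metrics(per_image_scores):
    """per_image_scores: list of (gt_score, distractor_scores) pairs, where
    distractor_scores are confidences of same-category candidate boxes."""
    gaps = [gt - max(distractors) for gt, distractors in per_image_scores]
    average_gap = sum(gaps) / len(gaps)                        # mean confidence margin
    positive_gap_ratio = sum(g > 0 for g in gaps) / len(gaps)  # fraction of wins
    return average_gap, positive_gap_ratio

# Two images; in the second, a distractor outscores the ground truth.
print(intra_class_metrics([(0.9, [0.4, 0.2]), (0.5, [0.7])]))  # ≈ (0.15, 0.5)
```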
408 |
409 | ### Key Findings
410 |
411 | #### 🎯 Targeted SOC Data Fixes Intra-Class Shortcuts
412 |
413 | | Training Data | Average Gap | Positive Gap Ratio |
414 | |--------------|-------------|-------------------|
415 | | Object365+GoldG | 37.5 | ~80% |
416 | | + GRIT (20M) | 34.6 (-2.9) | ~82% |
417 | | + V3Det (200K) | 36.7 (-0.8) | ~83% |
418 | | + GRIT + V3Det | 35.8 (-1.7) | ~85% |
419 | | **+ SOC-SFC-50K + SOC-SGC-50K** | **40.6 (+3.1)** | **90%** |
420 |
421 | **SOC-SFC/SGC**: Synthetic images with multiple instances of the same category but varied attributes (e.g., cars with different colors and makes).
422 |
423 | **Key Insight**: Large-scale auxiliary data (GRIT, V3Det) yields negligible or even negative impact. Only targeted synthetic data tailored to intra-class attribute variation significantly improves performance!
424 |
425 | ---
426 |
427 |
428 |
429 | # 📧 Contact
430 |
431 | * **Weikai Huang**: weikaih@cs.washington.edu
432 | * **Jieyu Zhang**: jieyuz2@cs.washington.edu
433 |
434 | ---
435 |
436 | # 📝 Citation
437 |
438 | ```bibtex
439 | @misc{huang2025syntheticobjectcompositionsscalable,
440 | title={Synthetic Object Compositions for Scalable and Accurate Learning in Detection, Segmentation, and Grounding},
441 | author={Weikai Huang and Jieyu Zhang and Taoyang Jia and Chenhao Zheng and Ziqi Gao and Jae Sung Park and Winson Han and Ranjay Krishna},
442 | year={2025},
443 | eprint={2510.09110},
444 | archivePrefix={arXiv},
445 | primaryClass={cs.CV},
446 | url={https://arxiv.org/abs/2510.09110},
447 | }
448 | ```
449 |
450 | ---
451 |
452 | # 🙏 Acknowledgments
453 |
454 | We thank the authors of FLUX-1, IC-Light, DIS, Qwen, and QwQ for their excellent open-source models that made this work possible.
455 |
--------------------------------------------------------------------------------
/referring_expression_generation/inference.py:
--------------------------------------------------------------------------------
1 | # --- Chunking helper ---
2 | import random
3 |
4 | def chunk_result(result, threshold=None):
5 | """
6 | Splits the result['grounding']['regions'] into chunks based on a character threshold,
7 | recomputing tokens_positive (now character offsets) for each region in each chunk.
8 | Each region in a chunk will have its tokens_positive field adjusted to match the chunk's caption.
9 | """
10 | if threshold is None:
11 | threshold = random.randint(150, 255)
12 |
13 | regions_orig = result['grounding']['regions']
14 | chunks = []
15 | region_idx = 0
16 | total_regions = len(regions_orig)
17 |
18 | while region_idx < total_regions:
19 | temp_caption = []
20 | temp_regions = []
21 | cursor_chunk = 0
22 |
23 | while region_idx < total_regions:
24 | region = regions_orig[region_idx]
25 | text = region['phrase']
26 | next_caption = '. '.join(temp_caption + [text]) + '. '
27 | if len(next_caption) > threshold and temp_regions:
28 | break
29 |
30 | # Character-based span
31 | span_start = cursor_chunk
32 | span_end = cursor_chunk + len(text) - 1
33 |
34 | chunk_region = {
35 | 'phrase': text,
36 | 'bboxes': region['bboxes'],
37 | 'tokens_positive': [[span_start, span_end]],
38 | }
39 | temp_caption.append(text)
40 | temp_regions.append(chunk_region)
41 |
42 | # Advance by text length + separator (". ")
43 | cursor_chunk += len(text) + 2
44 | region_idx += 1
45 |
46 | chunks.append({
47 | 'filename': result['filename'],
48 | 'height': result['height'],
49 | 'width': result['width'],
50 | 'grounding': {
51 | 'caption': '. '.join([r['phrase'] for r in temp_regions]) + '. ',
52 | 'regions': temp_regions
53 | }
54 | })
55 |
56 | return chunks
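57 | 
58 | # Example (hypothetical): with threshold=20 and regions whose phrases are
59 | # "red car", "blue bike", and "tall tree", chunk_result() yields two chunks:
60 | #   chunk 1 caption: "red car. blue bike. "  spans [[0, 6]] and [[9, 17]]
61 | #   chunk 2 caption: "tall tree. "           span [[0, 8]]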
57 | # import debugpy
58 | # try:
59 | # # 5678 is the default attach port in the VS Code debug configurations. Unless a host and port are specified, host defaults to 127.0.0.1
60 | # debugpy.listen(("localhost", 9501))
61 | # print("Waiting for debugger attach")
62 | # debugpy.wait_for_client()
63 | # except Exception as e:
64 | # pass
65 |
66 |
67 | # generate_referring_expressions_odvg.py
68 | """
69 | Generate attribute-, spatial- and reasoning-based referring expressions (single- and multi-object) for a synthetic
70 | dataset and write the results to ODVG-style JSONL.
71 | 
72 | Key design changes vs. original script
73 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
74 | 1. **One LLM call per image** – All six buckets of expressions (single/multi x attribute/spatial/reasoning) are
75 |    requested in **one** chat completion by asking the model for a structured JSON block, cutting token usage and latency.
76 | 2. **Configurable question counts** – Numbers per bucket are read from a JSON/YAML config or CLI flags.
77 | 3. **Cleaner prompt & examples** – Prompt now includes user-provided examples, and explicit instructions to:
78 |    - Use `category`, `short_phrase`, and `features` fields when generating expressions.
79 |    - Infer spatial relationships strictly from the provided `bbox` values (absolute and relative positions).
80 | 4. **Direct ODVG output** – Each question becomes an ODVG *region* entry and is grouped under one
81 |    caption per image. `tokens_positive` spans are computed automatically.
82 | 5. **Stateless retries** – Automatic validation plus a retry loop keeps the pipeline robust.
83 | """
84 |
85 | import argparse, json, time
86 | import os
87 | import tempfile
88 | from typing import Any, Dict, List
89 | from multiprocessing import Pool
90 | from tqdm import tqdm
91 | import openai
92 | from pydantic import BaseModel, Field
93 | from google.cloud import storage
97 |
98 | # --- GCS helpers ---
99 | def parse_gcs_path(gcs_path: str):
100 | if not gcs_path.startswith("gs://"):
101 | raise ValueError(f"Not a valid GCS path: {gcs_path}")
102 | path = gcs_path[5:]
103 | bucket, _, blob = path.partition("/")
104 | return bucket, blob
105 |
106 | def download_from_gcs(gcs_path: str, local_path: str):
107 | bucket_name, blob_name = parse_gcs_path(gcs_path)
108 | client = storage.Client()
109 | bucket = client.bucket(bucket_name)
110 | blob = bucket.blob(blob_name)
111 | blob.download_to_filename(local_path)
112 | print(f"Downloaded {gcs_path} to {local_path}")
113 |
114 | def upload_to_gcs(local_path: str, gcs_path: str):
115 | bucket_name, blob_name = parse_gcs_path(gcs_path)
116 | client = storage.Client()
117 | bucket = client.bucket(bucket_name)
118 | blob = bucket.blob(blob_name)
119 | blob.upload_from_filename(local_path)
120 | print(f"Uploaded {local_path} to {gcs_path}")
121 |
122 | class Expression(BaseModel):
123 |     q: str = Field(..., description="The referring-expression string")
124 | ids: List[int] = Field(..., description="List of segment IDs referenced")
125 |
126 | class ObjectExpressions(BaseModel):
127 |     attribute: List[Expression] = Field(..., description="Attribute-based expressions")
128 |     spatial: List[Expression] = Field(..., description="Spatial-based expressions")
129 |     reasoning: List[Expression] = Field(..., description="Reasoning-based expressions")
130 |
131 | class RefExpressions(BaseModel):
132 | single_object: ObjectExpressions = Field(..., description="Expressions targeting exactly one object")
133 | multi_object: ObjectExpressions = Field(..., description="Expressions targeting multiple objects")
134 |
135 | #####################################################################
136 | # -------------------------- CONFIG --------------------------------#
137 | #####################################################################
138 | DEFAULT_NUM_Q = {
139 | "single": {"attribute": 6, "spatial": 6, "reasoning": 6},
140 | "multi": {"attribute": 3, "spatial": 3, "reasoning": 3},
141 | }
142 |
143 | MAX_RETRIES = 10
144 | SLEEP_BETWEEN_RETRIES = 10 # seconds
145 |
146 | #####################################################################
147 | # ----------------------- EXAMPLES --------------------------------#
148 | #####################################################################
149 | EXAMPLE_SEGMENTS = '''
150 | [ID: 7377552] tongs | short_phrase: tongs with rough iron texture | features: [] | description: tongs with a rough iron texture, painted in old bronze | bbox: [318, 535, 128, 282]
151 | [ID: 10569372] bath_towel | short_phrase: bath towel with tribal flair | features: [] | description: a bath_towel with geometric tribal flair in coppery tones | bbox: [474, 10, 493, 879]
152 | [ID: 2187630] Canned | short_phrase: cylindrical can of recycled aluminum | features: [] | description: A cylindrical can made of recycled aluminum. | bbox: [376, 217, 102, 247]
153 | [ID: 10733385] shovel | short_phrase: shovel with gleaming blade | features: [] | description: a shovel with a blade that gleams like polished alabaster | bbox: [424, 726, 48, 184]
154 | [ID: 519546] knitting_needle | short_phrase: knitting needle with glossy finish | features: [] | description: a knitting_needle with a glossy, clear finish and a spiral ridge | bbox: [0, 80, 97, 125]
155 | [ID: 4995055] strap | short_phrase: slim clear strap with blue stripe | features: [] | description: a slim, clear strap with a spray-painted blue stripe | bbox: [725, 178, 54, 60]
156 | [ID: 9339368] teakettle | short_phrase: teakettle with glass body | features: [] | description: A teakettle with a round glass body and a charming, twisted copper handle. | bbox: [324, 789, 103, 117]
157 | [ID: 8109537] cushion | short_phrase: hunter green leather cushion | features: [] | description: A hunter green, sleek leather cushion. | bbox: [684, 219, 89, 79]
158 | [ID: 4123758] raspberry | short_phrase: glossy raspberry with maroon tinge | features: [] | description: a glossy raspberry with a subtle maroon tinge | bbox: [183, 324, 134, 140]
159 | [ID: 9903309] dropper | short_phrase: dropper with matte black body | features: [] | description: A dropper with a matte black body and a glossy dropper tip | bbox: [204, 868, 93, 122]
160 | [ID: 2998739] snowmobile | short_phrase: snowmobile with white surface | features: [] | description: a snowmobile with a glossy white surface decorated with longitudinal red strips | bbox: [898, 451, 63, 38]
161 | [ID: 13570439] box | short_phrase: box with bold stripes and smiley | features: [] | description: a box painted in bold stripes with a quirky smiley face | bbox: [305, 837, 46, 42]
162 | [ID: 232766] ram_animal | short_phrase: compact ram with patchwork wool | features: [] | description: a compact ram boasting an intricate pattern of color on its wool, resembling patchwork | bbox: [152, 704, 35, 36]
163 | [ID: 15801147] birthday_card | short_phrase: cheerful pirate ship with map | features: [] | description: A birthday_card featuring a cheerful pirate ship with a colorful map. | bbox: [34, 924, 45, 46]
164 | [ID: 8631767] Egg_tart | short_phrase: marigold custard egg tart | features: [] | description: An egg tart with a marigold-colored custard that slightly spills over around the cornflower-blue border. | bbox: [442, 207, 43, 44]
165 | [ID: 12491521] cornbread | short_phrase: crispy cornbread with golden flecks | features: [] | description: A crispy slice of cornbread, with a myriad of shimmering, golden flecks and patches. | bbox: [191, 621, 42, 42]
166 | [ID: 7054199] wooden_spoon | short_phrase: stout wooden spoon for dough | features: [] | description: a stout, thick wooden_spoon with a hefty feel, perfect for tackling hefty dough mixtures | bbox: [207, 292, 46, 40]
167 | [ID: 2379817] motor | short_phrase: tiny pink motor with metal plates | features: [] | description: a tiny, delicate motor painted in a pastel pink with tiny metal plates | bbox: [95, 812, 40, 25]
168 | '''
169 |
170 | EXAMPLE_PROMPT = '''
171 | {
172 | "single_object": {
173 | "attribute": [
174 | { "q": "The glossy raspberry with maroon tinge", "ids": [4123758] },
175 | { "q": "The slim clear strap with blue stripe", "ids": [4995055] },
176 | { "q": "The cylindrical can of recycled aluminum", "ids": [2187630] },
177 | { "q": "The hunter green leather cushion", "ids": [8109537] }
178 | ],
179 | "spatial": [
180 | { "q": "The knitting needle with glossy finish on the far left", "ids": [519546] },
181 | { "q": "The snowmobile with white surface on the far right", "ids": [2998739] }
182 | ],
183 | "reasoning": [
184 | { "q": "The teakettle with glass body below the shovel with the gleaming blade", "ids": [9339368] },
185 | { "q": "The box with bold stripes and smiley to the left of the snowmobile with white surface", "ids": [13570439] },
186 | { "q": "The glossy raspberry with maroon tinge to the left of the cylindrical can of recycled aluminum", "ids": [4123758] },
187 | { "q": "The slim clear strap with blue stripe on the bath towel with tribal flair", "ids": [4995055] },
188 | { "q": "The stout wooden spoon for dough above the shovel with the gleaming blade", "ids": [7054199] }
189 | ]
190 | },
191 | "multi_object": {
192 | "attribute": [
193 | { "q": "All the objects with a glossy finish", "ids": [519546, 4123758] },
194 | { "q": "All the striped objects", "ids": [13570439, 2998739] },
195 | { "q": "All the containers", "ids": [2187630, 13570439] },
196 | ],
197 | "spatial": [
198 | {
199 | "q": "All the objects above the horizontal midpoint of the image",
200 | "ids": [10569372, 2187630, 4995055, 8109537, 4123758, 8631767, 7054199, 519546]
201 | },
202 | {
203 | "q": "All the objects that span the central vertical band of the image",
204 | "ids": [7054199, 4123758, 2187630, 7377552, 2998739]
205 | },
206 | {
207 | "q": "All the objects to the left of center and above the shovel with gleaming blade",
208 | "ids": [7377552, 2187630, 519546, 4123758, 7054199, 13570439, 232766, 2379817, 8631767]
209 | },
210 | {
211 | "q": "All the objects surrounding the cylindrical can of recycled aluminum",
212 | "ids": [7377552, 7054199, 8631767, 4123758]
213 | }
214 | ],
215 | "reasoning": [
216 | { "q": "All the metallic objects to the left of the bath towel with tribal flair", "ids": [7377552, 2187630, 2379817] },
217 | { "q": "All the objects with a glossy finish above the box with bold stripes and smiley", "ids": [519546, 4123758] },
218 | { "q": "All the containers to the right of the knitting needle with glossy finish", "ids": [2187630, 13570439] },
219 | ]
220 | }
221 | }
222 | '''
223 |
224 | #####################################################################
225 | # ----------------------- PROMPT UTILS -----------------------------#
226 | #####################################################################
227 | OBJECT_TEMPLATE = "[ID: {id}] {category} | short_phrase: {short_phrase} | features: {features} | description: {description} | bbox: {bbox}"
228 |
229 | SCHEMA_SNIPPET = '''
230 | ### JSON schema you MUST return
231 | {
232 | "single_object": {"attribute": [{"q": str, "ids": [int]}], "spatial": [{"q": str, "ids": [int]}], "reasoning": [{"q": str, "ids": [int]}]},
233 | "multi_object": { … same keys, but ids lists contain 2+ ints … }
234 | }
235 | '''
236 |
237 |
238 | def build_prompt(segments: List[Dict[str, Any]], num_q: Dict[str, Dict[str, int]]) -> str:
239 | # Build object table
240 | table = "\n".join(OBJECT_TEMPLATE.format(**seg) for seg in segments)
241 | # Build count instructions
242 | counts_text = "\n".join(f"- {scope}/{kind}: {cnt} expressions" for scope, kinds in num_q.items() for kind, cnt in kinds.items())
243 |
244 | # Combine prompt
245 | prompt = f"""
246 | You are a referring expression detection data generator. I will provide you with a list of objects WITHIN an IMAGE, and you will generate referring expressions similar to RefCOCO, RefCOCO+, RefCOCOg, and GrefCOCO.
247 |
248 | We categorize referring expressions into three types: attribute-based, spatial-based, and reasoning-based.
249 | **Requirements**:
250 | - Attribute-based: refer to an object by its `features`, `category`, or `short_phrase` (e.g., "the white dog").
251 | - Spatial-based: infer absolute or relative positions strictly from the `bbox` values (e.g., "the dog to the left of the person with a brown shirt").
252 | - Reasoning-based: combine `features`, `short_phrase`, `category`, and spatial bbox relationships between objects (e.g., "the white animal to the left of the person wearing a brown shirt").
253 | - Use `short_phrase` or `features` preferentially to refer to objects; you may also use `category` together with some features.
254 | - Return ONE JSON block matching the schema exactly, with **exactly** the requested counts per bucket. No extra keys.
255 | 
256 | Each referring expression has one of two reference types:
257 | 1. **Single-object**: The expression refers to exactly one object in the image (e.g., "the white dog").
258 | 2. **Multi-object**: The expression refers to two or more objects in the image (e.g., "all the white dogs in the image").
272 |
273 |
274 |
275 | ### Example annotated segments
277 | {EXAMPLE_SEGMENTS}
278 |
279 | ### Example generated expressions
280 | {EXAMPLE_PROMPT}
281 |
282 |
283 | ### Objects in the image (bounding boxes are given in COCO XYWH format; infer spatial relations from them accordingly):
284 | {table}
285 |
286 | ### Counts to generate
287 | {counts_text}
288 |
289 | {SCHEMA_SNIPPET}
290 | """
291 | return prompt
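292 | 
293 | # Example: with DEFAULT_NUM_Q, the counts block rendered into the prompt reads:
294 | #   - single/attribute: 6 expressions
295 | #   - single/spatial: 6 expressions
296 | #   ...
297 | #   - multi/reasoning: 3 expressions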
292 |
293 |
294 | #####################################################################
295 | # ----------------------- OVDG HELPERS -----------------------------#
296 | #####################################################################
297 |
298 | def ovdg_from_expressions(image: Dict[str, Any], qs: Dict[str, Any]) -> Dict[str, Any]:
299 | regions, caption_parts, cursor = [], [], 0
300 | # Map segment IDs to their bounding boxes
301 | id2bbox = {seg['id']: seg['bbox'] for seg in image['segments_info']}
302 |
303 | def add_region(text: str, ids: List[int]):
304 | nonlocal cursor
305 | start, end = cursor, cursor + len(text) - 1
306 |         # Gather all bboxes for each referenced ID
307 | bboxes = [id2bbox.get(i, [0, 0, 0, 0]) for i in ids]
308 | regions.append({
309 | 'bboxes': bboxes,
310 | 'phrase': text,
311 | 'tokens_positive': [[start, end]]
312 | })
313 | caption_parts.append(text)
314 |         cursor += len(text) + 2  # account for the '. ' separator joined into the caption
315 |
316 | # Process both single-object and multi-object queries
317 | for scope in ('single_object', 'multi_object'):
318 | for kind in ('attribute', 'spatial', 'reasoning'):
319 | for e in qs[scope][kind]:
320 | # Strip punctuation and add region
321 | add_region(e['q'].rstrip('?. '), e['ids'])
322 |
323 | return {
324 | 'filename': image.get('file_name', ''),
325 | 'height': image.get('height', 1024),
326 | 'width': image.get('width', 1024),
327 | 'grounding': {
328 | 'caption': '. '.join(caption_parts) + '. ',
329 | 'regions': regions
330 | }
331 | }
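332 | 
333 | # Example (hypothetical): for expressions "the red car" and "the blue bike", the
334 | # caption becomes "the red car. the blue bike. " and the regions carry
335 | # tokens_positive spans [[0, 10]] and [[13, 25]] (character offsets into the caption).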
332 |
333 |
334 | #####################################################################
335 | # -------------------- MAIN PIPELINE -------------------------------#
336 | #####################################################################
337 | def infer_expressions(client: openai.Client, prompt: str) -> Dict[str, Any]:
338 | from pydantic import ValidationError
339 |
340 | attempt = 0
341 | parsed: RefExpressions | None = None
342 |
343 | while attempt < MAX_RETRIES:
344 | try:
350 | resp = client.beta.chat.completions.parse(
351 | model="Qwen/QwQ-32B-AWQ",
352 | messages=[{'role':'user','content':prompt}],
353 | response_format=RefExpressions,
354 | temperature=0.1
355 | )
356 | # .choices[0].message.parsed is already a RefExpressions instance
357 | parsed = resp.choices[0].message.parsed
358 | break
359 | except ValidationError as ve:
360 | # JSON structure didn’t match—retry
361 | attempt += 1
362 | print(f"[Attempt {attempt}] validation error: {ve}")
363 | time.sleep(SLEEP_BETWEEN_RETRIES)
364 | except Exception as e:
365 | # LLM or network error—also retry
366 | attempt += 1
367 | print(f"[Attempt {attempt}] error: {e}")
368 | time.sleep(SLEEP_BETWEEN_RETRIES)
369 |
370 | if parsed is None:
371 | # Give back an empty—but correctly typed—fallback
372 | parsed = RefExpressions(
373 | single_object=ObjectExpressions(attribute=[], spatial=[], reasoning=[]),
374 | multi_object= ObjectExpressions(attribute=[], spatial=[], reasoning=[])
375 | )
376 |
377 | # Finally return a plain dict for the rest of your pipeline
378 | return parsed.dict()
379 |
380 | 
386 | # -------------------- Distributed Runner --------------------#
387 |
388 | def create_client(api_key: str):
389 | return openai.Client(base_url="http://localhost:8080/v1", api_key=api_key)
390 |
391 | def process_annotation(args):
392 | idx, ann, cache_dir, api_key, orig_output_dir = args
393 | cache_file = os.path.join(cache_dir, f"{idx}.json")
394 | output_dir = orig_output_dir
395 | cache_dir_name = os.path.basename(cache_dir)
396 | if cache_dir_name.startswith("cache_job_"):
397 | job_index = cache_dir_name[len("cache_job_") :]
398 | else:
399 | job_index = "0"
400 | remote_cache_path = os.path.join(orig_output_dir, f"cache_job_{job_index}", f"{idx}.json")
401 | # If output_dir is a GCS path, check remote cache existence
402 | if output_dir.startswith("gs://"):
403 | bucket_name, blob_name = parse_gcs_path(remote_cache_path)
404 | client_gcs = storage.Client()
405 | bucket = client_gcs.bucket(bucket_name)
406 | blob = bucket.blob(blob_name)
407 | if blob.exists():
408 | # Remote cache exists, download and use it
409 | blob.download_to_filename(cache_file)
410 | with open(cache_file, 'r') as f:
411 | raw = json.load(f)
412 | chunks = chunk_result(raw)
413 | return idx, chunks
414 | # Only check local cache for non-GCS cases
415 | if not output_dir.startswith("gs://") and os.path.exists(cache_file) and os.path.getsize(cache_file) > 0:
416 | with open(cache_file, 'r') as f:
417 | raw = json.load(f)
418 | chunks = chunk_result(raw)
419 | return idx, chunks
420 |
421 | client = create_client(api_key)
422 | prompt = build_prompt(ann['segments_info'], DEFAULT_NUM_Q)
423 |
424 | # Use the shared infer_expressions function instead of inline logic
425 | qs = infer_expressions(client, prompt)
426 |
427 | valid_ids = {seg['id'] for seg in ann['segments_info']}
428 | for scope in ('single_object','multi_object'):
429 | for kind in ('attribute','spatial','reasoning'):
430 | for expr in qs[scope][kind]:
431 | expr['ids'] = [rid for rid in expr['ids'] if rid in valid_ids]
432 |
433 | result = ovdg_from_expressions(ann, qs)
434 | chunks = chunk_result(result)
435 | with open(cache_file, 'w') as f:
436 | json.dump(result, f)
437 | # After writing local cache, upload to GCS if output_dir is GCS
438 | if output_dir.startswith("gs://"):
439 | upload_to_gcs(local_path=cache_file, gcs_path=remote_cache_path)
440 | return idx, chunks
441 |
442 | def main():
443 | parser = argparse.ArgumentParser(description="Distributed ODVG inference")
444 | parser.add_argument('total_jobs', type=int, help='Total number of jobs')
445 | parser.add_argument('job_index', type=int, help='Index of this job (0-based)')
446 |     parser.add_argument('input_file', type=str, help='Path to input JSON annotations file (read with json.load; must contain an "annotations" list)')
447 | parser.add_argument('output_dir', type=str, help='Directory for outputs and cache')
448 | parser.add_argument('--api_key', type=str, default="placeholder", help='OpenAI API key')
449 | parser.add_argument('--num_workers', type=int, default=300, help='Number of parallel processes')
450 | args = parser.parse_args()
451 |
452 | # --- GCS input/output support ---
453 | # Download input file if on GCS
454 | is_gcs_input = args.input_file.startswith("gs://")
455 | if is_gcs_input:
456 | local_input = tempfile.NamedTemporaryFile(delete=False, suffix=".jsonl").name
457 | download_from_gcs(args.input_file, local_input)
458 | args.input_file = local_input
459 |
460 | # Prepare local output directory (and flag for later upload)
461 | is_gcs_output = args.output_dir.startswith("gs://")
462 | if is_gcs_output:
463 | local_output_dir = tempfile.mkdtemp()
464 | else:
465 | local_output_dir = args.output_dir
466 |
467 | api_key = args.api_key or os.environ.get('OPENAI_API_KEY')
468 | if not api_key:
469 | raise ValueError("OpenAI API key must be provided via --api_key or OPENAI_API_KEY")
470 |
471 | os.makedirs(local_output_dir, exist_ok=True)
472 | cache_dir = os.path.join(local_output_dir, f"cache_job_{args.job_index}")
473 | os.makedirs(cache_dir, exist_ok=True)
474 |
475 |     annotations = json.load(open(args.input_file)).get('annotations', [])
477 |
478 | total = len(annotations)
479 | per_job = total // args.total_jobs
480 | start = args.job_index * per_job
481 | end = total if args.job_index == args.total_jobs-1 else start + per_job
482 | partition = annotations[start:end]
483 |
484 | tasks = [(i, ann, cache_dir, api_key, args.output_dir) for i, ann in enumerate(partition)]
485 | with Pool(processes=args.num_workers) as pool:
486 | results = list(tqdm(pool.imap_unordered(process_annotation, tasks), total=len(tasks)))
487 |
488 | # results is a list of (idx, chunks), where chunks is a list of dicts
489 | results.sort(key=lambda x: x[0])
490 | output_path = os.path.join(local_output_dir, f"job_{args.job_index}.jsonl")
491 | with open(output_path, 'w') as out_f:
492 | for _, chunks in results:
493 | if chunks is None:
494 | continue
495 | for entry in chunks:
496 | out_f.write(json.dumps(entry) + '\n')
497 |
498 |
499 | # Upload result back to GCS if requested
500 | if is_gcs_output:
501 | gcs_dest = args.output_dir.rstrip("/") + f"/job_{args.job_index}.jsonl"
502 | upload_to_gcs(output_path, gcs_dest)
503 |
504 | if __name__ == '__main__':
505 | main()
506 |
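507 | # Example usage (hypothetical paths): run job 0 of 4 against a local
508 | # vLLM OpenAI-compatible server on localhost:8080 (see create_client):
509 | #   python inference.py 4 0 gs://my-bucket/annotations.json gs://my-bucket/refexp_out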
--------------------------------------------------------------------------------
/segmentation_generator/segmentataion_synthesizer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from .generate_image_and_mask import generate_image_and_mask
3 | from segment_layout.existing_layout_generator import ExistingLayoutGenerator
4 | from segment_layout.fine_grained_bbox_layout_generator import FineGrainedBoundingBoxLayoutGenerator
5 | from PIL import Image
6 |
7 | def generate_random_color_and_id(rng):
8 | # Generate random values for R, G, and B between 0 and 255
9 | R = rng.integers(0, 256)
10 | G = rng.integers(0, 256)
11 | B = rng.integers(0, 256)
12 |
13 | # Calculate the color id based on the formula provided
14 | color_id = R + G * 256 + B * (256**2)
15 |
16 | # Return the color and the id
17 | color = (R, G, B)
18 | return color, color_id
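19 | 
20 | # Example: color (R, G, B) = (17, 2, 1) encodes to id 17 + 2*256 + 1*256**2 = 66065,
21 | # matching the rgb2id convention of the COCO panoptic API.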
19 |
20 |
21 | class BaseSegmentationSynthesizer():
22 | synthesize_method = None
23 | def __init__(self, data_manager, save_path, random_seed) -> None:
24 | if random_seed is None:
25 | random_seed = 42 # use the default seed
26 | self.rng = np.random.default_rng(random_seed)
27 | self.data_manager = data_manager
28 | self.save_path = save_path
29 |
30 | def random_position(self, width, height):
31 | x = self.rng.integers(0, width)
32 | y = self.rng.integers(0, height)
33 | return (x, y)
34 |
35 | def random_augmentation(self):
36 | # Randomly select a scale factor between 0.5 and 1.5,
37 | # a boolean flag for horizontal flip (50% chance),
38 | # a boolean flag for vertical flip (50% chance),
39 | # and a rotation angle between -180 and 180 degrees.
40 | scale = self.rng.uniform(0.5, 1.5)
41 | flip_horizontal = bool(self.rng.integers(0, 2))
42 | flip_vertical = bool(self.rng.integers(0, 2))
43 | rotate = int(self.rng.integers(-180, 180))
44 | return {"scale": scale, "flip_horizontal": flip_horizontal, "flip_vertical": flip_vertical, "rotate": rotate}
45 |
46 |
47 |
48 | def category_to_id(self, category):
49 | if category.startswith("COCO2017"):
50 | return int(category.split("_")[-1])
51 | if category in self.data_manager.unified_object_categories_to_idx:
52 | return self.data_manager.unified_object_categories_to_idx[category]
53 | else:
54 | raise ValueError(f"Category {category} not found in the unified object categories.")
55 |
56 | def generate(self, image_metadata, resize_mode, containSmallObjectMask=False):
57 | # start_time = time.time()
58 | if containSmallObjectMask:
59 | image, mask, small_object_mask_np = generate_image_and_mask(self.data_manager, image_metadata, resize_mode=resize_mode, containSmallObjectMask=True)
60 | output = {"image_metadata": image_metadata, "image": image, "mask": mask, "small_object_mask_np": small_object_mask_np}
61 | else:
62 | image, mask = generate_image_and_mask(self.data_manager, image_metadata, resize_mode=resize_mode, containSmallObjectMask=False)
63 | output = {"image_metadata": image_metadata, "image": image, "mask": mask}
64 |
65 | # elapsed_time = time.time() - start_time # compute elapsed time
66 | # print(f"generate_image_and_mask took {elapsed_time:.6f} seconds")
67 |
68 | return output
69 |
70 |
71 | def generate_with_coco_panoptic_format(self, image_metadata, resize_mode="fit", containRGBA=False, containBbox=True, containCategory=True, containSmallObjectMask=False):
72 | output = self.generate(image_metadata, resize_mode, containSmallObjectMask)
73 | image, mask, metadata = output["image"], output["mask"], output["image_metadata"]
74 |
75 | segment_id_to_color = {}
76 | used_colors = set()
77 | mask = np.array(mask)
78 | height, width = mask.shape
79 |
80 | coco_mask = np.zeros((height, width, 3), dtype=np.uint8)
81 | segments_info = []
82 |
83 | for object_info in metadata["objects"]:
84 | segment_id = object_info["segment_id"]
85 |             if segment_id != 0:
86 |                 # Draw a new unique color/id only once per segment_id so the mask
87 |                 # color and the id recorded in segments_info stay consistent
88 |                 if segment_id not in segment_id_to_color:
89 |                     while True:
90 |                         unique_color, unique_id = generate_random_color_and_id(self.rng)
91 |                         if unique_color not in used_colors:
92 |                             used_colors.add(unique_color)  # Mark this color as used
93 |                             break
94 |                     segment_id_to_color[segment_id] = (unique_color, unique_id)
95 |                 unique_color, unique_id = segment_id_to_color[segment_id]
96 |                 coco_mask[mask == segment_id] = unique_color
95 |
96 | if containBbox:
97 | rows, cols = np.where(mask == segment_id)
98 | if rows.size and cols.size:
99 | x_min, y_min = int(cols.min()), int(rows.min())
100 | x_max, y_max = int(cols.max()), int(rows.max())
101 | # COCO uses [x, y, width, height] format.
102 | bbox = [x_min, y_min, x_max - x_min + 1, y_max - y_min + 1]
103 | else:
104 | bbox = [0, 0, 0, 0]
105 | else:
106 | bbox = [0, 0, 0, 0]
107 |
108 |             # iscrowd and area are not used downstream, so they are set to dummy values
109 | segments_info.append({
110 | "id": unique_id,
111 | "category_id": self.category_to_id(object_info['object_metadata']["category"]) if containCategory else 1,
112 | "iscrowd": 0,
113 | "bbox": bbox,
114 | "area": 1,
115 | })
116 |
117 | output["coco_mask"] = coco_mask
118 | output["segments_info"] = segments_info
119 | if containRGBA:
120 | alpha_channel = (mask > 0).astype(np.uint8) * 255
121 | image_rgba = np.dstack((image, alpha_channel))
122 | output["image_rgba"] = image_rgba
123 | if containSmallObjectMask:
124 | alpha_channel_small_obj = (output['small_object_mask_np'] > 0).astype(np.uint8) * 255
125 | image_rgba_small_obj = np.dstack((image, alpha_channel_small_obj))
126 | output['image_rgba_small_object'] = image_rgba_small_obj
127 | return output
128 |
129 |
130 | def generate_with_unified_format(self, image_metadata, resize_mode="fit", containRGBA=False, containBbox=True, containCategory=True, containSmallObjectMask=False):
131 | output = self.generate(image_metadata, resize_mode, containSmallObjectMask)
132 | image, mask, metadata = output["image"], output["mask"], output["image_metadata"]
133 |
134 | segment_id_to_color = {}
135 | used_colors = set()
136 | mask = np.array(mask)
137 | height, width = mask.shape
138 |
139 | coco_mask = np.zeros((height, width, 3), dtype=np.uint8)
140 | segments_info = []
141 |
142 | for object_info in metadata["objects"]:
143 | segment_id = object_info["segment_id"]
144 |             if segment_id != 0:
145 |                 # Draw a new unique color/id only once per segment_id so the mask
146 |                 # color and the id recorded in segments_info stay consistent
147 |                 if segment_id not in segment_id_to_color:
148 |                     while True:
149 |                         unique_color, unique_id = generate_random_color_and_id(self.rng)
150 |                         if unique_color not in used_colors:
151 |                             used_colors.add(unique_color)  # Mark this color as used
152 |                             break
153 |                     segment_id_to_color[segment_id] = (unique_color, unique_id)
154 |                 unique_color, unique_id = segment_id_to_color[segment_id]
155 |                 coco_mask[mask == segment_id] = unique_color
154 |
155 | if containBbox:
156 | rows, cols = np.where(mask == segment_id)
157 | if rows.size and cols.size:
158 | x_min, y_min = int(cols.min()), int(rows.min())
159 | x_max, y_max = int(cols.max()), int(rows.max())
160 | # COCO uses [x, y, width, height] format.
161 | bbox = [x_min, y_min, x_max - x_min + 1, y_max - y_min + 1]
162 | else:
163 | bbox = [0, 0, 0, 0]
164 | else:
165 | bbox = [0, 0, 0, 0]
166 |
167 |             # iscrowd and area are not used downstream, so they are set to dummy values
168 | segments_info.append({
169 | "id": unique_id,
170 | "category_id": self.category_to_id(object_info['object_metadata']["category"]) if containCategory else 1,
171 | "iscrowd": 0,
172 | "bbox": bbox,
173 | "category": object_info['object_metadata']["category"],
174 | "sub_category": object_info['object_metadata']["sub_category"],
175 | "description": object_info['object_metadata']["description"],
176 | "features": object_info['object_metadata']["features"],
177 | "short_phrase": object_info['object_metadata']["short_phrase"],
178 | "area": 1,
179 | })
180 |
181 | output["coco_mask"] = coco_mask
182 | output["segments_info"] = segments_info
183 | if containRGBA:
184 | alpha_channel = (mask > 0).astype(np.uint8) * 255
185 | image_rgba = np.dstack((image, alpha_channel))
186 | output["image_rgba"] = image_rgba
187 | if containSmallObjectMask:
188 | alpha_channel_small_obj = (output['small_object_mask_np'] > 0).astype(np.uint8) * 255
189 | image_rgba_small_obj = np.dstack((image, alpha_channel_small_obj))
190 | output['image_rgba_small_object'] = image_rgba_small_obj
191 | return output
192 |
193 | class RandomCenterPointSegmentationSynthesizer(BaseSegmentationSynthesizer):
194 | synthesize_method = "random"
195 | def __init__(self, data_manager, save_path, random_seed=None):
196 | super().__init__(data_manager, save_path, random_seed)
197 |
198 | def sampling_metadata(self, width, height, number_of_objects, hasBackground=False, dataAugmentation=False): # -> dict:
199 | image_metadata = {}
200 | image_metadata["width"] = width
201 | image_metadata["height"] = height
202 | image_metadata["number_of_objects"] = number_of_objects
203 | image_metadata["synthesize_method"] = self.synthesize_method
204 |
205 | if hasBackground:
206 | background_metadata = self.data_manager.get_random_background_metadata(self.rng)
207 | image_metadata["background"] = {"background_metadata": background_metadata}
208 |
209 | curr_segment_id = 1
210 | objects = []
211 | for _ in range(number_of_objects):
212 | # Randomly select a segmentation from the data manager
213 | object_metadata = self.data_manager.get_random_object_metadata(self.rng)
214 | object_position = self.random_position(width, height)
215 | if dataAugmentation:
216 | augmentation = self.random_augmentation()
217 | objects.append({
218 | "object_metadata": object_metadata,
219 | "object_position": object_position,
220 | "segment_id": curr_segment_id,
221 | "augmentation": augmentation
222 | })
223 | else:
224 | objects.append({
225 | "object_metadata": object_metadata,
226 | "object_position": object_position,
227 | "segment_id": curr_segment_id
228 | })
229 | curr_segment_id += 1
230 | image_metadata["objects"] = objects
231 | return image_metadata
232 |
316 |
317 | class FineGrainedBoundingBoxSingleCategorySegmentationSynthesizer(BaseSegmentationSynthesizer):
318 | synthesize_method = "fine_grained_bbox"
319 | def __init__(self, data_manager, save_path, random_seed=None):
320 | super().__init__(data_manager, save_path, random_seed)
321 | self.fine_grained_bbox_layout_generator = FineGrainedBoundingBoxLayoutGenerator()
322 |
323 | def sampling_metadata(
324 | self,
325 | width,
326 | height,
327 | number_of_objects,
328 | hasBackground=False,
329 | dataAugmentation=False,
330 | considerArea=False
331 | ):
332 | image_metadata = {
333 | "width": width,
334 | "height": height,
335 | "number_of_objects": number_of_objects,
336 | "synthesize_method": self.synthesize_method
337 | }
338 |
339 | if hasBackground:
340 | bg_meta = self.data_manager.get_random_background_metadata(self.rng)
341 | image_metadata["background"] = {"background_metadata": bg_meta}
342 |
343 | # First object: random to determine category
344 | first_meta = self.data_manager.get_random_object_metadata(self.rng)
345 | category = first_meta.get('category')
346 | if category is None:
347 | raise KeyError("First sampled object metadata has no 'category' field.")
348 |
349 | objects = []
350 | curr_id = 1
351 | # Process first object
352 | if considerArea:
353 | path = first_meta['image_path']
354 | with Image.open(path) as img:
355 | arr = np.array(img)
356 | first_meta['area'] = int(np.sum(arr[:, :, 3] > 0))
357 | entry = {"object_metadata": first_meta, "segment_id": curr_id}
358 | if dataAugmentation:
359 | entry["augmentation"] = self.random_augmentation()
360 | objects.append(entry)
361 | curr_id += 1
362 |
363 | # Sample remaining objects of the same category
364 | for _ in range(number_of_objects - 1):
365 | obj_meta = self.data_manager.get_random_object_metadata_by_category(self.rng, category)
366 | if considerArea:
367 | path = obj_meta['image_path']
368 | with Image.open(path) as img:
369 | arr = np.array(img)
370 | obj_meta['area'] = int(np.sum(arr[:, :, 3] > 0))
371 | entry = {"object_metadata": obj_meta, "segment_id": curr_id}
372 | if dataAugmentation:
373 | entry["augmentation"] = self.random_augmentation()
374 | objects.append(entry)
375 | curr_id += 1
376 |
377 | # Compute bbox size splits per COCO proportions
378 | n = len(objects)
379 | n_large = int(round(n * 0.24))
380 | n_mid = int(round(n * 0.34))
381 | n_small = n - n_large - n_mid
382 | bboxes = self.fine_grained_bbox_layout_generator.generate(
383 | num_large=n_large,
384 | num_mid=n_mid,
385 | num_small=n_small,
386 | width=width,
387 | height=height
388 | )
389 |
390 | if considerArea:
391 | boxes_with_area = [
392 | {"bbox": box, "area": (box[2] - box[0]) * (box[3] - box[1])}
393 | for box in bboxes
394 | ]
395 | objects_sorted = sorted(objects, key=lambda o: o['object_metadata']['area'], reverse=True)
396 | boxes_sorted = sorted(boxes_with_area, key=lambda b: b['area'], reverse=True)
397 | for obj, box in zip(objects_sorted, boxes_sorted):
398 | obj['object_position'] = box['bbox']
399 | image_metadata['objects'] = objects_sorted
400 | else:
401 | for obj, box in zip(objects, bboxes):
402 | obj['object_position'] = box
403 | image_metadata['objects'] = objects
404 |
405 | return image_metadata
406 |
407 |
408 |
409 |
410 | class ExistingLayoutSegmentationSynthesizer(BaseSegmentationSynthesizer):
411 | synthesize_method = "existing_layout"
412 | def __init__(self, data_manager, save_path, table_path, random_seed = None):
413 | super().__init__(data_manager, save_path, random_seed)
414 | self.existing_layout_generator = ExistingLayoutGenerator(table_path=table_path)
415 |
416 |     def sampling_metadata(self, width, height, number_of_objects, hasBackground=False, dataAugmentation=False, considerArea=False):  # -> dict:
417 | image_metadata = {}
418 | image_metadata["width"] = width
419 | image_metadata["height"] = height
420 | image_metadata["number_of_objects"] = number_of_objects
421 | image_metadata["synthesize_method"] = self.synthesize_method
422 |         # Optionally attach a background image
423 |
424 | if hasBackground:
425 | background_metadata = self.data_manager.get_random_background_metadata(self.rng)
426 | image_metadata["background"] = {"background_metadata": background_metadata}
427 |
428 | curr_segment_id = 1
429 | objects = []
431 | for _ in range(number_of_objects):
432 | # Randomly select a segmentation from the data manager
433 | object_metadata = self.data_manager.get_random_object_metadata(self.rng)
434 | if dataAugmentation:
435 | augmentation = self.random_augmentation()
436 | objects.append({
437 | "object_metadata": object_metadata,
438 | "segment_id": curr_segment_id,
439 | "augmentation": augmentation
440 | })
441 | else:
442 | objects.append({
443 | "object_metadata": object_metadata,
444 | "segment_id": curr_segment_id
445 | })
446 | curr_segment_id += 1
447 |
448 | segment_path_list = [obj['object_metadata']['image_path'] for obj in objects]
449 | if considerArea:
450 |             # The area-matching algorithm hasn't been implemented yet, so the
451 |             # segment paths are passed as a placeholder for per-object areas.
452 |             # object_areas = [obj['object_metadata']['area'] for obj in objects]
453 |             object_areas = segment_path_list
453 | object_positions = self.existing_layout_generator.predict_wi_area(num_bboxes=len(objects), segment_path_list=segment_path_list, areas=object_areas, width=width, height=height)
454 | else:
455 | object_positions = self.existing_layout_generator.predict_wo_area(num_bboxes=len(objects), width=width, height=height)
456 |
457 | for idx, obj in enumerate(objects):
458 | obj["object_position"] = object_positions[idx]
459 | image_metadata["objects"] = objects
460 |
461 | return image_metadata
462 |
463 |
464 | class FineGrainedBoundingBoxSegmentationSynthesizer(BaseSegmentationSynthesizer):
465 | synthesize_method = "fine_grained_bbox"
466 | def __init__(self, data_manager, save_path, random_seed=None):
467 | super().__init__(data_manager, save_path, random_seed)
468 | self.fine_grained_bbox_layout_generator = FineGrainedBoundingBoxLayoutGenerator()
469 |
470 | def sampling_metadata(self, width, height, number_of_objects, hasBackground=False, dataAugmentation=False, considerArea=False): # -> dict:
471 | image_metadata = {}
472 | image_metadata["width"] = width
473 | image_metadata["height"] = height
474 | image_metadata["number_of_objects"] = number_of_objects
475 | image_metadata["synthesize_method"] = self.synthesize_method
476 |
477 | if hasBackground:
478 | background_metadata = self.data_manager.get_random_background_metadata(self.rng)
479 | image_metadata["background"] = {"background_metadata": background_metadata}
480 |
481 | curr_segment_id = 1
482 | objects = []
483 | # sort based on the original size
484 | for _ in range(number_of_objects):
485 | # Randomly select a segmentation from the data manager
486 | object_metadata = self.data_manager.get_random_object_metadata(self.rng)
487 | if considerArea:
488 | # add area
489 | path = object_metadata['image_path']
490 | with Image.open(path) as img:
491 | img_np = np.array(img)
492 | area = np.sum(img_np[:, :, 3] > 0)
493 | object_metadata['area'] = area
494 |
495 | if dataAugmentation:
496 | augmentation = self.random_augmentation()
497 | objects.append({
498 | "object_metadata": object_metadata,
499 | "segment_id": curr_segment_id,
500 | "augmentation": augmentation
501 | })
502 | else:
503 | objects.append({
504 | "object_metadata": object_metadata,
505 | "segment_id": curr_segment_id
506 | })
507 | curr_segment_id += 1
508 |
509 | num_bboxes = len(objects)
510 | # based on the COCO website
511 | num_large = int(round(num_bboxes * 0.24))
512 | num_mid = int(round(num_bboxes * 0.34))
513 | num_small = num_bboxes - num_large - num_mid
514 | object_positions = self.fine_grained_bbox_layout_generator.generate(num_large=num_large, num_mid=num_mid,
515 | num_small=num_small, width=width, height=height)
516 | if considerArea:
517 | bbox_with_area = []
518 | for bbox in object_positions:
519 |                 # Compute the box area from [x_min, y_min, x_max, y_max]; local
520 |                 # names avoid shadowing the width/height function arguments
521 |                 bbox_width = bbox[2] - bbox[0]
522 |                 bbox_height = bbox[3] - bbox[1]
523 |                 bbox_area = bbox_width * bbox_height
523 | bbox_with_area.append({"bbox": bbox, "area": bbox_area})
524 |
525 | # Sort objects based on the computed object area (descending order)
526 | objects_sorted = sorted(objects, key=lambda x: x['object_metadata']['area'], reverse=True)
527 | # Sort bounding boxes based on their computed area (descending order)
528 | bboxes_sorted = sorted(bbox_with_area, key=lambda x: x['area'], reverse=True)
529 |
530 | # Assign the sorted bounding boxes to the sorted objects
531 | for idx, obj in enumerate(objects_sorted):
532 | obj["object_position"] = bboxes_sorted[idx]["bbox"]
533 | image_metadata["objects"] = objects_sorted
534 | # print the match results
535 | for obj in objects_sorted:
536 | print(f"Object: {obj['object_metadata']['image_path']}, Area: {obj['object_metadata']['area']}")
537 | for bbox in bboxes_sorted:
538 | print(f"bbox: {bbox['bbox']}, Area: {bbox['area']}")
539 | else:
540 | for idx, obj in enumerate(objects):
541 | obj["object_position"] = object_positions[idx]
542 | image_metadata["objects"] = objects
543 |
544 | return image_metadata
--------------------------------------------------------------------------------
/relighting_and_blending/inference.py:
--------------------------------------------------------------------------------
1 | # import debugpy
2 | # try:
3 | # # 5678 is the default attach port in the VS Code debug configurations. Unless a host and port are specified, host defaults to 127.0.0.1
4 | # debugpy.listen(("localhost", 9501))
5 | # print("Waiting for debugger attach")
6 | # debugpy.wait_for_client()
7 | # except Exception as e:
8 | # pass
9 |
10 | #!/usr/bin/env python
11 | """
12 | Modified version of inference_with_blending.py with GCS integration.
13 | Supports downloading input images (and related JSON/masks) from GCS and uploading outputs to GCS,
14 | similar to the provided flux pipeline code.
15 | """
16 |
17 | import os
18 | import re
19 | import json
20 | import time
21 | import argparse
22 | import numpy as np
23 | from tqdm import tqdm
24 | import tempfile
25 |
26 | from PIL import Image
27 | from skimage.color import rgb2lab, lab2rgb
28 | import torch
29 | import safetensors.torch as sf
30 |
31 | # Diffusers and related libraries
32 | from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
33 | from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler, EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler
34 | from diffusers.models.attention_processor import AttnProcessor2_0
35 | from transformers import CLIPTextModel, CLIPTokenizer
36 | from torch.hub import download_url_to_file
37 |
38 | # Helper functions from IC-Light's gradio_demo (blending and resizing utilities)
39 | from gradio_demo import (
40 | hooked_unet_forward,
41 | encode_prompt_pair,
42 | pytorch2numpy,
43 | numpy2pytorch,
44 | resize_and_center_crop,
45 | resize_without_crop,
46 | )
47 |
48 | # --- GCS Utility Functions ---
49 | from google.cloud import storage
50 |
51 | def parse_gcs_path(gcs_path):
52 | """
53 | Given a GCS path in the form gs://bucket_name/path/to/file,
54 | returns a tuple (bucket_name, blob_path).
55 | """
56 | if not gcs_path.startswith("gs://"):
57 | raise ValueError("Not a valid GCS path")
58 | path = gcs_path[5:]
59 | parts = path.split("/", 1)
60 | bucket = parts[0]
61 | blob = parts[1] if len(parts) > 1 else ""
62 | return bucket, blob
63 |
64 | def download_from_gcs(gcs_path, local_path):
65 | """
66 | Downloads a file from a GCS path to a local file.
67 | """
68 | bucket_name, blob_name = parse_gcs_path(gcs_path)
69 | client = storage.Client()
70 | bucket = client.bucket(bucket_name)
71 | blob = bucket.blob(blob_name)
72 | blob.download_to_filename(local_path)
73 | print(f"Downloaded {gcs_path} to {local_path}")
74 |
75 | def upload_to_gcs(local_path, gcs_path):
76 | """
77 | Upload a local file to a GCS path.
78 | """
79 | bucket_name, blob_name = parse_gcs_path(gcs_path)
80 | client = storage.Client()
81 | bucket = client.bucket(bucket_name)
82 | blob = bucket.blob(blob_name)
83 | blob.upload_from_filename(local_path)
84 | print(f"Uploaded {local_path} to {gcs_path}")
85 |
86 | def gcs_blob_exists(gcs_path):
87 | """
88 | Check if a blob exists at the given GCS path.
89 | """
90 | bucket_name, blob_name = parse_gcs_path(gcs_path)
91 | client = storage.Client()
92 | bucket = client.bucket(bucket_name)
93 | blob = bucket.blob(blob_name)
94 | return blob.exists()
95 |
96 | def list_gcs_files(gcs_path, suffix=""):
97 | """
98 | List file names (relative to the prefix) from a GCS path with an optional suffix filter.
99 | """
100 | bucket_name, prefix = parse_gcs_path(gcs_path)
101 | client = storage.Client()
102 | bucket = client.bucket(bucket_name)
103 | blobs = list(client.list_blobs(bucket, prefix=prefix))
104 | file_names = []
105 | for blob in blobs:
106 | if blob.name.endswith(suffix):
107 | # Remove the prefix portion for clarity (if needed)
108 | relative_name = blob.name[len(prefix):].lstrip("/")
109 | file_names.append(relative_name)
110 | return file_names
111 |
112 | def build_color_mask_gcs_path(data_path, file_id, original_folder="train", new_folder="panoptic_train"):
113 | """
114 | Builds the GCS path for the color mask by replacing the folder name
115 | in the blob portion without altering the bucket name.
116 | """
117 | if not data_path.startswith("gs://"):
118 | raise ValueError("Not a valid GCS path")
119 | # Remove the "gs://" prefix.
120 | path_without_prefix = data_path[5:]
121 | # Split into bucket and blob path.
122 | bucket, blob_path = path_without_prefix.split("/", 1)
123 | # Replace only the intended folder segment in blob_path.
124 | new_blob_path = blob_path.replace(original_folder, new_folder, 1)
125 | return f"gs://{bucket}/{new_blob_path}/{file_id}.png"
126 |
127 | # --- Diffusion and Relighting Setup ---
128 |
129 | from enum import Enum
130 |
131 | class BGSource(Enum):
132 | NONE = "None"
133 | LEFT = "Left Light"
134 | RIGHT = "Right Light"
135 | TOP = "Top Light"
136 | BOTTOM = "Bottom Light"
137 |
138 | # Set up Stable Diffusion components
139 | sd15_name = 'stablediffusionapi/realistic-vision-v51'
140 | tokenizer = CLIPTokenizer.from_pretrained(sd15_name, subfolder="tokenizer")
141 | text_encoder = CLIPTextModel.from_pretrained(sd15_name, subfolder="text_encoder")
142 | vae = AutoencoderKL.from_pretrained(sd15_name, subfolder="vae")
143 | unet = UNet2DConditionModel.from_pretrained(sd15_name, subfolder="unet")
144 |
145 | # RMBG and UNet modifications
146 | from briarmbg import BriaRMBG
147 | rmbg = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
148 |
149 | with torch.no_grad():
150 | new_conv_in = torch.nn.Conv2d(8, unet.conv_in.out_channels,
151 | unet.conv_in.kernel_size,
152 | unet.conv_in.stride,
153 | unet.conv_in.padding)
154 | new_conv_in.weight.zero_()
155 | new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
156 | new_conv_in.bias = unet.conv_in.bias
157 | unet.conv_in = new_conv_in
158 |
159 | unet_original_forward = unet.forward
160 | unet.forward = hooked_unet_forward
161 |
162 | # Load model offset and merge with UNet
163 | model_path = './models/iclight_sd15_fc.safetensors'
164 | if not os.path.exists(model_path):
165 | download_url_to_file(
166 | url='https://huggingface.co/lllyasviel/ic-light/resolve/main/iclight_sd15_fc.safetensors',
167 | dst=model_path
168 | )
169 | sd_offset = sf.load_file(model_path)
170 | sd_origin = unet.state_dict()
171 | sd_merged = {k: sd_origin[k] + sd_offset[k] for k in sd_origin.keys()}
172 | unet.load_state_dict(sd_merged, strict=True)
173 | del sd_offset, sd_origin, sd_merged
174 |
175 | # Set device and send models to device
176 | device = torch.device('cuda')
177 | text_encoder = text_encoder.to(device=device, dtype=torch.float16)
178 | vae = vae.to(device=device, dtype=torch.bfloat16)
179 | unet = unet.to(device=device, dtype=torch.float16)
180 | rmbg = rmbg.to(device=device, dtype=torch.float32)
181 |
182 | # SDP attention processor
183 | unet.set_attn_processor(AttnProcessor2_0())
184 | vae.set_attn_processor(AttnProcessor2_0())
185 |
186 | # Define schedulers
187 | ddim_scheduler = DDIMScheduler(
188 | num_train_timesteps=1000,
189 | beta_start=0.00085,
190 | beta_end=0.012,
191 | beta_schedule="scaled_linear",
192 | clip_sample=False,
193 | set_alpha_to_one=False,
194 | steps_offset=1,
195 | )
196 | euler_a_scheduler = EulerAncestralDiscreteScheduler(
197 | num_train_timesteps=1000,
198 | beta_start=0.00085,
199 | beta_end=0.012,
200 | steps_offset=1
201 | )
202 | dpmpp_2m_sde_karras_scheduler = DPMSolverMultistepScheduler(
203 | num_train_timesteps=1000,
204 | beta_start=0.00085,
205 | beta_end=0.012,
206 | algorithm_type="sde-dpmsolver++",
207 | use_karras_sigmas=True,
208 | steps_offset=1
209 | )
210 |
211 | # Pipelines for text-to-image and image-to-image
212 | t2i_pipe = StableDiffusionPipeline(
213 | vae=vae,
214 | text_encoder=text_encoder,
215 | tokenizer=tokenizer,
216 | unet=unet,
217 | scheduler=dpmpp_2m_sde_karras_scheduler,
218 | safety_checker=None,
219 | requires_safety_checker=False,
220 | feature_extractor=None,
221 | image_encoder=None
222 | )
223 | i2i_pipe = StableDiffusionImg2ImgPipeline(
224 | vae=vae,
225 | text_encoder=text_encoder,
226 | tokenizer=tokenizer,
227 | unet=unet,
228 | scheduler=dpmpp_2m_sde_karras_scheduler,
229 | safety_checker=None,
230 | requires_safety_checker=False,
231 | feature_extractor=None,
232 | image_encoder=None
233 | )
234 |
235 | # --- Utility Functions for Image Parsing and Blending ---
236 |
237 | def parse_rgba(img, sigma=0.0):
238 | """
239 | Given a RGBA image (as a NumPy array), returns the blended RGB image using the alpha channel.
240 | """
241 | assert img.shape[2] == 4, "Input image must have 4 channels (RGBA)."
242 | rgb = img[:, :, :3]
243 | alpha = img[:, :, 3].astype(np.float32) / 255.0
244 | # set alpha to be like if it is transparent then 0, otherwise 1
245 | # alpha = np.where(alpha < 0.5, 0.0, 1.0)
246 | result = 127 + (rgb.astype(np.float32) - 127 + sigma) * alpha[:, :, None]
247 | # temporarily
248 | return result.clip(0, 255).astype(np.uint8), alpha
249 |
250 | def blend_images_with_mask_rank_sigmoid(old_image, new_image, color_mask, alpha_min=0.3, alpha_max=0.9, steepness=10):
251 | """
252 | Blends new_image into old_image according to a color mask using a sigmoid function based on segment area.
253 | """
254 | if not isinstance(old_image, Image.Image):
255 | old_image = Image.fromarray(old_image)
256 | if not isinstance(new_image, Image.Image):
257 | new_image = Image.fromarray(new_image)
258 | if not isinstance(color_mask, Image.Image):
259 | color_mask = Image.fromarray(color_mask)
260 |
261 | old_image = old_image.convert("RGBA")
262 | new_image = new_image.convert("RGBA")
263 | color_mask = color_mask.convert("RGBA")
264 |
265 | old_np = np.array(old_image).astype(np.float32)
266 | new_np = np.array(new_image).astype(np.float32)
267 | mask_np = np.array(color_mask).astype(np.float32)
268 | mask_rgb = mask_np[..., :3]
269 | total_pixels = mask_rgb.shape[0] * mask_rgb.shape[1]
270 |
271 | unique_colors = np.unique(mask_rgb.reshape(-1, 3), axis=0)
272 | unique_colors = [c for c in unique_colors if not np.allclose(c, [0, 0, 0])]
273 |
274 | color_to_norm_area = {}
275 | for color in unique_colors:
276 | region = ((mask_rgb[..., 0] == color[0]) &
277 | (mask_rgb[..., 1] == color[1]) &
278 | (mask_rgb[..., 2] == color[2]))
279 | area = np.sum(region)
280 | norm_area = area / total_pixels
281 | color_to_norm_area[tuple(color)] = norm_area
282 |
283 | segments_sorted = sorted(color_to_norm_area.items(), key=lambda x: x[1])
284 | total_segments = len(segments_sorted)
285 | color_to_alpha = {}
286 | for rank, (color, norm_area) in enumerate(segments_sorted):
287 | normalized_rank = rank / (total_segments - 1) if total_segments > 1 else 0
288 | sigmoid_value = 1 / (1 + np.exp(steepness * (normalized_rank - 0.5)))
289 | alpha = alpha_min + (alpha_max - alpha_min) * sigmoid_value
290 | color_to_alpha[color] = alpha
291 | print(f"Segment color: {color}, Norm Area: {norm_area:.6f}, Rank: {rank}, Normalized Rank: {normalized_rank:.3f}, Alpha: {alpha}")
292 |
293 | # Convert images from RGBA numpy arrays to Lab color space for lightness-only blending
294 | rgb_old = old_np[..., :3] / 255.0
295 | rgb_new = new_np[..., :3] / 255.0
296 | lab_old = rgb2lab(rgb_old)
297 | lab_new = rgb2lab(rgb_new)
298 | lab_out = lab_old.copy()
299 |
300 |     # Blend only the L (lightness) channel per segment. The per-segment alpha from
301 |     # the area-rank sigmoid weights old vs. new lightness: small segments (high
302 |     # alpha) keep more of the old image's lightness, large segments (low alpha)
303 |     # take more from the new image. The chroma channels (a and b) are taken
304 |     # directly from the new image so its colors are preserved.
305 |     for color, alpha in color_to_alpha.items():
306 |         region = ((mask_rgb[..., 0] == color[0]) &
307 |                   (mask_rgb[..., 1] == color[1]) &
308 |                   (mask_rgb[..., 2] == color[2]))
309 |         # Weighted mix of old and new lightness
310 |         lab_out[..., 0][region] = alpha * lab_old[..., 0][region] + (1 - alpha) * lab_new[..., 0][region]
311 |         # Chroma straight from the new image
312 |         lab_out[..., 1][region] = lab_new[..., 1][region]
313 |         lab_out[..., 2][region] = lab_new[..., 2][region]
314 | # Convert Lab back to RGB
315 | rgb_out = lab2rgb(lab_out)
316 | rgb_out_uint8 = (rgb_out * 255).clip(0, 255).astype(np.uint8)
317 | blended_image = Image.fromarray(rgb_out_uint8)
318 | return blended_image
319 |
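# Illustrative numbers for the rank-sigmoid blending above (not part of the pipeline):
# with alpha_min=0.3, alpha_max=0.9, steepness=10, the smallest segment gets
# alpha ~= 0.3 + 0.6 / (1 + exp(-5)) ~= 0.90 (lightness mostly from old_image), and the
# largest gets alpha ~= 0.3 + 0.6 / (1 + exp(5)) ~= 0.30 (lightness mostly from new_image).
#
#   blended = blend_images_with_mask_rank_sigmoid(relit_np, fg_np, mask_np)  # hypothetical arrays
#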
320 | # --- Inference Functions ---
321 |
322 | @torch.inference_mode()
323 | def process(input_fg, prompt, image_width, image_height, num_samples,
324 | seed, steps, a_prompt, n_prompt, cfg,
325 | highres_scale, highres_denoise, lowres_denoise, bg_source):
326 | bg_source = BGSource(bg_source)
327 | input_bg = None
328 | if bg_source == BGSource.NONE:
329 | pass
330 | elif bg_source == BGSource.LEFT:
331 | gradient = np.linspace(255, 0, image_width)
332 | image = np.tile(gradient, (image_height, 1))
333 | input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
334 | elif bg_source == BGSource.RIGHT:
335 | gradient = np.linspace(0, 255, image_width)
336 | image = np.tile(gradient, (image_height, 1))
337 | input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
338 | elif bg_source == BGSource.TOP:
339 | gradient = np.linspace(255, 0, image_height)[:, None]
340 | image = np.tile(gradient, (1, image_width))
341 | input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
342 | elif bg_source == BGSource.BOTTOM:
343 | gradient = np.linspace(0, 255, image_height)[:, None]
344 | image = np.tile(gradient, (1, image_width))
345 | input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
346 | else:
347 |         raise ValueError(f"Unknown background source: {bg_source}")
348 |
349 | rng = torch.Generator(device=device).manual_seed(int(seed))
350 | fg = resize_and_center_crop(input_fg, image_width, image_height)
351 | concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
352 | concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
353 | conds, unconds = encode_prompt_pair(positive_prompt=prompt + ', ' + a_prompt, negative_prompt=n_prompt)
354 |
355 | if input_bg is None:
356 | latents = t2i_pipe(
357 | prompt_embeds=conds,
358 | negative_prompt_embeds=unconds,
359 | width=image_width,
360 | height=image_height,
361 | num_inference_steps=steps,
362 | num_images_per_prompt=num_samples,
363 | generator=rng,
364 | output_type='latent',
365 | guidance_scale=cfg,
366 | cross_attention_kwargs={'concat_conds': concat_conds},
367 | ).images.to(vae.dtype) / vae.config.scaling_factor
368 | else:
369 | bg = resize_and_center_crop(input_bg, image_width, image_height)
370 | bg_latent = numpy2pytorch([bg]).to(device=vae.device, dtype=vae.dtype)
371 | bg_latent = vae.encode(bg_latent).latent_dist.mode() * vae.config.scaling_factor
372 | latents = i2i_pipe(
373 | image=bg_latent,
374 | strength=lowres_denoise,
375 | prompt_embeds=conds,
376 | negative_prompt_embeds=unconds,
377 | width=image_width,
378 | height=image_height,
379 | num_inference_steps=int(round(steps / lowres_denoise)),
380 | num_images_per_prompt=num_samples,
381 | generator=rng,
382 | output_type='latent',
383 | guidance_scale=cfg,
384 | cross_attention_kwargs={'concat_conds': concat_conds},
385 | ).images.to(vae.dtype) / vae.config.scaling_factor
386 |
387 | pixels = vae.decode(latents).sample
388 | pixels = pytorch2numpy(pixels)
389 | pixels = [resize_without_crop(
390 | image=p,
391 | target_width=int(round(image_width * highres_scale / 64.0) * 64),
392 | target_height=int(round(image_height * highres_scale / 64.0) * 64))
393 | for p in pixels]
394 | pixels = numpy2pytorch(pixels).to(device=vae.device, dtype=vae.dtype)
395 | latents = vae.encode(pixels).latent_dist.mode() * vae.config.scaling_factor
396 | latents = latents.to(device=unet.device, dtype=unet.dtype)
397 | image_height, image_width = latents.shape[2] * 8, latents.shape[3] * 8
398 | fg = resize_and_center_crop(input_fg, image_width, image_height)
399 | concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
400 | concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
401 | latents = i2i_pipe(
402 | image=latents,
403 | strength=highres_denoise,
404 | prompt_embeds=conds,
405 | negative_prompt_embeds=unconds,
406 | width=image_width,
407 | height=image_height,
408 | num_inference_steps=int(round(steps / highres_denoise)),
409 | num_images_per_prompt=num_samples,
410 | generator=rng,
411 | output_type='latent',
412 | guidance_scale=cfg,
413 | cross_attention_kwargs={'concat_conds': concat_conds},
414 | ).images.to(vae.dtype) / vae.config.scaling_factor
415 |
416 | pixels = vae.decode(latents).sample
417 | return pytorch2numpy(pixels)
418 |
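# Note: process() above runs two diffusion passes: a low-resolution pass (text-to-image,
# or image-to-image over a synthesized gradient background when bg_source is directional),
# followed by a high-resolution image-to-image refinement at highres_scale; both passes
# are conditioned on the VAE-encoded foreground via cross_attention_kwargs={'concat_conds': ...}.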
419 | @torch.inference_mode()
420 | def process_relight(input_fg, prompt, image_width, image_height, num_samples,
421 | seed, steps, a_prompt, n_prompt, cfg,
422 | highres_scale, highres_denoise, lowres_denoise, bg_source):
423 | input_fg, _ = parse_rgba(input_fg)
424 | results = process(input_fg, prompt, image_width, image_height, num_samples,
425 | seed, steps, a_prompt, n_prompt, cfg,
426 | highres_scale, highres_denoise, lowres_denoise, bg_source)
427 | return input_fg, results
428 |
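# Minimal call sketch for process_relight (hypothetical values; the input must be RGBA):
#
#   fg_rgb, outs = process_relight(
#       input_fg=np.array(Image.open("fg.png").convert("RGBA")),
#       prompt="warm sunset light", image_width=1024, image_height=1024,
#       num_samples=1, seed=123456, steps=25, a_prompt="best quality",
#       n_prompt="lowres", cfg=2.0, highres_scale=1.0, highres_denoise=0.5,
#       lowres_denoise=0.9, bg_source=BGSource.NONE.value)
#   Image.fromarray(outs[0]).save("relit.jpg")
#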
429 | def adjust_dimensions(width, height, max_dim=1024, divisible_by=8):
430 |     # Placeholder: always returns 1024x1024. The arguments are kept so an aspect-preserving resize (capped at max_dim, rounded to divisible_by) can be dropped in later.
431 |     return 1024, 1024
432 |
433 | # --- Main Inference Workflow with GCS Integration ---
434 |
435 | def main(args):
436 | data_path = args.dataset_path
437 | output_data_path = args.output_data_path
438 | illuminate_prompts_path = args.illuminate_prompts_path
439 | record_path = args.record_path
440 |
441 | # Determine if the dataset, index JSON, and output_data_path reside on GCS.
442 | input_on_gcs = data_path.startswith("gs://")
443 | output_on_gcs = output_data_path.startswith("gs://")
444 | index_on_gcs = args.index_json_path is not None and args.index_json_path.startswith("gs://")
445 | illuminate_on_gcs = illuminate_prompts_path.startswith("gs://") if illuminate_prompts_path else False
446 |
447 | # Create a local temporary directory to use when downloading files from GCS.
448 | temp_dir = tempfile.mkdtemp()
449 |
450 | # If the illumination prompts are on GCS, download them locally.
451 | if illuminate_on_gcs:
452 | local_illuminate_path = os.path.join(temp_dir, "illumination_prompt.json")
453 | download_from_gcs(illuminate_prompts_path, local_illuminate_path)
454 | illuminate_prompts_path = local_illuminate_path
455 |
456 | with open(illuminate_prompts_path, "r") as f:
457 | illuminate_prompts = json.load(f)
458 |
459 | records = {}
460 | split_index = args.split
461 | num_splits = args.num_splits
462 |
463 | # Prepare list of filenames based on index JSON if provided; otherwise list .png files in dataset.
464 | if args.index_json_path:
465 | # If index JSON is on GCS, download it.
466 | if index_on_gcs:
467 | local_index_path = os.path.join(temp_dir, "index.json")
468 | download_from_gcs(args.index_json_path, local_index_path)
469 | index_json_path = local_index_path
470 | else:
471 | index_json_path = args.index_json_path
472 |
473 | with open(index_json_path, 'r') as f:
474 | all_filenames = json.load(f)
475 | if not isinstance(all_filenames, list):
476 | raise ValueError("The index JSON file must contain a list of filenames.")
477 | splits = np.array_split(all_filenames, num_splits)
478 | split_filenames = list(splits[split_index])
479 | print(f"Processing split {split_index + 1}/{num_splits} with {len(split_filenames)} images from index JSON.")
480 | else:
481 | if input_on_gcs:
482 |             # List .png files from the GCS path.
483 |             files = list_gcs_files(data_path, suffix=".png")
484 |             # Keep only purely numeric filenames (e.g., "12.png") and sort them
485 |             # numerically; non-matching files are skipped.
486 | pattern = re.compile(r'^(\d+)\.png$')
487 | file_numbers = []
488 | for f in files:
489 | m = pattern.match(os.path.basename(f))
490 | if m:
491 | file_numbers.append((int(m.group(1)), f))
492 | file_numbers.sort(key=lambda x: x[0])
493 | sorted_filenames = [f for _, f in file_numbers]
494 | splits = np.array_split(sorted_filenames, num_splits)
495 | split_filenames = list(splits[split_index])
496 | print(f"Processing split {split_index + 1}/{num_splits} with {len(split_filenames)} images from GCS.")
497 | else:
498 | # Local directory listing.
499 | all_files = [f for f in os.listdir(data_path) if f.endswith(".png")]
500 | pattern = re.compile(r'^(\d+)\.png$')
501 | filtered_files = []
502 | for f in all_files:
503 | m = pattern.match(f)
504 | if m:
505 | numeric_value = int(m.group(1))
506 | filtered_files.append((numeric_value, f))
507 | filtered_files.sort(key=lambda x: x[0])
508 | sorted_filenames = [f for _, f in filtered_files]
509 | splits = np.array_split(sorted_filenames, num_splits)
510 | split_filenames = list(splits[split_index])
511 | print(f"Processing split {split_index + 1}/{num_splits} with {len(split_filenames)} images (local).")
512 |
513 | # Prepare the output destination.
514 | if not output_on_gcs:
515 | os.makedirs(output_data_path, exist_ok=True)
516 |
517 | # Process each file.
518 | for fg_name in tqdm(split_filenames, desc="Processing images"):
519 |         # If the input image resides on GCS, build its full path and download it to a temp file.
520 | if input_on_gcs:
521 | full_fg_path = os.path.join(data_path, fg_name)
522 | local_fg_path = os.path.join(temp_dir, os.path.basename(fg_name))
523 | download_from_gcs(full_fg_path, local_fg_path)
524 | else:
525 | local_fg_path = os.path.join(data_path, fg_name)
526 |
527 | # Open the foreground image.
528 | try:
529 | input_fg = np.array(Image.open(local_fg_path))
530 | except Exception as e:
531 | print(f"Error opening image {local_fg_path}: {e}")
532 | continue
533 |
534 | # Determine output file name.
535 | file_id = os.path.splitext(os.path.basename(fg_name))[0]
536 | if output_on_gcs:
537 | output_blob = f"{file_id}.jpg"
538 | full_output_path = os.path.join(output_data_path, output_blob).replace("\\", "/")
539 | if gcs_blob_exists(full_output_path):
540 | print(f"Skipping '{fg_name}': Output blob already exists on GCS.")
541 | continue
542 | else:
543 | output_path = os.path.join(output_data_path, f"{file_id}.jpg")
544 | if os.path.exists(output_path):
545 | print(f"Skipping '{fg_name}': Output file '{output_path}' already exists.")
546 | continue
547 |
548 | # Dynamically adjust dimensions.
549 | orig_height, orig_width = input_fg.shape[:2]
550 | image_width, image_height = adjust_dimensions(orig_width, orig_height, max_dim=1024, divisible_by=8)
551 | print(f"Processing '{fg_name}': Adjusted dimensions: {image_width} x {image_height}")
552 |
553 | # Select a random prompt from illumination prompts.
554 | prompt = np.random.choice(illuminate_prompts)
555 |         bg_source = np.random.choice([BGSource.NONE, BGSource.NONE, BGSource.NONE, BGSource.NONE,
556 |                                       BGSource.LEFT, BGSource.RIGHT, BGSource.TOP, BGSource.BOTTOM])  # NONE repeated so a plain background is 4x as likely as any single gradient
557 | seed = 123456
558 | steps = 25
559 | a_prompt = "not obvious objects in the background, best quality, don't significantly change foreground objects, keep its semantic meaning"
560 | n_prompt = "have obvious objects in the background, lowres, bad anatomy, bad hands, cropped, worst quality, change foreground objects, don't keep its semantic meaning"
561 | cfg = 2.0
562 | highres_scale = 1.0
563 | highres_denoise = 0.5
564 | lowres_denoise = 0.9
565 | num_samples = 1
566 |
567 | # Process relighting.
568 | input_fg, results = process_relight(
569 | input_fg=input_fg,
570 | prompt=prompt,
571 | image_width=image_width,
572 | image_height=image_height,
573 | num_samples=num_samples,
574 | seed=seed,
575 | steps=steps,
576 | a_prompt=a_prompt,
577 | n_prompt=n_prompt,
578 | cfg=cfg,
579 | highres_scale=highres_scale,
580 | highres_denoise=highres_denoise,
581 | lowres_denoise=lowres_denoise,
582 |             bg_source=bg_source.value  # pass the enum's string value; process() reconstructs BGSource from it
583 | )
584 |
585 | # For blending, determine the color mask path.
586 | if input_on_gcs:
587 | # Use the helper function to build the color mask path properly.
588 | color_mask_path = build_color_mask_gcs_path(data_path, file_id)
589 | local_mask_path = os.path.join(temp_dir, f"{file_id}_mask.png")
590 | download_from_gcs(color_mask_path, local_mask_path)
591 | else:
592 | color_mask_path = os.path.join(data_path.replace("train", "panoptic_train"), f"{file_id}.png")
593 | local_mask_path = color_mask_path
594 |
595 | try:
596 | color_mask = Image.open(local_mask_path)
597 | except Exception as e:
598 | print(f"Error opening color mask {local_mask_path}: {e}")
599 | continue
600 |
601 | blended_image = blend_images_with_mask_rank_sigmoid(
602 | old_image=results[0],
603 | new_image=input_fg,
604 | color_mask=color_mask
605 | )
606 |
607 | # Save the output.
608 | if output_on_gcs:
609 | local_output_file = os.path.join(temp_dir, f"{file_id}.jpg")
610 | blended_image.save(local_output_file)
611 | upload_to_gcs(local_output_file, full_output_path)
612 | os.remove(local_output_file)
613 | else:
614 | blended_image.save(output_path)
615 | print(f"Saved relit image to '{output_path}'")
616 |
617 | # Record details.
618 | records[fg_name] = {
619 | "output_path": full_output_path if output_on_gcs else output_path,
620 | "prompt": prompt,
621 | "bg_source": bg_source.value,
622 | "seed": seed,
623 | "steps": steps,
624 | "cfg": cfg,
625 | "highres_scale": highres_scale,
626 | "highres_denoise": highres_denoise,
627 | "lowres_denoise": lowres_denoise
628 | }
629 |
630 |     # Save the per-image records file if a record path was provided.
631 |     if record_path:
632 |         if record_path.startswith("gs://"):
633 |             local_record_path = os.path.join(temp_dir, "record.json")
634 |         else:
635 |             local_record_path = record_path
636 |         with open(local_record_path, 'w') as f:
637 |             json.dump(records, f, indent=4)
638 |         if record_path.startswith("gs://"):
639 |             upload_to_gcs(local_record_path, record_path)
640 |             os.remove(local_record_path)
641 |
642 |     print("Processing complete.")
643 |
644 | if __name__ == "__main__":
645 | def parse_args():
646 | parser = argparse.ArgumentParser(description="Relight images using Stable Diffusion pipelines with GCS integration.")
647 | parser.add_argument('--dataset_path', type=str, required=True,
648 | help="Path to the segment dataset. Can be a local path or a GCS path (e.g., gs://bucket/path).")
649 | parser.add_argument('--output_data_path', type=str, required=True,
650 | help="Path to save the output data. Can be a local path or a GCS path.")
651 | parser.add_argument('--num_splits', type=int, default=1, help="Number of splits to create")
652 | parser.add_argument('--split', type=int, default=0, help="Split index to process (0-indexed)")
653 | parser.add_argument('--index_json_path', type=str, default=None,
654 | help="Path to the JSON file containing image filenames; supports GCS paths.")
655 | parser.add_argument('--illuminate_prompts_path', type=str, required=True,
656 | help="Path to the JSON file containing illumination prompts; supports GCS paths.")
657 | parser.add_argument('--record_path', type=str, default=None,
658 | help="Path to the JSON file where records are saved; supports GCS paths.")
659 | return parser.parse_args()
660 |
661 | args = parse_args()
662 | main(args)
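# Example invocation (hypothetical bucket and paths):
#
#   python relighting_and_blending/inference.py \
#       --dataset_path gs://my-bucket/train \
#       --output_data_path gs://my-bucket/relit_train \
#       --illuminate_prompts_path gs://my-bucket/illumination_prompt.json \
#       --num_splits 4 --split 0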
--------------------------------------------------------------------------------
/datasets/object_datasets.py:
--------------------------------------------------------------------------------
1 | from .base_datasets import BaseObjectDataset
2 | import os
3 | from tqdm import tqdm
4 | import pickle
5 | from concurrent.futures import ThreadPoolExecutor, as_completed
6 | import pandas as pd
7 | import json
8 | import re
9 | from shutil import copyfile
10 | from PIL import Image
11 | import PIL
12 | from .coco_categories import COCO_CATEGORIES
13 | def build_category_index(categories):
14 | return {category['name'].replace("_", "").replace(" ", ""): category['id'] for category in categories}
15 |
16 | COCO_CATEGORY_INDEX = build_category_index(COCO_CATEGORIES)
17 |
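# For example, assuming COCO_CATEGORIES contains an entry {"name": "person", "id": 1},
# the cleaned-name index maps it as:
#
#   COCO_CATEGORY_INDEX["person"]  # -> 1
#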
18 |
19 | def getting_filtering_annotation(universal_idx, filtering_annotations, default_annotation):
20 |     """
21 |     Return the filtering annotation for universal_idx, or the default annotation
22 |     (the schema of the first entry with every flag set to False) when absent.
23 |     """
24 |     if universal_idx in filtering_annotations:
25 |         return filtering_annotations[universal_idx]
26 |     return default_annotation
27 |
28 | class ADE20KDataset(BaseObjectDataset):
29 | def __init__(self, dataset_path, filtering_annotations_path):
30 | dataset_name = "ADE20K"
31 | data_type = "panoptic"
32 | super().__init__(dataset_name, data_type, dataset_path, filtering_annotations_path)
33 | # get filtering annotation
34 | if self.filtering_annotations_path is not None:
35 | filtering_annotations = json.load(open(self.filtering_annotations_path, "r"))
36 | default_annotation = filtering_annotations[list(filtering_annotations.keys())[0]]
37 | default_annotation = {key: False for key in default_annotation.keys()}
38 | else:
39 | filtering_annotations = None
40 |
41 | # get metadata_table
42 | index_counter = 0
43 | rows = []
44 | for folder_name in os.listdir(self.dataset_path):
45 | folder_path = os.path.join(self.dataset_path, folder_name)
46 | if os.path.isdir(folder_path):
47 | for file_name in os.listdir(folder_path):
48 | if ".png" in file_name:
49 | image_path = os.path.join(self.dataset_path, folder_name, file_name)
50 |                     universal_idx = f"{dataset_name}_{folder_name}_{file_name}"
51 |                     rows.append({
52 |                         'universal_idx': universal_idx,
53 |                         'idx_for_curr_dataset': index_counter,
54 |                         'category': folder_name,
55 |                         'image_path': image_path,
56 |                         'mask_path': image_path,  # Assuming mask_path is the same as image_path
57 |                         'dataset_name': dataset_name,
58 |                         'data_type': data_type,
59 |                         'filtering_annotation': getting_filtering_annotation(universal_idx, filtering_annotations, default_annotation) if filtering_annotations is not None else None
60 | })
61 | index_counter += 1
62 | self.metadata_table = pd.DataFrame(rows)
63 | self.categories = set(self.metadata_table['category'])
64 |
65 | class VOC2012Dataset(BaseObjectDataset):
66 | def __init__(self, dataset_path, filtering_annotations_path):
67 | dataset_name = "VOC2012"
68 | data_type = "panoptic"
69 | super().__init__(dataset_name, data_type, dataset_path, filtering_annotations_path)
70 | # get filtering annotation
71 | if self.filtering_annotations_path is not None:
72 | filtering_annotations = json.load(open(self.filtering_annotations_path, "r"))
73 | default_annotation = filtering_annotations[list(filtering_annotations.keys())[0]]
74 | default_annotation = {key: False for key in default_annotation.keys()}
75 | else:
76 | filtering_annotations = None
77 |
78 | # get metadata_table
79 | index_counter = 0
80 | rows = []
81 | for folder_name in os.listdir(self.dataset_path):
82 | folder_path = os.path.join(self.dataset_path, folder_name)
83 | if os.path.isdir(folder_path):
84 | for file_name in os.listdir(folder_path):
85 | if ".png" in file_name:
86 | image_path = os.path.join(self.dataset_path, folder_name, file_name)
87 |                         universal_idx = f"{dataset_name}_{folder_name}_{file_name}"
88 |                         rows.append({
89 |                             'universal_idx': universal_idx,
90 |                             'idx_for_curr_dataset': index_counter,
91 |                             'category': folder_name,
92 |                             'image_path': image_path,
93 |                             'mask_path': image_path,  # Assuming mask_path is the same as image_path
94 |                             'dataset_name': dataset_name,
95 |                             'data_type': data_type,
96 |                             'filtering_annotation': getting_filtering_annotation(universal_idx, filtering_annotations, default_annotation) if filtering_annotations is not None else None
97 | })
98 | index_counter += 1
99 | self.metadata_table = pd.DataFrame(rows)
100 | self.categories = set(self.metadata_table['category'])
101 |
102 |
103 |
104 | # COCO2017Dataset resolves each folder name to a COCO category id using the
105 | # pre-built COCO_CATEGORY_INDEX defined at the top of this module.
106 |
107 | class COCO2017Dataset(BaseObjectDataset):
108 |     def _load_coco_id(self, folder_name: str) -> str:
109 |         tokens = folder_name.strip().split()
110 |         if not tokens:
111 |             raise ValueError("Empty folder name provided.")
112 |         # Strip all digits and underscores from the first token to recover the category name.
113 |         cleaned_folder_name = re.sub(r'\d+', '', tokens[0])
114 |         cleaned_folder_name = cleaned_folder_name.replace("_", "")
115 |         try:
116 |             return f"COCO2017_{cleaned_folder_name}_{COCO_CATEGORY_INDEX[cleaned_folder_name]}"
117 |         except KeyError:
118 |             raise ValueError(f"Could not find COCO category for folder name '{folder_name}', cleaned to '{cleaned_folder_name}'.")
119 |
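    # Worked example for _load_coco_id (hypothetical folder name): "person_0042"
    # -> tokens[0] = "person_0042" -> digits and underscores stripped -> "person"
    # -> "COCO2017_person_1", assuming COCO_CATEGORY_INDEX maps "person" to id 1.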
120 | def __init__(self, dataset_path, filtering_annotations_path=None, available_coco_image_path=None, coco_segment_to_image_path=None, dataset_name="COCO2017"):
121 | data_type = "panoptic"
122 | super().__init__(dataset_name, data_type, dataset_path, filtering_annotations_path)
123 | # get filtering annotation
124 | if self.filtering_annotations_path is not None:
125 | filtering_annotations = json.load(open(self.filtering_annotations_path, "r"))
126 | default_annotation = filtering_annotations[list(filtering_annotations.keys())[0]]
127 | default_annotation = {key: False for key in default_annotation.keys()}
128 | else:
129 | filtering_annotations = None
130 |
131 |         # Optionally restrict which source images the segments may come from (COCO image allow-list)
132 | if available_coco_image_path is not None:
133 | available_coco_image_list = json.load(open(available_coco_image_path, "r"))
134 | print(f"Only consider {len(available_coco_image_list)} images from COCO2017")
135 |
136 | # get metadata_table
137 | index_counter = 0
138 | rows = []
139 | for folder_name in os.listdir(self.dataset_path):
140 | folder_path = os.path.join(self.dataset_path, folder_name)
141 | if os.path.isdir(folder_path):
142 | for file_name in os.listdir(folder_path):
143 | if ".png" in file_name:
144 |
145 |                         # Skip segments whose source image is not in available_coco_image_list
146 | if available_coco_image_path is not None:
147 | correspond_image_name = file_name.split("_")[0] + ".jpg"
148 | if correspond_image_name not in available_coco_image_list:
149 | continue
150 |
151 | image_path = os.path.join(self.dataset_path, folder_name, file_name)
152 |                         universal_idx = f"{dataset_name}_{folder_name}_{file_name}"
153 |                         rows.append({
154 |                             'universal_idx': universal_idx,
155 |                             'idx_for_curr_dataset': index_counter,
156 |                             'category': self._load_coco_id(folder_name),
157 |                             'image_path': image_path,
158 |                             'mask_path': image_path,  # Assuming mask_path is the same as image_path
159 |                             'dataset_name': dataset_name,
160 |                             'data_type': data_type,
161 |                             'filtering_annotation': getting_filtering_annotation(universal_idx, filtering_annotations, default_annotation) if filtering_annotations is not None else None
162 | })
163 | index_counter += 1
164 | self.metadata_table = pd.DataFrame(rows)
165 | self.categories = set(self.metadata_table['category'])
166 |
167 |
168 | class COCO2017FullDataset(COCO2017Dataset):
169 | def __init__(self, dataset_path, filtering_annotations_path=None, available_coco_image_path=None, coco_segment_to_image_path=None):
170 | super().__init__(dataset_path, filtering_annotations_path, available_coco_image_path, coco_segment_to_image_path, dataset_name="COCO_Full")
171 |
232 |
233 | class SyntheticDataset(BaseObjectDataset):
234 | def __init__(self, dataset_path, filtering_annotations_path=None, synthetic_annotation_path=None, dataset_name="SyntheticDatasetPlaceHolder", cache_path="metadata_cache.pkl"):
235 | data_type = "panoptic"
236 | super().__init__(dataset_name, data_type, dataset_path, filtering_annotations_path)
237 |
238 | # Load filtering annotations if provided
239 | if self.filtering_annotations_path is not None:
240 | filtering_annotations = json.load(open(self.filtering_annotations_path, "r"))
241 | default_annotation = filtering_annotations[list(filtering_annotations.keys())[0]]
242 | default_annotation = {key: False for key in default_annotation.keys()}
243 | else:
244 | filtering_annotations = None
245 | default_annotation = None
246 |
247 | if synthetic_annotation_path is not None:
248 | synthetic_annotation = json.load(open(synthetic_annotation_path, "r"))
249 | else:
250 | synthetic_annotation = {}
251 |
252 | # Check if cache exists:
253 | if os.path.exists(cache_path):
254 | print(f"Loading metadata from cache: {cache_path}")
255 | with open(cache_path, "rb") as fp:
256 | self.metadata_table = pickle.load(fp)
257 | else:
258 | # If cache does not exist, create the metadata table
259 | self.metadata_table = self._build_metadata_table(synthetic_annotation, filtering_annotations, default_annotation, dataset_name)
260 | # Save to cache for future runs
261 | with open(cache_path, "wb") as fp:
262 | pickle.dump(self.metadata_table, fp)
263 | print(f"Metadata cache saved to: {cache_path}")
264 |
265 | # Create a set of categories for later use
266 | self.categories = set(self.metadata_table['category'])
267 |
268 | def _process_folder(self, folder_name, synthetic_annotation, filtering_annotations, default_annotation, dataset_name):
269 | """
270 | Process one top-level directory and return a list of rows (dicts) for that folder.
271 | """
272 |         # idx_for_curr_dataset is assigned sequentially after all folders are merged (see _build_metadata_table)
273 | rows = []
274 | folder_path = os.path.join(self.dataset_path, folder_name)
275 | if os.path.isdir(folder_path):
276 | for subfolder_name in os.listdir(folder_path):
277 | subfolder_path = os.path.join(folder_path, subfolder_name)
278 | if os.path.isdir(subfolder_path):
279 | for file_name in os.listdir(subfolder_path):
280 | if file_name.endswith(".png"):
281 | image_path = os.path.join(self.dataset_path, folder_name, subfolder_name, file_name)
282 | image_subfolder_id = file_name.replace('.png', '').split("_")[-1]
283 |
284 |                                 universal_idx = f"{dataset_name}_{folder_name}_{subfolder_name}_{image_subfolder_id}"
285 |                                 query_idx = f"{folder_name}_{subfolder_name}"
286 |
287 |                                 if query_idx in synthetic_annotation:
288 |                                     description = synthetic_annotation[query_idx]["description"]
289 |                                     short_phrase = synthetic_annotation[query_idx]["short_phrase"]
290 |                                     features = synthetic_annotation[query_idx]["features"]
291 |                                     # Look up the filtering annotation if one exists
292 |                                     filtering_ann = (getting_filtering_annotation(universal_idx, filtering_annotations, default_annotation)
293 |                                                      if filtering_annotations is not None else None)
294 |                                     rows.append({
295 |                                         'universal_idx': universal_idx,
296 | 'query_idx': query_idx,
297 | 'image_path': image_path,
298 | 'mask_path': image_path, # assuming mask_path is the same as image_path
299 | 'dataset_name': dataset_name,
300 | 'data_type': "panoptic",
301 | 'filtering_annotation': filtering_ann,
302 | 'category': folder_name,
303 | 'sub_category': subfolder_name,
304 | 'description': description,
305 | 'short_phrase': short_phrase,
306 | 'features': features
307 | })
308 |
309 | return rows
310 |
311 | def _build_metadata_table(self, synthetic_annotation, filtering_annotations, default_annotation, dataset_name):
312 | """
313 | Build the metadata table from the dataset_path using multithreading.
314 | """
315 | all_rows = []
316 |
317 | # Get the list of top-level directories
318 | folders = [folder for folder in os.listdir(self.dataset_path) if os.path.isdir(os.path.join(self.dataset_path, folder))]
319 |
320 |         # Process folders concurrently; directory listing is I/O-bound, so a large worker count is used (tune max_workers for your filesystem)
321 |         with ThreadPoolExecutor(max_workers=600) as executor:
322 | # Submit jobs for each top-level folder
323 | future_to_folder = {
324 | executor.submit(self._process_folder, folder, synthetic_annotation, filtering_annotations, default_annotation, dataset_name): folder
325 | for folder in folders
326 | }
327 | # Use tqdm to monitor progress
328 | for future in tqdm(as_completed(future_to_folder), total=len(future_to_folder), desc="Processing folders"):
329 | folder = future_to_folder[future]
330 | try:
331 | rows = future.result()
332 | all_rows.extend(rows)
333 | except Exception as exc:
334 | print(f"Folder {folder} generated an exception: {exc}")
335 |
336 |         # Assign a sequential index across all rows
337 | for idx, row in enumerate(all_rows):
338 | row["idx_for_curr_dataset"] = idx
339 |
340 | return pd.DataFrame(all_rows)
341 |
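# A minimal construction sketch for SyntheticDataset (hypothetical paths; the first run
# builds and pickles the metadata table, later runs load it from cache_path):
#
#   ds = SyntheticDataset(
#       dataset_path="/data/synthetic_segments",
#       synthetic_annotation_path="/data/synthetic_annotation.json",
#       cache_path="synthetic_metadata_cache.pkl")
#   print(len(ds.metadata_table), "segments in", len(ds.categories), "categories")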
342 |
343 | # class SyntheticDataset(BaseObjectDataset):
344 | # def __init__(self, dataset_path, filtering_annotations_path=None, synthetic_annotation_path=None,
345 | # dataset_name="SyntheticDatasetPlaceHolder", cache_path="metadata_cache.pkl",
346 | # image_cache_dir=None): # image_cache_dir is the directory to save verified images (optional)
347 | # data_type = "panoptic"
348 | # super().__init__(dataset_name, data_type, dataset_path, filtering_annotations_path)
349 |
350 | # # Load filtering annotations if provided
351 | # if self.filtering_annotations_path is not None:
352 | # filtering_annotations = json.load(open(self.filtering_annotations_path, "r"))
353 | # default_annotation = filtering_annotations[list(filtering_annotations.keys())[0]]
354 | # default_annotation = {key: False for key in default_annotation.keys()}
355 | # else:
356 | # filtering_annotations = None
357 | # default_annotation = None
358 |
359 | # if synthetic_annotation_path is not None:
360 | # synthetic_annotation = json.load(open(synthetic_annotation_path, "r"))
361 | # else:
362 | # synthetic_annotation = {}
363 |
364 | # # Image cache directory for verified images (optional)
365 | # self.image_cache_dir = image_cache_dir
366 |
367 | # # Check if metadata cache exists:
368 | # if os.path.exists(cache_path):
369 | # print(f"Loading metadata from cache: {cache_path}")
370 | # with open(cache_path, "rb") as fp:
371 | # self.metadata_table = pickle.load(fp)
372 | # else:
373 | # # If cache does not exist, create the metadata table
374 | # self.metadata_table = self._build_metadata_table(synthetic_annotation, filtering_annotations, default_annotation, dataset_name)
375 | # # Save to cache for future runs
376 | # with open(cache_path, "wb") as fp:
377 | # pickle.dump(self.metadata_table, fp)
378 | # print(f"Metadata cache saved to: {cache_path}")
379 |
380 | # # Create a set of categories for later use
381 | # self.categories = set(self.metadata_table['category'])
382 |
383 | # def _process_folder(self, folder_name, synthetic_annotation, filtering_annotations, default_annotation, dataset_name):
384 | # """
385 | # Process one top-level directory and return a list of rows (dicts) for that folder.
386 | # Includes checking if each image is valid before adding it to the dataset.
387 | # """
388 | # rows = []
389 | # folder_path = os.path.join(self.dataset_path, folder_name)
390 | # if os.path.isdir(folder_path):
391 | # for subfolder_name in os.listdir(folder_path):
392 | # subfolder_path = os.path.join(folder_path, subfolder_name)
393 | # if os.path.isdir(subfolder_path):
394 | # for file_name in os.listdir(subfolder_path):
395 | # # Only process PNG files
396 | # if file_name.endswith(".png"):
397 | # image_path = os.path.join(self.dataset_path, folder_name, subfolder_name, file_name)
398 |
399 | # # Attempt to open and verify the image using Pillow
400 | # try:
401 | # with Image.open(image_path) as img:
402 | # img.verify() # Verify image integrity
403 | # except Exception as e:
404 | # print(f"Skipping corrupted image: {image_path} ({e})")
405 | # continue # Skip this image if it is corrupted
406 |
407 | # # If image is verified and an image cache directory is set, save the verified image there
408 | # if self.image_cache_dir:
409 | # # Reconstruct a similar folder structure in the cache directory
410 | # cache_folder = os.path.join(self.image_cache_dir, folder_name, subfolder_name)
411 | # os.makedirs(cache_folder, exist_ok=True)
412 | # cached_image_path = os.path.join(cache_folder, file_name)
413 | # # Copy the image file to the cache directory
414 | # copyfile(image_path, cached_image_path)
415 | # # Use the cached image path for further processing
416 | # image_path = cached_image_path
417 |
418 | # image_subfolder_id = file_name.replace('.png', '').split("_")[-1]
419 | # universial_idx = f"{dataset_name}_{folder_name}_{subfolder_name}_{image_subfolder_id}"
420 | # query_idx = f"{folder_name}_{subfolder_name}"
421 |
422 | # if query_idx in synthetic_annotation:
423 | # description = synthetic_annotation[query_idx]["description"]
424 | # short_phrase = synthetic_annotation[query_idx]["short_phrase"]
425 | # features = synthetic_annotation[query_idx]["features"]
426 | # # Get filtering annotation if exists
427 | # filtering_ann = (getting_filtering_annotation(universial_idx, filtering_annotations, default_annotation)
428 | # if filtering_annotations is not None else None)
429 | # rows.append({
430 | # 'universal_idx': universial_idx,
431 | # 'query_idx': query_idx,
432 | # 'image_path': image_path,
433 | # 'mask_path': image_path, # Assuming mask_path is the same as image_path
434 | # 'dataset_name': dataset_name,
435 | # 'data_type': "panoptic",
436 | # 'filtering_annotation': filtering_ann,
437 | # 'category': folder_name,
438 | # 'sub_category': subfolder_name,
439 | # 'description': description,
440 | # 'short_phrase': short_phrase,
441 | # 'features': features
442 | # })
443 | # return rows
444 |
445 | # def _build_metadata_table(self, synthetic_annotation, filtering_annotations, default_annotation, dataset_name):
446 | # """
447 | # Build the metadata table from the dataset_path using multithreading.
448 | # """
449 | # all_rows = []
450 |
451 | # # Get the list of top-level directories
452 | # folders = [folder for folder in os.listdir(self.dataset_path)
453 | # if os.path.isdir(os.path.join(self.dataset_path, folder))]
454 |
455 | # # Define the number of threads you want to use
456 | # num_threads = 300 # Adjust as needed
457 |
458 | # # Use a ThreadPoolExecutor with a defined number of workers
459 | # with ThreadPoolExecutor(max_workers=num_threads) as executor:
460 | # # Submit jobs for each top-level folder
461 | # future_to_folder = {
462 | # executor.submit(self._process_folder, folder, synthetic_annotation, filtering_annotations, default_annotation, dataset_name): folder
463 | # for folder in folders
464 | # }
465 | # # Use tqdm to monitor progress
466 | # for future in tqdm(as_completed(future_to_folder), total=len(future_to_folder), desc="Processing folders"):
467 | # folder = future_to_folder[future]
468 | # try:
469 | # rows = future.result()
470 | # all_rows.extend(rows)
471 | # except Exception as exc:
472 | # print(f"Folder {folder} generated an exception: {exc}")
473 |
474 | # # Optionally, assign a sequential index across all rows
475 | # for idx, row in enumerate(all_rows):
476 | # row["idx_for_curr_dataset"] = idx
477 |
478 | # return pd.DataFrame(all_rows)
479 |
480 |
488 | class SA1BDataset(BaseObjectDataset):
489 | def __init__(self, dataset_path, filtering_annotations_path):
490 | dataset_name = "sa1b_parsed"
491 | data_type = "non-semantic"
492 | super().__init__(dataset_name, data_type, dataset_path, filtering_annotations_path)
493 | # get filtering annotation
494 | if self.filtering_annotations_path is not None:
495 | filtering_annotations = json.load(open(self.filtering_annotations_path, "r"))
496 | default_annotation = filtering_annotations[list(filtering_annotations.keys())[0]]
497 | default_annotation = {key: False for key in default_annotation.keys()}
498 | else:
499 | filtering_annotations = None
500 |
501 |
502 | # get metadata_table
503 | index_counter = 0
504 | rows = []
505 | for folder_name in os.listdir(self.dataset_path):
506 | folder_path = os.path.join(self.dataset_path, folder_name)
507 | if os.path.isdir(folder_path):
508 | for file_name in os.listdir(folder_path):
509 | if ".png" in file_name or ".jpg" in file_name:
510 | image_path = os.path.join(self.dataset_path, folder_name, file_name)
511 |                         universal_idx = f"{dataset_name}_{folder_name}_{file_name}"
512 |                         if filtering_annotations is None or universal_idx in filtering_annotations:
513 |                             rows.append({
514 |                                 'universal_idx': universal_idx,
515 |                                 'idx_for_curr_dataset': index_counter,
516 |                                 'category': folder_name,
517 |                                 'image_path': image_path,
518 |                                 'mask_path': image_path,  # Assuming mask_path is the same as image_path
519 |                                 'dataset_name': dataset_name,
520 |                                 'data_type': data_type,
521 |                                 'filtering_annotation': getting_filtering_annotation(universal_idx, filtering_annotations, default_annotation) if filtering_annotations is not None else None
522 |                             })
523 |                             index_counter += 1
524 | self.metadata_table = pd.DataFrame(rows)
525 | self.categories = set(self.metadata_table['category'])
526 |
527 |
528 | class CustomizedDatasetOurFormat(BaseObjectDataset):
529 | def __init__(self, dataset_path, dataset_name, data_type):
530 | super().__init__(dataset_name, data_type, dataset_path)
531 | # get metadata_table
532 | index_counter = 0
533 | rows = []
534 | for folder_name in os.listdir(self.dataset_path):
535 | folder_path = os.path.join(self.dataset_path, folder_name)
536 | if os.path.isdir(folder_path):
537 | for file_name in os.listdir(folder_path):
538 | if ".png" in file_name:
539 | image_path = os.path.join(self.dataset_path, folder_name, file_name)
540 | rows.append({
541 | 'idx_for_curr_dataset': index_counter,
542 | 'category': folder_name,
543 | 'image_path': image_path,
544 | 'mask_path': image_path, # Assuming mask_path is the same as image_path
545 | 'dataset_name': dataset_name,
546 | 'data_type': data_type
547 | })
548 | index_counter += 1
549 | self.metadata_table = pd.DataFrame(rows)
550 | self.categories = set(self.metadata_table['category'])
551 |
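# CustomizedDatasetOurFormat expects one folder per category with segment PNGs inside,
# e.g. (hypothetical paths):
#
#   /data/custom_segments/cat/0001.png
#   /data/custom_segments/dog/0001.png
#
#   ds = CustomizedDatasetOurFormat("/data/custom_segments", dataset_name="MyCustomSet", data_type="panoptic")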
552 |
--------------------------------------------------------------------------------