├── .gitignore ├── datasets ├── __init__.py ├── background_datasets.py ├── base_datasets.py ├── data_manager.py ├── panoptic_coco_categories.json ├── coco_categories.py └── object_datasets.py ├── segment_layout ├── __init__.py ├── existing_layout_generator.py ├── random_bbox_layout_generator.py └── fine_grained_bbox_layout_generator.py ├── segmentation_generator ├── utils.py ├── __init__.py ├── generate_image_and_mask.py ├── panoptic_coco_categories.json └── segmentataion_synthesizer.py ├── requirements.txt ├── requirements_referring_expression_generation.txt ├── .gitattributes ├── assets ├── ovd.png ├── ref.png ├── fcgc.png ├── pipeline.png ├── results.png ├── sfcsgc.png ├── teaser.png ├── comparison.png └── intra-class.png ├── requirements_relight_and_blending.txt ├── scripts └── generate_with_batch.py ├── README.md ├── referring_expression_generation └── inference.py └── relighting_and_blending └── inference.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /segment_layout/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /segmentation_generator/utils.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /segmentation_generator/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | open_clip_torch 2 | pandas 3 | pycocotools 4 | imageio[pyav] -------------------------------------------------------------------------------- /requirements_referring_expression_generation.txt: -------------------------------------------------------------------------------- 1 | pytorch 2 | vllm 3 | google-cloud-storage -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.pdf filter=lfs diff=lfs merge=lfs -text 2 | *.png filter=lfs diff=lfs merge=lfs -text 3 | -------------------------------------------------------------------------------- /assets/ovd.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:900c4744d0d78ea674ec2e1ba0e77d34046d5bf3206697e82ddc313b2e69b8a2 3 | size 474129 4 | -------------------------------------------------------------------------------- /assets/ref.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:5c57b5a85ae1778a44b7af9c0a49a5506e3b306151923f1dc1f7b6276d3026ab 3 | size 504939 4 | -------------------------------------------------------------------------------- /assets/fcgc.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid 
sha256:a5638531abd756b072c6135233e750be5ba4f8ea269fc5eb283f3813d5e206ce 3 | size 1039387 4 | -------------------------------------------------------------------------------- /assets/pipeline.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:76725df1ae1ddc3906464b997cce6a8b90548fa65a1cb159e0f0435133efa14a 3 | size 872622 4 | -------------------------------------------------------------------------------- /assets/results.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:ca283f0a17b8958b1857fe24fbd1edc6b25b9c261b49112c2732c1e0e08cde95 3 | size 387757 4 | -------------------------------------------------------------------------------- /assets/sfcsgc.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:d862973fcb1de835f907486a06aa1c1db87b0b946a53b0c1db3ba99a5a95dd63 3 | size 708186 4 | -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:7d6bc7e1b3a3d6258b180c5396cde1db6da76e0b1a118868bf963cf10274ba24 3 | size 2273149 4 | -------------------------------------------------------------------------------- /assets/comparison.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:7ce424160c8d85ea1baf949734a2f7891d5d9bd848cd702fa3fbb04d4001e39f 3 | size 1307979 4 | -------------------------------------------------------------------------------- /assets/intra-class.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:2c615f4a058fa08f4d1bf98f468164f752d452b7f00c411d1c066c596bc3c90e 3 | size 715132 4 | -------------------------------------------------------------------------------- /requirements_relight_and_blending.txt: -------------------------------------------------------------------------------- 1 | diffusers==0.32.1 2 | transformers==4.48.0 3 | opencv-python 4 | safetensors 5 | pillow==10.2.0 6 | einops 7 | torch 8 | peft 9 | gradio==3.41.2 10 | protobuf==3.20 11 | google-cloud-storage 12 | scikit-image -------------------------------------------------------------------------------- /datasets/background_datasets.py: -------------------------------------------------------------------------------- 1 | from .base_datasets import BaseBackgroundDataset 2 | import os 3 | import glob 4 | import pandas as pd 5 | 6 | 7 | class BG20KDataset(BaseBackgroundDataset): 8 | def __init__(self, dataset_path): 9 | dataset_name = "BG20k" 10 | data_type = "background" 11 | super().__init__(dataset_name, data_type, dataset_path) 12 | # get metadata_table 13 | rows = [] 14 | self.metadata_table = pd.DataFrame(columns=['idx_for_curr_dataset', 'category', 'image_path', 'dataset_name', 'data_type']) 15 | index_counter = 0 16 | for image_path in glob.glob(self.dataset_path + "/train/*.jpg"): 17 | rows.append({ 18 | 'idx_for_curr_dataset': index_counter, 19 | 'category': "no category", 20 | 'image_path': image_path, 21 | 'dataset_name': dataset_name, 22 | 'data_type': data_type 23 | }) 24 | index_counter += 1 25 | self.metadata_table = pd.DataFrame(rows) 26 | 27 | 
class NoBGDataset(BaseBackgroundDataset): 28 | pass -------------------------------------------------------------------------------- /segment_layout/existing_layout_generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | import numpy as np 4 | import pandas as pd 5 | import random 6 | import ast 7 | 8 | def calculate_image_areas(image_paths): 9 | components = [] 10 | for path in image_paths: 11 | with Image.open(path) as img: 12 | width, height = img.size 13 | components.append((width, height)) 14 | return components 15 | 16 | def compute_similarity(component, box): 17 | """ 18 | Compute similarity score based on area and aspect ratio. 19 | """ 20 | # Unpack dimensions (component is a (width, height) tuple; box is [x_min, y_min, x_max, y_max]) 21 | comp_w, comp_h = component 22 | box_w = box[2] - box[0] 23 | box_h = box[3] - box[1] 24 | 25 | # Calculate areas and aspect ratios 26 | comp_area = comp_h * comp_w 27 | box_area = box_h * box_w 28 | 29 | comp_aspect_ratio = comp_w / comp_h 30 | box_aspect_ratio = box_w / box_h 31 | 32 | # Area difference (normalized) 33 | area_diff = abs(comp_area - box_area) / max(comp_area, box_area) 34 | 35 | # Aspect ratio difference (normalized) 36 | aspect_ratio_diff = abs(comp_aspect_ratio - box_aspect_ratio) 37 | 38 | # Weighted score (lower is better) 39 | score = 0.8 * area_diff + 0.2 * aspect_ratio_diff 40 | return score 41 | 42 | 43 | def postprocess_bboxes(bbox_predictions, width, height): 44 | # Scale the normalized bbox predictions to image dimensions 45 | processed_bboxes = [] 46 | for bbox in bbox_predictions: 47 | xmin = bbox[0] * width 48 | ymin = bbox[1] * height 49 | xmax = bbox[2] * width 50 | ymax = bbox[3] * height 51 | processed_bboxes.append([xmin, ymin, xmax, ymax]) 52 | 53 | return processed_bboxes 54 | 55 | class ExistingLayoutGenerator(): 56 | def __init__(self, table_path, device=None): 57 | # Determine device 58 | if device is None: 59 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 60 | else: 61 | self.device = device 62 | 63 | # Load the table of precomputed layouts 64 | self.layout_table = pd.read_csv(table_path) 65 | 66 | def predict_wi_area(self, num_bboxes, segment_list, width, height): 67 | # To be implemented: layout prediction that also takes segment areas into account. 68 | pass 69 | 70 | def predict_wo_area(self, num_bboxes, width, height): 71 | # Keep only layouts that contain at least num_bboxes boxes. 72 | layout_table_filtered = self.layout_table[self.layout_table['num_bboxes'] >= num_bboxes] 73 | # sample a layout from the filtered table 74 | bboxes_result = ast.literal_eval(layout_table_filtered.sample(n=1)['bboxes'].item()) 75 | correspond_bboxes_result = bboxes_result[0:num_bboxes] 76 | # sort the selected boxes by area, largest first 77 | correspond_bboxes_result = sorted( 78 | correspond_bboxes_result, key=lambda bbox: (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]), reverse=True 79 | ) 80 | # correspond_bboxes_result = random.sample(bboxes_result, num_bboxes) 81 | 82 | outputs = postprocess_bboxes(correspond_bboxes_result, width, height) 83 | return outputs 84 | -------------------------------------------------------------------------------- /datasets/base_datasets.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | 5 | 6 | def convert_rgba_to_rgb_and_mask(rgba_image: Image): 7 | # Ensure the input image is RGBA 8 | if rgba_image.mode != "RGBA": 9 | raise ValueError("Input image must be in RGBA mode") 10 | # Convert to RGB 11 | rgb_image = rgba_image.convert("RGB") 12 | 13 | # Create a mask based on the alpha
channel 14 | alpha_channel = np.array(rgba_image.getchannel('A')) 15 | mask = Image.fromarray(alpha_channel) # This will be a grayscale image 16 | return rgb_image, mask 17 | 18 | 19 | def get_bounding_box(mask): 20 | pass 21 | 22 | def get_image_and_mask(image_path, mask_path): 23 | if image_path == mask_path: 24 | image_with_mask = Image.open(image_path) 25 | if image_with_mask.mode == "RGBA": 26 | return convert_rgba_to_rgb_and_mask(image_with_mask) 27 | else: 28 | raise ValueError("Image and mask are the same, but image is not in RGBA mode") 29 | else: 30 | return Image.open(image_path), Image.open(mask_path) 31 | 32 | 33 | class BaseObjectDataset(Dataset): 34 | def __init__(self, dataset_name, data_type, dataset_path, filtering_annotations_path=None): 35 | self.dataset_name = dataset_name 36 | self.dataset_path = dataset_path 37 | self.data_type = data_type 38 | self.metadata_table = None 39 | self.filtering_annotations_path = filtering_annotations_path 40 | 41 | 42 | def __len__(self): 43 | return len(self.metadata_table) 44 | 45 | def __getitem__(self, index): 46 | data_package = {} 47 | metadata = self.metadata_table[self.metadata_table['idx_for_curr_dataset'] == index].iloc[0] 48 | data_package['metadata'] = metadata 49 | data_package['image'], data_package['mask'] = get_image_and_mask(metadata['image_path'], metadata['mask_path']) 50 | return data_package 51 | 52 | def return_metadata_table(self): 53 | # metadata contains the information of each mask, like its cateogory, path. 54 | # each mask store as a row in the metadata table 55 | return self.metadata_table 56 | 57 | 58 | class BaseBackgroundDataset(Dataset): 59 | def __init__(self, dataset_name, data_type, dataset_path): 60 | self.dataset_name = dataset_name 61 | self.dataset_path = dataset_path 62 | self.data_type = data_type 63 | self.metadata_table = None 64 | 65 | def __len__(self): 66 | return len(self.metadata_table) 67 | 68 | def __getitem__(self, index): 69 | data_package = {} 70 | metadata = self.metadata_table[self.metadata_table['idx_for_curr_dataset'] == index].iloc[0] 71 | data_package['metadata'] = metadata 72 | data_package['image'] = Image.open(metadata['image_path']).convert("RGB") 73 | return data_package 74 | 75 | def return_metadata_table(self): 76 | # metadata contains the information of each mask, like its cateogory, path. 77 | # each mask store as a row in the metadata table 78 | return self.metadata_table -------------------------------------------------------------------------------- /segment_layout/random_bbox_layout_generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | import numpy as np 4 | import pandas as pd 5 | import random 6 | import ast 7 | 8 | def postprocess_bboxes(bbox_predictions, width, height): 9 | # Scale the normalized bbox predictions to image dimensions. 10 | processed_bboxes = [] 11 | for bbox in bbox_predictions: 12 | xmin = bbox[0] * width 13 | ymin = bbox[1] * height 14 | xmax = bbox[2] * width 15 | ymax = bbox[3] * height 16 | processed_bboxes.append([xmin, ymin, xmax, ymax]) 17 | return processed_bboxes 18 | 19 | def compute_iou(bbox1, bbox2): 20 | """ 21 | Compute Intersection over Union (IoU) for two boxes in normalized coordinates. 22 | Each bbox is in the format [x_min, y_min, x_max, y_max]. 
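Returns a float in [0, 1]; 0.0 is returned for disjoint boxes or a degenerate (zero-area) union.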
23 | """ 24 | x_left = max(bbox1[0], bbox2[0]) 25 | y_top = max(bbox1[1], bbox2[1]) 26 | x_right = min(bbox1[2], bbox2[2]) 27 | y_bottom = min(bbox1[3], bbox2[3]) 28 | 29 | if x_right < x_left or y_bottom < y_top: 30 | return 0.0 31 | 32 | intersection_area = (x_right - x_left) * (y_bottom - y_top) 33 | area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]) 34 | area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1]) 35 | union_area = area1 + area2 - intersection_area 36 | if union_area == 0: 37 | return 0.0 38 | return intersection_area / union_area 39 | 40 | class RandomBoundingBoxLayoutGenerator(): 41 | def __init__(self, device=None): 42 | # Determine device. 43 | if device is None: 44 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 45 | else: 46 | self.device = device 47 | 48 | def generate(self, num_bboxes, width, height, min_box_scale=0.05, max_box_scale=0.6, max_overlap=0.5): 49 | """ 50 | Generate num_bboxes random bounding boxes in a controlled way with limited overlap. 51 | The boxes are defined in normalized coordinates and then scaled to the given width and height. 52 | min_box_scale and max_box_scale define the fraction of the image that a box can take in width and height. 53 | max_overlap defines the maximum allowed Intersection over Union (IoU) between any two boxes. 54 | """ 55 | bboxes = [] 56 | max_attempts_total = num_bboxes * 50 # Maximum number of candidate generations. 57 | attempts = 0 58 | 59 | while len(bboxes) < num_bboxes and attempts < max_attempts_total: 60 | # Randomly select box width and height as a fraction of the full image. 61 | box_w = random.uniform(min_box_scale, max_box_scale) 62 | box_h = random.uniform(min_box_scale, max_box_scale) 63 | # Choose top-left coordinates so that the box fits within [0,1]. 64 | x_min = random.uniform(0, 1 - box_w) 65 | y_min = random.uniform(0, 1 - box_h) 66 | x_max = x_min + box_w 67 | y_max = y_min + box_h 68 | candidate_box = [x_min, y_min, x_max, y_max] 69 | 70 | # Check overlap with all previously accepted boxes. 71 | valid = True 72 | for existing_box in bboxes: 73 | if compute_iou(candidate_box, existing_box) > max_overlap: 74 | valid = False 75 | break 76 | 77 | if valid: 78 | bboxes.append(candidate_box) 79 | attempts += 1 80 | 81 | if len(bboxes) < num_bboxes: 82 | print(f"Warning: Only generated {len(bboxes)} boxes with the given overlap constraints.") 83 | 84 | processed_bboxes = postprocess_bboxes(bboxes, width, height) 85 | return processed_bboxes -------------------------------------------------------------------------------- /datasets/data_manager.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | class DataManager(): 4 | def __init__(self, available_object_datasets, available_background_datasets, filtering_setting): 5 | # Process object datasets. 6 | self.available_object_datasets = available_object_datasets 7 | self.filtered_object_data = {} # To store filtered data for each object dataset. 8 | self.accumulated_object_metadata_table = pd.DataFrame() 9 | unified_object_categories = set() 10 | 11 | for dataset_name, curr_dataset in self.available_object_datasets.items(): 12 | # Retrieve the original metadata table. 13 | metadata_table = curr_dataset.return_metadata_table() 14 | count_before = len(metadata_table) 15 | 16 | # Apply filtering based on the provided settings.
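# Expected shapes (illustrative only; the keys below are hypothetical, not from the original code):
#   filtering_setting maps a metric name to the string "filter" when rows failing that metric
#   should be dropped, e.g. {"mask_quality": "filter", "aesthetic": "keep"}, and each row's
#   'filtering_annotation' is a dict of booleans such as {"mask_quality": True}.
# If any row of a dataset has a missing (NaN) annotation, filtering is skipped entirely for that
# dataset, as the branch below shows.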
17 | if metadata_table['filtering_annotation'].isna().any(): 18 | filtered_table = metadata_table 19 | count_after = count_before 20 | else: 21 | filter_mask = pd.Series(True, index=metadata_table.index) 22 | for metric, condition in filtering_setting.items(): 23 | if condition == "filter": 24 | filter_mask &= metadata_table['filtering_annotation'].apply( 25 | lambda x: x.get(metric, False) if pd.notna(x) else False 26 | ) 27 | filtered_table = metadata_table[filter_mask] 28 | count_after = len(filtered_table) 29 | 30 | # Print counts for this dataset. 31 | print(f"Object dataset '{dataset_name}':") 32 | print(f" Count before filtering: {count_before}") 33 | print(f" Count after filtering: {count_after}") 34 | 35 | # Store the filtered table for later access and accumulate in one table. 36 | self.filtered_object_data[dataset_name] = filtered_table 37 | self.accumulated_object_metadata_table = pd.concat( 38 | [self.accumulated_object_metadata_table, filtered_table] 39 | ) 40 | 41 | # Collect unified object categories. 42 | unified_object_categories.update(curr_dataset.categories) 43 | 44 | # Prepare a sorted list of unified object categories and a mapping to indices. 45 | self.unified_object_categories = sorted(list(unified_object_categories)) 46 | self.unified_object_categories_to_idx = {category: idx for idx, category in enumerate(self.unified_object_categories)} 47 | 48 | # Process background datasets. 49 | self.available_background_datasets = available_background_datasets 50 | self.accumulated_background_metadata_table = pd.DataFrame() 51 | self.background_data = {} # Store each background dataset's metadata table. 52 | for dataset_name, curr_dataset in self.available_background_datasets.items(): 53 | metadata_table = curr_dataset.return_metadata_table() 54 | self.background_data[dataset_name] = metadata_table 55 | self.accumulated_background_metadata_table = pd.concat( 56 | [self.accumulated_background_metadata_table, metadata_table] 57 | ) 58 | print(f"Background dataset '{dataset_name}' has {len(metadata_table)} records.") 59 | 60 | print("Finished processing object and background datasets.") 61 | 62 | # object 63 | def query_object_metadata(self, query): 64 | return self.accumulated_object_metadata_table.query(query) 65 | 66 | def get_object_by_metadata(self, metadata): 67 | dataset_name = metadata["dataset_name"] 68 | idx_for_curr_dataset = metadata["idx_for_curr_dataset"] 69 | return self.available_object_datasets[dataset_name][idx_for_curr_dataset] 70 | 71 | def get_random_object_metadata(self, rng): 72 | # First uniformly select an object dataset. 73 | dataset_name = rng.choice(list(self.filtered_object_data.keys())) 74 | dataset_table = self.filtered_object_data[dataset_name] 75 | # Then uniformly sample from that dataset's metadata table. 76 | idx = rng.choice(len(dataset_table)) 77 | selected_metadata = dataset_table.iloc[idx].copy() 78 | # Add dataset information to the metadata. 79 | selected_metadata["dataset_name"] = dataset_name 80 | return selected_metadata 81 | 82 | # background 83 | def query_background_metadata(self, query): 84 | return self.accumulated_background_metadata_table.query(query) 85 | 86 | def get_background_by_metadata(self, metadata): 87 | dataset_name = metadata["dataset_name"] 88 | idx_for_curr_dataset = metadata["idx_for_curr_dataset"] 89 | return self.available_background_datasets[dataset_name][idx_for_curr_dataset] 90 | 91 | def get_random_background_metadata(self, rng): 92 | # First uniformly select a background dataset. 
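# (As with get_random_object_metadata above, this two-stage sampling weights every
# dataset equally, independent of how many records each one contributes.)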
93 | dataset_name = rng.choice(list(self.background_data.keys())) 94 | metadata_table = self.background_data[dataset_name] 95 | # Then uniformly sample from that dataset's metadata table. 96 | idx = rng.choice(len(metadata_table)) 97 | selected_metadata = metadata_table.iloc[idx].copy() 98 | # Add dataset information to the metadata. 99 | selected_metadata["dataset_name"] = dataset_name 100 | return selected_metadata 101 | 102 | def get_random_object_metadata_by_category(self, rng, category): 103 | df = self.accumulated_object_metadata_table 104 | # vectorized filter 105 | cat_df = df[df['category'] == category] 106 | if cat_df.empty: 107 | raise ValueError(f"No objects found for category '{category}'") 108 | # sample a single row by integer location 109 | i = rng.integers(0, len(cat_df)) 110 | return cat_df.iloc[i].copy() -------------------------------------------------------------------------------- /segment_layout/fine_grained_bbox_layout_generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import matplotlib.pyplot as plt 4 | import matplotlib.patches as patches 5 | from shapely.geometry import box as shapely_box 6 | from shapely.ops import unary_union 7 | import math 8 | 9 | # Function to scale normalized bounding boxes to pixel coordinates 10 | def postprocess_bboxes(bboxes, width, height): 11 | processed = [] 12 | for bbox in bboxes: 13 | x_min, y_min, x_max, y_max = bbox 14 | processed.append([int(x_min * width), 15 | int(y_min * height), 16 | int(min(x_max * width, width)), 17 | int(min(y_max * height, height))]) 18 | return processed 19 | 20 | def compute_visible_ratios(bboxes): 21 | """ 22 | For a given list of boxes (in generation order), compute the visible area ratio 23 | for each box (visible_area / total_area), assuming later boxes occlude earlier ones. 24 | """ 25 | ratios = [] 26 | polys = [shapely_box(*bbox) for bbox in bboxes] 27 | n = len(polys) 28 | for i in range(n): 29 | poly = polys[i] 30 | if i == n - 1: # last box is fully visible 31 | ratios.append(1.0) 32 | else: 33 | occluders = polys[i+1:] 34 | union_occlusion = unary_union(occluders) 35 | intersection = poly.intersection(union_occlusion) 36 | occluded_area = intersection.area 37 | visible_area = poly.area - occluded_area 38 | ratio = visible_area / poly.area if poly.area > 0 else 0 39 | ratios.append(ratio) 40 | return ratios 41 | 42 | def sample_box(min_area, max_area, aspect_range=(0.5, 2.0)): 43 | """ 44 | Sample a box (in normalized coordinates) with an area in [min_area, max_area] and 45 | an aspect ratio in aspect_range. 46 | Returns [x_min, y_min, x_max, y_max] (all between 0 and 1), or None if failed. 47 | """ 48 | # Sample a desired area uniformly 49 | area = random.uniform(min_area, max_area) 50 | # Sample an aspect ratio (width/height) 51 | aspect = random.uniform(aspect_range[0], aspect_range[1]) 52 | # Compute width and height from area and aspect ratio: 53 | # area = width * height, and width = aspect * height 54 | # => height = sqrt(area / aspect), width = sqrt(area * aspect) 55 | h = math.sqrt(area / aspect) 56 | w = math.sqrt(area * aspect) 57 | # Ensure the box fits inside [0,1] by sampling a top-left coordinate. 
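# A sampled (area, aspect) pair can still imply a side longer than the unit square;
# such candidates are rejected below and the caller simply retries.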
58 | if w > 1 or h > 1: 59 | return None 60 | x_min = random.uniform(0, 1 - w) 61 | y_min = random.uniform(0, 1 - h) 62 | return [x_min, y_min, x_min + w, y_min + h] 63 | 64 | class FineGrainedBoundingBoxLayoutGenerator(): 65 | def __init__(self, device=None): 66 | if device is None: 67 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 68 | else: 69 | self.device = device 70 | 71 | def generate(self, num_large, num_mid, num_small, width, height, 72 | # COCO area thresholds (normalized) 73 | # For an image of area 1, small: <1024/1e6, medium: [1024/1e6, 9216/1e6), large: >=9216/1e6 74 | # For a 1024x1024 image, total area=1. 75 | small_area_range=((32 * 16)/(640*480), (32 * 32)/(640*480)), # ~[0, 0.0009766) 76 | mid_area_range=((32*32)/(640*480), (96 * 96)/(640*480)), # ~[0.0009766, 0.0087891) 77 | large_area_range=((96 * 96)/(640*480), 0.5), # large boxes: area >=0.0087891 and up to 0.5 78 | aspect_range=(0.5, 2.0), 79 | min_avg_display_area=0.8, 80 | min_single_display_area=0.5 81 | ): 82 | """ 83 | Generate a layout of boxes in three groups: 84 | - Large boxes: number=num_large, area in large_area_range 85 | - Mid boxes: number=num_mid, area in mid_area_range 86 | - Small boxes: number=num_small, area in small_area_range 87 | 88 | Boxes are generated sequentially (large first, then mid, then small). 89 | Overlap among boxes (regardless of group) is taken into account via visible area constraints. 90 | """ 91 | total_boxes = num_large + num_mid + num_small 92 | max_layout_attempts = total_boxes * 100 # maximum attempts to generate the full layout 93 | layout_attempt = 0 94 | 95 | candidate_bboxes = [] 96 | groups = [ 97 | (num_large, large_area_range), 98 | (num_mid, mid_area_range), 99 | (num_small, small_area_range) 100 | ] 101 | 102 | while layout_attempt < max_layout_attempts: 103 | candidate_bboxes = [] 104 | valid = True 105 | # For each group in order: 106 | for (num, area_range) in groups: 107 | for _ in range(num): 108 | max_box_attempts = 100 109 | box_found = False 110 | for _ in range(max_box_attempts): 111 | candidate_box = sample_box(area_range[0], area_range[1], aspect_range) 112 | if candidate_box is None: 113 | continue 114 | temp_bboxes = candidate_bboxes + [candidate_box] 115 | ratios = compute_visible_ratios(temp_bboxes) 116 | avg_ratio = sum(ratios) / len(ratios) 117 | if avg_ratio >= min_avg_display_area and all(r >= min_single_display_area for r in ratios): 118 | candidate_bboxes.append(candidate_box) 119 | box_found = True 120 | break 121 | if not box_found: 122 | valid = False 123 | break # break out if a box in this group cannot be generated 124 | if not valid: 125 | break 126 | if valid and len(candidate_bboxes) == total_boxes: 127 | return postprocess_bboxes(candidate_bboxes, width, height) 128 | layout_attempt += 1 129 | 130 | raise ValueError(f"After {max_layout_attempts} attempts, failed to generate a layout satisfying the constraints.") -------------------------------------------------------------------------------- /segmentation_generator/generate_image_and_mask.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | 4 | def generate_image_and_mask(data_manager, image_metadata, resize_mode, containSmallObjectMask=False, small_object_ratio=(36 * 36)/(640 * 480)): 5 | width = image_metadata["width"] 6 | height = image_metadata["height"] 7 | # Create a blank canvas and ground truth mask 8 | canvas = Image.fromarray(np.zeros((height, width, 3), 
dtype=np.uint8)) 9 | ground_truth_mask = Image.new("L", (width, height), 0) 10 | 11 | # This will collect segment IDs for small objects based on the bounding box area. 12 | small_object_ids = [] 13 | 14 | # Add background if available 15 | if "background" in image_metadata: 16 | background_metadata = image_metadata["background"]["background_metadata"] 17 | background_image = data_manager.get_background_by_metadata(background_metadata)['image'] 18 | canvas = paste_background(background_image, canvas) 19 | 20 | # Process objects 21 | if "objects" in image_metadata: 22 | for object_info in image_metadata["objects"]: 23 | object_metadata = object_info["object_metadata"] 24 | object_position = object_info["object_position"] # either a bounding box or center point 25 | segment_id = object_info["segment_id"] 26 | 27 | # Use the bounding box area if available (i.e., when object_position has more than 2 elements). 28 | if len(object_position) != 2: 29 | bbox_area = (object_position[2] - object_position[0]) * (object_position[3] - object_position[1]) 30 | if bbox_area / (width * height) <= small_object_ratio: 31 | small_object_ids.append(segment_id) 32 | # Optionally, you might decide what to do if only a center point is provided. 33 | # For example, you could skip the small object check or assign a default small area. 34 | 35 | # Get the object image and its mask from the data manager. 36 | object_data_package = data_manager.get_object_by_metadata(object_metadata) 37 | image_obj, mask_obj = object_data_package['image'], object_data_package['mask'] 38 | 39 | # If augmentation information is provided, apply the augmentation 40 | aug_params = object_info.get("augmentation", None) 41 | if aug_params is not None: 42 | image_obj, mask_obj = apply_augmentation(image_obj, mask_obj, aug_params) 43 | 44 | # Paste using the appropriate method based on how the object is defined. 45 | if len(object_position) == 2: 46 | paste_segment_wi_center_point(image_obj, mask_obj, canvas, ground_truth_mask, object_position, segment_id) 47 | else: 48 | paste_segment_wi_bbox(image_obj, mask_obj, canvas, ground_truth_mask, object_position, segment_id, resize_mode) 49 | 50 | # If a small object mask is desired, create it by keeping only the visible pixels for small objects. 51 | if containSmallObjectMask: 52 | ground_truth_mask_np = np.array(ground_truth_mask) 53 | # Retain only the pixels whose segment ID is in the small_object_ids list. 54 | small_object_mask_np = np.where(np.isin(ground_truth_mask_np, small_object_ids), ground_truth_mask_np, 0).astype(np.uint8) 55 | return canvas, ground_truth_mask, small_object_mask_np 56 | else: 57 | return canvas, ground_truth_mask 58 | 59 | def paste_background(background, canvas): 60 | # Resize background if necessary 61 | if canvas.size != background.size: 62 | background = background.resize(canvas.size) 63 | return background 64 | 65 | def normalize_mask(mask: Image) -> Image: 66 | """ 67 | Convert a mask to mode 'L' and ensure its pixel values are in [0, 255]. 
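Binary masks stored with values in {0, 1} are rescaled to {0, 255} so they behave as proper
alpha/paste masks for PIL operations.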
68 | """ 69 | if mask.mode != 'L': 70 | mask = mask.convert('L') 71 | mask_np = np.array(mask) 72 | if mask_np.max() <= 1: 73 | mask_np = (mask_np * 255).astype(np.uint8) 74 | return Image.fromarray(mask_np) 75 | 76 | def paste_segment_wi_center_point(image: Image, mask: Image, canvas: Image, 77 | ground_truth_mask: Image, position: tuple, segment_id): 78 | # Ensure correct image modes 79 | if image.mode != 'RGB': 80 | image = image.convert('RGB') 81 | mask = normalize_mask(mask) 82 | 83 | # Get the bounding box of the nonzero (segmented) area 84 | bbox = mask.getbbox() 85 | if bbox is None: 86 | raise ValueError("No segmented area found in the mask.") 87 | 88 | # Crop both the image and mask to the bounding box 89 | cropped_image = image.crop(bbox) 90 | cropped_mask = mask.crop(bbox) 91 | 92 | # Compute offset so that the segment's center aligns with the given position 93 | segment_center = ((bbox[0] + bbox[2]) // 2, (bbox[1] + bbox[3]) // 2) 94 | offset = (int(position[0] - segment_center[0]), int(position[1] - segment_center[1])) 95 | 96 | # Build a segment ID mask using vectorized operations (avoid slow Python loops) 97 | cropped_mask_np = np.array(cropped_mask) 98 | segment_mask_np = np.where(cropped_mask_np > 0, segment_id, 0).astype(np.uint8) 99 | segment_id_mask = Image.fromarray(segment_mask_np) 100 | 101 | # Paste the cropped image and the segment ID mask onto their respective canvases 102 | canvas.paste(cropped_image, offset, cropped_mask) 103 | ground_truth_mask.paste(segment_id_mask, offset, cropped_mask) 104 | 105 | def paste_segment_wi_bbox(image, mask, canvas, 106 | ground_truth_mask, bbox, segment_id, 107 | resize_mode): 108 | # Ensure the image is in RGB mode 109 | if image.mode != 'RGB': 110 | image = image.convert('RGB') 111 | mask = normalize_mask(mask) 112 | 113 | # Get the bounding box of the segmented area to crop out unnecessary parts 114 | original_object_bbox = mask.getbbox() 115 | if original_object_bbox is None: 116 | raise ValueError("No segmented area found in the mask.") 117 | 118 | cropped_image = image.crop(original_object_bbox) 119 | cropped_mask = mask.crop(original_object_bbox) 120 | 121 | # Calculate the target region size based on the provided bounding box 122 | x_min, y_min, x_max, y_max = bbox 123 | bbox_width = int(round(x_max - x_min)) 124 | bbox_height = int(round(y_max - y_min)) 125 | 126 | if resize_mode == "full": 127 | # Stretch the image to fully fit the bounding box 128 | new_width, new_height = bbox_width, bbox_height 129 | offset = (int(round(x_min)), int(round(y_min))) 130 | elif resize_mode == "fit": 131 | # Maintain the aspect ratio of the cropped image 132 | orig_width, orig_height = cropped_image.size 133 | ratio = orig_width / orig_height 134 | # Compute the maximum size that fits within the bounding box while preserving the ratio 135 | if bbox_width / bbox_height > ratio: 136 | new_height = bbox_height 137 | new_width = int(round(bbox_height * ratio)) 138 | else: 139 | new_width = bbox_width 140 | new_height = int(round(bbox_width / ratio)) 141 | # Center the image within the bounding box 142 | offset = (int(round(x_min + (bbox_width - new_width) / 2)), 143 | int(round(y_min + (bbox_height - new_height) / 2))) 144 | else: 145 | raise ValueError("resize_mode must be 'full' or 'fit'.") 146 | 147 | if new_width <= 0 or new_height <= 0: 148 | new_width = max(1, new_width) 149 | new_height = max(1, new_height) 150 | # fixing now 151 | # Resize the cropped image and mask to the calculated dimensions 152 | resized_image = 
cropped_image.resize((new_width, new_height), resample=Image.BICUBIC) 153 | resized_mask = cropped_mask.resize((new_width, new_height), resample=Image.NEAREST) 154 | 155 | # Create a segment ID mask using vectorized operations 156 | resized_mask_np = np.array(resized_mask) 157 | segment_id_mask_np = np.where(resized_mask_np > 0, segment_id, 0).astype(np.uint8) 158 | segment_id_mask = Image.fromarray(segment_id_mask_np) 159 | 160 | # Paste the resized image and segment ID mask onto the canvas and ground truth mask 161 | canvas.paste(resized_image, offset, resized_mask) 162 | ground_truth_mask.paste(segment_id_mask, offset, resized_mask) 163 | 164 | def apply_augmentation(image: Image, mask: Image, aug_params: dict): 165 | """ 166 | Apply scaling, horizontal and vertical flips, and rotation to both the image and its mask. 167 | Scaling uses bicubic interpolation for the image and nearest for the mask; 168 | rotation is performed with expansion to preserve the full transformed segment. 169 | """ 170 | scale = aug_params.get("scale", 1.0) 171 | flip_horizontal = aug_params.get("flip_horizontal", False) 172 | flip_vertical = aug_params.get("flip_vertical", False) 173 | rotate_angle = aug_params.get("rotate", 0) 174 | # Scaling: resize the image and mask if the scale factor is different from 1.0 175 | if scale != 1.0: 176 | new_size = (int(round(image.width * scale)), int(round(image.height * scale))) 177 | image = image.resize(new_size, resample=Image.BICUBIC) 178 | mask = mask.resize(new_size, resample=Image.NEAREST) 179 | 180 | # Horizontal flip: if the flag is True, perform a horizontal flip 181 | if flip_horizontal: 182 | image = image.transpose(Image.FLIP_LEFT_RIGHT) 183 | mask = mask.transpose(Image.FLIP_LEFT_RIGHT) 184 | 185 | # Vertical flip: if the flag is True, perform a vertical flip 186 | if flip_vertical: 187 | image = image.transpose(Image.FLIP_TOP_BOTTOM) 188 | mask = mask.transpose(Image.FLIP_TOP_BOTTOM) 189 | 190 | # Rotation: if the rotation angle is non-zero, rotate the image and mask. 
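# expand=True grows the output canvas so the rotated content is not clipped; applying the
# same angle and expansion to both image and mask keeps them the same size and pixel-aligned.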
191 | if rotate_angle != 0: 192 | image = image.rotate(rotate_angle, expand=True, resample=Image.BICUBIC) 193 | mask = mask.rotate(rotate_angle, expand=True, resample=Image.NEAREST) 194 | 195 | return image, mask 196 | 197 | -------------------------------------------------------------------------------- /datasets/panoptic_coco_categories.json: -------------------------------------------------------------------------------- 1 | [{"supercategory": "person", "color": [220, 20, 60], "isthing": 1, "id": 1, "name": "person"}, {"supercategory": "vehicle", "color": [119, 11, 32], "isthing": 1, "id": 2, "name": "bicycle"}, {"supercategory": "vehicle", "color": [0, 0, 142], "isthing": 1, "id": 3, "name": "car"}, {"supercategory": "vehicle", "color": [0, 0, 230], "isthing": 1, "id": 4, "name": "motorcycle"}, {"supercategory": "vehicle", "color": [106, 0, 228], "isthing": 1, "id": 5, "name": "airplane"}, {"supercategory": "vehicle", "color": [0, 60, 100], "isthing": 1, "id": 6, "name": "bus"}, {"supercategory": "vehicle", "color": [0, 80, 100], "isthing": 1, "id": 7, "name": "train"}, {"supercategory": "vehicle", "color": [0, 0, 70], "isthing": 1, "id": 8, "name": "truck"}, {"supercategory": "vehicle", "color": [0, 0, 192], "isthing": 1, "id": 9, "name": "boat"}, {"supercategory": "outdoor", "color": [250, 170, 30], "isthing": 1, "id": 10, "name": "traffic light"}, {"supercategory": "outdoor", "color": [100, 170, 30], "isthing": 1, "id": 11, "name": "fire hydrant"}, {"supercategory": "outdoor", "color": [220, 220, 0], "isthing": 1, "id": 13, "name": "stop sign"}, {"supercategory": "outdoor", "color": [175, 116, 175], "isthing": 1, "id": 14, "name": "parking meter"}, {"supercategory": "outdoor", "color": [250, 0, 30], "isthing": 1, "id": 15, "name": "bench"}, {"supercategory": "animal", "color": [165, 42, 42], "isthing": 1, "id": 16, "name": "bird"}, {"supercategory": "animal", "color": [255, 77, 255], "isthing": 1, "id": 17, "name": "cat"}, {"supercategory": "animal", "color": [0, 226, 252], "isthing": 1, "id": 18, "name": "dog"}, {"supercategory": "animal", "color": [182, 182, 255], "isthing": 1, "id": 19, "name": "horse"}, {"supercategory": "animal", "color": [0, 82, 0], "isthing": 1, "id": 20, "name": "sheep"}, {"supercategory": "animal", "color": [120, 166, 157], "isthing": 1, "id": 21, "name": "cow"}, {"supercategory": "animal", "color": [110, 76, 0], "isthing": 1, "id": 22, "name": "elephant"}, {"supercategory": "animal", "color": [174, 57, 255], "isthing": 1, "id": 23, "name": "bear"}, {"supercategory": "animal", "color": [199, 100, 0], "isthing": 1, "id": 24, "name": "zebra"}, {"supercategory": "animal", "color": [72, 0, 118], "isthing": 1, "id": 25, "name": "giraffe"}, {"supercategory": "accessory", "color": [255, 179, 240], "isthing": 1, "id": 27, "name": "backpack"}, {"supercategory": "accessory", "color": [0, 125, 92], "isthing": 1, "id": 28, "name": "umbrella"}, {"supercategory": "accessory", "color": [209, 0, 151], "isthing": 1, "id": 31, "name": "handbag"}, {"supercategory": "accessory", "color": [188, 208, 182], "isthing": 1, "id": 32, "name": "tie"}, {"supercategory": "accessory", "color": [0, 220, 176], "isthing": 1, "id": 33, "name": "suitcase"}, {"supercategory": "sports", "color": [255, 99, 164], "isthing": 1, "id": 34, "name": "frisbee"}, {"supercategory": "sports", "color": [92, 0, 73], "isthing": 1, "id": 35, "name": "skis"}, {"supercategory": "sports", "color": [133, 129, 255], "isthing": 1, "id": 36, "name": "snowboard"}, {"supercategory": "sports", "color": [78, 180, 255], 
"isthing": 1, "id": 37, "name": "sports ball"}, {"supercategory": "sports", "color": [0, 228, 0], "isthing": 1, "id": 38, "name": "kite"}, {"supercategory": "sports", "color": [174, 255, 243], "isthing": 1, "id": 39, "name": "baseball bat"}, {"supercategory": "sports", "color": [45, 89, 255], "isthing": 1, "id": 40, "name": "baseball glove"}, {"supercategory": "sports", "color": [134, 134, 103], "isthing": 1, "id": 41, "name": "skateboard"}, {"supercategory": "sports", "color": [145, 148, 174], "isthing": 1, "id": 42, "name": "surfboard"}, {"supercategory": "sports", "color": [255, 208, 186], "isthing": 1, "id": 43, "name": "tennis racket"}, {"supercategory": "kitchen", "color": [197, 226, 255], "isthing": 1, "id": 44, "name": "bottle"}, {"supercategory": "kitchen", "color": [171, 134, 1], "isthing": 1, "id": 46, "name": "wine glass"}, {"supercategory": "kitchen", "color": [109, 63, 54], "isthing": 1, "id": 47, "name": "cup"}, {"supercategory": "kitchen", "color": [207, 138, 255], "isthing": 1, "id": 48, "name": "fork"}, {"supercategory": "kitchen", "color": [151, 0, 95], "isthing": 1, "id": 49, "name": "knife"}, {"supercategory": "kitchen", "color": [9, 80, 61], "isthing": 1, "id": 50, "name": "spoon"}, {"supercategory": "kitchen", "color": [84, 105, 51], "isthing": 1, "id": 51, "name": "bowl"}, {"supercategory": "food", "color": [74, 65, 105], "isthing": 1, "id": 52, "name": "banana"}, {"supercategory": "food", "color": [166, 196, 102], "isthing": 1, "id": 53, "name": "apple"}, {"supercategory": "food", "color": [208, 195, 210], "isthing": 1, "id": 54, "name": "sandwich"}, {"supercategory": "food", "color": [255, 109, 65], "isthing": 1, "id": 55, "name": "orange"}, {"supercategory": "food", "color": [0, 143, 149], "isthing": 1, "id": 56, "name": "broccoli"}, {"supercategory": "food", "color": [179, 0, 194], "isthing": 1, "id": 57, "name": "carrot"}, {"supercategory": "food", "color": [209, 99, 106], "isthing": 1, "id": 58, "name": "hot dog"}, {"supercategory": "food", "color": [5, 121, 0], "isthing": 1, "id": 59, "name": "pizza"}, {"supercategory": "food", "color": [227, 255, 205], "isthing": 1, "id": 60, "name": "donut"}, {"supercategory": "food", "color": [147, 186, 208], "isthing": 1, "id": 61, "name": "cake"}, {"supercategory": "furniture", "color": [153, 69, 1], "isthing": 1, "id": 62, "name": "chair"}, {"supercategory": "furniture", "color": [3, 95, 161], "isthing": 1, "id": 63, "name": "couch"}, {"supercategory": "furniture", "color": [163, 255, 0], "isthing": 1, "id": 64, "name": "potted plant"}, {"supercategory": "furniture", "color": [119, 0, 170], "isthing": 1, "id": 65, "name": "bed"}, {"supercategory": "furniture", "color": [0, 182, 199], "isthing": 1, "id": 67, "name": "dining table"}, {"supercategory": "furniture", "color": [0, 165, 120], "isthing": 1, "id": 70, "name": "toilet"}, {"supercategory": "electronic", "color": [183, 130, 88], "isthing": 1, "id": 72, "name": "tv"}, {"supercategory": "electronic", "color": [95, 32, 0], "isthing": 1, "id": 73, "name": "laptop"}, {"supercategory": "electronic", "color": [130, 114, 135], "isthing": 1, "id": 74, "name": "mouse"}, {"supercategory": "electronic", "color": [110, 129, 133], "isthing": 1, "id": 75, "name": "remote"}, {"supercategory": "electronic", "color": [166, 74, 118], "isthing": 1, "id": 76, "name": "keyboard"}, {"supercategory": "electronic", "color": [219, 142, 185], "isthing": 1, "id": 77, "name": "cell phone"}, {"supercategory": "appliance", "color": [79, 210, 114], "isthing": 1, "id": 78, "name": "microwave"}, 
{"supercategory": "appliance", "color": [178, 90, 62], "isthing": 1, "id": 79, "name": "oven"}, {"supercategory": "appliance", "color": [65, 70, 15], "isthing": 1, "id": 80, "name": "toaster"}, {"supercategory": "appliance", "color": [127, 167, 115], "isthing": 1, "id": 81, "name": "sink"}, {"supercategory": "appliance", "color": [59, 105, 106], "isthing": 1, "id": 82, "name": "refrigerator"}, {"supercategory": "indoor", "color": [142, 108, 45], "isthing": 1, "id": 84, "name": "book"}, {"supercategory": "indoor", "color": [196, 172, 0], "isthing": 1, "id": 85, "name": "clock"}, {"supercategory": "indoor", "color": [95, 54, 80], "isthing": 1, "id": 86, "name": "vase"}, {"supercategory": "indoor", "color": [128, 76, 255], "isthing": 1, "id": 87, "name": "scissors"}, {"supercategory": "indoor", "color": [201, 57, 1], "isthing": 1, "id": 88, "name": "teddy bear"}, {"supercategory": "indoor", "color": [246, 0, 122], "isthing": 1, "id": 89, "name": "hair drier"}, {"supercategory": "indoor", "color": [191, 162, 208], "isthing": 1, "id": 90, "name": "toothbrush"}, {"supercategory": "textile", "color": [255, 255, 128], "isthing": 0, "id": 92, "name": "banner"}, {"supercategory": "textile", "color": [147, 211, 203], "isthing": 0, "id": 93, "name": "blanket"}, {"supercategory": "building", "color": [150, 100, 100], "isthing": 0, "id": 95, "name": "bridge"}, {"supercategory": "raw-material", "color": [168, 171, 172], "isthing": 0, "id": 100, "name": "cardboard"}, {"supercategory": "furniture-stuff", "color": [146, 112, 198], "isthing": 0, "id": 107, "name": "counter"}, {"supercategory": "textile", "color": [210, 170, 100], "isthing": 0, "id": 109, "name": "curtain"}, {"supercategory": "furniture-stuff", "color": [92, 136, 89], "isthing": 0, "id": 112, "name": "door-stuff"}, {"supercategory": "floor", "color": [218, 88, 184], "isthing": 0, "id": 118, "name": "floor-wood"}, {"supercategory": "plant", "color": [241, 129, 0], "isthing": 0, "id": 119, "name": "flower"}, {"supercategory": "food-stuff", "color": [217, 17, 255], "isthing": 0, "id": 122, "name": "fruit"}, {"supercategory": "ground", "color": [124, 74, 181], "isthing": 0, "id": 125, "name": "gravel"}, {"supercategory": "building", "color": [70, 70, 70], "isthing": 0, "id": 128, "name": "house"}, {"supercategory": "furniture-stuff", "color": [255, 228, 255], "isthing": 0, "id": 130, "name": "light"}, {"supercategory": "furniture-stuff", "color": [154, 208, 0], "isthing": 0, "id": 133, "name": "mirror-stuff"}, {"supercategory": "structural", "color": [193, 0, 92], "isthing": 0, "id": 138, "name": "net"}, {"supercategory": "textile", "color": [76, 91, 113], "isthing": 0, "id": 141, "name": "pillow"}, {"supercategory": "ground", "color": [255, 180, 195], "isthing": 0, "id": 144, "name": "platform"}, {"supercategory": "ground", "color": [106, 154, 176], "isthing": 0, "id": 145, "name": "playingfield"}, {"supercategory": "ground", "color": [230, 150, 140], "isthing": 0, "id": 147, "name": "railroad"}, {"supercategory": "water", "color": [60, 143, 255], "isthing": 0, "id": 148, "name": "river"}, {"supercategory": "ground", "color": [128, 64, 128], "isthing": 0, "id": 149, "name": "road"}, {"supercategory": "building", "color": [92, 82, 55], "isthing": 0, "id": 151, "name": "roof"}, {"supercategory": "ground", "color": [254, 212, 124], "isthing": 0, "id": 154, "name": "sand"}, {"supercategory": "water", "color": [73, 77, 174], "isthing": 0, "id": 155, "name": "sea"}, {"supercategory": "furniture-stuff", "color": [255, 160, 98], "isthing": 0, "id": 156, 
"name": "shelf"}, {"supercategory": "ground", "color": [255, 255, 255], "isthing": 0, "id": 159, "name": "snow"}, {"supercategory": "furniture-stuff", "color": [104, 84, 109], "isthing": 0, "id": 161, "name": "stairs"}, {"supercategory": "building", "color": [169, 164, 131], "isthing": 0, "id": 166, "name": "tent"}, {"supercategory": "textile", "color": [225, 199, 255], "isthing": 0, "id": 168, "name": "towel"}, {"supercategory": "wall", "color": [137, 54, 74], "isthing": 0, "id": 171, "name": "wall-brick"}, {"supercategory": "wall", "color": [135, 158, 223], "isthing": 0, "id": 175, "name": "wall-stone"}, {"supercategory": "wall", "color": [7, 246, 231], "isthing": 0, "id": 176, "name": "wall-tile"}, {"supercategory": "wall", "color": [107, 255, 200], "isthing": 0, "id": 177, "name": "wall-wood"}, {"supercategory": "water", "color": [58, 41, 149], "isthing": 0, "id": 178, "name": "water-other"}, {"supercategory": "window", "color": [183, 121, 142], "isthing": 0, "id": 180, "name": "window-blind"}, {"supercategory": "window", "color": [255, 73, 97], "isthing": 0, "id": 181, "name": "window-other"}, {"supercategory": "plant", "color": [107, 142, 35], "isthing": 0, "id": 184, "name": "tree-merged"}, {"supercategory": "structural", "color": [190, 153, 153], "isthing": 0, "id": 185, "name": "fence-merged"}, {"supercategory": "ceiling", "color": [146, 139, 141], "isthing": 0, "id": 186, "name": "ceiling-merged"}, {"supercategory": "sky", "color": [70, 130, 180], "isthing": 0, "id": 187, "name": "sky-other-merged"}, {"supercategory": "furniture-stuff", "color": [134, 199, 156], "isthing": 0, "id": 188, "name": "cabinet-merged"}, {"supercategory": "furniture-stuff", "color": [209, 226, 140], "isthing": 0, "id": 189, "name": "table-merged"}, {"supercategory": "floor", "color": [96, 36, 108], "isthing": 0, "id": 190, "name": "floor-other-merged"}, {"supercategory": "ground", "color": [96, 96, 96], "isthing": 0, "id": 191, "name": "pavement-merged"}, {"supercategory": "solid", "color": [64, 170, 64], "isthing": 0, "id": 192, "name": "mountain-merged"}, {"supercategory": "plant", "color": [152, 251, 152], "isthing": 0, "id": 193, "name": "grass-merged"}, {"supercategory": "ground", "color": [208, 229, 228], "isthing": 0, "id": 194, "name": "dirt-merged"}, {"supercategory": "raw-material", "color": [206, 186, 171], "isthing": 0, "id": 195, "name": "paper-merged"}, {"supercategory": "food-stuff", "color": [152, 161, 64], "isthing": 0, "id": 196, "name": "food-other-merged"}, {"supercategory": "building", "color": [116, 112, 0], "isthing": 0, "id": 197, "name": "building-other-merged"}, {"supercategory": "solid", "color": [0, 114, 143], "isthing": 0, "id": 198, "name": "rock-merged"}, {"supercategory": "wall", "color": [102, 102, 156], "isthing": 0, "id": 199, "name": "wall-other-merged"}, {"supercategory": "textile", "color": [250, 141, 255], "isthing": 0, "id": 200, "name": "rug-merged"}] 2 | -------------------------------------------------------------------------------- /segmentation_generator/panoptic_coco_categories.json: -------------------------------------------------------------------------------- 1 | [{"supercategory": "person", "color": [220, 20, 60], "isthing": 1, "id": 1, "name": "person"}, {"supercategory": "vehicle", "color": [119, 11, 32], "isthing": 1, "id": 2, "name": "bicycle"}, {"supercategory": "vehicle", "color": [0, 0, 142], "isthing": 1, "id": 3, "name": "car"}, {"supercategory": "vehicle", "color": [0, 0, 230], "isthing": 1, "id": 4, "name": "motorcycle"}, {"supercategory": 
"vehicle", "color": [106, 0, 228], "isthing": 1, "id": 5, "name": "airplane"}, {"supercategory": "vehicle", "color": [0, 60, 100], "isthing": 1, "id": 6, "name": "bus"}, {"supercategory": "vehicle", "color": [0, 80, 100], "isthing": 1, "id": 7, "name": "train"}, {"supercategory": "vehicle", "color": [0, 0, 70], "isthing": 1, "id": 8, "name": "truck"}, {"supercategory": "vehicle", "color": [0, 0, 192], "isthing": 1, "id": 9, "name": "boat"}, {"supercategory": "outdoor", "color": [250, 170, 30], "isthing": 1, "id": 10, "name": "traffic light"}, {"supercategory": "outdoor", "color": [100, 170, 30], "isthing": 1, "id": 11, "name": "fire hydrant"}, {"supercategory": "outdoor", "color": [220, 220, 0], "isthing": 1, "id": 13, "name": "stop sign"}, {"supercategory": "outdoor", "color": [175, 116, 175], "isthing": 1, "id": 14, "name": "parking meter"}, {"supercategory": "outdoor", "color": [250, 0, 30], "isthing": 1, "id": 15, "name": "bench"}, {"supercategory": "animal", "color": [165, 42, 42], "isthing": 1, "id": 16, "name": "bird"}, {"supercategory": "animal", "color": [255, 77, 255], "isthing": 1, "id": 17, "name": "cat"}, {"supercategory": "animal", "color": [0, 226, 252], "isthing": 1, "id": 18, "name": "dog"}, {"supercategory": "animal", "color": [182, 182, 255], "isthing": 1, "id": 19, "name": "horse"}, {"supercategory": "animal", "color": [0, 82, 0], "isthing": 1, "id": 20, "name": "sheep"}, {"supercategory": "animal", "color": [120, 166, 157], "isthing": 1, "id": 21, "name": "cow"}, {"supercategory": "animal", "color": [110, 76, 0], "isthing": 1, "id": 22, "name": "elephant"}, {"supercategory": "animal", "color": [174, 57, 255], "isthing": 1, "id": 23, "name": "bear"}, {"supercategory": "animal", "color": [199, 100, 0], "isthing": 1, "id": 24, "name": "zebra"}, {"supercategory": "animal", "color": [72, 0, 118], "isthing": 1, "id": 25, "name": "giraffe"}, {"supercategory": "accessory", "color": [255, 179, 240], "isthing": 1, "id": 27, "name": "backpack"}, {"supercategory": "accessory", "color": [0, 125, 92], "isthing": 1, "id": 28, "name": "umbrella"}, {"supercategory": "accessory", "color": [209, 0, 151], "isthing": 1, "id": 31, "name": "handbag"}, {"supercategory": "accessory", "color": [188, 208, 182], "isthing": 1, "id": 32, "name": "tie"}, {"supercategory": "accessory", "color": [0, 220, 176], "isthing": 1, "id": 33, "name": "suitcase"}, {"supercategory": "sports", "color": [255, 99, 164], "isthing": 1, "id": 34, "name": "frisbee"}, {"supercategory": "sports", "color": [92, 0, 73], "isthing": 1, "id": 35, "name": "skis"}, {"supercategory": "sports", "color": [133, 129, 255], "isthing": 1, "id": 36, "name": "snowboard"}, {"supercategory": "sports", "color": [78, 180, 255], "isthing": 1, "id": 37, "name": "sports ball"}, {"supercategory": "sports", "color": [0, 228, 0], "isthing": 1, "id": 38, "name": "kite"}, {"supercategory": "sports", "color": [174, 255, 243], "isthing": 1, "id": 39, "name": "baseball bat"}, {"supercategory": "sports", "color": [45, 89, 255], "isthing": 1, "id": 40, "name": "baseball glove"}, {"supercategory": "sports", "color": [134, 134, 103], "isthing": 1, "id": 41, "name": "skateboard"}, {"supercategory": "sports", "color": [145, 148, 174], "isthing": 1, "id": 42, "name": "surfboard"}, {"supercategory": "sports", "color": [255, 208, 186], "isthing": 1, "id": 43, "name": "tennis racket"}, {"supercategory": "kitchen", "color": [197, 226, 255], "isthing": 1, "id": 44, "name": "bottle"}, {"supercategory": "kitchen", "color": [171, 134, 1], "isthing": 1, "id": 46, 
"name": "wine glass"}, {"supercategory": "kitchen", "color": [109, 63, 54], "isthing": 1, "id": 47, "name": "cup"}, {"supercategory": "kitchen", "color": [207, 138, 255], "isthing": 1, "id": 48, "name": "fork"}, {"supercategory": "kitchen", "color": [151, 0, 95], "isthing": 1, "id": 49, "name": "knife"}, {"supercategory": "kitchen", "color": [9, 80, 61], "isthing": 1, "id": 50, "name": "spoon"}, {"supercategory": "kitchen", "color": [84, 105, 51], "isthing": 1, "id": 51, "name": "bowl"}, {"supercategory": "food", "color": [74, 65, 105], "isthing": 1, "id": 52, "name": "banana"}, {"supercategory": "food", "color": [166, 196, 102], "isthing": 1, "id": 53, "name": "apple"}, {"supercategory": "food", "color": [208, 195, 210], "isthing": 1, "id": 54, "name": "sandwich"}, {"supercategory": "food", "color": [255, 109, 65], "isthing": 1, "id": 55, "name": "orange"}, {"supercategory": "food", "color": [0, 143, 149], "isthing": 1, "id": 56, "name": "broccoli"}, {"supercategory": "food", "color": [179, 0, 194], "isthing": 1, "id": 57, "name": "carrot"}, {"supercategory": "food", "color": [209, 99, 106], "isthing": 1, "id": 58, "name": "hot dog"}, {"supercategory": "food", "color": [5, 121, 0], "isthing": 1, "id": 59, "name": "pizza"}, {"supercategory": "food", "color": [227, 255, 205], "isthing": 1, "id": 60, "name": "donut"}, {"supercategory": "food", "color": [147, 186, 208], "isthing": 1, "id": 61, "name": "cake"}, {"supercategory": "furniture", "color": [153, 69, 1], "isthing": 1, "id": 62, "name": "chair"}, {"supercategory": "furniture", "color": [3, 95, 161], "isthing": 1, "id": 63, "name": "couch"}, {"supercategory": "furniture", "color": [163, 255, 0], "isthing": 1, "id": 64, "name": "potted plant"}, {"supercategory": "furniture", "color": [119, 0, 170], "isthing": 1, "id": 65, "name": "bed"}, {"supercategory": "furniture", "color": [0, 182, 199], "isthing": 1, "id": 67, "name": "dining table"}, {"supercategory": "furniture", "color": [0, 165, 120], "isthing": 1, "id": 70, "name": "toilet"}, {"supercategory": "electronic", "color": [183, 130, 88], "isthing": 1, "id": 72, "name": "tv"}, {"supercategory": "electronic", "color": [95, 32, 0], "isthing": 1, "id": 73, "name": "laptop"}, {"supercategory": "electronic", "color": [130, 114, 135], "isthing": 1, "id": 74, "name": "mouse"}, {"supercategory": "electronic", "color": [110, 129, 133], "isthing": 1, "id": 75, "name": "remote"}, {"supercategory": "electronic", "color": [166, 74, 118], "isthing": 1, "id": 76, "name": "keyboard"}, {"supercategory": "electronic", "color": [219, 142, 185], "isthing": 1, "id": 77, "name": "cell phone"}, {"supercategory": "appliance", "color": [79, 210, 114], "isthing": 1, "id": 78, "name": "microwave"}, {"supercategory": "appliance", "color": [178, 90, 62], "isthing": 1, "id": 79, "name": "oven"}, {"supercategory": "appliance", "color": [65, 70, 15], "isthing": 1, "id": 80, "name": "toaster"}, {"supercategory": "appliance", "color": [127, 167, 115], "isthing": 1, "id": 81, "name": "sink"}, {"supercategory": "appliance", "color": [59, 105, 106], "isthing": 1, "id": 82, "name": "refrigerator"}, {"supercategory": "indoor", "color": [142, 108, 45], "isthing": 1, "id": 84, "name": "book"}, {"supercategory": "indoor", "color": [196, 172, 0], "isthing": 1, "id": 85, "name": "clock"}, {"supercategory": "indoor", "color": [95, 54, 80], "isthing": 1, "id": 86, "name": "vase"}, {"supercategory": "indoor", "color": [128, 76, 255], "isthing": 1, "id": 87, "name": "scissors"}, {"supercategory": "indoor", "color": [201, 57, 1], 
"isthing": 1, "id": 88, "name": "teddy bear"}, {"supercategory": "indoor", "color": [246, 0, 122], "isthing": 1, "id": 89, "name": "hair drier"}, {"supercategory": "indoor", "color": [191, 162, 208], "isthing": 1, "id": 90, "name": "toothbrush"}, {"supercategory": "textile", "color": [255, 255, 128], "isthing": 0, "id": 92, "name": "banner"}, {"supercategory": "textile", "color": [147, 211, 203], "isthing": 0, "id": 93, "name": "blanket"}, {"supercategory": "building", "color": [150, 100, 100], "isthing": 0, "id": 95, "name": "bridge"}, {"supercategory": "raw-material", "color": [168, 171, 172], "isthing": 0, "id": 100, "name": "cardboard"}, {"supercategory": "furniture-stuff", "color": [146, 112, 198], "isthing": 0, "id": 107, "name": "counter"}, {"supercategory": "textile", "color": [210, 170, 100], "isthing": 0, "id": 109, "name": "curtain"}, {"supercategory": "furniture-stuff", "color": [92, 136, 89], "isthing": 0, "id": 112, "name": "door-stuff"}, {"supercategory": "floor", "color": [218, 88, 184], "isthing": 0, "id": 118, "name": "floor-wood"}, {"supercategory": "plant", "color": [241, 129, 0], "isthing": 0, "id": 119, "name": "flower"}, {"supercategory": "food-stuff", "color": [217, 17, 255], "isthing": 0, "id": 122, "name": "fruit"}, {"supercategory": "ground", "color": [124, 74, 181], "isthing": 0, "id": 125, "name": "gravel"}, {"supercategory": "building", "color": [70, 70, 70], "isthing": 0, "id": 128, "name": "house"}, {"supercategory": "furniture-stuff", "color": [255, 228, 255], "isthing": 0, "id": 130, "name": "light"}, {"supercategory": "furniture-stuff", "color": [154, 208, 0], "isthing": 0, "id": 133, "name": "mirror-stuff"}, {"supercategory": "structural", "color": [193, 0, 92], "isthing": 0, "id": 138, "name": "net"}, {"supercategory": "textile", "color": [76, 91, 113], "isthing": 0, "id": 141, "name": "pillow"}, {"supercategory": "ground", "color": [255, 180, 195], "isthing": 0, "id": 144, "name": "platform"}, {"supercategory": "ground", "color": [106, 154, 176], "isthing": 0, "id": 145, "name": "playingfield"}, {"supercategory": "ground", "color": [230, 150, 140], "isthing": 0, "id": 147, "name": "railroad"}, {"supercategory": "water", "color": [60, 143, 255], "isthing": 0, "id": 148, "name": "river"}, {"supercategory": "ground", "color": [128, 64, 128], "isthing": 0, "id": 149, "name": "road"}, {"supercategory": "building", "color": [92, 82, 55], "isthing": 0, "id": 151, "name": "roof"}, {"supercategory": "ground", "color": [254, 212, 124], "isthing": 0, "id": 154, "name": "sand"}, {"supercategory": "water", "color": [73, 77, 174], "isthing": 0, "id": 155, "name": "sea"}, {"supercategory": "furniture-stuff", "color": [255, 160, 98], "isthing": 0, "id": 156, "name": "shelf"}, {"supercategory": "ground", "color": [255, 255, 255], "isthing": 0, "id": 159, "name": "snow"}, {"supercategory": "furniture-stuff", "color": [104, 84, 109], "isthing": 0, "id": 161, "name": "stairs"}, {"supercategory": "building", "color": [169, 164, 131], "isthing": 0, "id": 166, "name": "tent"}, {"supercategory": "textile", "color": [225, 199, 255], "isthing": 0, "id": 168, "name": "towel"}, {"supercategory": "wall", "color": [137, 54, 74], "isthing": 0, "id": 171, "name": "wall-brick"}, {"supercategory": "wall", "color": [135, 158, 223], "isthing": 0, "id": 175, "name": "wall-stone"}, {"supercategory": "wall", "color": [7, 246, 231], "isthing": 0, "id": 176, "name": "wall-tile"}, {"supercategory": "wall", "color": [107, 255, 200], "isthing": 0, "id": 177, "name": "wall-wood"}, 
{"supercategory": "water", "color": [58, 41, 149], "isthing": 0, "id": 178, "name": "water-other"}, {"supercategory": "window", "color": [183, 121, 142], "isthing": 0, "id": 180, "name": "window-blind"}, {"supercategory": "window", "color": [255, 73, 97], "isthing": 0, "id": 181, "name": "window-other"}, {"supercategory": "plant", "color": [107, 142, 35], "isthing": 0, "id": 184, "name": "tree-merged"}, {"supercategory": "structural", "color": [190, 153, 153], "isthing": 0, "id": 185, "name": "fence-merged"}, {"supercategory": "ceiling", "color": [146, 139, 141], "isthing": 0, "id": 186, "name": "ceiling-merged"}, {"supercategory": "sky", "color": [70, 130, 180], "isthing": 0, "id": 187, "name": "sky-other-merged"}, {"supercategory": "furniture-stuff", "color": [134, 199, 156], "isthing": 0, "id": 188, "name": "cabinet-merged"}, {"supercategory": "furniture-stuff", "color": [209, 226, 140], "isthing": 0, "id": 189, "name": "table-merged"}, {"supercategory": "floor", "color": [96, 36, 108], "isthing": 0, "id": 190, "name": "floor-other-merged"}, {"supercategory": "ground", "color": [96, 96, 96], "isthing": 0, "id": 191, "name": "pavement-merged"}, {"supercategory": "solid", "color": [64, 170, 64], "isthing": 0, "id": 192, "name": "mountain-merged"}, {"supercategory": "plant", "color": [152, 251, 152], "isthing": 0, "id": 193, "name": "grass-merged"}, {"supercategory": "ground", "color": [208, 229, 228], "isthing": 0, "id": 194, "name": "dirt-merged"}, {"supercategory": "raw-material", "color": [206, 186, 171], "isthing": 0, "id": 195, "name": "paper-merged"}, {"supercategory": "food-stuff", "color": [152, 161, 64], "isthing": 0, "id": 196, "name": "food-other-merged"}, {"supercategory": "building", "color": [116, 112, 0], "isthing": 0, "id": 197, "name": "building-other-merged"}, {"supercategory": "solid", "color": [0, 114, 143], "isthing": 0, "id": 198, "name": "rock-merged"}, {"supercategory": "wall", "color": [102, 102, 156], "isthing": 0, "id": 199, "name": "wall-other-merged"}, {"supercategory": "textile", "color": [250, 141, 255], "isthing": 0, "id": 200, "name": "rug-merged"}] 2 | -------------------------------------------------------------------------------- /datasets/coco_categories.py: -------------------------------------------------------------------------------- 1 | COCO_CATEGORIES = [{"supercategory": "person", "color": [220, 20, 60], "isthing": 1, "id": 1, "name": "person"}, {"supercategory": "vehicle", "color": [119, 11, 32], "isthing": 1, "id": 2, "name": "bicycle"}, {"supercategory": "vehicle", "color": [0, 0, 142], "isthing": 1, "id": 3, "name": "car"}, {"supercategory": "vehicle", "color": [0, 0, 230], "isthing": 1, "id": 4, "name": "motorcycle"}, {"supercategory": "vehicle", "color": [106, 0, 228], "isthing": 1, "id": 5, "name": "airplane"}, {"supercategory": "vehicle", "color": [0, 60, 100], "isthing": 1, "id": 6, "name": "bus"}, {"supercategory": "vehicle", "color": [0, 80, 100], "isthing": 1, "id": 7, "name": "train"}, {"supercategory": "vehicle", "color": [0, 0, 70], "isthing": 1, "id": 8, "name": "truck"}, {"supercategory": "vehicle", "color": [0, 0, 192], "isthing": 1, "id": 9, "name": "boat"}, {"supercategory": "outdoor", "color": [250, 170, 30], "isthing": 1, "id": 10, "name": "traffic light"}, {"supercategory": "outdoor", "color": [100, 170, 30], "isthing": 1, "id": 11, "name": "fire hydrant"}, {"supercategory": "outdoor", "color": [220, 220, 0], "isthing": 1, "id": 13, "name": "stop sign"}, {"supercategory": "outdoor", "color": [175, 116, 175], "isthing": 1, 
"id": 14, "name": "parking meter"}, {"supercategory": "outdoor", "color": [250, 0, 30], "isthing": 1, "id": 15, "name": "bench"}, {"supercategory": "animal", "color": [165, 42, 42], "isthing": 1, "id": 16, "name": "bird"}, {"supercategory": "animal", "color": [255, 77, 255], "isthing": 1, "id": 17, "name": "cat"}, {"supercategory": "animal", "color": [0, 226, 252], "isthing": 1, "id": 18, "name": "dog"}, {"supercategory": "animal", "color": [182, 182, 255], "isthing": 1, "id": 19, "name": "horse"}, {"supercategory": "animal", "color": [0, 82, 0], "isthing": 1, "id": 20, "name": "sheep"}, {"supercategory": "animal", "color": [120, 166, 157], "isthing": 1, "id": 21, "name": "cow"}, {"supercategory": "animal", "color": [110, 76, 0], "isthing": 1, "id": 22, "name": "elephant"}, {"supercategory": "animal", "color": [174, 57, 255], "isthing": 1, "id": 23, "name": "bear"}, {"supercategory": "animal", "color": [199, 100, 0], "isthing": 1, "id": 24, "name": "zebra"}, {"supercategory": "animal", "color": [72, 0, 118], "isthing": 1, "id": 25, "name": "giraffe"}, {"supercategory": "accessory", "color": [255, 179, 240], "isthing": 1, "id": 27, "name": "backpack"}, {"supercategory": "accessory", "color": [0, 125, 92], "isthing": 1, "id": 28, "name": "umbrella"}, {"supercategory": "accessory", "color": [209, 0, 151], "isthing": 1, "id": 31, "name": "handbag"}, {"supercategory": "accessory", "color": [188, 208, 182], "isthing": 1, "id": 32, "name": "tie"}, {"supercategory": "accessory", "color": [0, 220, 176], "isthing": 1, "id": 33, "name": "suitcase"}, {"supercategory": "sports", "color": [255, 99, 164], "isthing": 1, "id": 34, "name": "frisbee"}, {"supercategory": "sports", "color": [92, 0, 73], "isthing": 1, "id": 35, "name": "skis"}, {"supercategory": "sports", "color": [133, 129, 255], "isthing": 1, "id": 36, "name": "snowboard"}, {"supercategory": "sports", "color": [78, 180, 255], "isthing": 1, "id": 37, "name": "sports ball"}, {"supercategory": "sports", "color": [0, 228, 0], "isthing": 1, "id": 38, "name": "kite"}, {"supercategory": "sports", "color": [174, 255, 243], "isthing": 1, "id": 39, "name": "baseball bat"}, {"supercategory": "sports", "color": [45, 89, 255], "isthing": 1, "id": 40, "name": "baseball glove"}, {"supercategory": "sports", "color": [134, 134, 103], "isthing": 1, "id": 41, "name": "skateboard"}, {"supercategory": "sports", "color": [145, 148, 174], "isthing": 1, "id": 42, "name": "surfboard"}, {"supercategory": "sports", "color": [255, 208, 186], "isthing": 1, "id": 43, "name": "tennis racket"}, {"supercategory": "kitchen", "color": [197, 226, 255], "isthing": 1, "id": 44, "name": "bottle"}, {"supercategory": "kitchen", "color": [171, 134, 1], "isthing": 1, "id": 46, "name": "wine glass"}, {"supercategory": "kitchen", "color": [109, 63, 54], "isthing": 1, "id": 47, "name": "cup"}, {"supercategory": "kitchen", "color": [207, 138, 255], "isthing": 1, "id": 48, "name": "fork"}, {"supercategory": "kitchen", "color": [151, 0, 95], "isthing": 1, "id": 49, "name": "knife"}, {"supercategory": "kitchen", "color": [9, 80, 61], "isthing": 1, "id": 50, "name": "spoon"}, {"supercategory": "kitchen", "color": [84, 105, 51], "isthing": 1, "id": 51, "name": "bowl"}, {"supercategory": "food", "color": [74, 65, 105], "isthing": 1, "id": 52, "name": "banana"}, {"supercategory": "food", "color": [166, 196, 102], "isthing": 1, "id": 53, "name": "apple"}, {"supercategory": "food", "color": [208, 195, 210], "isthing": 1, "id": 54, "name": "sandwich"}, {"supercategory": "food", "color": [255, 109, 
65], "isthing": 1, "id": 55, "name": "orange"}, {"supercategory": "food", "color": [0, 143, 149], "isthing": 1, "id": 56, "name": "broccoli"}, {"supercategory": "food", "color": [179, 0, 194], "isthing": 1, "id": 57, "name": "carrot"}, {"supercategory": "food", "color": [209, 99, 106], "isthing": 1, "id": 58, "name": "hot dog"}, {"supercategory": "food", "color": [5, 121, 0], "isthing": 1, "id": 59, "name": "pizza"}, {"supercategory": "food", "color": [227, 255, 205], "isthing": 1, "id": 60, "name": "donut"}, {"supercategory": "food", "color": [147, 186, 208], "isthing": 1, "id": 61, "name": "cake"}, {"supercategory": "furniture", "color": [153, 69, 1], "isthing": 1, "id": 62, "name": "chair"}, {"supercategory": "furniture", "color": [3, 95, 161], "isthing": 1, "id": 63, "name": "couch"}, {"supercategory": "furniture", "color": [163, 255, 0], "isthing": 1, "id": 64, "name": "potted plant"}, {"supercategory": "furniture", "color": [119, 0, 170], "isthing": 1, "id": 65, "name": "bed"}, {"supercategory": "furniture", "color": [0, 182, 199], "isthing": 1, "id": 67, "name": "dining table"}, {"supercategory": "furniture", "color": [0, 165, 120], "isthing": 1, "id": 70, "name": "toilet"}, {"supercategory": "electronic", "color": [183, 130, 88], "isthing": 1, "id": 72, "name": "tv"}, {"supercategory": "electronic", "color": [95, 32, 0], "isthing": 1, "id": 73, "name": "laptop"}, {"supercategory": "electronic", "color": [130, 114, 135], "isthing": 1, "id": 74, "name": "mouse"}, {"supercategory": "electronic", "color": [110, 129, 133], "isthing": 1, "id": 75, "name": "remote"}, {"supercategory": "electronic", "color": [166, 74, 118], "isthing": 1, "id": 76, "name": "keyboard"}, {"supercategory": "electronic", "color": [219, 142, 185], "isthing": 1, "id": 77, "name": "cell phone"}, {"supercategory": "appliance", "color": [79, 210, 114], "isthing": 1, "id": 78, "name": "microwave"}, {"supercategory": "appliance", "color": [178, 90, 62], "isthing": 1, "id": 79, "name": "oven"}, {"supercategory": "appliance", "color": [65, 70, 15], "isthing": 1, "id": 80, "name": "toaster"}, {"supercategory": "appliance", "color": [127, 167, 115], "isthing": 1, "id": 81, "name": "sink"}, {"supercategory": "appliance", "color": [59, 105, 106], "isthing": 1, "id": 82, "name": "refrigerator"}, {"supercategory": "indoor", "color": [142, 108, 45], "isthing": 1, "id": 84, "name": "book"}, {"supercategory": "indoor", "color": [196, 172, 0], "isthing": 1, "id": 85, "name": "clock"}, {"supercategory": "indoor", "color": [95, 54, 80], "isthing": 1, "id": 86, "name": "vase"}, {"supercategory": "indoor", "color": [128, 76, 255], "isthing": 1, "id": 87, "name": "scissors"}, {"supercategory": "indoor", "color": [201, 57, 1], "isthing": 1, "id": 88, "name": "teddy bear"}, {"supercategory": "indoor", "color": [246, 0, 122], "isthing": 1, "id": 89, "name": "hair drier"}, {"supercategory": "indoor", "color": [191, 162, 208], "isthing": 1, "id": 90, "name": "toothbrush"}, {"supercategory": "textile", "color": [255, 255, 128], "isthing": 0, "id": 92, "name": "banner"}, {"supercategory": "textile", "color": [147, 211, 203], "isthing": 0, "id": 93, "name": "blanket"}, {"supercategory": "building", "color": [150, 100, 100], "isthing": 0, "id": 95, "name": "bridge"}, {"supercategory": "raw-material", "color": [168, 171, 172], "isthing": 0, "id": 100, "name": "cardboard"}, {"supercategory": "furniture-stuff", "color": [146, 112, 198], "isthing": 0, "id": 107, "name": "counter"}, {"supercategory": "textile", "color": [210, 170, 100], "isthing": 
0, "id": 109, "name": "curtain"}, {"supercategory": "furniture-stuff", "color": [92, 136, 89], "isthing": 0, "id": 112, "name": "door-stuff"}, {"supercategory": "floor", "color": [218, 88, 184], "isthing": 0, "id": 118, "name": "floor-wood"}, {"supercategory": "plant", "color": [241, 129, 0], "isthing": 0, "id": 119, "name": "flower"}, {"supercategory": "food-stuff", "color": [217, 17, 255], "isthing": 0, "id": 122, "name": "fruit"}, {"supercategory": "ground", "color": [124, 74, 181], "isthing": 0, "id": 125, "name": "gravel"}, {"supercategory": "building", "color": [70, 70, 70], "isthing": 0, "id": 128, "name": "house"}, {"supercategory": "furniture-stuff", "color": [255, 228, 255], "isthing": 0, "id": 130, "name": "light"}, {"supercategory": "furniture-stuff", "color": [154, 208, 0], "isthing": 0, "id": 133, "name": "mirror-stuff"}, {"supercategory": "structural", "color": [193, 0, 92], "isthing": 0, "id": 138, "name": "net"}, {"supercategory": "textile", "color": [76, 91, 113], "isthing": 0, "id": 141, "name": "pillow"}, {"supercategory": "ground", "color": [255, 180, 195], "isthing": 0, "id": 144, "name": "platform"}, {"supercategory": "ground", "color": [106, 154, 176], "isthing": 0, "id": 145, "name": "playingfield"}, {"supercategory": "ground", "color": [230, 150, 140], "isthing": 0, "id": 147, "name": "railroad"}, {"supercategory": "water", "color": [60, 143, 255], "isthing": 0, "id": 148, "name": "river"}, {"supercategory": "ground", "color": [128, 64, 128], "isthing": 0, "id": 149, "name": "road"}, {"supercategory": "building", "color": [92, 82, 55], "isthing": 0, "id": 151, "name": "roof"}, {"supercategory": "ground", "color": [254, 212, 124], "isthing": 0, "id": 154, "name": "sand"}, {"supercategory": "water", "color": [73, 77, 174], "isthing": 0, "id": 155, "name": "sea"}, {"supercategory": "furniture-stuff", "color": [255, 160, 98], "isthing": 0, "id": 156, "name": "shelf"}, {"supercategory": "ground", "color": [255, 255, 255], "isthing": 0, "id": 159, "name": "snow"}, {"supercategory": "furniture-stuff", "color": [104, 84, 109], "isthing": 0, "id": 161, "name": "stairs"}, {"supercategory": "building", "color": [169, 164, 131], "isthing": 0, "id": 166, "name": "tent"}, {"supercategory": "textile", "color": [225, 199, 255], "isthing": 0, "id": 168, "name": "towel"}, {"supercategory": "wall", "color": [137, 54, 74], "isthing": 0, "id": 171, "name": "wall-brick"}, {"supercategory": "wall", "color": [135, 158, 223], "isthing": 0, "id": 175, "name": "wall-stone"}, {"supercategory": "wall", "color": [7, 246, 231], "isthing": 0, "id": 176, "name": "wall-tile"}, {"supercategory": "wall", "color": [107, 255, 200], "isthing": 0, "id": 177, "name": "wall-wood"}, {"supercategory": "water", "color": [58, 41, 149], "isthing": 0, "id": 178, "name": "water-other"}, {"supercategory": "window", "color": [183, 121, 142], "isthing": 0, "id": 180, "name": "window-blind"}, {"supercategory": "window", "color": [255, 73, 97], "isthing": 0, "id": 181, "name": "window-other"}, {"supercategory": "plant", "color": [107, 142, 35], "isthing": 0, "id": 184, "name": "tree-merged"}, {"supercategory": "structural", "color": [190, 153, 153], "isthing": 0, "id": 185, "name": "fence-merged"}, {"supercategory": "ceiling", "color": [146, 139, 141], "isthing": 0, "id": 186, "name": "ceiling-merged"}, {"supercategory": "sky", "color": [70, 130, 180], "isthing": 0, "id": 187, "name": "sky-other-merged"}, {"supercategory": "furniture-stuff", "color": [134, 199, 156], "isthing": 0, "id": 188, "name": 
"cabinet-merged"}, {"supercategory": "furniture-stuff", "color": [209, 226, 140], "isthing": 0, "id": 189, "name": "table-merged"}, {"supercategory": "floor", "color": [96, 36, 108], "isthing": 0, "id": 190, "name": "floor-other-merged"}, {"supercategory": "ground", "color": [96, 96, 96], "isthing": 0, "id": 191, "name": "pavement-merged"}, {"supercategory": "solid", "color": [64, 170, 64], "isthing": 0, "id": 192, "name": "mountain-merged"}, {"supercategory": "plant", "color": [152, 251, 152], "isthing": 0, "id": 193, "name": "grass-merged"}, {"supercategory": "ground", "color": [208, 229, 228], "isthing": 0, "id": 194, "name": "dirt-merged"}, {"supercategory": "raw-material", "color": [206, 186, 171], "isthing": 0, "id": 195, "name": "paper-merged"}, {"supercategory": "food-stuff", "color": [152, 161, 64], "isthing": 0, "id": 196, "name": "food-other-merged"}, {"supercategory": "building", "color": [116, 112, 0], "isthing": 0, "id": 197, "name": "building-other-merged"}, {"supercategory": "solid", "color": [0, 114, 143], "isthing": 0, "id": 198, "name": "rock-merged"}, {"supercategory": "wall", "color": [102, 102, 156], "isthing": 0, "id": 199, "name": "wall-other-merged"}, {"supercategory": "textile", "color": [250, 141, 255], "isthing": 0, "id": 200, "name": "rug-merged"}] 2 | THING_CATEGORIES = [cat for cat in COCO_CATEGORIES if cat["isthing"] == 1] -------------------------------------------------------------------------------- /scripts/generate_with_batch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | # Switch execution environment to parent directory 5 | project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) 6 | os.chdir(project_root) 7 | sys.path.insert(0, project_root) 8 | 9 | from datasets.object_datasets import * 10 | from datasets.background_datasets import * 11 | from datasets.data_manager import DataManager 12 | from segmentation_generator.segmentataion_synthesizer import RandomCenterPointSegmentationSynthesizer, FineGrainedBoundingBoxSegmentationSynthesizer 13 | import numpy as np 14 | import json 15 | from PIL import Image 16 | import os 17 | import random 18 | import argparse 19 | from tqdm import tqdm 20 | import concurrent.futures 21 | 22 | 23 | # Filter conditions 24 | filtering_setting_0 = { 25 | "IsInstance": "non-filter", 26 | "Integrity": "non-filter", 27 | "QualityAndRegularity": "non-filter" 28 | } 29 | filtering_setting_1 = { 30 | "IsInstance": "filter", 31 | "Integrity": "non-filter", 32 | "QualityAndRegularity": "non-filter" 33 | } 34 | filtering_setting_2 = { 35 | "IsInstance": "non-filter", 36 | "Integrity": "filter", 37 | "QualityAndRegularity": "non-filter" 38 | } 39 | filtering_setting_3 = { 40 | "IsInstance": "non-filter", 41 | "Integrity": "non-filter", 42 | "QualityAndRegularity": "filter" 43 | } 44 | filtering_setting_4 = { 45 | "IsInstance": "filter", 46 | "Integrity": "filter", 47 | "QualityAndRegularity": "filter" 48 | } 49 | 50 | 51 | def reset_folders(paths): 52 | """Deletes and recreates specified directories to ensure a clean slate.""" 53 | for path in paths: 54 | if os.path.exists(path): 55 | # Uncomment below to delete existing folders if needed: 56 | # shutil.rmtree(path) 57 | pass 58 | if not os.path.exists(path): 59 | print(f"Creating: {path}") 60 | os.makedirs(path, exist_ok=True) 61 | 62 | 63 | def prepare_folders(image_save_path, mask_save_path, annotation_path): 64 | reset_folders([image_save_path, mask_save_path, annotation_path]) 65 | 66 
| 67 | class NumpyEncoder(json.JSONEncoder): 68 | """Custom JSON encoder to handle NumPy data types.""" 69 | def default(self, obj): 70 | if isinstance(obj, np.integer): 71 | return int(obj) 72 | return super(NumpyEncoder, self).default(obj) 73 | 74 | # without bg, for relightening and blending 75 | def initialize_global_categories(annotation_path, filtering_setting, global_data_manager): 76 | """ 77 | Initializes the global categories from the DataManager. 78 | This function runs in the main process so that the global categories 79 | are available when merging annotation JSONs. 80 | """ 81 | categories = global_data_manager.unified_object_categories 82 | categories_to_idx = {category: idx for idx, category in enumerate(categories)} 83 | idx_to_categories = {idx: category for idx, category in enumerate(categories)} 84 | global_categories = [{"id": idx, "name": category} for idx, category in enumerate(categories)] 85 | 86 | # Save category files for reference. 87 | json.dump(categories, open(os.path.join(annotation_path, "categories.json"), "w"), indent=4) 88 | json.dump(categories_to_idx, open(os.path.join(annotation_path, "categories_to_idx.json"), "w"), indent=4) 89 | json.dump(idx_to_categories, open(os.path.join(annotation_path, "idx_to_categories.json"), "w"), indent=4) 90 | 91 | return global_categories 92 | 93 | 94 | def process_image_worker(start_idx, end_idx, worker_seed, filtering_setting, queue, image_save_path, mask_save_path, annotation_path, data_manager): 95 | """ 96 | Worker function that processes a batch of images with a unique seed. 97 | For each image, the corresponding annotation is saved as a separate JSON file. 98 | """ 99 | random.seed(worker_seed) 100 | np.random.seed(worker_seed) 101 | 102 | rss = FineGrainedBoundingBoxSegmentationSynthesizer(data_manager, "./", random_seed=worker_seed) 103 | 104 | # Create a subfolder for separate annotations 105 | separate_annotation_path = os.path.join(annotation_path, "separate_annotations") 106 | if not (os.path.exists(separate_annotation_path) and os.path.isdir(separate_annotation_path)): 107 | os.makedirs(separate_annotation_path, exist_ok=True) 108 | 109 | for i in range(start_idx, end_idx): 110 | image_name = f"{i}.png" 111 | mask_name = f"{i}.png" 112 | image_path_full = os.path.join(image_save_path, image_name) 113 | mask_path_full = os.path.join(mask_save_path, mask_name) 114 | separate_annot_file = os.path.join(separate_annotation_path, f"{i}.json") 115 | 116 | # Skip processing if image, mask, and annotation all exist. 117 | if os.path.exists(image_path_full) and os.path.exists(mask_path_full) and os.path.exists(separate_annot_file): 118 | queue.put(1) 119 | continue 120 | 121 | obj_nums = np.random.randint(5, 20) 122 | data = rss.sampling_metadata(1024, 1024, obj_nums, hasBackground=False, dataAugmentation=False) 123 | res = rss.generate_with_unified_format(data, containRGBA=True, containCategory=True, containSmallObjectMask=False, resize_mode="fit") 124 | 125 | # Save image and mask. 126 | Image.fromarray(res['image_rgba']).save(image_path_full) 127 | Image.fromarray(res['coco_mask']).save(mask_path_full) 128 | 129 | # Build annotation for this image. 130 | annotation = { 131 | "segments_info": res["segments_info"], 132 | "file_name": mask_name, 133 | "image_id": i, 134 | } 135 | # Save the annotation as a separate JSON file. 
136 | with open(separate_annot_file, "w") as f: 137 | json.dump(annotation, f, cls=NumpyEncoder, indent=4) 138 | 139 | queue.put(1) 140 | 141 | # with bg 142 | # def process_image_worker(start_idx, end_idx, worker_seed, filtering_setting, queue, image_save_path, mask_save_path, annotation_path, data_manager): 143 | # """ 144 | # Worker function that processes a batch of images with a unique seed. 145 | # For each image, the corresponding annotation is saved as a separate JSON file. 146 | # """ 147 | # random.seed(worker_seed) 148 | # np.random.seed(worker_seed) 149 | 150 | # rss = FineGrainedBoundingBoxSegmentationSynthesizer(data_manager, "./", random_seed=worker_seed) 151 | 152 | # # Create a subfolder for separate annotations 153 | # separate_annotation_path = os.path.join(annotation_path, "separate_annotations") 154 | # if not (os.path.exists(separate_annotation_path) and os.path.isdir(separate_annotation_path)): 155 | # os.makedirs(separate_annotation_path, exist_ok=True) 156 | 157 | # for i in range(start_idx, end_idx): 158 | # image_name = f"{i}.jpg" 159 | # mask_name = f"{i}.png" 160 | # image_path_full = os.path.join(image_save_path, image_name) 161 | # mask_path_full = os.path.join(mask_save_path, mask_name) 162 | # separate_annot_file = os.path.join(separate_annotation_path, f"{i}.json") 163 | 164 | # # Skip processing if image, mask, and annotation all exist. 165 | # if os.path.exists(image_path_full) and os.path.exists(mask_path_full) and os.path.exists(separate_annot_file): 166 | # queue.put(1) 167 | # continue 168 | 169 | # obj_nums = np.random.randint(5, 20) 170 | # data = rss.sampling_metadata(1024, 1024, obj_nums, hasBackground=True, dataAugmentation=False) 171 | # res = rss.generate_with_unified_format(data, containRGBA=True, containCategory=True, containSmallObjectMask=False, resize_mode="fit") 172 | 173 | # # Save image and mask. 174 | # res['image'].save(image_path_full) 175 | # Image.fromarray(res['coco_mask']).save(mask_path_full) 176 | 177 | # # Build annotation for this image. 178 | # annotation = { 179 | # "segments_info": res["segments_info"], 180 | # "file_name": mask_name, 181 | # "image_id": i, 182 | # } 183 | # # Save the annotation as a separate JSON file. 
184 | # with open(separate_annot_file, "w") as f: 185 | # json.dump(annotation, f, cls=NumpyEncoder, indent=4) 186 | 187 | # queue.put(1) 188 | 189 | 190 | def listener(queue, total): 191 | """Updates the tqdm progress bar as images are generated.""" 192 | pbar = tqdm(total=total) 193 | while True: 194 | msg = queue.get() 195 | if msg == 'kill': 196 | break 197 | pbar.update(msg) 198 | pbar.close() 199 | 200 | 201 | 202 | 203 | class NumpyEncoder(json.JSONEncoder): 204 | """Custom JSON encoder to handle NumPy data types.""" 205 | def default(self, obj): 206 | try: 207 | import numpy as np 208 | if isinstance(obj, np.integer): 209 | return int(obj) 210 | except ImportError: 211 | pass 212 | return super(NumpyEncoder, self).default(obj) 213 | 214 | # Define process_file at the module level so it is pickleable 215 | def process_file(annot_file): 216 | try: 217 | with open(annot_file, "r") as f: 218 | annot = json.load(f) 219 | # Build image info assuming a fixed size (1024x1024) 220 | image_info = { 221 | "file_name": f"{annot['image_id']}.jpg", 222 | "height": 1024, 223 | "width": 1024, 224 | "id": annot["image_id"] 225 | } 226 | return annot, image_info 227 | except Exception: 228 | return None 229 | 230 | def merge_annotation_jsons(annotation_path, json_save_path, categories): 231 | """ 232 | Merges separate annotation JSON files (stored in the "separate_annotations" subfolder) 233 | into a final COCO-format JSON. Also constructs the images list based on the annotation files. 234 | This version uses multiprocessing via a ProcessPoolExecutor with os.scandir for efficient directory listing. 235 | """ 236 | separate_annotation_path = os.path.join(annotation_path, "separate_annotations") 237 | # Use os.scandir for efficient directory traversal 238 | json_files = [entry.path for entry in os.scandir(separate_annotation_path) 239 | if entry.is_file() and entry.name.endswith(".json")] 240 | 241 | annotations = [] 242 | images = [] 243 | 244 | # Adjust max_workers to a number more appropriate for your system (e.g., 4 to 8) 245 | max_workers = 100 246 | with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor: 247 | results = list(tqdm(executor.map(process_file, json_files), 248 | total=len(json_files), 249 | desc="Merging annotations")) 250 | for res in results: 251 | if res is not None: 252 | annot, image_info = res 253 | annotations.append(annot) 254 | images.append(image_info) 255 | 256 | # Sort the lists based on their IDs 257 | images = sorted(images, key=lambda x: x["id"]) 258 | annotations = sorted(annotations, key=lambda x: x["image_id"]) 259 | 260 | coco_json = { 261 | "images": images, 262 | "annotations": annotations, 263 | "categories": categories 264 | } 265 | 266 | with open(json_save_path, "w") as f: 267 | json.dump(coco_json, f, cls=NumpyEncoder, indent=4) 268 | 269 | 270 | def run_generation(args): 271 | """ 272 | Runs the synthetic image generation process. 273 | Each worker saves its annotation into a separate JSON file. 274 | In the end, all these JSON files are merged into one final JSON. 
275 | """ 276 | image_save_path = args.image_save_path 277 | mask_save_path = args.mask_save_path 278 | annotation_path = args.annotation_path 279 | json_save_path = args.json_save_path 280 | 281 | prepare_folders(image_save_path, mask_save_path, annotation_path) 282 | 283 | # Set the filtering setting based on argument 284 | fs = str(args.filtering_setting) 285 | if fs.isdigit(): 286 | fs = "filter_" + fs 287 | if fs == 'filter_0': 288 | filtering_setting = filtering_setting_0 289 | elif fs == 'filter_1': 290 | filtering_setting = filtering_setting_1 291 | elif fs == 'filter_2': 292 | filtering_setting = filtering_setting_2 293 | elif fs == 'filter_3': 294 | filtering_setting = filtering_setting_3 295 | else: 296 | filtering_setting = filtering_setting_4 297 | 298 | # Initialize global DataManager instance in the main process. 299 | global_data_manager = DataManager(available_object_datasets, available_background_datasets, filtering_setting) 300 | 301 | # Initialize global categories in the main process by passing global_data_manager. 302 | global_categories = initialize_global_categories(annotation_path, filtering_setting, global_data_manager) 303 | 304 | num_images = args.total_images 305 | num_workers = args.num_processes 306 | 307 | chunk_size = num_images // num_workers 308 | remainder = num_images % num_workers 309 | 310 | manager = multiprocessing.Manager() 311 | queue = manager.Queue() 312 | 313 | listener_process = multiprocessing.Process(target=listener, args=(queue, num_images)) 314 | listener_process.start() 315 | 316 | processes = [] 317 | start_index = 0 318 | 319 | for i in range(num_workers): 320 | end_index = start_index + chunk_size 321 | if i < remainder: 322 | end_index += 1 323 | 324 | worker_seed = random.randint(0, int(1e6)) 325 | 326 | def worker_task(start_idx=start_index, end_idx=end_index, worker_seed=worker_seed): 327 | process_image_worker(start_idx, end_idx, worker_seed, filtering_setting, queue, image_save_path, mask_save_path, annotation_path, data_manager=global_data_manager) 328 | 329 | p = multiprocessing.Process(target=worker_task) 330 | processes.append(p) 331 | p.start() 332 | start_index = end_index 333 | 334 | for p in processes: 335 | p.join() 336 | 337 | queue.put('kill') 338 | listener_process.join() 339 | 340 | # Merge individual annotation JSON files into the final COCO JSON. 341 | merge_annotation_jsons(annotation_path, json_save_path, global_categories) 342 | 343 | 344 | def main(): 345 | parser = argparse.ArgumentParser( 346 | description="Synthetic Image Generation with Segmentation Annotations" 347 | ) 348 | parser.add_argument('--num_processes', type=int, default=100, 349 | help='Number of processes to use') 350 | parser.add_argument('--total_images', type=int, default=1000, 351 | help='Total number of images to generate') 352 | parser.add_argument('--filtering_setting', type=str, default='filter_4', 353 | help='Filtering setting to apply. 
Either "filter_1", "filter_2", "filter_3", "filter_4", or a digit 1-4.') 354 | parser.add_argument('--image_save_path', type=str, 355 | default="/output/train", 356 | help="Path to save images") 357 | parser.add_argument('--mask_save_path', type=str, 358 | default="/output/panoptic_train", 359 | help="Path to save masks") 360 | parser.add_argument('--annotation_path', type=str, 361 | default="/output/annotations", 362 | help="Path to save separate annotation JSON files") 363 | parser.add_argument('--json_save_path', type=str, 364 | default="/output/annotations/panoptic_train.json", 365 | help="Path to save the merged COCO panoptic JSON") 366 | args = parser.parse_args() 367 | 368 | run_generation(args) 369 | 370 | 371 | if __name__ == "__main__": 372 | 373 | # using FC split data 374 | available_object_datasets = { 375 | "Synthetic": SyntheticDataset( 376 | dataset_path="/fc_10m", 377 | synthetic_annotation_path="/fc_10m/gc_object_segments_metadata.json", 378 | dataset_name="Synthetic", 379 | cache_path="./metadata_ovd_cache" 380 | ), 381 | } 382 | 383 | # using GC split data 384 | # available_object_datasets = { 385 | # "Synthetic": SyntheticDataset( 386 | # dataset_path="/gc_10m", 387 | # synthetic_annotation_path="/gc_10m/gc_object_segments_metadata.json", 388 | # dataset_name="Synthetic", 389 | # cache_path="./metadata_ovd_cache" 390 | # ), 391 | # } 392 | 393 | # mixing data from different source 394 | # available_object_datasets = { 395 | # "Synthetic_fc": SyntheticDataset( 396 | # dataset_path="/fc_10m", 397 | # synthetic_annotation_path="/fc_10m/gc_object_segments_metadata.json", 398 | # dataset_name="Synthetic_fc", 399 | # cache_path="./metadata_fc_cache" 400 | # ), 401 | # "Synthetic_gc": SyntheticDataset( 402 | # dataset_path="/gc_10m", 403 | # synthetic_annotation_path="/gc_10m/gc_object_segments_metadata.json", 404 | # dataset_name="Synthetic_gc", 405 | # cache_path="./metadata_gc_cache" 406 | # ), 407 | # } 408 | 409 | # By default, this loads from a placeholder dataset containing only a single image using `background`. 410 | # To use `background` in BG20k, please set up the dataset path. 411 | available_background_datasets = { 412 | "BG20k": BG20KDataset("datasets/one_image_bg") 413 | # "BG20k": BG20KDataset("/your/path/to/bg20k_dataset") 414 | 415 | } 416 | 417 | main() 418 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# Synthetic Object Compositions for Scalable and Accurate Learning in Detection, Segmentation, and Grounding

2 Million Diverse, Accurate Synthetic Dense-Annotated Images (FC-1M + GC-1M) + 20M Synthetic Object Segments to Supercharge Grounding-DINO, Mask2Former, and Any Detectors / Segmentors / Grounding-VLMs

📑 Paper | 🤗 Datasets: 2M images + 20M segments

Weikai Huang¹, Jieyu Zhang¹, Taoyang Jia¹, Chenhao Zheng¹, Ziqi Gao¹, Jae Sung Park¹, Ranjay Krishna¹,²
¹University of Washington   ²Allen Institute for AI

---

[Figure: Text-to-Image Results]

A scalable pipeline for composing high-quality synthetic object segments into richly annotated images for object detection, instance segmentation, and visual grounding.

30 | 31 | ## 🌟 Highlights 32 | 33 | **Why SOC?** A small amount of high-quality synthetic data can outperform orders of magnitude more real data: 34 | 35 | - 🚀 **Efficient & Scalable**: Just **50K** SOC images match the gains from **20M** model-generated (GRIT) or **200K** human-annotated (V3Det) images on LVIS detection. We compose 2 million diverse images (FC-1M + GC-1M) with annotations, along with 20 million synthetic object segments across 47,000+ categories from Flux. 36 | - 🎯 **Accurate Annotations**: Object-centric composition provides pixel-perfect masks, boxes, and referring expressions—no noisy pseudo-labels 37 | - 🎨 **Controllable Generation**: Synthesize targeted data for specific scenarios (e.g., intra-class referring, rare categories, domain-specific applications) 38 | - 🔄 **Complementary to Real Data**: Adding SOC to existing datasets (COCO, LVIS, V3Det, GRIT) yields consistent additive gains across all benchmarks 39 | - 💰 **Cost-Effective**: Generate unlimited training data from 20M object segments without expensive human annotation 40 | - 📈 **100K SOC surpasses larger real-data baselines**: +10.9 LVIS AP (OVD) and +8.4 gRefCOCO NAcc (VG); and remains complementary when combined with GRIT/V3Det 41 | 42 | 43 | --- 44 | 45 | # 📊 Released Datasets 46 | 47 | We release the following datasets for research use: 48 | 49 | | Dataset Name | # Images | # Categories | Description | Download | 50 | |-------------|----------|--------------|-------------|----------| 51 | | **FC-1M** | 1,000,000 | 1,600 | Frequent Categories | [🤗 HuggingFace](https://huggingface.co/datasets/weikaih/SOC-FC-1M) | 52 | | **GC-1M** | 1,000,000 | 47,000+ | General Categories | [🤗 HuggingFace](https://huggingface.co/datasets/weikaih/SOC-GC-1M) | 53 | | **SFC-200K** | 200,000 | 1,600 | Single-category Frequent Category — same category objects with varied attributes | [🤗 HuggingFace](https://huggingface.co/datasets/weikaih/SOC-SFC-200K) | 54 | | **SGC-200K** | 200,000 | 47,000+ | Single-category General Category — same category objects with varied attributes | [🤗 HuggingFace](https://huggingface.co/datasets/weikaih/SOC-SGC-200K) | 55 | 56 | 57 |
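Each release in the table above can also be pulled programmatically. Below is a minimal sketch, not part of the repository: it assumes `huggingface_hub` is installed (`pip install huggingface_hub`), reuses the FC-1M repo ID from the download link above, and uses a purely illustrative destination path.

```python
# Minimal sketch: download one released split locally with huggingface_hub.
# Assumptions: `pip install huggingface_hub`; repo_id matches the FC-1M link above;
# local_dir is a hypothetical destination path.
from huggingface_hub import snapshot_download

local_path = snapshot_download(
    repo_id="weikaih/SOC-FC-1M",   # or SOC-GC-1M / SOC-SFC-200K / SOC-SGC-200K
    repo_type="dataset",
    local_dir="/data/soc_fc_1m",   # example destination
)
print("Downloaded to:", local_path)
```

The same call works for the object-segment sets listed under Object Segments below; only the `repo_id` changes.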

Examples of dataset types:

[Figure: FC / GC examples | SFC / SGC examples]
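The composing pipeline in this repository (see Usage below) writes a merged COCO-style panoptic JSON whose top-level keys are `images`, `annotations`, and `categories`. As a quick sanity check on such an output, the sketch below is a minimal, illustrative example: the path is a placeholder for whatever you pass as `--json_save_path`, and the field names follow `scripts/generate_with_batch.py`.

```python
# Minimal sketch: inspect a merged COCO-style panoptic JSON produced by
# scripts/generate_with_batch.py. The path is an example; point it at your
# own --json_save_path output.
import json

with open("./out/annotations/panoptic_train.json") as f:
    coco = json.load(f)

print(len(coco["images"]), "images,", len(coco["categories"]), "categories")

# Each annotation pairs an image_id with its panoptic mask file name and the
# per-image segments_info list.
for ann in coco["annotations"][:5]:
    print(ann["image_id"], ann["file_name"], len(ann["segments_info"]), "segments")
```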
70 | 71 | **All datasets include:** 72 | - ✅ High-resolution images with photorealistic relighting and blending 73 | - ✅ Pixel-perfect segmentation masks 74 | - ✅ Tight bounding boxes 75 | - ✅ Category labels 76 | - ✅ Diverse referring expressions (attribute-based, spatial-based, and mixed) 77 | 78 | **Note:** Other dataset variants (e.g., SOC-LVIS, MixCOCO) contain segments from existing datasets and cannot be released. Please use the code in this repository to compose your own datasets from the released object segments. 79 | 80 | ## Object Segments 81 | 82 | We also release **20M synthetic object segments** used to compose the above datasets: 83 | 84 | | Segment Set | # Segments | # Categories | Prompts/Category | Segments/Prompt | Download | 85 | |-------------|------------|--------------|------------------|-----------------|----------| 86 | | **FC Object Segments** | 10,000,000 | 1,600 | 200 | 3 | [🤗 SOC-FC-Object-Segments-10M](https://huggingface.co/datasets/weikaih/SOC-FC-Object-Segments-10M) | 87 | | **GC Object Segments** | 10,000,000 | 47,000+ | 10 | 3 | [🤗 SOC-GC-Object-Segments-10M](https://huggingface.co/datasets/weikaih/SOC-GC-Object-Segments-10M) | 88 | 89 | Browse all sets via the collection: [🤗 HuggingFace Collection](https://huggingface.co/collections/weikaih/SOC-synthetic-object-segments-improves-detection-segmentat-682679751d20faa20800033c) 90 | 91 | --- 92 | 93 | # 📦 Installation 94 | 95 | *Notice*: We provide only minimal guidance for the core parts of the codebase: image composing, relighting and blending, and referring expression generation. The full documentation (with an accompanying arXiv paper) covering additional tasks and case studies will be released soon. 96 | 97 | ## Environment Setup 98 | Follow the steps below to set up the environment and use the repository: 99 | ```bash 100 | # Clone the repository 101 | git clone https://github.com/weikaih04/SOC 102 | cd ./SOC 103 | 104 | # Create and activate a Python virtual environment: 105 | conda create -n SOC python==3.10 106 | conda activate SOC 107 | 108 | # Install the required dependencies for composing images with synthetic object segments: 109 | pip install -r requirements.txt 110 | 111 | # If you want to perform relighting and blending: 112 | conda create -n SOC-relight python==3.10 113 | conda activate SOC-relight 114 | pip install -r requirements_relight_and_blending.txt 115 | 116 | # If you want to generate referring expressions: 117 | conda create -n SOC-ref python==3.10 118 | conda activate SOC-ref 119 | pip install -r requirements_referring_expression_generation.txt 120 | ``` 121 | 122 | 123 | ## Background Dataset (Optional) 124 | If you want to relight images rather than paste object segments directly onto a background, just use a random image as the background and set `hasBackground` to false in `scripts/generate_with_batch.py`. 125 | 126 | You can download BG-20K from this repo: https://github.com/JizhiziLi/GFM.git 127 | 128 | # 🚀 Usage 129 | 130 | ## Composing Synthetic Images 131 | We provide scripts to compose images with synthetic segments: 132 | 133 | If you want to generate images for relighting and blending that only contain foreground object segments (set --num_processes according to your CPU cores): 134 | ```bash 135 | python scripts/generate_with_batch.py \ 136 | --num_processes 100 \ 137 | --total_images 100000 \ 138 | --filtering_setting filter_0 \ 139 | --image_save_path "/output/dataset_name/train" \ 140 | --mask_save_path "/output/dataset_name/panoptic_train" \ 141 | --annotation_path
"/output/dataset_name/annotations" \ 142 | --json_save_path "/output/dataset_name/annotations/panoptic_train.json" 143 | ``` 144 | 145 | 146 | ### Key parameters 147 | - --num_processes: Number of parallel workers to generate images; set based on CPU cores. 148 | - --total_images: Total images to generate. 149 | - --filtering_setting: One of filter_0..filter_4 (filter_4 = strictest). Controls segment quality filters. 150 | - --image_save_path: Output path for rendered RGBA images (PNG). 151 | - --mask_save_path: Output path for color panoptic masks (PNG). 152 | - --annotation_path: Output folder for per-image JSONs and category maps. 153 | - --json_save_path: Final merged COCO-style panoptic JSON path. 154 | 155 | Important: At the end of scripts/generate_with_batch.py, available_object_datasets must point to your local copies of released FC/GC object segments and their metadata JSON. For example, if you downloaded SOC-FC-Object-Segments-10M to /data/fc_10m with metadata fc_object_segments_metadata.json, set: 156 | - dataset_path="/data/fc_10m" 157 | - synthetic_annotation_path="/data/fc_10m/fc_object_segments_metadata.json" 158 | Similarly for GC: gc_object_segments_metadata.json 159 | 160 | 161 | 162 | 163 | Notes 164 | - We expect dataset_path to contain category/subcategory/ID.png structure as provided in our released object-segment datasets. 165 | - The script writes per-image JSONs under annotation_path/separate_annotations and merges them into the final COCO-style panoptic JSON at json_save_path. 166 | 167 | Minimal example 168 | ```bash 169 | # Symlink your datasets to the default paths expected by the script (optional) 170 | ln -s /data/fc_10m /fc_10m 171 | ln -s /data/gc_10m /gc_10m 172 | 173 | # Generate a tiny sample dataset locally 174 | python scripts/generate_with_batch.py \ 175 | --num_processes 4 \ 176 | --total_images 20 \ 177 | --filtering_setting filter_0 \ 178 | --image_save_path "./out/train" \ 179 | --mask_save_path "./out/panoptic_train" \ 180 | --annotation_path "./out/annotations" \ 181 | --json_save_path "./out/annotations/panoptic_train.json" 182 | ``` 183 | 184 | If you want to generate images that directly paste objects onto backgrounds, uncomment the `with bg process_image_worker` function in `scripts/generate_with_batch.py`. 185 | 186 | ## Relighting and Blending 187 | Relight and blend images using IC-Light with mask-area-weighted blending to enhance photorealism while preserving object details and colors: 188 | 189 | ```bash 190 | python relighting_and_blending/inference.py \ 191 | --dataset_path "$DATASET_PATH" \ 192 | --output_data_path "$OUTPUT_DATA_PATH" \ 193 | --num_splits "$NUM_SPLITS" \ 194 | --split "$SPLIT" \ 195 | --index_json_path "" \ 196 | --illuminate_prompts_path "$ILLUMINATE_PROMPTS_PATH" \ 197 | --record_path "$RECORD_PATH" 198 | ``` 199 | 200 | Currently supports Google Cloud Storage access and local file system. 201 | 202 | Notes 203 | - Requires a CUDA GPU. Models load in half precision; 12GB+ VRAM recommended. 204 | - Weights auto-download on first run: 205 | - Stable Diffusion components from stablediffusionapi/realistic-vision-v51 206 | - Background remover briaai/RMBG-1.4 207 | - IC-Light offset iclight_sd15_fc.safetensors (downloaded to ./models if missing) 208 | - Input expectations: 209 | - dataset_path should point to the folder with RGBA foreground PNGs (e.g., ./out/train) named 0.png, 1.png, ... 
210 | - A matching color panoptic mask must exist at the same id under dataset_path with "train" replaced by "panoptic_train" (e.g., ./out/panoptic_train/0.png) 211 | - illuminate_prompts_path must be a JSON file containing an array of prompt strings for relighting. 212 | 213 | Minimal example 214 | ````bash 215 | # Create a tiny illumination prompt list 216 | cat > ./illumination_prompt.json << 'JSON' 217 | [ 218 | "golden hour lighting, soft shadows", 219 | "overcast daylight, diffuse light", 220 | "studio softbox lighting" 221 | ] 222 | JSON 223 | 224 | # Relight a small sample from the composed outputs 225 | python relighting_and_blending/inference.py \ 226 | --dataset_path ./out/train \ 227 | --output_data_path ./out/relit \ 228 | --num_splits 1 \ 229 | --split 0 \ 230 | --illuminate_prompts_path ./illumination_prompt.json 231 | ```` 232 | 233 | 234 | ## Referring Expression Generation 235 | We use an OpenAI-compatible endpoint (vLLM) and query a local model. 236 | 237 | Step 1) Start an OpenAI-compatible server (port 8080) 238 | ````bash 239 | # Example: start vLLM OpenAI server with the model used in our script 240 | python -m vllm.entrypoints.openai.api_server \ 241 | --model Qwen/QwQ-32B-AWQ \ 242 | --host 0.0.0.0 \ 243 | --port 8080 244 | ```` 245 | Notes 246 | - Our script currently assumes base_url=http://localhost:8080/v1. 247 | - Ensure your GPU/driver supports the chosen model; adjust model name if needed. 248 | 249 | Step 2) Run the generator 250 | ````bash 251 | # INPUT_FILE is the merged COCO-style JSON from the composing stage 252 | # OUTPUT_DIR will contain jsonl shards (one per job): job_0.jsonl, ... 253 | export OPENAI_API_KEY=dummy_key # any non-empty string is accepted 254 | 255 | python referring_expression_generation/inference.py \ 256 | 1 \ 257 | 0 \ 258 | ./out/annotations/panoptic_train.json \ 259 | ./out/refexp \ 260 | --api_key "$OPENAI_API_KEY" \ 261 | --num_workers 8 262 | ```` 263 | Outputs 264 | - At least 9 expressions per image (balanced across attribute/spatial/reasoning, single/multi). 265 | - Writes per-job jsonl files under OUTPUT_DIR. 266 | - Supports local paths and GCS (gs://) for both inputs and outputs. 267 | 268 | # 📈 Results 269 | 270 | ## Task 1: Open-Vocabulary Object Detection 271 | 272 | **Model**: MM-Grounding-DINO | **Benchmarks**: LVIS v1.0 Full Val, OdinW-35 273 | 274 |

[Figure: Open-Vocabulary Detection Results]

277 | 278 | ### Key Findings 279 | 280 | #### 🎯 Small Amount of SOC Efficiently Brings Strong Gains 281 | With only **50K** synthetic images, SOC delivers gains comparable to orders of magnitude more real data: 282 | 283 | | Training Data | LVIS AP | APrare | Gain vs Baseline | 284 | |--------------|---------|-------------------|------------------| 285 | | Object365+GoldG (Baseline) | 20.1 | 10.1 | - | 286 | | + GRIT (20M images) | 27.1 | 17.1 | +7.0 AP | 287 | | + V3Det (200K images) | 30.6 | 24.6 | +10.5 AP | 288 | | **+ SOC-50K** | **29.8** | **23.5** | **+9.7 AP** | 289 | 290 | **SOC-50K matches V3Det's gains with 400× fewer images!** 291 | 292 | #### 📊 Scaling Up SOC Data Leads to Better Performance 293 | Continuous improvements as we scale from 50K → 100K → 400K: 294 | 295 | | SOC Scale | LVIS AP | APrare | OdinW-35 mAP | 296 | |-----------|---------|-------------------|--------------| 297 | | 50K | 29.8 | 23.5 | 21.0 | 298 | | 100K | 31.0 (+1.2) | 26.3 (+2.8) | 21.0 | 299 | | 400K | **31.4 (+1.6)** | **27.9 (+1.6)** | **22.8 (+1.8)** | 300 | 301 | #### 🔄 SOC is Complementary to Real Datasets 302 | Adding SOC on top of large-scale real datasets yields additive gains: 303 | 304 | | Training Data | LVIS AP | APrare | OdinW-35 mAP | 305 | |--------------|---------|-------------------|--------------| 306 | | Object365+GoldG+V3Det+GRIT | 31.9 | 23.6 | - | 307 | | **+ SOC-100K** | **33.2 (+1.3)** | **29.8 (+6.2)** | **+2.8** | 308 | 309 | SOC introduces novel vocabulary and contextual variations not captured by existing real datasets. 310 | 311 | ## Task 2: Visual Grounding 312 | 313 | **Model**: MM-Grounding-DINO | **Benchmarks**: RefCOCO/+/g, gRefCOCO, DoD 314 | 315 |

[Figure: Visual Grounding Results]

318 | 319 | ### Key Findings 320 | 321 | #### ⚠️ Existing Large Detection and Grounding Datasets Yield Only Marginal Improvements 322 | Large-scale real datasets provide limited gains for referring expression tasks: 323 | 324 | | Training Data | gRefCOCO P@1 | gRefCOCO NAcc | DoD FULL mAP | 325 | |--------------|--------------|--------------------------|--------------| 326 | | Object365+GoldG | - | 89.3 | - | 327 | | + V3Det (200K) | +0.5 | +0.0 | - | 328 | | + GRIT (20M) | - | - | +1.4 | 329 | 330 | **Why?** V3Det lacks sentence-level supervision; GRIT uses noisy model-generated caption-box pairs. 331 | 332 | #### ✨ SOC Provides Diverse, High-Quality Referring Expressions 333 | SOC generates precise referring pairs from ground truth annotations without human labels: 334 | 335 | | Training Data | gRefCOCO NAcc | DoD FULL mAP | Gain | 336 | |--------------|--------------------------|--------------|------| 337 | | Object365+GoldG | 89.3 | - | Baseline | 338 | | **+ SOC-50K** | **93.9 (+4.6)** | **+1.0** | 50K images | 339 | | **+ SOC-100K** | **97.7 (+8.4)** | **+3.8** | 100K images | 340 | 341 | **Expression Types** (3-6 per type, balanced coverage): 342 | - **Attribute-based**: "the red apple", "charcoal-grey cat" 343 | - **Spatial-based**: "dog to the right of the bike" 344 | - **Mixed-type**: "red object to the right of the child" 345 | 346 | SOC's gains per example far outperform GRIT (20M) and V3Det (200K)! 347 | 348 | ## Task 3: Instance Segmentation 349 | 350 | **Model**: APE (LVIS pre-trained) | **Benchmark**: LVIS v1.0 Val 351 | 352 |

[Figure: Instance Segmentation Results]

355 | 356 | ### Key Findings 357 | 358 | #### 🎯 SOC Continuously Improves LVIS Segmentation 359 | Two-stage fine-tuning: (1) Train on 50K SOC-LVIS → (2) Continue on LVIS train split 360 | 361 | | Training Protocol | AP | APrare | APcommon | APfrequent | 362 | |------------------|-------|-------------------|---------------------|----------------------| 363 | | LVIS only | 46.96 | 40.87 | - | - | 364 | | **SOC-50K → LVIS** | **48.48 (+1.52)** | **44.70 (+3.83)** | - | **(+0.31)** | 365 | 366 | **Why the large rare-class gain?** Synthetic data can be generated to cover underrepresented classes, mitigating LVIS's long-tail imbalance. Frequent classes already have ample real examples and benefit less. 367 | 368 | --- 369 | 370 | ## Task 4: Small-Vocabulary, Limited-Data Regimes 371 | 372 | **Model**: Mask2Former-ResNet-50 | **Benchmark**: COCO Instance Segmentation 373 | 374 | ### Key Findings 375 | 376 | #### 💰 SOC Excels in Low-Data Regimes 377 | Mixing real COCO segments with SOC synthetic segments (80 COCO categories): 378 | 379 | | COCO Data Scale | COCO Only | COCO + SOC | Gain | 380 | |----------------|-----------|------------|------| 381 | | 1% (~1K images) | - | - | **+6.59 AP** | 382 | | 10% (~10K images) | - | - | **~+3 AP** | 383 | | 50% (~50K images) | - | - | **~+3 AP** | 384 | | 100% (Full) | - | - | **~+3 AP** | 385 | 386 | **Key Insight**: The boost is particularly dramatic at 1% COCO (+6.59 AP), and grows by roughly 3% at each subsequent data scale. SOC is most effective when real data is scarce! 387 | 388 | --- 389 | 390 | ## Task 5: Intra-Class Referring Expression 391 | 392 | **Model**: MM-Grounding-DINO | **Benchmark**: Custom intra-class benchmark (COCO + OpenImages V7) 393 | 394 |

[Figure: Intra-Class Referring Results]

397 | 398 | ### What is Intra-Class Referring? 399 | A challenging visual grounding task requiring fine-grained attribute discrimination among same-category instances. 400 | 401 | **Example**: In an image with multiple cars of different colors and makes, locate "the charcoal-grey sedan" (not just "car"). 402 | 403 | **Why it's hard**: Models often shortcut by ignoring attributes and relying solely on category nouns. 404 | 405 | ### Evaluation Metrics 406 | - **Average Gap**: Average confidence margin between ground-truth box and highest-scoring same-category distractor 407 | - **Positive Gap Ratio**: Percentage of images where ground-truth box receives highest confidence among same-category candidates 408 | 409 | ### Key Findings 410 | 411 | #### 🎯 Targeted SOC Data Fixes Intra-Class Shortcuts 412 | 413 | | Training Data | Average Gap | Positive Gap Ratio | 414 | |--------------|-------------|-------------------| 415 | | Object365+GoldG | 37.5 | ~80% | 416 | | + GRIT (20M) | 34.6 (-2.9) | ~82% | 417 | | + V3Det (200K) | 36.7 (-0.8) | ~83% | 418 | | + GRIT + V3Det | 35.8 (-1.7) | ~85% | 419 | | **+ SOC-SFC-50K + SOC-SGC-50K** | **40.6 (+3.1)** | **90%** | 420 | 421 | **SOC-SFC/SGC**: Synthetic images with multiple instances of the same category but varied attributes (e.g., cars with different colors and makes). 422 | 423 | **Key Insight**: Large-scale auxiliary data (GRIT, V3Det) yields negligible or even negative impact. Only targeted synthetic data tailored to intra-class attribute variation significantly improves performance! 424 | 425 | --- 426 | 427 | 428 | 429 | # 📧 Contact 430 | 431 | * **Weikai Huang**: weikaih@cs.washington.edu 432 | * **Jieyu Zhang**: jieyuz2@cs.washington.edu 433 | 434 | --- 435 | 436 | # 📝 Citation 437 | 438 | ```bibtex 439 | @misc{huang2025syntheticobjectcompositionsscalable, 440 | title={Synthetic Object Compositions for Scalable and Accurate Learning in Detection, Segmentation, and Grounding}, 441 | author={Weikai Huang and Jieyu Zhang and Taoyang Jia and Chenhao Zheng and Ziqi Gao and Jae Sung Park and Winson Han and Ranjay Krishna}, 442 | year={2025}, 443 | eprint={2510.09110}, 444 | archivePrefix={arXiv}, 445 | primaryClass={cs.CV}, 446 | url={https://arxiv.org/abs/2510.09110}, 447 | } 448 | ``` 449 | 450 | --- 451 | 452 | # 🙏 Acknowledgments 453 | 454 | We thank the authors of FLUX-1, IC-Light, DIS, Qwen, and QwQ for their excellent open-source models that made this work possible. 455 | -------------------------------------------------------------------------------- /referring_expression_generation/inference.py: -------------------------------------------------------------------------------- 1 | # --- Chunking helper --- 2 | import random 3 | 4 | def chunk_result(result, threshold=None): 5 | """ 6 | Splits the result['grounding']['regions'] into chunks based on a character threshold, 7 | recomputing tokens_positive (now character offsets) for each region in each chunk. 8 | Each region in a chunk will have its tokens_positive field adjusted to match the chunk's caption. 9 | """ 10 | if threshold is None: 11 | threshold = random.randint(150, 255) 12 | 13 | regions_orig = result['grounding']['regions'] 14 | chunks = [] 15 | region_idx = 0 16 | total_regions = len(regions_orig) 17 | 18 | while region_idx < total_regions: 19 | temp_caption = [] 20 | temp_regions = [] 21 | cursor_chunk = 0 22 | 23 | while region_idx < total_regions: 24 | region = regions_orig[region_idx] 25 | text = region['phrase'] 26 | next_caption = '. '.join(temp_caption + [text]) + '. 
' 27 | if len(next_caption) > threshold and temp_regions: 28 | break 29 | 30 | # Character-based span 31 | span_start = cursor_chunk 32 | span_end = cursor_chunk + len(text) - 1 33 | 34 | chunk_region = { 35 | 'phrase': text, 36 | 'bboxes': region['bboxes'], 37 | 'tokens_positive': [[span_start, span_end]], 38 | } 39 | temp_caption.append(text) 40 | temp_regions.append(chunk_region) 41 | 42 | # Advance by text length + separator (". ") 43 | cursor_chunk += len(text) + 2 44 | region_idx += 1 45 | 46 | chunks.append({ 47 | 'filename': result['filename'], 48 | 'height': result['height'], 49 | 'width': result['width'], 50 | 'grounding': { 51 | 'caption': '. '.join([r['phrase'] for r in temp_regions]) + '. ', 52 | 'regions': temp_regions 53 | } 54 | }) 55 | 56 | return chunks 57 | # import debugpy 58 | # try: 59 | # # 5678 is the default attach port in the VS Code debug configurations. Unless a host and port are specified, host defaults to 127.0.0.1 60 | # debugpy.listen(("localhost", 9501)) 61 | # print("Waiting for debugger attach") 62 | # debugpy.wait_for_client() 63 | # except Exception as e: 64 | # pass 65 | 66 | 67 | # generate_referring_expressions_odvg.py 68 | """ 69 | Generate attribute-, spatial- and reasoning-based referring expressions (single, multi, non) for a synthetic dataset 70 | and write the results to OVDG-style jsonl. 71 | 72 | Key design changes vs. original script 73 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 74 | 1. **One LLM call per image** – All nine buckets of expressions are requested in **one** chat completion by 75 | asking the model for a structured JSON block. This halves token usage and latency. 76 | 2. **Configurable question counts** – Numbers per bucket are read from a JSON/YAML config or CLI flags. 77 | 3. **Cleaner prompt & examples** – Prompt now includes user-provided examples, and explicit instructions to: 78 | - Use `category`, `short_phrase`, and `features` fields when generating expressions. 79 | - Infer spatial relationships strictly from the provided `bbox` values (absolute and relative positions). 80 | 4. **Direct OVDG output** – Each question becomes an OVDG *region* entry and grouped under one 81 | caption per image. `tokens_positive` are computed automatically. 82 | 5. **Stateless retries** – Automatic validation + adaptive temperature retry loop keeps the pipeline robust. 
83 | """ 84 | 85 | import argparse, json, re, time 86 | from pathlib import Path 87 | from typing import Any, Dict, List 88 | from tqdm import tqdm 89 | import openai 90 | import random 91 | from pydantic import BaseModel, Field 92 | from typing import List 93 | import os 94 | from multiprocessing import Pool 95 | import tempfile 96 | from google.cloud import storage 97 | 98 | # --- GCS helpers --- 99 | def parse_gcs_path(gcs_path: str): 100 | if not gcs_path.startswith("gs://"): 101 | raise ValueError(f"Not a valid GCS path: {gcs_path}") 102 | path = gcs_path[5:] 103 | bucket, _, blob = path.partition("/") 104 | return bucket, blob 105 | 106 | def download_from_gcs(gcs_path: str, local_path: str): 107 | bucket_name, blob_name = parse_gcs_path(gcs_path) 108 | client = storage.Client() 109 | bucket = client.bucket(bucket_name) 110 | blob = bucket.blob(blob_name) 111 | blob.download_to_filename(local_path) 112 | print(f"Downloaded {gcs_path} to {local_path}") 113 | 114 | def upload_to_gcs(local_path: str, gcs_path: str): 115 | bucket_name, blob_name = parse_gcs_path(gcs_path) 116 | client = storage.Client() 117 | bucket = client.bucket(bucket_name) 118 | blob = bucket.blob(blob_name) 119 | blob.upload_from_filename(local_path) 120 | print(f"Uploaded {local_path} to {gcs_path}") 121 | 122 | class Expression(BaseModel): 123 | q: str = Field(..., description="The referring‐expression string") 124 | ids: List[int] = Field(..., description="List of segment IDs referenced") 125 | 126 | class ObjectExpressions(BaseModel): 127 | attribute: List[Expression] = Field(..., description="Attribute‐based expressions") 128 | spatial: List[Expression] = Field(..., description="Spatial‐based expressions") 129 | reasoning: List[Expression] = Field(..., description="Reasoning‐based expressions") 130 | 131 | class RefExpressions(BaseModel): 132 | single_object: ObjectExpressions = Field(..., description="Expressions targeting exactly one object") 133 | multi_object: ObjectExpressions = Field(..., description="Expressions targeting multiple objects") 134 | 135 | ##################################################################### 136 | # -------------------------- CONFIG --------------------------------# 137 | ##################################################################### 138 | DEFAULT_NUM_Q = { 139 | "single": {"attribute": 6, "spatial": 6, "reasoning": 6}, 140 | "multi": {"attribute": 3, "spatial": 3, "reasoning": 3}, 141 | } 142 | 143 | MAX_RETRIES = 10 144 | SLEEP_BETWEEN_RETRIES = 10 # seconds 145 | 146 | ##################################################################### 147 | # ----------------------- EXAMPLES --------------------------------# 148 | ##################################################################### 149 | EXAMPLE_SEGMENTS = ''' 150 | [ID: 7377552] tongs | short_phrase: tongs with rough iron texture | features: [] | description: tongs with a rough iron texture, painted in old bronze | bbox: [318, 535, 128, 282] 151 | [ID: 10569372] bath_towel | short_phrase: bath towel with tribal flair | features: [] | description: a bath_towel with geometric tribal flair in coppery tones | bbox: [474, 10, 493, 879] 152 | [ID: 2187630] Canned | short_phrase: cylindrical can of recycled aluminum | features: [] | description: A cylindrical can made of recycled aluminum. 
| bbox: [376, 217, 102, 247] 153 | [ID: 10733385] shovel | short_phrase: shovel with gleaming blade | features: [] | description: a shovel with a blade that gleams like polished alabaster | bbox: [424, 726, 48, 184] 154 | [ID: 519546] knitting_needle | short_phrase: knitting needle with glossy finish | features: [] | description: a knitting_needle with a glossy, clear finish and a spiral ridge | bbox: [0, 80, 97, 125] 155 | [ID: 4995055] strap | short_phrase: slim clear strap with blue stripe | features: [] | description: a slim, clear strap with a spray-painted blue stripe | bbox: [725, 178, 54, 60] 156 | [ID: 9339368] teakettle | short_phrase: teakettle with glass body | features: [] | description: A teakettle with a round glass body and a charming, twisted copper handle. | bbox: [324, 789, 103, 117] 157 | [ID: 8109537] cushion | short_phrase: hunter green leather cushion | features: [] | description: A hunter green, sleek leather cushion. | bbox: [684, 219, 89, 79] 158 | [ID: 4123758] raspberry | short_phrase: glossy raspberry with maroon tinge | features: [] | description: a glossy raspberry with a subtle maroon tinge | bbox: [183, 324, 134, 140] 159 | [ID: 9903309] dropper | short_phrase: dropper with matte black body | features: [] | description: A dropper with a matte black body and a glossy dropper tip | bbox: [204, 868, 93, 122] 160 | [ID: 2998739] snowmobile | short_phrase: snowmobile with white surface | features: [] | description: a snowmobile with a glossy white surface decorated with longitudinal red strips | bbox: [898, 451, 63, 38] 161 | [ID: 13570439] box | short_phrase: box with bold stripes and smiley | features: [] | description: a box painted in bold stripes with a quirky smiley face | bbox: [305, 837, 46, 42] 162 | [ID: 232766] ram_animal | short_phrase: compact ram with patchwork wool | features: [] | description: a compact ram boasting an intricate pattern of color on its wool, resembling patchwork | bbox: [152, 704, 35, 36] 163 | [ID: 15801147] birthday_card | short_phrase: cheerful pirate ship with map | features: [] | description: A birthday_card featuring a cheerful pirate ship with a colorful map. | bbox: [34, 924, 45, 46] 164 | [ID: 8631767] Egg_tart | short_phrase: marigold custard egg tart | features: [] | description: An egg tart with a marigold-colored custard that slightly spills over around the cornflower-blue border. | bbox: [442, 207, 43, 44] 165 | [ID: 12491521] cornbread | short_phrase: crispy cornbread with golden flecks | features: [] | description: A crispy slice of cornbread, with a myriad of shimmering, golden flecks and patches. 
| bbox: [191, 621, 42, 42] 166 | [ID: 7054199] wooden_spoon | short_phrase: stout wooden spoon for dough | features: [] | description: a stout, thick wooden_spoon with a hefty feel, perfect for tackling hefty dough mixtures | bbox: [207, 292, 46, 40] 167 | [ID: 2379817] motor | short_phrase: tiny pink motor with meta plates | features: [] | description: a tiny, delicate motor painted in a pastel pink with tiny meta plates | bbox: [95, 812, 40, 25] 168 | ''' 169 | 170 | EXAMPLE_PROMPT = ''' 171 | { 172 | "single_object": { 173 | "attribute": [ 174 | { "q": "The glossy raspberry with maroon tinge", "ids": [4123758] }, 175 | { "q": "The slim clear strap with blue stripe", "ids": [4995055] }, 176 | { "q": "The cylindrical can of recycled aluminum", "ids": [2187630] }, 177 | { "q": "The hunter green leather cushion", "ids": [8109537] } 178 | ], 179 | "spatial": [ 180 | { "q": "The knitting needle with glossy finish on the far left", "ids": [519546] }, 181 | { "q": "The snowmobile with white surface on the far right", "ids": [2998739] } 182 | ], 183 | "reasoning": [ 184 | { "q": "The teakettle with glass body below the shovel with the gleaming blade", "ids": [9339368] }, 185 | { "q": "The box with bold stripes and smiley to the left of the snowmobile with white surface", "ids": [13570439] }, 186 | { "q": "The glossy raspberry with maroon tinge to the left of the cylindrical can of recycled aluminum", "ids": [4123758] }, 187 | { "q": "The slim clear strap with blue stripe on the bath towel with tribal flair", "ids": [4995055] }, 188 | { "q": "The stout wooden spoon for dough above the shovel with the gleaming blade", "ids": [7054199] } 189 | ] 190 | }, 191 | "multi_object": { 192 | "attribute": [ 193 | { "q": "All the objects with a glossy finish", "ids": [519546, 4123758] }, 194 | { "q": "All the striped objects", "ids": [13570439, 2998739] }, 195 | { "q": "All the containers", "ids": [2187630, 13570439] }, 196 | ], 197 | "spatial": [ 198 | { 199 | "q": "All the objects above the horizontal midpoint of the image", 200 | "ids": [10569372, 2187630, 4995055, 8109537, 4123758, 8631767, 7054199, 519546] 201 | }, 202 | { 203 | "q": "All the objects that span the central vertical band of the image", 204 | "ids": [7054199, 4123758, 2187630, 7377552, 2998739] 205 | }, 206 | { 207 | "q": "All the objects to the left of center and above the shovel with gleaming blade", 208 | "ids": [7377552, 2187630, 519546, 4123758, 7054199, 13570439, 232766, 2379817, 8631767] 209 | }, 210 | { 211 | "q": "All the objects surrounding the cylindrical can of recycled aluminum", 212 | "ids": [7377552, 7054199, 8631767, 4123758] 213 | } 214 | ], 215 | "reasoning": [ 216 | { "q": "All the metallic objects to the left of the bath towel with tribal flair", "ids": [7377552, 2187630, 2379817] }, 217 | { "q": "All the objects with a glossy finish above the box with bold stripes and smiley", "ids": [519546, 4123758] }, 218 | { "q": "All the containers to the right of the knitting needle with glossy finish", "ids": [2187630, 13570439] }, 219 | ] 220 | } 221 | } 222 | ''' 223 | 224 | ##################################################################### 225 | # ----------------------- PROMPT UTILS -----------------------------# 226 | ##################################################################### 227 | OBJECT_TEMPLATE = "[ID: {id}] {category} | short_phrase: {short_phrase} | features: {features} | description: {description} | bbox: {bbox}" 228 | 229 | SCHEMA_SNIPPET = ''' 230 | ### JSON schema you MUST return 231 | { 232 | 
"single_object": {"attribute": [{"q": str, "ids": [int]}], "spatial": [{"q": str, "ids": [int]}], "reasoning": [{"q": str, "ids": [int]}]}, 233 | "multi_object": { … same keys, but ids lists contain 2+ ints … } 234 | } 235 | ''' 236 | 237 | 238 | def build_prompt(segments: List[Dict[str, Any]], num_q: Dict[str, Dict[str, int]]) -> str: 239 | # Build object table 240 | table = "\n".join(OBJECT_TEMPLATE.format(**seg) for seg in segments) 241 | # Build count instructions 242 | counts_text = "\n".join(f"- {scope}/{kind}: {cnt} expressions" for scope, kinds in num_q.items() for kind, cnt in kinds.items()) 243 | 244 | # Combine prompt 245 | prompt = f""" 246 | You are a referring expression detection data generator. I will provide you with a list of objects WITHIN an IMAGE, and you will generate referring expressions similar to RefCOCO, RefCOCO+, RefCOCOg, and GrefCOCO. 247 | 248 | We categorized referring expressions into 3 types: attribute-based, spatial-based, and reasoning-based. 249 | **Requirements**: 250 | - Attribute-based: ask about `features`, `category`, or `short_phrase` of exactly one object. (e.g. The white dog) 251 | - Spatial-based: infer absolute or relative positions strictly from the `bbox` values (e.g. left/right, above/below, center). (e.g. The dog left to the people with brown shirt) 252 | - Reasoning-based: combine features, short_phrase, category and spatial bbox relationships between objects. (e.g. The whilte animal left to the person with brown shirt) 253 | - Use `short_phrase` or `features` preferentially to refer to objects; also can use `category` with some features to refer it. 254 | - Return ONE JSON block matching the schema exactly, with **exactly** the requested counts per bucket. No extra keys. 255 | 256 | For each referring expressions, we have 3 types of returning objects: 257 | 1. **Single object**: Expression refers to exactly one object in the image. (e.g. The white dog) 258 | 2. **Multi-object**: Expression refers to 2 or more objects in the image. (e.g. All the white dog in the images) 259 | 260 | We have categorized referring expressions into three types: attribute-based, spatial-based, and reasoning-based. 261 | 262 | Requirements: 263 | - Attribute-based: Refer to exactly one object by its features, category, or short_phrase (e.g., "the white dog"). 264 | - Spatial-based: Infer absolute or relative positions strictly from bbox values (e.g., "the dog to the left of the person with a brown shirt"). 265 | - Reasoning-based: Combine features, short_phrases, categories, and spatial bbox relationships between objects (e.g., "the white animal to the left of the person wearing a brown shirt"). 266 | - Use short_phrase or features preferentially to refer to objects; you may also use category together with features. 267 | - Return a single JSON block matching the schema exactly, containing exactly the requested counts per bucket; include no extra keys. 268 | 269 | For each referring expression, there are three reference types: 270 | 1. Single-object: The expression refers to exactly one object in the image (e.g., "the white dog"). 271 | 2. Multi-object: The expression refers to two or more objects in the image (e.g., "all the white dogs in the image"). 
272 | 
273 | 
274 | 
275 | ### Example segments and expressions
276 | # Example annotation
277 | {EXAMPLE_SEGMENTS}
278 | 
279 | # Example generated expression
280 | {EXAMPLE_PROMPT}
281 | 
282 | 
283 | ### Here are the objects in the image (bounding boxes are given in COCO XYWH format; compare them accordingly):
284 | {table}
285 | 
286 | ### Counts to generate
287 | {counts_text}
288 | 
289 | {SCHEMA_SNIPPET}
290 | """
291 |     return prompt
292 | 
293 | 
294 | #####################################################################
295 | # ----------------------- OVDG HELPERS -----------------------------#
296 | #####################################################################
297 | 
298 | def ovdg_from_expressions(image: Dict[str, Any], qs: Dict[str, Any]) -> Dict[str, Any]:
299 |     regions, caption_parts, cursor = [], [], 0
300 |     # Map segment IDs to their bounding boxes
301 |     id2bbox = {seg['id']: seg['bbox'] for seg in image['segments_info']}
302 | 
303 |     def add_region(text: str, ids: List[int]):
304 |         nonlocal cursor
305 |         start, end = cursor, cursor + len(text) - 1
306 |         # Gather all bboxes for each referenced ID
307 |         bboxes = [id2bbox.get(i, [0, 0, 0, 0]) for i in ids]
308 |         regions.append({
309 |             'bboxes': bboxes,
310 |             'phrase': text,
311 |             'tokens_positive': [[start, end]]
312 |         })
313 |         caption_parts.append(text)
314 |         cursor += len(text) + 2  # advance past the phrase plus the '. ' separator used in the caption
315 | 
316 |     # Process both single-object and multi-object queries
317 |     for scope in ('single_object', 'multi_object'):
318 |         for kind in ('attribute', 'spatial', 'reasoning'):
319 |             for e in qs[scope][kind]:
320 |                 # Strip punctuation and add region
321 |                 add_region(e['q'].rstrip('?. '), e['ids'])
322 | 
323 |     return {
324 |         'filename': image.get('file_name', ''),
325 |         'height': image.get('height', 1024),
326 |         'width': image.get('width', 1024),
327 |         'grounding': {
328 |             'caption': '. '.join(caption_parts) + '. 
', 329 | 'regions': regions 330 | } 331 | } 332 | 333 | 334 | ##################################################################### 335 | # -------------------- MAIN PIPELINE -------------------------------# 336 | ##################################################################### 337 | def infer_expressions(client: openai.Client, prompt: str) -> Dict[str, Any]: 338 | from pydantic import ValidationError 339 | 340 | attempt = 0 341 | parsed: RefExpressions | None = None 342 | 343 | while attempt < MAX_RETRIES: 344 | try: 345 | # resp = client.beta.chat.completions.parse( 346 | # model=MODEL_NAME, 347 | # messages=[{'role':'user','content':prompt}], 348 | # response_format=RefExpressions 349 | # ) 350 | resp = client.beta.chat.completions.parse( 351 | model="Qwen/QwQ-32B-AWQ", 352 | messages=[{'role':'user','content':prompt}], 353 | response_format=RefExpressions, 354 | temperature=0.1 355 | ) 356 | # .choices[0].message.parsed is already a RefExpressions instance 357 | parsed = resp.choices[0].message.parsed 358 | break 359 | except ValidationError as ve: 360 | # JSON structure didn’t match—retry 361 | attempt += 1 362 | print(f"[Attempt {attempt}] validation error: {ve}") 363 | time.sleep(SLEEP_BETWEEN_RETRIES) 364 | except Exception as e: 365 | # LLM or network error—also retry 366 | attempt += 1 367 | print(f"[Attempt {attempt}] error: {e}") 368 | time.sleep(SLEEP_BETWEEN_RETRIES) 369 | 370 | if parsed is None: 371 | # Give back an empty—but correctly typed—fallback 372 | parsed = RefExpressions( 373 | single_object=ObjectExpressions(attribute=[], spatial=[], reasoning=[]), 374 | multi_object= ObjectExpressions(attribute=[], spatial=[], reasoning=[]) 375 | ) 376 | 377 | # Finally return a plain dict for the rest of your pipeline 378 | return parsed.dict() 379 | 380 | 381 | 382 | 383 | 384 | 385 | 386 | # -------------------- Distributed Runner --------------------# 387 | 388 | def create_client(api_key: str): 389 | return openai.Client(base_url="http://localhost:8080/v1", api_key=api_key) 390 | 391 | def process_annotation(args): 392 | idx, ann, cache_dir, api_key, orig_output_dir = args 393 | cache_file = os.path.join(cache_dir, f"{idx}.json") 394 | output_dir = orig_output_dir 395 | cache_dir_name = os.path.basename(cache_dir) 396 | if cache_dir_name.startswith("cache_job_"): 397 | job_index = cache_dir_name[len("cache_job_") :] 398 | else: 399 | job_index = "0" 400 | remote_cache_path = os.path.join(orig_output_dir, f"cache_job_{job_index}", f"{idx}.json") 401 | # If output_dir is a GCS path, check remote cache existence 402 | if output_dir.startswith("gs://"): 403 | bucket_name, blob_name = parse_gcs_path(remote_cache_path) 404 | client_gcs = storage.Client() 405 | bucket = client_gcs.bucket(bucket_name) 406 | blob = bucket.blob(blob_name) 407 | if blob.exists(): 408 | # Remote cache exists, download and use it 409 | blob.download_to_filename(cache_file) 410 | with open(cache_file, 'r') as f: 411 | raw = json.load(f) 412 | chunks = chunk_result(raw) 413 | return idx, chunks 414 | # Only check local cache for non-GCS cases 415 | if not output_dir.startswith("gs://") and os.path.exists(cache_file) and os.path.getsize(cache_file) > 0: 416 | with open(cache_file, 'r') as f: 417 | raw = json.load(f) 418 | chunks = chunk_result(raw) 419 | return idx, chunks 420 | 421 | client = create_client(api_key) 422 | prompt = build_prompt(ann['segments_info'], DEFAULT_NUM_Q) 423 | 424 | # Use the shared infer_expressions function instead of inline logic 425 | qs = infer_expressions(client, prompt) 
426 | 427 | valid_ids = {seg['id'] for seg in ann['segments_info']} 428 | for scope in ('single_object','multi_object'): 429 | for kind in ('attribute','spatial','reasoning'): 430 | for expr in qs[scope][kind]: 431 | expr['ids'] = [rid for rid in expr['ids'] if rid in valid_ids] 432 | 433 | result = ovdg_from_expressions(ann, qs) 434 | chunks = chunk_result(result) 435 | with open(cache_file, 'w') as f: 436 | json.dump(result, f) 437 | # After writing local cache, upload to GCS if output_dir is GCS 438 | if output_dir.startswith("gs://"): 439 | upload_to_gcs(local_path=cache_file, gcs_path=remote_cache_path) 440 | return idx, chunks 441 | 442 | def main(): 443 | parser = argparse.ArgumentParser(description="Distributed ODVG inference") 444 | parser.add_argument('total_jobs', type=int, help='Total number of jobs') 445 | parser.add_argument('job_index', type=int, help='Index of this job (0-based)') 446 | parser.add_argument('input_file', type=str, help='Path to input JSONL annotations') 447 | parser.add_argument('output_dir', type=str, help='Directory for outputs and cache') 448 | parser.add_argument('--api_key', type=str, default="placeholder", help='OpenAI API key') 449 | parser.add_argument('--num_workers', type=int, default=300, help='Number of parallel processes') 450 | args = parser.parse_args() 451 | 452 | # --- GCS input/output support --- 453 | # Download input file if on GCS 454 | is_gcs_input = args.input_file.startswith("gs://") 455 | if is_gcs_input: 456 | local_input = tempfile.NamedTemporaryFile(delete=False, suffix=".jsonl").name 457 | download_from_gcs(args.input_file, local_input) 458 | args.input_file = local_input 459 | 460 | # Prepare local output directory (and flag for later upload) 461 | is_gcs_output = args.output_dir.startswith("gs://") 462 | if is_gcs_output: 463 | local_output_dir = tempfile.mkdtemp() 464 | else: 465 | local_output_dir = args.output_dir 466 | 467 | api_key = args.api_key or os.environ.get('OPENAI_API_KEY') 468 | if not api_key: 469 | raise ValueError("OpenAI API key must be provided via --api_key or OPENAI_API_KEY") 470 | 471 | os.makedirs(local_output_dir, exist_ok=True) 472 | cache_dir = os.path.join(local_output_dir, f"cache_job_{args.job_index}") 473 | os.makedirs(cache_dir, exist_ok=True) 474 | 475 | annotations = [] 476 | annotations = json.load(open(args.input_file)).get('annotations', []) 477 | 478 | total = len(annotations) 479 | per_job = total // args.total_jobs 480 | start = args.job_index * per_job 481 | end = total if args.job_index == args.total_jobs-1 else start + per_job 482 | partition = annotations[start:end] 483 | 484 | tasks = [(i, ann, cache_dir, api_key, args.output_dir) for i, ann in enumerate(partition)] 485 | with Pool(processes=args.num_workers) as pool: 486 | results = list(tqdm(pool.imap_unordered(process_annotation, tasks), total=len(tasks))) 487 | 488 | # results is a list of (idx, chunks), where chunks is a list of dicts 489 | results.sort(key=lambda x: x[0]) 490 | output_path = os.path.join(local_output_dir, f"job_{args.job_index}.jsonl") 491 | with open(output_path, 'w') as out_f: 492 | for _, chunks in results: 493 | if chunks is None: 494 | continue 495 | for entry in chunks: 496 | out_f.write(json.dumps(entry) + '\n') 497 | 498 | 499 | # Upload result back to GCS if requested 500 | if is_gcs_output: 501 | gcs_dest = args.output_dir.rstrip("/") + f"/job_{args.job_index}.jsonl" 502 | upload_to_gcs(output_path, gcs_dest) 503 | 504 | if __name__ == '__main__': 505 | main() 506 | 
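The script above is run once per shard (`inference.py <total_jobs> <job_index> <input_file> <output_dir> [--api_key ...] [--num_workers ...]`) and assumes an OpenAI-compatible endpoint at `http://localhost:8080/v1` serving `Qwen/QwQ-32B-AWQ` (see `create_client` and `infer_expressions`). Below is a minimal sketch of reading back the resulting ODVG-style `job_<i>.jsonl` shards; the path `job_0.jsonl` is hypothetical, while the field accesses mirror the writer code above.

```python
import json

# Read one grounding jsonl shard produced by the script above (path is illustrative).
with open("job_0.jsonl") as f:
    records = [json.loads(line) for line in f if line.strip()]

for rec in records:
    caption = rec["grounding"]["caption"]
    for region in rec["grounding"]["regions"]:
        start, end = region["tokens_positive"][0]  # inclusive character offsets into the caption
        phrase = caption[start:end + 1]            # round-trips to region["phrase"]
        boxes = region["bboxes"]                   # one or more [x, y, w, h] boxes
        print(rec["filename"], phrase, boxes)
```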
-------------------------------------------------------------------------------- /segmentation_generator/segmentataion_synthesizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .generate_image_and_mask import generate_image_and_mask 3 | from segment_layout.existing_layout_generator import ExistingLayoutGenerator 4 | from segment_layout.fine_grained_bbox_layout_generator import FineGrainedBoundingBoxLayoutGenerator 5 | from PIL import Image 6 | 7 | def generate_random_color_and_id(rng): 8 | # Generate random values for R, G, and B between 0 and 255 9 | R = rng.integers(0, 256) 10 | G = rng.integers(0, 256) 11 | B = rng.integers(0, 256) 12 | 13 | # Calculate the color id based on the formula provided 14 | color_id = R + G * 256 + B * (256**2) 15 | 16 | # Return the color and the id 17 | color = (R, G, B) 18 | return color, color_id 19 | 20 | 21 | class BaseSegmentationSynthesizer(): 22 | synthesize_method = None 23 | def __init__(self, data_manager, save_path, random_seed) -> None: 24 | if random_seed is None: 25 | random_seed = 42 # use the default seed 26 | self.rng = np.random.default_rng(random_seed) 27 | self.data_manager = data_manager 28 | self.save_path = save_path 29 | 30 | def random_position(self, width, height): 31 | x = self.rng.integers(0, width) 32 | y = self.rng.integers(0, height) 33 | return (x, y) 34 | 35 | def random_augmentation(self): 36 | # Randomly select a scale factor between 0.5 and 1.5, 37 | # a boolean flag for horizontal flip (50% chance), 38 | # a boolean flag for vertical flip (50% chance), 39 | # and a rotation angle between -180 and 180 degrees. 40 | scale = self.rng.uniform(0.5, 1.5) 41 | flip_horizontal = bool(self.rng.integers(0, 2)) 42 | flip_vertical = bool(self.rng.integers(0, 2)) 43 | rotate = int(self.rng.integers(-180, 180)) 44 | return {"scale": scale, "flip_horizontal": flip_horizontal, "flip_vertical": flip_vertical, "rotate": rotate} 45 | 46 | 47 | 48 | def category_to_id(self, category): 49 | if category.startswith("COCO2017"): 50 | return int(category.split("_")[-1]) 51 | if category in self.data_manager.unified_object_categories_to_idx: 52 | return self.data_manager.unified_object_categories_to_idx[category] 53 | else: 54 | raise ValueError(f"Category {category} not found in the unified object categories.") 55 | 56 | def generate(self, image_metadata, resize_mode, containSmallObjectMask=False): 57 | # start_time = time.time() 58 | if containSmallObjectMask: 59 | image, mask, small_object_mask_np = generate_image_and_mask(self.data_manager, image_metadata, resize_mode=resize_mode, containSmallObjectMask=True) 60 | output = {"image_metadata": image_metadata, "image": image, "mask": mask, "small_object_mask_np": small_object_mask_np} 61 | else: 62 | image, mask = generate_image_and_mask(self.data_manager, image_metadata, resize_mode=resize_mode, containSmallObjectMask=False) 63 | output = {"image_metadata": image_metadata, "image": image, "mask": mask} 64 | 65 | # elapsed_time = time.time() - start_time # compute elapsed time 66 | # print(f"generate_image_and_mask took {elapsed_time:.6f} seconds") 67 | 68 | return output 69 | 70 | 71 | def generate_with_coco_panoptic_format(self, image_metadata, resize_mode="fit", containRGBA=False, containBbox=True, containCategory=True, containSmallObjectMask=False): 72 | output = self.generate(image_metadata, resize_mode, containSmallObjectMask) 73 | image, mask, metadata = output["image"], output["mask"], output["image_metadata"] 74 | 75 
| segment_id_to_color = {} 76 | used_colors = set() 77 | mask = np.array(mask) 78 | height, width = mask.shape 79 | 80 | coco_mask = np.zeros((height, width, 3), dtype=np.uint8) 81 | segments_info = [] 82 | 83 | for object_info in metadata["objects"]: 84 | segment_id = object_info["segment_id"] 85 | if segment_id != 0: 86 | while True: 87 | unique_color, unique_id = generate_random_color_and_id(self.rng) 88 | if unique_color not in used_colors: 89 | used_colors.add(unique_color) # Mark this color as used 90 | break 91 | 92 | if segment_id not in segment_id_to_color: 93 | segment_id_to_color[segment_id] = (unique_color, unique_id) 94 | coco_mask[mask == segment_id] = segment_id_to_color[segment_id][0] 95 | 96 | if containBbox: 97 | rows, cols = np.where(mask == segment_id) 98 | if rows.size and cols.size: 99 | x_min, y_min = int(cols.min()), int(rows.min()) 100 | x_max, y_max = int(cols.max()), int(rows.max()) 101 | # COCO uses [x, y, width, height] format. 102 | bbox = [x_min, y_min, x_max - x_min + 1, y_max - y_min + 1] 103 | else: 104 | bbox = [0, 0, 0, 0] 105 | else: 106 | bbox = [0, 0, 0, 0] 107 | 108 | # iscrowd, bbox, area are not used in this case, so we set them to dummy values 109 | segments_info.append({ 110 | "id": unique_id, 111 | "category_id": self.category_to_id(object_info['object_metadata']["category"]) if containCategory else 1, 112 | "iscrowd": 0, 113 | "bbox": bbox, 114 | "area": 1, 115 | }) 116 | 117 | output["coco_mask"] = coco_mask 118 | output["segments_info"] = segments_info 119 | if containRGBA: 120 | alpha_channel = (mask > 0).astype(np.uint8) * 255 121 | image_rgba = np.dstack((image, alpha_channel)) 122 | output["image_rgba"] = image_rgba 123 | if containSmallObjectMask: 124 | alpha_channel_small_obj = (output['small_object_mask_np'] > 0).astype(np.uint8) * 255 125 | image_rgba_small_obj = np.dstack((image, alpha_channel_small_obj)) 126 | output['image_rgba_small_object'] = image_rgba_small_obj 127 | return output 128 | 129 | 130 | def generate_with_unified_format(self, image_metadata, resize_mode="fit", containRGBA=False, containBbox=True, containCategory=True, containSmallObjectMask=False): 131 | output = self.generate(image_metadata, resize_mode, containSmallObjectMask) 132 | image, mask, metadata = output["image"], output["mask"], output["image_metadata"] 133 | 134 | segment_id_to_color = {} 135 | used_colors = set() 136 | mask = np.array(mask) 137 | height, width = mask.shape 138 | 139 | coco_mask = np.zeros((height, width, 3), dtype=np.uint8) 140 | segments_info = [] 141 | 142 | for object_info in metadata["objects"]: 143 | segment_id = object_info["segment_id"] 144 | if segment_id != 0: 145 | while True: 146 | unique_color, unique_id = generate_random_color_and_id(self.rng) 147 | if unique_color not in used_colors: 148 | used_colors.add(unique_color) # Mark this color as used 149 | break 150 | 151 | if segment_id not in segment_id_to_color: 152 | segment_id_to_color[segment_id] = (unique_color, unique_id) 153 | coco_mask[mask == segment_id] = segment_id_to_color[segment_id][0] 154 | 155 | if containBbox: 156 | rows, cols = np.where(mask == segment_id) 157 | if rows.size and cols.size: 158 | x_min, y_min = int(cols.min()), int(rows.min()) 159 | x_max, y_max = int(cols.max()), int(rows.max()) 160 | # COCO uses [x, y, width, height] format. 
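                        # e.g. a segment spanning columns 10..19 and rows 5..8 yields bbox [10, 5, 10, 4]
                        # (the +1 below makes the pixel extents inclusive).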
161 | bbox = [x_min, y_min, x_max - x_min + 1, y_max - y_min + 1] 162 | else: 163 | bbox = [0, 0, 0, 0] 164 | else: 165 | bbox = [0, 0, 0, 0] 166 | 167 | # iscrowd, bbox, area are not used in this case, so we set them to dummy values 168 | segments_info.append({ 169 | "id": unique_id, 170 | "category_id": self.category_to_id(object_info['object_metadata']["category"]) if containCategory else 1, 171 | "iscrowd": 0, 172 | "bbox": bbox, 173 | "category": object_info['object_metadata']["category"], 174 | "sub_category": object_info['object_metadata']["sub_category"], 175 | "description": object_info['object_metadata']["description"], 176 | "features": object_info['object_metadata']["features"], 177 | "short_phrase": object_info['object_metadata']["short_phrase"], 178 | "area": 1, 179 | }) 180 | 181 | output["coco_mask"] = coco_mask 182 | output["segments_info"] = segments_info 183 | if containRGBA: 184 | alpha_channel = (mask > 0).astype(np.uint8) * 255 185 | image_rgba = np.dstack((image, alpha_channel)) 186 | output["image_rgba"] = image_rgba 187 | if containSmallObjectMask: 188 | alpha_channel_small_obj = (output['small_object_mask_np'] > 0).astype(np.uint8) * 255 189 | image_rgba_small_obj = np.dstack((image, alpha_channel_small_obj)) 190 | output['image_rgba_small_object'] = image_rgba_small_obj 191 | return output 192 | 193 | class RandomCenterPointSegmentationSynthesizer(BaseSegmentationSynthesizer): 194 | synthesize_method = "random" 195 | def __init__(self, data_manager, save_path, random_seed=None): 196 | super().__init__(data_manager, save_path, random_seed) 197 | 198 | def sampling_metadata(self, width, height, number_of_objects, hasBackground=False, dataAugmentation=False): # -> dict: 199 | image_metadata = {} 200 | image_metadata["width"] = width 201 | image_metadata["height"] = height 202 | image_metadata["number_of_objects"] = number_of_objects 203 | image_metadata["synthesize_method"] = self.synthesize_method 204 | 205 | if hasBackground: 206 | background_metadata = self.data_manager.get_random_background_metadata(self.rng) 207 | image_metadata["background"] = {"background_metadata": background_metadata} 208 | 209 | curr_segment_id = 1 210 | objects = [] 211 | for _ in range(number_of_objects): 212 | # Randomly select a segmentation from the data manager 213 | object_metadata = self.data_manager.get_random_object_metadata(self.rng) 214 | object_position = self.random_position(width, height) 215 | if dataAugmentation: 216 | augmentation = self.random_augmentation() 217 | objects.append({ 218 | "object_metadata": object_metadata, 219 | "object_position": object_position, 220 | "segment_id": curr_segment_id, 221 | "augmentation": augmentation 222 | }) 223 | else: 224 | objects.append({ 225 | "object_metadata": object_metadata, 226 | "object_position": object_position, 227 | "segment_id": curr_segment_id 228 | }) 229 | curr_segment_id += 1 230 | image_metadata["objects"] = objects 231 | return image_metadata 232 | 233 | class FineGrainedBoundingBoxSegmentationSynthesizer(BaseSegmentationSynthesizer): 234 | synthesize_method = "fine_grained_bbox" 235 | def __init__(self, data_manager, save_path, random_seed=None): 236 | super().__init__(data_manager, save_path, random_seed) 237 | self.fine_grained_bbox_layout_generator = FineGrainedBoundingBoxLayoutGenerator() 238 | 239 | def sampling_metadata(self, width, height, number_of_objects, hasBackground=False, dataAugmentation=False, considerArea=False): # -> dict: 240 | image_metadata = {} 241 | image_metadata["width"] = width 242 | 
image_metadata["height"] = height 243 | image_metadata["number_of_objects"] = number_of_objects 244 | image_metadata["synthesize_method"] = self.synthesize_method 245 | 246 | if hasBackground: 247 | background_metadata = self.data_manager.get_random_background_metadata(self.rng) 248 | image_metadata["background"] = {"background_metadata": background_metadata} 249 | 250 | curr_segment_id = 1 251 | objects = [] 252 | # sort based on the original size 253 | for _ in range(number_of_objects): 254 | # Randomly select a segmentation from the data manager 255 | object_metadata = self.data_manager.get_random_object_metadata(self.rng) 256 | if considerArea: 257 | # add area 258 | path = object_metadata['image_path'] 259 | with Image.open(path) as img: 260 | img_np = np.array(img) 261 | area = np.sum(img_np[:, :, 3] > 0) 262 | object_metadata['area'] = area 263 | 264 | if dataAugmentation: 265 | augmentation = self.random_augmentation() 266 | objects.append({ 267 | "object_metadata": object_metadata, 268 | "segment_id": curr_segment_id, 269 | "augmentation": augmentation 270 | }) 271 | else: 272 | objects.append({ 273 | "object_metadata": object_metadata, 274 | "segment_id": curr_segment_id 275 | }) 276 | curr_segment_id += 1 277 | 278 | num_bboxes = len(objects) 279 | # based on the COCO website 280 | num_large = int(round(num_bboxes * 0.24)) 281 | num_mid = int(round(num_bboxes * 0.34)) 282 | num_small = num_bboxes - num_large - num_mid 283 | object_positions = self.fine_grained_bbox_layout_generator.generate(num_large=num_large, num_mid=num_mid, 284 | num_small=num_small, width=width, height=height) 285 | if considerArea: 286 | bbox_with_area = [] 287 | for bbox in object_positions: 288 | # Calculate width and height from [x_min, y_min, x_max, y_max] 289 | width = bbox[2] - bbox[0] 290 | height = bbox[3] - bbox[1] 291 | bbox_area = width * height 292 | bbox_with_area.append({"bbox": bbox, "area": bbox_area}) 293 | 294 | # Sort objects based on the computed object area (descending order) 295 | objects_sorted = sorted(objects, key=lambda x: x['object_metadata']['area'], reverse=True) 296 | # Sort bounding boxes based on their computed area (descending order) 297 | bboxes_sorted = sorted(bbox_with_area, key=lambda x: x['area'], reverse=True) 298 | 299 | # Assign the sorted bounding boxes to the sorted objects 300 | for idx, obj in enumerate(objects_sorted): 301 | obj["object_position"] = bboxes_sorted[idx]["bbox"] 302 | image_metadata["objects"] = objects_sorted 303 | # print the match results 304 | for obj in objects_sorted: 305 | print(f"Object: {obj['object_metadata']['image_path']}, Area: {obj['object_metadata']['area']}") 306 | for bbox in bboxes_sorted: 307 | print(f"bbox: {bbox['bbox']}, Area: {bbox['area']}") 308 | print("in") 309 | else: 310 | for idx, obj in enumerate(objects): 311 | obj["object_position"] = object_positions[idx] 312 | image_metadata["objects"] = objects 313 | 314 | return image_metadata 315 | 316 | 317 | class FineGrainedBoundingBoxSingleCategorySegmentationSynthesizer(BaseSegmentationSynthesizer): 318 | synthesize_method = "fine_grained_bbox" 319 | def __init__(self, data_manager, save_path, random_seed=None): 320 | super().__init__(data_manager, save_path, random_seed) 321 | self.fine_grained_bbox_layout_generator = FineGrainedBoundingBoxLayoutGenerator() 322 | 323 | def sampling_metadata( 324 | self, 325 | width, 326 | height, 327 | number_of_objects, 328 | hasBackground=False, 329 | dataAugmentation=False, 330 | considerArea=False 331 | ): 332 | image_metadata = { 333 
| "width": width, 334 | "height": height, 335 | "number_of_objects": number_of_objects, 336 | "synthesize_method": self.synthesize_method 337 | } 338 | 339 | if hasBackground: 340 | bg_meta = self.data_manager.get_random_background_metadata(self.rng) 341 | image_metadata["background"] = {"background_metadata": bg_meta} 342 | 343 | # First object: random to determine category 344 | first_meta = self.data_manager.get_random_object_metadata(self.rng) 345 | category = first_meta.get('category') 346 | if category is None: 347 | raise KeyError("First sampled object metadata has no 'category' field.") 348 | 349 | objects = [] 350 | curr_id = 1 351 | # Process first object 352 | if considerArea: 353 | path = first_meta['image_path'] 354 | with Image.open(path) as img: 355 | arr = np.array(img) 356 | first_meta['area'] = int(np.sum(arr[:, :, 3] > 0)) 357 | entry = {"object_metadata": first_meta, "segment_id": curr_id} 358 | if dataAugmentation: 359 | entry["augmentation"] = self.random_augmentation() 360 | objects.append(entry) 361 | curr_id += 1 362 | 363 | # Sample remaining objects of the same category 364 | for _ in range(number_of_objects - 1): 365 | obj_meta = self.data_manager.get_random_object_metadata_by_category(self.rng, category) 366 | if considerArea: 367 | path = obj_meta['image_path'] 368 | with Image.open(path) as img: 369 | arr = np.array(img) 370 | obj_meta['area'] = int(np.sum(arr[:, :, 3] > 0)) 371 | entry = {"object_metadata": obj_meta, "segment_id": curr_id} 372 | if dataAugmentation: 373 | entry["augmentation"] = self.random_augmentation() 374 | objects.append(entry) 375 | curr_id += 1 376 | 377 | # Compute bbox size splits per COCO proportions 378 | n = len(objects) 379 | n_large = int(round(n * 0.24)) 380 | n_mid = int(round(n * 0.34)) 381 | n_small = n - n_large - n_mid 382 | bboxes = self.fine_grained_bbox_layout_generator.generate( 383 | num_large=n_large, 384 | num_mid=n_mid, 385 | num_small=n_small, 386 | width=width, 387 | height=height 388 | ) 389 | 390 | if considerArea: 391 | boxes_with_area = [ 392 | {"bbox": box, "area": (box[2] - box[0]) * (box[3] - box[1])} 393 | for box in bboxes 394 | ] 395 | objects_sorted = sorted(objects, key=lambda o: o['object_metadata']['area'], reverse=True) 396 | boxes_sorted = sorted(boxes_with_area, key=lambda b: b['area'], reverse=True) 397 | for obj, box in zip(objects_sorted, boxes_sorted): 398 | obj['object_position'] = box['bbox'] 399 | image_metadata['objects'] = objects_sorted 400 | else: 401 | for obj, box in zip(objects, bboxes): 402 | obj['object_position'] = box 403 | image_metadata['objects'] = objects 404 | 405 | return image_metadata 406 | 407 | 408 | 409 | 410 | class ExistingLayoutSegmentationSynthesizer(BaseSegmentationSynthesizer): 411 | synthesize_method = "existing_layout" 412 | def __init__(self, data_manager, save_path, table_path, random_seed = None): 413 | super().__init__(data_manager, save_path, random_seed) 414 | self.existing_layout_generator = ExistingLayoutGenerator(table_path=table_path) 415 | 416 | def sampling_metadata(self, width, height, number_of_objects, hasBackground=False, dataAugmentation=None, considerArea=False):# -> dict: 417 | image_metadata = {} 418 | image_metadata["width"] = width 419 | image_metadata["height"] = height 420 | image_metadata["number_of_objects"] = number_of_objects 421 | image_metadata["synthesize_method"] = self.synthesize_method 422 | # add background image background 423 | 424 | if hasBackground: 425 | background_metadata = 
self.data_manager.get_random_background_metadata(self.rng) 426 | image_metadata["background"] = {"background_metadata": background_metadata} 427 | 428 | curr_segment_id = 1 429 | objects = [] 430 | # sort based on the original size 431 | for _ in range(number_of_objects): 432 | # Randomly select a segmentation from the data manager 433 | object_metadata = self.data_manager.get_random_object_metadata(self.rng) 434 | if dataAugmentation: 435 | augmentation = self.random_augmentation() 436 | objects.append({ 437 | "object_metadata": object_metadata, 438 | "segment_id": curr_segment_id, 439 | "augmentation": augmentation 440 | }) 441 | else: 442 | objects.append({ 443 | "object_metadata": object_metadata, 444 | "segment_id": curr_segment_id 445 | }) 446 | curr_segment_id += 1 447 | 448 | segment_path_list = [obj['object_metadata']['image_path'] for obj in objects] 449 | if considerArea: 450 | # havn't implemented the area matching algorithm 451 | # object_areas = [obj['object_metadata']['area'] for obj in objects] 452 | object_areas = segment_path_list 453 | object_positions = self.existing_layout_generator.predict_wi_area(num_bboxes=len(objects), segment_path_list=segment_path_list, areas=object_areas, width=width, height=height) 454 | else: 455 | object_positions = self.existing_layout_generator.predict_wo_area(num_bboxes=len(objects), width=width, height=height) 456 | 457 | for idx, obj in enumerate(objects): 458 | obj["object_position"] = object_positions[idx] 459 | image_metadata["objects"] = objects 460 | 461 | return image_metadata 462 | 463 | 464 | class FineGrainedBoundingBoxSegmentationSynthesizer(BaseSegmentationSynthesizer): 465 | synthesize_method = "fine_grained_bbox" 466 | def __init__(self, data_manager, save_path, random_seed=None): 467 | super().__init__(data_manager, save_path, random_seed) 468 | self.fine_grained_bbox_layout_generator = FineGrainedBoundingBoxLayoutGenerator() 469 | 470 | def sampling_metadata(self, width, height, number_of_objects, hasBackground=False, dataAugmentation=False, considerArea=False): # -> dict: 471 | image_metadata = {} 472 | image_metadata["width"] = width 473 | image_metadata["height"] = height 474 | image_metadata["number_of_objects"] = number_of_objects 475 | image_metadata["synthesize_method"] = self.synthesize_method 476 | 477 | if hasBackground: 478 | background_metadata = self.data_manager.get_random_background_metadata(self.rng) 479 | image_metadata["background"] = {"background_metadata": background_metadata} 480 | 481 | curr_segment_id = 1 482 | objects = [] 483 | # sort based on the original size 484 | for _ in range(number_of_objects): 485 | # Randomly select a segmentation from the data manager 486 | object_metadata = self.data_manager.get_random_object_metadata(self.rng) 487 | if considerArea: 488 | # add area 489 | path = object_metadata['image_path'] 490 | with Image.open(path) as img: 491 | img_np = np.array(img) 492 | area = np.sum(img_np[:, :, 3] > 0) 493 | object_metadata['area'] = area 494 | 495 | if dataAugmentation: 496 | augmentation = self.random_augmentation() 497 | objects.append({ 498 | "object_metadata": object_metadata, 499 | "segment_id": curr_segment_id, 500 | "augmentation": augmentation 501 | }) 502 | else: 503 | objects.append({ 504 | "object_metadata": object_metadata, 505 | "segment_id": curr_segment_id 506 | }) 507 | curr_segment_id += 1 508 | 509 | num_bboxes = len(objects) 510 | # based on the COCO website 511 | num_large = int(round(num_bboxes * 0.24)) 512 | num_mid = int(round(num_bboxes * 0.34)) 513 | 
num_small = num_bboxes - num_large - num_mid 514 | object_positions = self.fine_grained_bbox_layout_generator.generate(num_large=num_large, num_mid=num_mid, 515 | num_small=num_small, width=width, height=height) 516 | if considerArea: 517 | bbox_with_area = [] 518 | for bbox in object_positions: 519 | # Calculate width and height from [x_min, y_min, x_max, y_max] 520 | width = bbox[2] - bbox[0] 521 | height = bbox[3] - bbox[1] 522 | bbox_area = width * height 523 | bbox_with_area.append({"bbox": bbox, "area": bbox_area}) 524 | 525 | # Sort objects based on the computed object area (descending order) 526 | objects_sorted = sorted(objects, key=lambda x: x['object_metadata']['area'], reverse=True) 527 | # Sort bounding boxes based on their computed area (descending order) 528 | bboxes_sorted = sorted(bbox_with_area, key=lambda x: x['area'], reverse=True) 529 | 530 | # Assign the sorted bounding boxes to the sorted objects 531 | for idx, obj in enumerate(objects_sorted): 532 | obj["object_position"] = bboxes_sorted[idx]["bbox"] 533 | image_metadata["objects"] = objects_sorted 534 | # print the match results 535 | for obj in objects_sorted: 536 | print(f"Object: {obj['object_metadata']['image_path']}, Area: {obj['object_metadata']['area']}") 537 | for bbox in bboxes_sorted: 538 | print(f"bbox: {bbox['bbox']}, Area: {bbox['area']}") 539 | else: 540 | for idx, obj in enumerate(objects): 541 | obj["object_position"] = object_positions[idx] 542 | image_metadata["objects"] = objects 543 | 544 | return image_metadata -------------------------------------------------------------------------------- /relighting_and_blending/inference.py: -------------------------------------------------------------------------------- 1 | # import debugpy 2 | # try: 3 | # # 5678 is the default attach port in the VS Code debug configurations. Unless a host and port are specified, host defaults to 127.0.0.1 4 | # debugpy.listen(("localhost", 9501)) 5 | # print("Waiting for debugger attach") 6 | # debugpy.wait_for_client() 7 | # except Exception as e: 8 | # pass 9 | 10 | #!/usr/bin/env python 11 | """ 12 | Modified version of inference_with_blending.py with GCS integration. 13 | Supports downloading input images (and related JSON/masks) from GCS and uploading outputs to GCS, 14 | similar to the provided flux pipeline code. 
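High-level flow (summary of the code below): each foreground RGBA composite is flattened against
neutral grey via its alpha channel (`parse_rgba`), relit with the IC-Light offset
(`iclight_sd15_fc.safetensors`) merged into a Realistic Vision UNet
(`stablediffusionapi/realistic-vision-v51`), with an optional directional gradient used as the
initial background (`BGSource`), and finally blended with the original composite in Lab color
space, where the lightness channel is mixed per segment with weights from a sigmoid over the
segment-area rank (`blend_images_with_mask_rank_sigmoid`).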
15 | """ 16 | 17 | import os 18 | import re 19 | import json 20 | import time 21 | import argparse 22 | import numpy as np 23 | from tqdm import tqdm 24 | import tempfile 25 | 26 | from PIL import Image 27 | from skimage.color import rgb2lab, lab2rgb 28 | import torch 29 | import safetensors.torch as sf 30 | 31 | # Diffusers and related libraries 32 | from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline 33 | from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler, EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler 34 | from diffusers.models.attention_processor import AttnProcessor2_0 35 | from transformers import CLIPTextModel, CLIPTokenizer 36 | from torch.hub import download_url_to_file 37 | 38 | # Custom functions from your previous code (e.g., blending utilities) 39 | from gradio_demo import ( 40 | hooked_unet_forward, 41 | encode_prompt_pair, 42 | pytorch2numpy, 43 | numpy2pytorch, 44 | resize_and_center_crop, 45 | resize_without_crop, 46 | ) 47 | 48 | # --- GCS Utility Functions --- 49 | from google.cloud import storage 50 | 51 | def parse_gcs_path(gcs_path): 52 | """ 53 | Given a GCS path in the form gs://bucket_name/path/to/file, 54 | returns a tuple (bucket_name, blob_path). 55 | """ 56 | if not gcs_path.startswith("gs://"): 57 | raise ValueError("Not a valid GCS path") 58 | path = gcs_path[5:] 59 | parts = path.split("/", 1) 60 | bucket = parts[0] 61 | blob = parts[1] if len(parts) > 1 else "" 62 | return bucket, blob 63 | 64 | def download_from_gcs(gcs_path, local_path): 65 | """ 66 | Downloads a file from a GCS path to a local file. 67 | """ 68 | bucket_name, blob_name = parse_gcs_path(gcs_path) 69 | client = storage.Client() 70 | bucket = client.bucket(bucket_name) 71 | blob = bucket.blob(blob_name) 72 | blob.download_to_filename(local_path) 73 | print(f"Downloaded {gcs_path} to {local_path}") 74 | 75 | def upload_to_gcs(local_path, gcs_path): 76 | """ 77 | Upload a local file to a GCS path. 78 | """ 79 | bucket_name, blob_name = parse_gcs_path(gcs_path) 80 | client = storage.Client() 81 | bucket = client.bucket(bucket_name) 82 | blob = bucket.blob(blob_name) 83 | blob.upload_from_filename(local_path) 84 | print(f"Uploaded {local_path} to {gcs_path}") 85 | 86 | def gcs_blob_exists(gcs_path): 87 | """ 88 | Check if a blob exists at the given GCS path. 89 | """ 90 | bucket_name, blob_name = parse_gcs_path(gcs_path) 91 | client = storage.Client() 92 | bucket = client.bucket(bucket_name) 93 | blob = bucket.blob(blob_name) 94 | return blob.exists() 95 | 96 | def list_gcs_files(gcs_path, suffix=""): 97 | """ 98 | List file names (relative to the prefix) from a GCS path with an optional suffix filter. 99 | """ 100 | bucket_name, prefix = parse_gcs_path(gcs_path) 101 | client = storage.Client() 102 | bucket = client.bucket(bucket_name) 103 | blobs = list(client.list_blobs(bucket, prefix=prefix)) 104 | file_names = [] 105 | for blob in blobs: 106 | if blob.name.endswith(suffix): 107 | # Remove the prefix portion for clarity (if needed) 108 | relative_name = blob.name[len(prefix):].lstrip("/") 109 | file_names.append(relative_name) 110 | return file_names 111 | 112 | def build_color_mask_gcs_path(data_path, file_id, original_folder="train", new_folder="panoptic_train"): 113 | """ 114 | Builds the GCS path for the color mask by replacing the folder name 115 | in the blob portion without altering the bucket name. 
116 | """ 117 | if not data_path.startswith("gs://"): 118 | raise ValueError("Not a valid GCS path") 119 | # Remove the "gs://" prefix. 120 | path_without_prefix = data_path[5:] 121 | # Split into bucket and blob path. 122 | bucket, blob_path = path_without_prefix.split("/", 1) 123 | # Replace only the intended folder segment in blob_path. 124 | new_blob_path = blob_path.replace(original_folder, new_folder, 1) 125 | return f"gs://{bucket}/{new_blob_path}/{file_id}.png" 126 | 127 | # --- Diffusion and Relightening Setup --- 128 | 129 | from enum import Enum 130 | 131 | class BGSource(Enum): 132 | NONE = "None" 133 | LEFT = "Left Light" 134 | RIGHT = "Right Light" 135 | TOP = "Top Light" 136 | BOTTOM = "Bottom Light" 137 | 138 | # Set up Stable Diffusion components 139 | sd15_name = 'stablediffusionapi/realistic-vision-v51' 140 | tokenizer = CLIPTokenizer.from_pretrained(sd15_name, subfolder="tokenizer") 141 | text_encoder = CLIPTextModel.from_pretrained(sd15_name, subfolder="text_encoder") 142 | vae = AutoencoderKL.from_pretrained(sd15_name, subfolder="vae") 143 | unet = UNet2DConditionModel.from_pretrained(sd15_name, subfolder="unet") 144 | 145 | # RMBG and UNet modifications 146 | from briarmbg import BriaRMBG 147 | rmbg = BriaRMBG.from_pretrained("briaai/RMBG-1.4") 148 | 149 | with torch.no_grad(): 150 | new_conv_in = torch.nn.Conv2d(8, unet.conv_in.out_channels, 151 | unet.conv_in.kernel_size, 152 | unet.conv_in.stride, 153 | unet.conv_in.padding) 154 | new_conv_in.weight.zero_() 155 | new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight) 156 | new_conv_in.bias = unet.conv_in.bias 157 | unet.conv_in = new_conv_in 158 | 159 | unet_original_forward = unet.forward 160 | unet.forward = hooked_unet_forward 161 | 162 | # Load model offset and merge with UNet 163 | model_path = './models/iclight_sd15_fc.safetensors' 164 | if not os.path.exists(model_path): 165 | download_url_to_file( 166 | url='https://huggingface.co/lllyasviel/ic-light/resolve/main/iclight_sd15_fc.safetensors', 167 | dst=model_path 168 | ) 169 | sd_offset = sf.load_file(model_path) 170 | sd_origin = unet.state_dict() 171 | sd_merged = {k: sd_origin[k] + sd_offset[k] for k in sd_origin.keys()} 172 | unet.load_state_dict(sd_merged, strict=True) 173 | del sd_offset, sd_origin, sd_merged 174 | 175 | # Set device and send models to device 176 | device = torch.device('cuda') 177 | text_encoder = text_encoder.to(device=device, dtype=torch.float16) 178 | vae = vae.to(device=device, dtype=torch.bfloat16) 179 | unet = unet.to(device=device, dtype=torch.float16) 180 | rmbg = rmbg.to(device=device, dtype=torch.float32) 181 | 182 | # SDP attention processor 183 | unet.set_attn_processor(AttnProcessor2_0()) 184 | vae.set_attn_processor(AttnProcessor2_0()) 185 | 186 | # Define schedulers 187 | ddim_scheduler = DDIMScheduler( 188 | num_train_timesteps=1000, 189 | beta_start=0.00085, 190 | beta_end=0.012, 191 | beta_schedule="scaled_linear", 192 | clip_sample=False, 193 | set_alpha_to_one=False, 194 | steps_offset=1, 195 | ) 196 | euler_a_scheduler = EulerAncestralDiscreteScheduler( 197 | num_train_timesteps=1000, 198 | beta_start=0.00085, 199 | beta_end=0.012, 200 | steps_offset=1 201 | ) 202 | dpmpp_2m_sde_karras_scheduler = DPMSolverMultistepScheduler( 203 | num_train_timesteps=1000, 204 | beta_start=0.00085, 205 | beta_end=0.012, 206 | algorithm_type="sde-dpmsolver++", 207 | use_karras_sigmas=True, 208 | steps_offset=1 209 | ) 210 | 211 | # Pipelines for text-to-image and image-to-image 212 | t2i_pipe = StableDiffusionPipeline( 213 | 
vae=vae, 214 | text_encoder=text_encoder, 215 | tokenizer=tokenizer, 216 | unet=unet, 217 | scheduler=dpmpp_2m_sde_karras_scheduler, 218 | safety_checker=None, 219 | requires_safety_checker=False, 220 | feature_extractor=None, 221 | image_encoder=None 222 | ) 223 | i2i_pipe = StableDiffusionImg2ImgPipeline( 224 | vae=vae, 225 | text_encoder=text_encoder, 226 | tokenizer=tokenizer, 227 | unet=unet, 228 | scheduler=dpmpp_2m_sde_karras_scheduler, 229 | safety_checker=None, 230 | requires_safety_checker=False, 231 | feature_extractor=None, 232 | image_encoder=None 233 | ) 234 | 235 | # --- Utility Functions for Image Parsing and Blending --- 236 | 237 | def parse_rgba(img, sigma=0.0): 238 | """ 239 | Given a RGBA image (as a NumPy array), returns the blended RGB image using the alpha channel. 240 | """ 241 | assert img.shape[2] == 4, "Input image must have 4 channels (RGBA)." 242 | rgb = img[:, :, :3] 243 | alpha = img[:, :, 3].astype(np.float32) / 255.0 244 | # set alpha to be like if it is transparent then 0, otherwise 1 245 | # alpha = np.where(alpha < 0.5, 0.0, 1.0) 246 | result = 127 + (rgb.astype(np.float32) - 127 + sigma) * alpha[:, :, None] 247 | # temporarily 248 | return result.clip(0, 255).astype(np.uint8), alpha 249 | 250 | def blend_images_with_mask_rank_sigmoid(old_image, new_image, color_mask, alpha_min=0.3, alpha_max=0.9, steepness=10): 251 | """ 252 | Blends new_image into old_image according to a color mask using a sigmoid function based on segment area. 253 | """ 254 | if not isinstance(old_image, Image.Image): 255 | old_image = Image.fromarray(old_image) 256 | if not isinstance(new_image, Image.Image): 257 | new_image = Image.fromarray(new_image) 258 | if not isinstance(color_mask, Image.Image): 259 | color_mask = Image.fromarray(color_mask) 260 | 261 | old_image = old_image.convert("RGBA") 262 | new_image = new_image.convert("RGBA") 263 | color_mask = color_mask.convert("RGBA") 264 | 265 | old_np = np.array(old_image).astype(np.float32) 266 | new_np = np.array(new_image).astype(np.float32) 267 | mask_np = np.array(color_mask).astype(np.float32) 268 | mask_rgb = mask_np[..., :3] 269 | total_pixels = mask_rgb.shape[0] * mask_rgb.shape[1] 270 | 271 | unique_colors = np.unique(mask_rgb.reshape(-1, 3), axis=0) 272 | unique_colors = [c for c in unique_colors if not np.allclose(c, [0, 0, 0])] 273 | 274 | color_to_norm_area = {} 275 | for color in unique_colors: 276 | region = ((mask_rgb[..., 0] == color[0]) & 277 | (mask_rgb[..., 1] == color[1]) & 278 | (mask_rgb[..., 2] == color[2])) 279 | area = np.sum(region) 280 | norm_area = area / total_pixels 281 | color_to_norm_area[tuple(color)] = norm_area 282 | 283 | segments_sorted = sorted(color_to_norm_area.items(), key=lambda x: x[1]) 284 | total_segments = len(segments_sorted) 285 | color_to_alpha = {} 286 | for rank, (color, norm_area) in enumerate(segments_sorted): 287 | normalized_rank = rank / (total_segments - 1) if total_segments > 1 else 0 288 | sigmoid_value = 1 / (1 + np.exp(steepness * (normalized_rank - 0.5))) 289 | alpha = alpha_min + (alpha_max - alpha_min) * sigmoid_value 290 | color_to_alpha[color] = alpha 291 | print(f"Segment color: {color}, Norm Area: {norm_area:.6f}, Rank: {rank}, Normalized Rank: {normalized_rank:.3f}, Alpha: {alpha}") 292 | 293 | # Convert images from RGBA numpy arrays to Lab color space for lightness-only blending 294 | rgb_old = old_np[..., :3] / 255.0 295 | rgb_new = new_np[..., :3] / 255.0 296 | lab_old = rgb2lab(rgb_old) 297 | lab_new = rgb2lab(rgb_new) 298 | lab_out = lab_old.copy() 299 
| 300 | # Blend only the L channel per segment while preserving original a/b channels 301 | for color, alpha in color_to_alpha.items(): 302 | region = ((mask_rgb[..., 0] == color[0]) & 303 | (mask_rgb[..., 1] == color[1]) & 304 | (mask_rgb[..., 2] == color[2])) 305 | # Update lightness channel 306 | lab_out[..., 0][region] = alpha * lab_old[..., 0][region] + (1 - alpha) * lab_new[..., 0][region] 307 | # Lightly blend chroma channels (a and b) to avoid unrealistic color shifts 308 | lab_out[..., 1][region] = 0.25 * alpha * lab_old[..., 1][region] + (1 - 0.25 * alpha) * lab_new[..., 1][region] 309 | lab_out[..., 2][region] = 0.25 * alpha * lab_old[..., 2][region] + (1 - 0.25 * alpha) * lab_new[..., 2][region] 310 | # lab_out[..., 1][region] = lab_old[..., 1][region] 311 | # lab_out[..., 2][region] = lab_old[..., 2][region] 312 | lab_out[..., 1][region] = lab_new[..., 1][region] 313 | lab_out[..., 2][region] = lab_new[..., 2][region] 314 | # Convert Lab back to RGB 315 | rgb_out = lab2rgb(lab_out) 316 | rgb_out_uint8 = (rgb_out * 255).clip(0, 255).astype(np.uint8) 317 | blended_image = Image.fromarray(rgb_out_uint8) 318 | return blended_image 319 | 320 | # --- Inference Functions --- 321 | 322 | @torch.inference_mode() 323 | def process(input_fg, prompt, image_width, image_height, num_samples, 324 | seed, steps, a_prompt, n_prompt, cfg, 325 | highres_scale, highres_denoise, lowres_denoise, bg_source): 326 | bg_source = BGSource(bg_source) 327 | input_bg = None 328 | if bg_source == BGSource.NONE: 329 | pass 330 | elif bg_source == BGSource.LEFT: 331 | gradient = np.linspace(255, 0, image_width) 332 | image = np.tile(gradient, (image_height, 1)) 333 | input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8) 334 | elif bg_source == BGSource.RIGHT: 335 | gradient = np.linspace(0, 255, image_width) 336 | image = np.tile(gradient, (image_height, 1)) 337 | input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8) 338 | elif bg_source == BGSource.TOP: 339 | gradient = np.linspace(255, 0, image_height)[:, None] 340 | image = np.tile(gradient, (1, image_width)) 341 | input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8) 342 | elif bg_source == BGSource.BOTTOM: 343 | gradient = np.linspace(0, 255, image_height)[:, None] 344 | image = np.tile(gradient, (1, image_width)) 345 | input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8) 346 | else: 347 | raise ValueError("Wrong initial latent!") 348 | 349 | rng = torch.Generator(device=device).manual_seed(int(seed)) 350 | fg = resize_and_center_crop(input_fg, image_width, image_height) 351 | concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype) 352 | concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor 353 | conds, unconds = encode_prompt_pair(positive_prompt=prompt + ', ' + a_prompt, negative_prompt=n_prompt) 354 | 355 | if input_bg is None: 356 | latents = t2i_pipe( 357 | prompt_embeds=conds, 358 | negative_prompt_embeds=unconds, 359 | width=image_width, 360 | height=image_height, 361 | num_inference_steps=steps, 362 | num_images_per_prompt=num_samples, 363 | generator=rng, 364 | output_type='latent', 365 | guidance_scale=cfg, 366 | cross_attention_kwargs={'concat_conds': concat_conds}, 367 | ).images.to(vae.dtype) / vae.config.scaling_factor 368 | else: 369 | bg = resize_and_center_crop(input_bg, image_width, image_height) 370 | bg_latent = numpy2pytorch([bg]).to(device=vae.device, dtype=vae.dtype) 371 | bg_latent = vae.encode(bg_latent).latent_dist.mode() * vae.config.scaling_factor 372 
| latents = i2i_pipe( 373 | image=bg_latent, 374 | strength=lowres_denoise, 375 | prompt_embeds=conds, 376 | negative_prompt_embeds=unconds, 377 | width=image_width, 378 | height=image_height, 379 | num_inference_steps=int(round(steps / lowres_denoise)), 380 | num_images_per_prompt=num_samples, 381 | generator=rng, 382 | output_type='latent', 383 | guidance_scale=cfg, 384 | cross_attention_kwargs={'concat_conds': concat_conds}, 385 | ).images.to(vae.dtype) / vae.config.scaling_factor 386 | 387 | pixels = vae.decode(latents).sample 388 | pixels = pytorch2numpy(pixels) 389 | pixels = [resize_without_crop( 390 | image=p, 391 | target_width=int(round(image_width * highres_scale / 64.0) * 64), 392 | target_height=int(round(image_height * highres_scale / 64.0) * 64)) 393 | for p in pixels] 394 | pixels = numpy2pytorch(pixels).to(device=vae.device, dtype=vae.dtype) 395 | latents = vae.encode(pixels).latent_dist.mode() * vae.config.scaling_factor 396 | latents = latents.to(device=unet.device, dtype=unet.dtype) 397 | image_height, image_width = latents.shape[2] * 8, latents.shape[3] * 8 398 | fg = resize_and_center_crop(input_fg, image_width, image_height) 399 | concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype) 400 | concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor 401 | latents = i2i_pipe( 402 | image=latents, 403 | strength=highres_denoise, 404 | prompt_embeds=conds, 405 | negative_prompt_embeds=unconds, 406 | width=image_width, 407 | height=image_height, 408 | num_inference_steps=int(round(steps / highres_denoise)), 409 | num_images_per_prompt=num_samples, 410 | generator=rng, 411 | output_type='latent', 412 | guidance_scale=cfg, 413 | cross_attention_kwargs={'concat_conds': concat_conds}, 414 | ).images.to(vae.dtype) / vae.config.scaling_factor 415 | 416 | pixels = vae.decode(latents).sample 417 | return pytorch2numpy(pixels) 418 | 419 | @torch.inference_mode() 420 | def process_relight(input_fg, prompt, image_width, image_height, num_samples, 421 | seed, steps, a_prompt, n_prompt, cfg, 422 | highres_scale, highres_denoise, lowres_denoise, bg_source): 423 | input_fg, _ = parse_rgba(input_fg) 424 | results = process(input_fg, prompt, image_width, image_height, num_samples, 425 | seed, steps, a_prompt, n_prompt, cfg, 426 | highres_scale, highres_denoise, lowres_denoise, bg_source) 427 | return input_fg, results 428 | 429 | def adjust_dimensions(width, height, max_dim=1024, divisible_by=8): 430 | # For simplicity, we return a fixed dimension (or you can implement your resizing logic) 431 | return 1024, 1024 432 | 433 | # --- Main Inference Workflow with GCS Integration --- 434 | 435 | def main(args): 436 | data_path = args.dataset_path 437 | output_data_path = args.output_data_path 438 | illuminate_prompts_path = args.illuminate_prompts_path 439 | record_path = args.record_path 440 | 441 | # Determine if the dataset, index JSON, and output_data_path reside on GCS. 442 | input_on_gcs = data_path.startswith("gs://") 443 | output_on_gcs = output_data_path.startswith("gs://") 444 | index_on_gcs = args.index_json_path is not None and args.index_json_path.startswith("gs://") 445 | illuminate_on_gcs = illuminate_prompts_path.startswith("gs://") if illuminate_prompts_path else False 446 | 447 | # Create a local temporary directory to use when downloading files from GCS. 448 | temp_dir = tempfile.mkdtemp() 449 | 450 | # If the illumination prompts are on GCS, download them locally. 
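# The GCS helpers used below (parse_gcs_path, list_gcs_files, download_from_gcs,
# upload_to_gcs, gcs_blob_exists) are assumed to be defined earlier in this file.
# As a rough illustration only (not necessarily the exact implementation used here),
# download_from_gcs could be written with the google-cloud-storage client listed in
# requirements_relight_and_blending.txt, e.g.:
#
#     from google.cloud import storage
#
#     def download_from_gcs(gcs_path, local_path):
#         bucket_name, blob_name = parse_gcs_path(gcs_path)  # "gs://bucket/key" -> ("bucket", "key")
#         storage.Client().bucket(bucket_name).blob(blob_name).download_to_filename(local_path)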
451 | if illuminate_on_gcs: 452 | local_illuminate_path = os.path.join(temp_dir, "illumination_prompt.json") 453 | download_from_gcs(illuminate_prompts_path, local_illuminate_path) 454 | illuminate_prompts_path = local_illuminate_path 455 | 456 | with open(illuminate_prompts_path, "r") as f: 457 | illuminate_prompts = json.load(f) 458 | 459 | records = {} 460 | split_index = args.split 461 | num_splits = args.num_splits 462 | 463 | # Prepare list of filenames based on index JSON if provided; otherwise list .png files in dataset. 464 | if args.index_json_path: 465 | # If index JSON is on GCS, download it. 466 | if index_on_gcs: 467 | local_index_path = os.path.join(temp_dir, "index.json") 468 | download_from_gcs(args.index_json_path, local_index_path) 469 | index_json_path = local_index_path 470 | else: 471 | index_json_path = args.index_json_path 472 | 473 | with open(index_json_path, 'r') as f: 474 | all_filenames = json.load(f) 475 | if not isinstance(all_filenames, list): 476 | raise ValueError("The index JSON file must contain a list of filenames.") 477 | splits = np.array_split(all_filenames, num_splits) 478 | split_filenames = list(splits[split_index]) 479 | print(f"Processing split {split_index + 1}/{num_splits} with {len(split_filenames)} images from index JSON.") 480 | else: 481 | if input_on_gcs: 482 | # List .png files from the GCS path. 483 | bucket_name, prefix = parse_gcs_path(data_path) 484 | files = list_gcs_files(data_path, suffix=".png") 485 | # Optionally, sort numerically if filenames are numbers. 486 | pattern = re.compile(r'^(\d+)\.png$') 487 | file_numbers = [] 488 | for f in files: 489 | m = pattern.match(os.path.basename(f)) 490 | if m: 491 | file_numbers.append((int(m.group(1)), f)) 492 | file_numbers.sort(key=lambda x: x[0]) 493 | sorted_filenames = [f for _, f in file_numbers] 494 | splits = np.array_split(sorted_filenames, num_splits) 495 | split_filenames = list(splits[split_index]) 496 | print(f"Processing split {split_index + 1}/{num_splits} with {len(split_filenames)} images from GCS.") 497 | else: 498 | # Local directory listing. 499 | all_files = [f for f in os.listdir(data_path) if f.endswith(".png")] 500 | pattern = re.compile(r'^(\d+)\.png$') 501 | filtered_files = [] 502 | for f in all_files: 503 | m = pattern.match(f) 504 | if m: 505 | numeric_value = int(m.group(1)) 506 | filtered_files.append((numeric_value, f)) 507 | filtered_files.sort(key=lambda x: x[0]) 508 | sorted_filenames = [f for _, f in filtered_files] 509 | splits = np.array_split(sorted_filenames, num_splits) 510 | split_filenames = list(splits[split_index]) 511 | print(f"Processing split {split_index + 1}/{num_splits} with {len(split_filenames)} images (local).") 512 | 513 | # Prepare the output destination. 514 | if not output_on_gcs: 515 | os.makedirs(output_data_path, exist_ok=True) 516 | 517 | # Process each file. 518 | for fg_name in tqdm(split_filenames, desc="Processing images"): 519 | # For input image, if residing on GCS, construct full GCS path and download to temp file. 520 | if input_on_gcs: 521 | full_fg_path = os.path.join(data_path, fg_name) 522 | local_fg_path = os.path.join(temp_dir, os.path.basename(fg_name)) 523 | download_from_gcs(full_fg_path, local_fg_path) 524 | else: 525 | local_fg_path = os.path.join(data_path, fg_name) 526 | 527 | # Open the foreground image. 
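# The foreground PNGs produced by the segmentation stage are assumed to be RGBA here, so
# np.array(Image.open(...)) yields an (H, W, 4) uint8 array; parse_rgba (defined earlier in
# this file, not shown in this excerpt) is assumed to separate it into the RGB foreground and
# its alpha mask, conceptually along the lines of:
#     rgb, alpha = img[..., :3], img[..., 3]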
528 | try: 529 | input_fg = np.array(Image.open(local_fg_path)) 530 | except Exception as e: 531 | print(f"Error opening image {local_fg_path}: {e}") 532 | continue 533 | 534 | # Determine output file name. 535 | file_id = os.path.splitext(os.path.basename(fg_name))[0] 536 | if output_on_gcs: 537 | output_blob = f"{file_id}.jpg" 538 | full_output_path = os.path.join(output_data_path, output_blob).replace("\\", "/") 539 | if gcs_blob_exists(full_output_path): 540 | print(f"Skipping '{fg_name}': Output blob already exists on GCS.") 541 | continue 542 | else: 543 | output_path = os.path.join(output_data_path, f"{file_id}.jpg") 544 | if os.path.exists(output_path): 545 | print(f"Skipping '{fg_name}': Output file '{output_path}' already exists.") 546 | continue 547 | 548 | # Dynamically adjust dimensions. 549 | orig_height, orig_width = input_fg.shape[:2] 550 | image_width, image_height = adjust_dimensions(orig_width, orig_height, max_dim=1024, divisible_by=8) 551 | print(f"Processing '{fg_name}': Adjusted dimensions: {image_width} x {image_height}") 552 | 553 | # Select a random prompt from illumination prompts. 554 | prompt = np.random.choice(illuminate_prompts) 555 | bg_source = np.random.choice([BGSource.NONE, BGSource.NONE, BGSource.NONE, BGSource.NONE, 556 | BGSource.LEFT, BGSource.RIGHT, BGSource.TOP, BGSource.BOTTOM]) 557 | seed = 123456 558 | steps = 25 559 | a_prompt = "not obvious objects in the background, best quality, don't significantly change foreground objects, keep its semantic meaning" 560 | n_prompt = "have obvious objects in the background, lowres, bad anatomy, bad hands, cropped, worst quality, change foreground objects, don't keep its semantic meaning" 561 | cfg = 2.0 562 | highres_scale = 1.0 563 | highres_denoise = 0.5 564 | lowres_denoise = 0.9 565 | num_samples = 1 566 | 567 | # Process relighting. 568 | input_fg, results = process_relight( 569 | input_fg=input_fg, 570 | prompt=prompt, 571 | image_width=image_width, 572 | image_height=image_height, 573 | num_samples=num_samples, 574 | seed=seed, 575 | steps=steps, 576 | a_prompt=a_prompt, 577 | n_prompt=n_prompt, 578 | cfg=cfg, 579 | highres_scale=highres_scale, 580 | highres_denoise=highres_denoise, 581 | lowres_denoise=lowres_denoise, 582 | bg_source=bg_source.value # Pass Enum string value if needed 583 | ) 584 | 585 | # For blending, determine the color mask path. 586 | if input_on_gcs: 587 | # Use the helper function to build the color mask path properly. 588 | color_mask_path = build_color_mask_gcs_path(data_path, file_id) 589 | local_mask_path = os.path.join(temp_dir, f"{file_id}_mask.png") 590 | download_from_gcs(color_mask_path, local_mask_path) 591 | else: 592 | color_mask_path = os.path.join(data_path.replace("train", "panoptic_train"), f"{file_id}.png") 593 | local_mask_path = color_mask_path 594 | 595 | try: 596 | color_mask = Image.open(local_mask_path) 597 | except Exception as e: 598 | print(f"Error opening color mask {local_mask_path}: {e}") 599 | continue 600 | 601 | blended_image = blend_images_with_mask_rank_sigmoid( 602 | old_image=results[0], 603 | new_image=input_fg, 604 | color_mask=color_mask 605 | ) 606 | 607 | # Save the output. 608 | if output_on_gcs: 609 | local_output_file = os.path.join(temp_dir, f"{file_id}.jpg") 610 | blended_image.save(local_output_file) 611 | upload_to_gcs(local_output_file, full_output_path) 612 | os.remove(local_output_file) 613 | else: 614 | blended_image.save(output_path) 615 | print(f"Saved relit image to '{output_path}'") 616 | 617 | # Record details. 
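# Each processed image gets an in-memory record keyed by its source filename; if the
# (currently commented-out) record saving at the end of main() is enabled, record.json
# would contain entries shaped roughly like (filename and paths illustrative):
#     "000123.png": {"output_path": ".../000123.jpg", "prompt": "<illumination prompt>",
#                    "bg_source": "<BGSource value>", "seed": 123456, "steps": 25, "cfg": 2.0,
#                    "highres_scale": 1.0, "highres_denoise": 0.5, "lowres_denoise": 0.9}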
618 | records[fg_name] = { 619 | "output_path": full_output_path if output_on_gcs else output_path, 620 | "prompt": prompt, 621 | "bg_source": bg_source.value, 622 | "seed": seed, 623 | "steps": steps, 624 | "cfg": cfg, 625 | "highres_scale": highres_scale, 626 | "highres_denoise": highres_denoise, 627 | "lowres_denoise": lowres_denoise 628 | } 629 | 630 | # (Optional) Save or update your records file. 631 | # if record_path: 632 | # if record_path.startswith("gs://"): 633 | # local_record_path = os.path.join(temp_dir, "record.json") 634 | # else: 635 | # local_record_path = record_path 636 | # 637 | # with open(local_record_path, 'w') as f: 638 | # json.dump(records, f, indent=4) 639 | # if record_path.startswith("gs://"): 640 | # upload_to_gcs(local_record_path, record_path) 641 | # os.remove(local_record_path) 642 | # print("Processing complete.") 643 | 644 | if __name__ == "__main__": 645 | def parse_args(): 646 | parser = argparse.ArgumentParser(description="Relight images using Stable Diffusion pipelines with GCS integration.") 647 | parser.add_argument('--dataset_path', type=str, required=True, 648 | help="Path to the segment dataset. Can be a local path or a GCS path (e.g., gs://bucket/path).") 649 | parser.add_argument('--output_data_path', type=str, required=True, 650 | help="Path to save the output data. Can be a local path or a GCS path.") 651 | parser.add_argument('--num_splits', type=int, default=1, help="Number of splits to create") 652 | parser.add_argument('--split', type=int, default=0, help="Split index to process (0-indexed)") 653 | parser.add_argument('--index_json_path', type=str, default=None, 654 | help="Path to the JSON file containing image filenames; supports GCS paths.") 655 | parser.add_argument('--illuminate_prompts_path', type=str, required=True, 656 | help="Path to the JSON file containing illumination prompts; supports GCS paths.") 657 | parser.add_argument('--record_path', type=str, default=None, 658 | help="Path to the JSON file where records are saved; supports GCS paths.") 659 | return parser.parse_args() 660 | 661 | args = parse_args() 662 | main(args) -------------------------------------------------------------------------------- /datasets/object_datasets.py: -------------------------------------------------------------------------------- 1 | from .base_datasets import BaseObjectDataset 2 | import os 3 | from tqdm import tqdm 4 | import pickle 5 | from concurrent.futures import ThreadPoolExecutor, as_completed 6 | import pandas as pd 7 | import json 8 | import re 9 | from shutil import copyfile 10 | from PIL import Image 11 | import PIL 12 | from .coco_categories import COCO_CATEGORIES 13 | def build_category_index(categories): 14 | return {category['name'].replace("_", "").replace(" ", ""): category['id'] for category in categories} 15 | 16 | COCO_CATEGORY_INDEX = build_category_index(COCO_CATEGORIES) 17 | 18 | 19 | def getting_filtering_annotation(universial_idx, filtering_annotations, default_annotation): 20 | # getting the first value of the filtering_annotations as schema 21 | 22 | 23 | if universial_idx in filtering_annotations: 24 | return filtering_annotations[universial_idx] 25 | else: 26 | return default_annotation 27 | 28 | class ADE20KDataset(BaseObjectDataset): 29 | def __init__(self, dataset_path, filtering_annotations_path): 30 | dataset_name = "ADE20K" 31 | data_type = "panoptic" 32 | super().__init__(dataset_name, data_type, dataset_path, filtering_annotations_path) 33 | # get filtering annotation 34 | if 
self.filtering_annotations_path is not None: 35 | filtering_annotations = json.load(open(self.filtering_annotations_path, "r")) 36 | default_annotation = filtering_annotations[list(filtering_annotations.keys())[0]] 37 | default_annotation = {key: False for key in default_annotation.keys()} 38 | else: 39 | filtering_annotations = None 40 | 41 | # get metadata_table 42 | index_counter = 0 43 | rows = [] 44 | for folder_name in os.listdir(self.dataset_path): 45 | folder_path = os.path.join(self.dataset_path, folder_name) 46 | if os.path.isdir(folder_path): 47 | for file_name in os.listdir(folder_path): 48 | if ".png" in file_name: 49 | image_path = os.path.join(self.dataset_path, folder_name, file_name) 50 | universial_idx = f"{dataset_name}_{folder_name}_{file_name}" 51 | rows.append({ 52 | 'universal_idx': universial_idx, 53 | 'idx_for_curr_dataset': index_counter, 54 | 'category': folder_name, 55 | 'image_path': image_path, 56 | 'mask_path': image_path, # Assuming mask_path is the same as image_path 57 | 'dataset_name': dataset_name, 58 | 'data_type': data_type, 59 | 'filtering_annotation': getting_filtering_annotation(universial_idx, filtering_annotations, default_annotation) if filtering_annotations is not None else None 60 | }) 61 | index_counter += 1 62 | self.metadata_table = pd.DataFrame(rows) 63 | self.categories = set(self.metadata_table['category']) 64 | 65 | class VOC2012Dataset(BaseObjectDataset): 66 | def __init__(self, dataset_path, filtering_annotations_path): 67 | dataset_name = "VOC2012" 68 | data_type = "panoptic" 69 | super().__init__(dataset_name, data_type, dataset_path, filtering_annotations_path) 70 | # get filtering annotation 71 | if self.filtering_annotations_path is not None: 72 | filtering_annotations = json.load(open(self.filtering_annotations_path, "r")) 73 | default_annotation = filtering_annotations[list(filtering_annotations.keys())[0]] 74 | default_annotation = {key: False for key in default_annotation.keys()} 75 | else: 76 | filtering_annotations = None 77 | 78 | # get metadata_table 79 | index_counter = 0 80 | rows = [] 81 | for folder_name in os.listdir(self.dataset_path): 82 | folder_path = os.path.join(self.dataset_path, folder_name) 83 | if os.path.isdir(folder_path): 84 | for file_name in os.listdir(folder_path): 85 | if ".png" in file_name: 86 | image_path = os.path.join(self.dataset_path, folder_name, file_name) 87 | universial_idx = f"{dataset_name}_{folder_name}_{file_name}" 88 | rows.append({ 89 | 'universal_idx': universial_idx, 90 | 'idx_for_curr_dataset': index_counter, 91 | 'category': folder_name, 92 | 'image_path': image_path, 93 | 'mask_path': image_path, # Assuming mask_path is the same as image_path 94 | 'dataset_name': dataset_name, 95 | 'data_type': data_type, 96 | 'filtering_annotation': getting_filtering_annotation(universial_idx, filtering_annotations, default_annotation) if filtering_annotations is not None else None 97 | }) 98 | index_counter += 1 99 | self.metadata_table = pd.DataFrame(rows) 100 | self.categories = set(self.metadata_table['category']) 101 | 102 | 103 | 104 | # First, build an index mapping cleaned category names to the category data. 105 | # Now define the lookup function that uses the pre-built index. 106 | 107 | class COCO2017Dataset(BaseObjectDataset): 108 | def _load_coco_id(self, folder_name: str) -> dict: 109 | tokens = folder_name.strip().split() 110 | if not tokens: 111 | raise ValueError("Empty folder name provided.") 112 | # Remove trailing underscore and digits from the first token. 
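# Worked example of the cleanup below (all digits and underscores are removed from the first
# whitespace token): a folder named "dog_42 cropped" has first token "dog_42"; stripping digits
# gives "dog_", stripping underscores gives "dog", which is then looked up in
# COCO_CATEGORY_INDEX (built from COCO_CATEGORIES with spaces/underscores removed), so the
# method returns "COCO2017_dog_<category id>". Folder names whose cleaned first token is not a
# COCO category raise the ValueError below.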
113 | cleaned_folder_name = re.sub(r'\d+', '', tokens[0]) 114 | cleaned_folder_name = cleaned_folder_name.replace("_", "") 115 | try: 116 | return f"COCO2017_{cleaned_folder_name}_{COCO_CATEGORY_INDEX[cleaned_folder_name]}" 117 | except KeyError: 118 | raise ValueError(f"Could not find COCO category for folder name '{folder_name}, cleaned to '{cleaned_folder_name}'.") 119 | 120 | def __init__(self, dataset_path, filtering_annotations_path=None, available_coco_image_path=None, coco_segment_to_image_path=None, dataset_name="COCO2017"): 121 | data_type = "panoptic" 122 | super().__init__(dataset_name, data_type, dataset_path, filtering_annotations_path) 123 | # get filtering annotation 124 | if self.filtering_annotations_path is not None: 125 | filtering_annotations = json.load(open(self.filtering_annotations_path, "r")) 126 | default_annotation = filtering_annotations[list(filtering_annotations.keys())[0]] 127 | default_annotation = {key: False for key in default_annotation.keys()} 128 | else: 129 | filtering_annotations = None 130 | 131 | # add coco image limit, limit the source image that each segments coming from 132 | if available_coco_image_path is not None: 133 | available_coco_image_list = json.load(open(available_coco_image_path, "r")) 134 | print(f"Only consider {len(available_coco_image_list)} images from COCO2017") 135 | 136 | # get metadata_table 137 | index_counter = 0 138 | rows = [] 139 | for folder_name in os.listdir(self.dataset_path): 140 | folder_path = os.path.join(self.dataset_path, folder_name) 141 | if os.path.isdir(folder_path): 142 | for file_name in os.listdir(folder_path): 143 | if ".png" in file_name: 144 | 145 | # check if segment are not in the available_coco_image_list, then remove it from segments pool 146 | if available_coco_image_path is not None: 147 | correspond_image_name = file_name.split("_")[0] + ".jpg" 148 | if correspond_image_name not in available_coco_image_list: 149 | continue 150 | 151 | image_path = os.path.join(self.dataset_path, folder_name, file_name) 152 | universial_idx = f"{dataset_name}_{folder_name}_{file_name}" 153 | rows.append({ 154 | 'universal_idx': universial_idx, 155 | 'idx_for_curr_dataset': index_counter, 156 | 'category': self._load_coco_id(folder_name), 157 | 'image_path': image_path, 158 | 'mask_path': image_path, # Assuming mask_path is the same as image_path 159 | 'dataset_name': dataset_name, 160 | 'data_type': data_type, 161 | 'filtering_annotation': getting_filtering_annotation(universial_idx, filtering_annotations, default_annotation) if filtering_annotations is not None else None 162 | }) 163 | index_counter += 1 164 | self.metadata_table = pd.DataFrame(rows) 165 | self.categories = set(self.metadata_table['category']) 166 | 167 | 168 | class COCO2017FullDataset(COCO2017Dataset): 169 | def __init__(self, dataset_path, filtering_annotations_path=None, available_coco_image_path=None, coco_segment_to_image_path=None): 170 | super().__init__(dataset_path, filtering_annotations_path, available_coco_image_path, coco_segment_to_image_path, dataset_name="COCO_Full") 171 | 172 | 173 | 174 | # class SyntheticDataset(BaseObjectDataset): 175 | # def __init__(self, dataset_path, filtering_annotations_path=None, synthetic_annotation_path=None, dataset_name="SyntheticDatasetPlaceHolder"): 176 | # data_type = "panoptic" 177 | # super().__init__(dataset_name, data_type, dataset_path, filtering_annotations_path) 178 | # # get filtering annotation 179 | # if self.filtering_annotations_path is not None: 180 | # filtering_annotations = 
json.load(open(self.filtering_annotations_path, "r")) 181 | # default_annotation = filtering_annotations[list(filtering_annotations.keys())[0]] 182 | # default_annotation = {key: False for key in default_annotation.keys()} 183 | # else: 184 | # filtering_annotations = None 185 | 186 | # if synthetic_annotation_path is not None: 187 | # synthetic_annotation = json.load(open(synthetic_annotation_path, "r")) 188 | 189 | # # get metadata_table 190 | # index_counter = 0 191 | # rows = [] 192 | 193 | # # Wrap the first layer iteration with tqdm for progress display 194 | # for folder_name in tqdm(os.listdir(self.dataset_path), desc="Processing top-level directories"): 195 | # folder_path = os.path.join(self.dataset_path, folder_name) 196 | # if os.path.isdir(folder_path): 197 | # for subfolder_name in os.listdir(folder_path): 198 | # subfolder_path = os.path.join(folder_path, subfolder_name) 199 | # if os.path.isdir(subfolder_path): 200 | # for file_name in os.listdir(subfolder_path): 201 | # if ".png" in file_name: 202 | # # check if segment are not in the available_coco_image_list, then remove it from segments pool 203 | # image_path = os.path.join(self.dataset_path, folder_name, subfolder_name, file_name) 204 | # image_subfolder_id = file_name.replace('.png', '').split("_")[-1] 205 | 206 | # universial_idx = f"{dataset_name}_{folder_name}_{subfolder_name}_{image_subfolder_id}" 207 | # query_idx = f"{folder_name}_{subfolder_name}" 208 | 209 | # if query_idx in synthetic_annotation: 210 | # description = synthetic_annotation[query_idx]["description"] 211 | # short_phrase = synthetic_annotation[query_idx]["short_phrase"] 212 | # features = synthetic_annotation[query_idx]["features"] 213 | # rows.append({ 214 | # 'universal_idx': universial_idx, 215 | # 'query_idx': query_idx, 216 | # 'idx_for_curr_dataset': index_counter, 217 | # 'image_path': image_path, 218 | # 'mask_path': image_path, # Assuming mask_path is the same as image_path 219 | # 'dataset_name': dataset_name, 220 | # 'data_type': data_type, 221 | # 'filtering_annotation': getting_filtering_annotation(universial_idx, filtering_annotations, default_annotation) if filtering_annotations is not None else None, 222 | # 'category': folder_name, 223 | # 'sub_category': subfolder_name, 224 | # 'description': description, 225 | # 'short_phrase': short_phrase, 226 | # 'features': features 227 | # }) 228 | # index_counter += 1 229 | # self.metadata_table = pd.DataFrame(rows) 230 | # self.categories = set(self.metadata_table['category']) 231 | 232 | 233 | class SyntheticDataset(BaseObjectDataset): 234 | def __init__(self, dataset_path, filtering_annotations_path=None, synthetic_annotation_path=None, dataset_name="SyntheticDatasetPlaceHolder", cache_path="metadata_cache.pkl"): 235 | data_type = "panoptic" 236 | super().__init__(dataset_name, data_type, dataset_path, filtering_annotations_path) 237 | 238 | # Load filtering annotations if provided 239 | if self.filtering_annotations_path is not None: 240 | filtering_annotations = json.load(open(self.filtering_annotations_path, "r")) 241 | default_annotation = filtering_annotations[list(filtering_annotations.keys())[0]] 242 | default_annotation = {key: False for key in default_annotation.keys()} 243 | else: 244 | filtering_annotations = None 245 | default_annotation = None 246 | 247 | if synthetic_annotation_path is not None: 248 | synthetic_annotation = json.load(open(synthetic_annotation_path, "r")) 249 | else: 250 | synthetic_annotation = {} 251 | 252 | # Check if cache exists: 253 | if 
os.path.exists(cache_path): 254 | print(f"Loading metadata from cache: {cache_path}") 255 | with open(cache_path, "rb") as fp: 256 | self.metadata_table = pickle.load(fp) 257 | else: 258 | # If cache does not exist, create the metadata table 259 | self.metadata_table = self._build_metadata_table(synthetic_annotation, filtering_annotations, default_annotation, dataset_name) 260 | # Save to cache for future runs 261 | with open(cache_path, "wb") as fp: 262 | pickle.dump(self.metadata_table, fp) 263 | print(f"Metadata cache saved to: {cache_path}") 264 | 265 | # Create a set of categories for later use 266 | self.categories = set(self.metadata_table['category']) 267 | 268 | def _process_folder(self, folder_name, synthetic_annotation, filtering_annotations, default_annotation, dataset_name): 269 | """ 270 | Process one top-level directory and return a list of rows (dicts) for that folder. 271 | """ 272 | index_counter = 0 # each folder can have an independent counter, we'll fix indices after merging if needed 273 | rows = [] 274 | folder_path = os.path.join(self.dataset_path, folder_name) 275 | if os.path.isdir(folder_path): 276 | for subfolder_name in os.listdir(folder_path): 277 | subfolder_path = os.path.join(folder_path, subfolder_name) 278 | if os.path.isdir(subfolder_path): 279 | for file_name in os.listdir(subfolder_path): 280 | if file_name.endswith(".png"): 281 | image_path = os.path.join(self.dataset_path, folder_name, subfolder_name, file_name) 282 | image_subfolder_id = file_name.replace('.png', '').split("_")[-1] 283 | 284 | universial_idx = f"{dataset_name}_{folder_name}_{subfolder_name}_{image_subfolder_id}" 285 | query_idx = f"{folder_name}_{subfolder_name}" 286 | 287 | if query_idx in synthetic_annotation: 288 | description = synthetic_annotation[query_idx]["description"] 289 | short_phrase = synthetic_annotation[query_idx]["short_phrase"] 290 | features = synthetic_annotation[query_idx]["features"] 291 | # Get filtering annotation if exists 292 | filtering_ann = (getting_filtering_annotation(universial_idx, filtering_annotations, default_annotation) 293 | if filtering_annotations is not None else None) 294 | rows.append({ 295 | 'universal_idx': universial_idx, 296 | 'query_idx': query_idx, 297 | 'image_path': image_path, 298 | 'mask_path': image_path, # assuming mask_path is the same as image_path 299 | 'dataset_name': dataset_name, 300 | 'data_type': "panoptic", 301 | 'filtering_annotation': filtering_ann, 302 | 'category': folder_name, 303 | 'sub_category': subfolder_name, 304 | 'description': description, 305 | 'short_phrase': short_phrase, 306 | 'features': features 307 | }) 308 | index_counter += 1 309 | return rows 310 | 311 | def _build_metadata_table(self, synthetic_annotation, filtering_annotations, default_annotation, dataset_name): 312 | """ 313 | Build the metadata table from the dataset_path using multithreading. 
314 | """ 315 | all_rows = [] 316 | 317 | # Get the list of top-level directories 318 | folders = [folder for folder in os.listdir(self.dataset_path) if os.path.isdir(os.path.join(self.dataset_path, folder))] 319 | 320 | # Use a ThreadPoolExecutor to process folders concurrently 321 | with ThreadPoolExecutor(max_workers=600) as executor: 322 | # Submit jobs for each top-level folder 323 | future_to_folder = { 324 | executor.submit(self._process_folder, folder, synthetic_annotation, filtering_annotations, default_annotation, dataset_name): folder 325 | for folder in folders 326 | } 327 | # Use tqdm to monitor progress 328 | for future in tqdm(as_completed(future_to_folder), total=len(future_to_folder), desc="Processing folders"): 329 | folder = future_to_folder[future] 330 | try: 331 | rows = future.result() 332 | all_rows.extend(rows) 333 | except Exception as exc: 334 | print(f"Folder {folder} generated an exception: {exc}") 335 | 336 | # Optionally, if you want sequential indexing across all rows: 337 | for idx, row in enumerate(all_rows): 338 | row["idx_for_curr_dataset"] = idx 339 | 340 | return pd.DataFrame(all_rows) 341 | 342 | 343 | # class SyntheticDataset(BaseObjectDataset): 344 | # def __init__(self, dataset_path, filtering_annotations_path=None, synthetic_annotation_path=None, 345 | # dataset_name="SyntheticDatasetPlaceHolder", cache_path="metadata_cache.pkl", 346 | # image_cache_dir=None): # image_cache_dir is the directory to save verified images (optional) 347 | # data_type = "panoptic" 348 | # super().__init__(dataset_name, data_type, dataset_path, filtering_annotations_path) 349 | 350 | # # Load filtering annotations if provided 351 | # if self.filtering_annotations_path is not None: 352 | # filtering_annotations = json.load(open(self.filtering_annotations_path, "r")) 353 | # default_annotation = filtering_annotations[list(filtering_annotations.keys())[0]] 354 | # default_annotation = {key: False for key in default_annotation.keys()} 355 | # else: 356 | # filtering_annotations = None 357 | # default_annotation = None 358 | 359 | # if synthetic_annotation_path is not None: 360 | # synthetic_annotation = json.load(open(synthetic_annotation_path, "r")) 361 | # else: 362 | # synthetic_annotation = {} 363 | 364 | # # Image cache directory for verified images (optional) 365 | # self.image_cache_dir = image_cache_dir 366 | 367 | # # Check if metadata cache exists: 368 | # if os.path.exists(cache_path): 369 | # print(f"Loading metadata from cache: {cache_path}") 370 | # with open(cache_path, "rb") as fp: 371 | # self.metadata_table = pickle.load(fp) 372 | # else: 373 | # # If cache does not exist, create the metadata table 374 | # self.metadata_table = self._build_metadata_table(synthetic_annotation, filtering_annotations, default_annotation, dataset_name) 375 | # # Save to cache for future runs 376 | # with open(cache_path, "wb") as fp: 377 | # pickle.dump(self.metadata_table, fp) 378 | # print(f"Metadata cache saved to: {cache_path}") 379 | 380 | # # Create a set of categories for later use 381 | # self.categories = set(self.metadata_table['category']) 382 | 383 | # def _process_folder(self, folder_name, synthetic_annotation, filtering_annotations, default_annotation, dataset_name): 384 | # """ 385 | # Process one top-level directory and return a list of rows (dicts) for that folder. 386 | # Includes checking if each image is valid before adding it to the dataset. 
387 | # """ 388 | # rows = [] 389 | # folder_path = os.path.join(self.dataset_path, folder_name) 390 | # if os.path.isdir(folder_path): 391 | # for subfolder_name in os.listdir(folder_path): 392 | # subfolder_path = os.path.join(folder_path, subfolder_name) 393 | # if os.path.isdir(subfolder_path): 394 | # for file_name in os.listdir(subfolder_path): 395 | # # Only process PNG files 396 | # if file_name.endswith(".png"): 397 | # image_path = os.path.join(self.dataset_path, folder_name, subfolder_name, file_name) 398 | 399 | # # Attempt to open and verify the image using Pillow 400 | # try: 401 | # with Image.open(image_path) as img: 402 | # img.verify() # Verify image integrity 403 | # except Exception as e: 404 | # print(f"Skipping corrupted image: {image_path} ({e})") 405 | # continue # Skip this image if it is corrupted 406 | 407 | # # If image is verified and an image cache directory is set, save the verified image there 408 | # if self.image_cache_dir: 409 | # # Reconstruct a similar folder structure in the cache directory 410 | # cache_folder = os.path.join(self.image_cache_dir, folder_name, subfolder_name) 411 | # os.makedirs(cache_folder, exist_ok=True) 412 | # cached_image_path = os.path.join(cache_folder, file_name) 413 | # # Copy the image file to the cache directory 414 | # copyfile(image_path, cached_image_path) 415 | # # Use the cached image path for further processing 416 | # image_path = cached_image_path 417 | 418 | # image_subfolder_id = file_name.replace('.png', '').split("_")[-1] 419 | # universial_idx = f"{dataset_name}_{folder_name}_{subfolder_name}_{image_subfolder_id}" 420 | # query_idx = f"{folder_name}_{subfolder_name}" 421 | 422 | # if query_idx in synthetic_annotation: 423 | # description = synthetic_annotation[query_idx]["description"] 424 | # short_phrase = synthetic_annotation[query_idx]["short_phrase"] 425 | # features = synthetic_annotation[query_idx]["features"] 426 | # # Get filtering annotation if exists 427 | # filtering_ann = (getting_filtering_annotation(universial_idx, filtering_annotations, default_annotation) 428 | # if filtering_annotations is not None else None) 429 | # rows.append({ 430 | # 'universal_idx': universial_idx, 431 | # 'query_idx': query_idx, 432 | # 'image_path': image_path, 433 | # 'mask_path': image_path, # Assuming mask_path is the same as image_path 434 | # 'dataset_name': dataset_name, 435 | # 'data_type': "panoptic", 436 | # 'filtering_annotation': filtering_ann, 437 | # 'category': folder_name, 438 | # 'sub_category': subfolder_name, 439 | # 'description': description, 440 | # 'short_phrase': short_phrase, 441 | # 'features': features 442 | # }) 443 | # return rows 444 | 445 | # def _build_metadata_table(self, synthetic_annotation, filtering_annotations, default_annotation, dataset_name): 446 | # """ 447 | # Build the metadata table from the dataset_path using multithreading. 
448 | # """ 449 | # all_rows = [] 450 | 451 | # # Get the list of top-level directories 452 | # folders = [folder for folder in os.listdir(self.dataset_path) 453 | # if os.path.isdir(os.path.join(self.dataset_path, folder))] 454 | 455 | # # Define the number of threads you want to use 456 | # num_threads = 300 # Adjust as needed 457 | 458 | # # Use a ThreadPoolExecutor with a defined number of workers 459 | # with ThreadPoolExecutor(max_workers=num_threads) as executor: 460 | # # Submit jobs for each top-level folder 461 | # future_to_folder = { 462 | # executor.submit(self._process_folder, folder, synthetic_annotation, filtering_annotations, default_annotation, dataset_name): folder 463 | # for folder in folders 464 | # } 465 | # # Use tqdm to monitor progress 466 | # for future in tqdm(as_completed(future_to_folder), total=len(future_to_folder), desc="Processing folders"): 467 | # folder = future_to_folder[future] 468 | # try: 469 | # rows = future.result() 470 | # all_rows.extend(rows) 471 | # except Exception as exc: 472 | # print(f"Folder {folder} generated an exception: {exc}") 473 | 474 | # # Optionally, assign a sequential index across all rows 475 | # for idx, row in enumerate(all_rows): 476 | # row["idx_for_curr_dataset"] = idx 477 | 478 | # return pd.DataFrame(all_rows) 479 | 480 | 481 | 482 | 483 | 484 | 485 | 486 | 487 | 488 | class SA1BDataset(BaseObjectDataset): 489 | def __init__(self, dataset_path, filtering_annotations_path): 490 | dataset_name = "sa1b_parsed" 491 | data_type = "non-semantic" 492 | super().__init__(dataset_name, data_type, dataset_path, filtering_annotations_path) 493 | # get filtering annotation 494 | if self.filtering_annotations_path is not None: 495 | filtering_annotations = json.load(open(self.filtering_annotations_path, "r")) 496 | default_annotation = filtering_annotations[list(filtering_annotations.keys())[0]] 497 | default_annotation = {key: False for key in default_annotation.keys()} 498 | else: 499 | filtering_annotations = None 500 | 501 | 502 | # get metadata_table 503 | index_counter = 0 504 | rows = [] 505 | for folder_name in os.listdir(self.dataset_path): 506 | folder_path = os.path.join(self.dataset_path, folder_name) 507 | if os.path.isdir(folder_path): 508 | for file_name in os.listdir(folder_path): 509 | if ".png" in file_name or ".jpg" in file_name: 510 | image_path = os.path.join(self.dataset_path, folder_name, file_name) 511 | universial_idx = f"{dataset_name}_{folder_name}_{file_name}" 512 | if filtering_annotations is None or universial_idx in filtering_annotations:  # keep all segments when no filtering annotations are given (avoids a TypeError on None) 513 | rows.append({ 514 | 'universal_idx': universial_idx, 515 | 'idx_for_curr_dataset': index_counter, 516 | 'category': folder_name, 517 | 'image_path': image_path, 518 | 'mask_path': image_path, # Assuming mask_path is the same as image_path 519 | 'dataset_name': dataset_name, 520 | 'data_type': data_type, 521 | 'filtering_annotation': getting_filtering_annotation(universial_idx, filtering_annotations, default_annotation) if filtering_annotations is not None else None 522 | }) 523 | index_counter += 1 524 | self.metadata_table = pd.DataFrame(rows) 525 | self.categories = set(self.metadata_table['category']) 526 | 527 | 528 | class CustomizedDatasetOurFormat(BaseObjectDataset): 529 | def __init__(self, dataset_path, dataset_name, data_type): 530 | super().__init__(dataset_name, data_type, dataset_path) 531 | # get metadata_table 532 | index_counter = 0 533 | rows = [] 534 | for folder_name in os.listdir(self.dataset_path): 535 | folder_path = os.path.join(self.dataset_path, folder_name) 536 | if
os.path.isdir(folder_path): 537 | for file_name in os.listdir(folder_path): 538 | if ".png" in file_name: 539 | image_path = os.path.join(self.dataset_path, folder_name, file_name) 540 | rows.append({ 541 | 'idx_for_curr_dataset': index_counter, 542 | 'category': folder_name, 543 | 'image_path': image_path, 544 | 'mask_path': image_path, # Assuming mask_path is the same as image_path 545 | 'dataset_name': dataset_name, 546 | 'data_type': data_type 547 | }) 548 | index_counter += 1 549 | self.metadata_table = pd.DataFrame(rows) 550 | self.categories = set(self.metadata_table['category']) 551 | 552 | --------------------------------------------------------------------------------