├── .gitignore ├── README.md ├── data ├── convert_tfrecord_to_np.py ├── create_shapes.py ├── download_clevrtex.sh ├── download_pascal_and_coco.sh ├── download_rrn.sh ├── download_satnet.sh ├── download_synths.sh └── tfloaders │ ├── clevr_with_masks.py │ ├── multi_dsprites.py │ └── tetrominoes.py ├── eval_obj.py ├── eval_sudoku.py ├── requirements.txt ├── scripts ├── sudoku.md └── synths.md ├── source ├── data │ ├── augs.py │ └── datasets │ │ ├── objs │ │ ├── clevr.py │ │ ├── clevr_tex.py │ │ ├── coco.py │ │ ├── dsprites.py │ │ ├── imagenet.py │ │ ├── load_data.py │ │ ├── npdataset.py │ │ ├── pascal.py │ │ ├── shapes.py │ │ └── tetrominoes.py │ │ └── sudoku │ │ └── sudoku.py ├── evals │ ├── objs │ │ ├── fgari.py │ │ └── mbo.py │ └── sudoku │ │ └── evals.py ├── layers │ ├── common_fns.py │ ├── common_layers.py │ ├── gta.py │ ├── klayer.py │ └── kutils.py ├── models │ ├── objs │ │ ├── knet.py │ │ ├── patchconv.py │ │ ├── pretrained_vits.py │ │ └── vit.py │ └── sudoku │ │ ├── knet.py │ │ └── transformer.py ├── training_utils.py └── utils.py ├── train_obj.py └── train_sudoku.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

Artificial Kuramoto Oscillatory Neurons (AKOrN)

2 |

3 | Takeru Miyato 4 | · 5 | Sindy Löwe 6 | · 7 | Andreas Geiger 8 | · 9 | Max Welling 10 |

11 |

[Project page] [Paper]

12 |

ICLR 2025 (Oral)

13 |

14 | 15 | 16 | 17 |

18 | 19 | 20 | 21 | This page contains instructions for the initial environment setup and code for the CLEVR-Tex experiments. 22 | - Minimal AKOrN model on Google Colab (The fish example in the paper) [here](https://colab.research.google.com/drive/1n8x2uskNxRIqJvvNaljWDuLAMvxkw0Qn) 23 | - Code for other synthetic datasets (Tetrominoes, dSprits, CLEVR): [here](https://github.com/autonomousvision/akorn/blob/main/scripts/synths.md) 24 | - Sudoku solving: [here](https://github.com/autonomousvision/akorn/blob/main/scripts/sudoku.md) 25 | 26 | ## Setup Conda env 27 | 28 | ``` 29 | yes | conda create -n akorn python=3.12 30 | conda activate akorn 31 | pip3 install -r requirements.txt 32 | ``` 33 | 34 | ## Download the CLEVRTex dataset 35 | ``` 36 | cd data 37 | bash download_clevrtex.sh 38 | cd .. 39 | ``` 40 | 41 | ## Training 42 | ``` 43 | export NUM_GPUS= # If you use a single GPU, run a command without the multi GPU option (remove `--multi-gpu`). 44 | ``` 45 | 46 | ### CLEVRTex 47 | 48 | #### AKOrN 49 | ``` 50 | export L=1 # The number of layers. L=1 or 2. This can be >2, but we only experimented with a single or two-layer model. 51 | accelerate launch --multi-gpu --num_processes=$NUM_GPUS train_obj.py --exp_name=clvtex_akorn --data_root=./data/clevrtex_full/ --model=akorn --data=clevrtex_full --J=attn --L=${L} 52 | 53 | # Larger model (L=2, ch=512, bs=512) 54 | accelerate launch --multi-gpu --num_processes=$NUM_GPUS train_obj.py --exp_name=clvtex_large_akorn --data_root=./data/clevrtex_full/ --model=akorn --data=clevrtex_full --J=attn --L=2 --ch=512 --batchsize=512 --epochs=1024 --lr=0.0005 55 | ``` 56 | 57 | #### ItrSA 58 | ``` 59 | export L=1 60 | accelerate launch --multi-gpu --num_processes=$NUM_GPUS train_obj.py --exp_name=clvtex_itrsa --data_root=./data/clevrtex_full/ --model=vit --data=clevrtex_full --L=${L} --gta=False 61 | ``` 62 | 63 | ## Evaluation 64 | 65 | ### CLEVRTex (-OOD, -CAMO) 66 | 67 | ``` 68 | export DATA_TYPE=full #{full, outd, camo} 69 | export L=1 70 | # AKOrN 71 | python eval_obj.py --data_root=./data/clevrtex_${DATA_TYPE}/ --model=akorn --data=clevrtex_${DATA_TYPE} --J=attn --L=${L} --model_path=runs/clvtex_akorn/ema_499.pth --model_imsize=128 72 | # ItrSA 73 | python eval_obj.py --data_root=./data/clevrtex_${DATA_TYPE}/ --model=vit --data=clevrtex_${DATA_TYPE} --gta=False --L=${L} --model_path=runs/clvtex_itrsa/ema_499.pth --model_imsize=128 74 | ``` 75 | 76 | ### Eval with Up-tiling (See Appendix section). 
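Up-tiling (enabled via `--saccade_r`) evaluates the model on several patch-aligned shifts of each image and interleaves the resulting feature maps into a finer grid before clustering. A rough sketch of the idea, based on the `eval()` function in `eval_obj.py` (the helpers `gen_saccade_imgs` and `model_preds` come from that file):

```
# Sketch only: mirrors the up-tiling logic in eval_obj.py's eval()
shifted, _ = gen_saccade_imgs(images, model.psize, model.psize // saccade_r)  # saccade_r^2 shifted crops
feats = [model_preds(model, im) for im in shifted]  # one (C, h, w) feature map per shift
# The shifted maps are interleaved into a (C, h*saccade_r, w*saccade_r) grid,
# optionally reduced with PCA, and then clustered to produce the segmentation.
```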
77 | ``` 78 | # Might take long time depending on the CPU spec 79 | python eval_obj.py --data_root=./data/clevrtex_${DATA_TYPE}/ --saccade_r=4 --model=akorn --data=clevrtex_${DATA_TYPE} --J=attn --L=${L} --model_path=runs/clvtex_akorn/ema_499.pth --model_imsize=128 80 | ``` 81 | 82 | #### Performance table 83 | | Model | CLEVRTex FG-ARI | CLEVRTex MBO | OOD FG-ARI | OOD MBO | CAMO FG-ARI | CAMO MBO | 84 | |------------------------------------|-----------------|--------------|------------|---------|-------------|----------| 85 | | ViT | 46.4±0.6 | 25.1±0.7 | 44.1±0.5 | 27.2±0.5 | 32.5±0.6 | 16.1±1.1 | 86 | | ItrSA (L = 1) | 65.7±0.3 | 44.6±0.9 | 64.6±0.8 | 45.1±0.4 | 49.0±0.7 | 30.2±0.8 | 87 | | ItrSA (L = 2) | 76.3±0.4 | 48.5±0.1 | 74.9±0.8 | 46.4±0.5 | 61.9±1.3 | 37.1±0.5 | 88 | | AKOrN (attn, L = 1) | 75.6±0.2 | 55.0±0.0 | 73.4±0.4 | 56.1±1.1 | 59.9±0.1 | 44.3±0.9 | 89 | | AKOrN (attn, L = 2) | 80.5±1.5 | 54.9±0.6 | 79.2±1.2 | 55.7±0.5 | 67.7±1.5 | 46.2±0.9 | 90 | 91 | ##### With Up-tiling (x4) 92 | | Model | CLEVRTex FG-ARI | CLEVRTex MBO | OOD FG-ARI | OOD MBO | CAMO FG-ARI | CAMO MBO | 93 | |------------------------------------|-----------------|--------------|------------|---------|-------------|----------| 94 | | AKOrN (attn, L = 2) | 87.7±1.0 | 55.3±2.1 | 85.2±0.9 | 55.6±1.5 | 74.5±1.2 | 45.6±3.4 | 95 | | Large AKOrN (attn, L = 2) | 88.5±0.9 | 59.7±0.9 | 87.7±0.5 | 60.8±0.6 | 77.0±0.5 | 53.4±0.7 | 96 | -------------------------------------------------------------------------------- /data/convert_tfrecord_to_np.py: -------------------------------------------------------------------------------- 1 | # provided by @loeweX 2 | import os 3 | from typing import List 4 | 5 | import numpy as np 6 | from PIL import Image 7 | from einops import rearrange 8 | import argparse 9 | 10 | 11 | 12 | from tfloaders import ( 13 | clevr_with_masks, 14 | multi_dsprites, 15 | tetrominoes, 16 | ) 17 | 18 | def resize_input(x): 19 | # Center-crop with boundaries [29, 221] (height) and [64, 256] (width). 20 | x = x[29:221, 64:256] 21 | # Resize cropped image to 128x128 resolution. 
22 | return np.array( 23 | Image.fromarray(x).resize((128, 128), resample=Image.Resampling.NEAREST) 24 | ) 25 | 26 | def resize_input_tetrominoes(x): 27 | # Center-crop 28 | x = x[1:33, 1:33] 29 | return x 30 | 31 | 32 | def get_hparams(dataset_name): 33 | 34 | pdir = './' 35 | 36 | if dataset_name == "multi_dsprites": 37 | variant = "colored_on_grayscale" # binarized, colored_on_colored 38 | input_path = f"{pdir}/multi_dsprites/multi_dsprites_{variant}.tfrecords" 39 | output_path = f"{pdir}/multi_dsprites/" 40 | 41 | dataset = multi_dsprites.dataset(input_path, variant) 42 | train_size = 60000 43 | dataset_name = variant 44 | 45 | elif dataset_name == "tetrominoes": 46 | input_path = f"{pdir}/tetrominoes/tetrominoes_train.tfrecords" 47 | output_path = f"{pdir}/tetrominoes/" 48 | 49 | dataset = tetrominoes.dataset(input_path) 50 | train_size = 60000 51 | 52 | elif dataset_name in ["clevr_with_masks"]: 53 | input_path = f"{pdir}/clevr_with_masks/clevr_with_masks_train.tfrecords" 54 | output_path = f"{pdir}/clevr_with_masks/" 55 | 56 | dataset = clevr_with_masks.dataset(input_path) 57 | train_size = 70000 58 | 59 | val_size = 10000 # 5000 60 | test_size = 320 61 | eval_size = 64 62 | return ( 63 | input_path, 64 | output_path, 65 | dataset, 66 | train_size, 67 | val_size, 68 | test_size, 69 | eval_size, 70 | dataset_name, 71 | ) 72 | 73 | 74 | if __name__ == "__main__": 75 | 76 | parser = argparse.ArgumentParser(description="Convert TFRecord to NPZ") 77 | parser.add_argument( 78 | "--dataset_name", 79 | type=str, 80 | required=True, 81 | choices=["multi_dsprites", "tetrominoes", "clevr_with_masks"], 82 | help="Name of the dataset to convert", 83 | ) 84 | args = parser.parse_args() 85 | dataset_name = args.dataset_name 86 | 87 | clevr6 = True 88 | 89 | ( 90 | input_path, 91 | output_path, 92 | dataset, 93 | train_size, 94 | val_size, 95 | test_size, 96 | eval_size, 97 | dataset_name, 98 | ) = get_hparams(dataset_name) 99 | 100 | if not os.path.exists(output_path): 101 | os.mkdir(output_path) 102 | 103 | batched_dataset = dataset.batch(1) 104 | iterator = iter(batched_dataset) 105 | 106 | counter = 0 107 | images: List[np.array] = [] 108 | labels: List[np.array] = [] 109 | 110 | 111 | while True: 112 | try: 113 | data = next(iterator) 114 | except StopIteration: 115 | break 116 | 117 | input_image = np.squeeze(data["image"].numpy()) 118 | 119 | pixelwise_label = np.zeros( 120 | (1, input_image.shape[0], input_image.shape[1]), dtype=np.uint8 121 | ) 122 | for idx in range(data["mask"].shape[1]): 123 | pixelwise_label[ 124 | np.where(data["mask"].numpy()[:, idx, :, :, 0] == 255) 125 | ] = idx 126 | 127 | pixelwise_label = np.squeeze(pixelwise_label) 128 | 129 | if dataset_name in ["clevr_with_masks"]: 130 | input_image = resize_input(input_image) 131 | pixelwise_label = resize_input(pixelwise_label) 132 | 133 | if clevr6 and np.max(pixelwise_label) > 6: 134 | # CLEVR6: only use images with maximally 6 objects 135 | continue 136 | 137 | if dataset_name in ["tetrominoes"]: 138 | input_image = resize_input_tetrominoes(input_image) 139 | pixelwise_label = resize_input_tetrominoes(pixelwise_label) 140 | 141 | input_image = rearrange(input_image, "h w c -> c h w") 142 | input_image = (input_image / 255) 143 | 144 | # pixelwise_label = rearrange(pixelwise_label, "w h -> h w") 145 | 146 | images.append(input_image) 147 | labels.append(pixelwise_label) 148 | 149 | counter += 1 150 | 151 | if counter % 1000 == 0: 152 | print(counter) 153 | 154 | if counter % (train_size + val_size + test_size + eval_size) 
== 0: 155 | break 156 | 157 | print("Save files") 158 | 159 | np.savez_compressed( 160 | os.path.join(output_path, f"{dataset_name}_eval"), 161 | images=np.squeeze(np.array(images[:eval_size])), 162 | labels=np.squeeze(np.array(labels[:eval_size])), 163 | ) 164 | 165 | start_idx = eval_size 166 | np.savez_compressed( 167 | os.path.join(output_path, f"{dataset_name}_test"), 168 | images=np.squeeze(np.array(images[start_idx : start_idx + test_size])), 169 | labels=np.squeeze(np.array(labels[start_idx : start_idx + test_size])), 170 | ) 171 | 172 | start_idx += test_size 173 | np.savez_compressed( 174 | os.path.join(output_path, f"{dataset_name}_val"), 175 | images=np.squeeze(np.array(images[start_idx : start_idx + val_size])), 176 | labels=np.squeeze(np.array(labels[start_idx : start_idx + val_size])), 177 | ) 178 | 179 | start_idx += val_size 180 | np.savez_compressed( 181 | os.path.join(output_path, f"{dataset_name}_train"), 182 | images=np.squeeze(np.array(images[start_idx:])), 183 | labels=np.squeeze(np.array(labels[start_idx:])), 184 | ) 185 | 186 | print(f"Train dataset size: {len(images[start_idx:])}") 187 | 188 | 189 | -------------------------------------------------------------------------------- /data/create_shapes.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import random 4 | from skimage.draw import polygon, disk 5 | from tqdm import tqdm 6 | 7 | seed = 1234 8 | np.random.seed(seed) 9 | random.seed(seed) 10 | 11 | # Initialize dimensions 12 | H, W = 18, 18 13 | 14 | # Triangle (hollow with thicker lines) 15 | triangle = np.zeros((H, W), dtype=np.float32) 16 | 17 | outer_rr, outer_cc = polygon([4, 11, 11], [8, 1, 15], shape=(H, W)) 18 | inner_rr, inner_cc = polygon([6, 9, 9], [8, 5, 11], shape=(H, W)) 19 | triangle[outer_rr, outer_cc] = 1.0 20 | triangle[inner_rr, inner_cc] = 0.0 21 | 22 | # Square (hollow with thicker lines) 23 | square = np.zeros((H, W), dtype=np.float32) 24 | square[3:13, 3:13] = 1.0 # Outer square 25 | square[5:11, 5:11] = 0.0 # Inner square 26 | 27 | # Circle (hollow, larger) 28 | circle = np.zeros((H, W), dtype=np.float32) 29 | outer_rr, outer_cc = disk((8, 8), 7, shape=(H, W)) 30 | inner_rr, inner_cc = disk((8, 8), 5, shape=(H, W)) 31 | circle[outer_rr, outer_cc] = 1.0 32 | circle[inner_rr, inner_cc] = 0.0 33 | 34 | # Diamond (hollow, larger and thicker lines) 35 | diamond = np.zeros((H, W), dtype=np.float32) 36 | outer_rr, outer_cc = polygon([8, 2, 8, 14], [2, 8, 14, 8], shape=(H, W)) 37 | inner_rr, inner_cc = polygon([8, 4, 8, 12], [4, 8, 12, 8], shape=(H, W)) 38 | diamond[outer_rr, outer_cc] = 1.0 39 | diamond[inner_rr, inner_cc] = 0.0 40 | 41 | # Shape dictionary 42 | shapes = {1: triangle, 2: square, 3: circle, 4: diamond} 43 | 44 | # Parameters 45 | image_size = 40 46 | num_images = 42000 47 | min_shapes = 2 48 | max_shapes = 4 49 | padding = 0 50 | 51 | # Storage for images and labels 52 | images = np.zeros((num_images, image_size, image_size), dtype=np.float32) 53 | class_labels = [] # Class label for each pixel 54 | instance_labels = [] # Instance label for each pixel 55 | 56 | # Function to place a shape in an image with bounds checking 57 | def place_shape(image, class_label, instance_label, shape, shape_id, instance_id, pos_x, pos_y): 58 | for i in range(H): 59 | for j in range(W): 60 | if shape[i, j] == 1.0: 61 | xi, yj = pos_x + i, pos_y + j 62 | if class_label[xi, yj] == 0: # If background, place shape 63 | image[xi, yj] = 1.0 # Shape pixel 
value 64 | class_label[xi, yj] = shape_id 65 | instance_label[xi, yj] = instance_id 66 | else: # Overlapping area 67 | class_label[xi, yj] = -1 68 | instance_label[xi, yj] = -1 69 | image[xi, yj] = 1.0 # Ensure overlapping area is also bright 70 | 71 | # Generate images and labels 72 | for img_index in tqdm(range(num_images)): 73 | # Generate a random background color between 0 and 0.3 74 | bg_color = random.uniform(0.1, 0.6) 75 | 76 | # Initialize the image with the random background color 77 | image = np.full((image_size, image_size), bg_color, dtype=np.float32) 78 | class_label = np.zeros((image_size, image_size), dtype=np.int8) 79 | instance_label = np.zeros((image_size, image_size), dtype=np.int8) 80 | 81 | # Randomly select number of shapes 82 | num_shapes = random.randint(min_shapes, max_shapes) 83 | 84 | # Randomly place shapes 85 | for instance_id in range(1, num_shapes + 1): 86 | shape_id = random.randint(1, 4) # Randomly select a shape 87 | shape = shapes[shape_id] 88 | 89 | # Random position with padding 90 | pos_x = random.randint(padding, image_size - H - padding) 91 | pos_y = random.randint(padding, image_size - W - padding) 92 | 93 | # Place shape 94 | place_shape(image, class_label, instance_label, shape, shape_id, instance_id, pos_x, pos_y) 95 | 96 | # Store the image and labels 97 | images[img_index] = image 98 | class_labels.append(class_label) 99 | instance_labels.append(instance_label) 100 | 101 | # Convert class_labels and instance_labels lists to numpy arrays for easy saving 102 | class_labels = np.array(class_labels, dtype=np.int8) 103 | instance_labels = np.array(instance_labels, dtype=np.int8) 104 | 105 | # Example of displaying one generated image with labels 106 | fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(12, 4)) 107 | ax1.imshow(images[0], cmap="gray", vmin=0, vmax=1) 108 | ax1.set_title(f"Generated Image (Background color: {images[0][0,0]:.2f})") 109 | ax2.imshow(class_labels[0], cmap="gray") 110 | ax2.set_title("Class Label (0: background, -1: overlap, 1-4: shapes)") 111 | ax3.imshow(instance_labels[0], cmap="nipy_spectral") 112 | ax3.set_title("Instance Label (-1: overlap)") 113 | plt.show() 114 | 115 | train = np.arange(0, 40000) 116 | val = np.arange(40000, 41000) 117 | test = np.arange(41000, 42000) 118 | 119 | images = images[:, None] 120 | # Saving images and labels (optional) 121 | import os 122 | os.makedirs("./Shapes", exist_ok=True) 123 | dataset = {"images":images[train], "labels": instance_labels[train], "pixelwise_class_labels": class_labels[train]} 124 | np.savez_compressed("./Shapes/train.npz", **dataset) 125 | dataset = {"images":images[val], "labels": instance_labels[val], "pixelwise_class_labels": class_labels[val]} 126 | np.savez_compressed("./Shapes/val.npz", **dataset) 127 | dataset = {"images":images[test], "labels": instance_labels[test], "pixelwise_class_labels": class_labels[test]} 128 | np.savez_compressed("./Shapes/test.npz", **dataset) 129 | 130 | # np.save("instance_labels.npy", instance_labels) 131 | -------------------------------------------------------------------------------- /data/download_clevrtex.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ##### Download CLEVRTEX dataset. 
4 | urls=( 5 | "https://thor.robots.ox.ac.uk/datasets/clevrtex/clevrtex_full_part1.tar.gz" 6 | "https://thor.robots.ox.ac.uk/datasets/clevrtex/clevrtex_full_part2.tar.gz" 7 | "https://thor.robots.ox.ac.uk/datasets/clevrtex/clevrtex_full_part3.tar.gz" 8 | "https://thor.robots.ox.ac.uk/datasets/clevrtex/clevrtex_full_part4.tar.gz" 9 | "https://thor.robots.ox.ac.uk/datasets/clevrtex/clevrtex_full_part5.tar.gz" 10 | "https://thor.robots.ox.ac.uk/datasets/clevrtex/clevrtex_outd.tar.gz" 11 | "https://thor.robots.ox.ac.uk/datasets/clevrtex/clevrtex_camo.tar.gz" 12 | ) 13 | 14 | output_dir="./" 15 | 16 | mkdir -p $output_dir 17 | 18 | for url in "${urls[@]}"; do 19 | 20 | filename=$(basename "$url") 21 | 22 | echo "Downloading $filename..." 23 | wget -q --show-progress "$url" -P "$output_dir" 24 | 25 | echo "Extracting $filename..." 26 | tar -xzf "$output_dir/$filename" -C "$output_dir" 27 | 28 | rm "$output_dir/$filename" 29 | done 30 | 31 | -------------------------------------------------------------------------------- /data/download_pascal_and_coco.sh: -------------------------------------------------------------------------------- 1 | 2 | ##### Download Pascal VOC 2012 dataset and add trainaug split. 3 | wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar 4 | wget https://www.dropbox.com/s/oeu149j8qtbs1x0/SegmentationClassAug.zip 5 | 6 | tar -xf VOCtrainval_11-May-2012.tar 7 | unzip SegmentationClassAug.zip -d VOCdevkit/VOC2012 8 | 9 | mv trainaug.txt VOCdevkit/VOC2012/ImageSets/Segmentation 10 | mv VOCdevkit/VOC2012/SegmentationClassAug/* VOCdevkit/VOC2012/SegmentationClass/ 11 | 12 | rm -r VOCdevkit/VOC2012/__MACOSX 13 | rm SegmentationClassAug.zip 14 | rm VOCtrainval_11-May-2012.tar 15 | 16 | ###### Download COCO 2017 dataset. 17 | wget http://images.cocodataset.org/zips/train2017.zip 18 | wget http://images.cocodataset.org/zips/val2017.zip 19 | wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip 20 | 21 | unzip train2017.zip 22 | unzip val2017.zip 23 | unzip annotations_trainval2017.zip 24 | 25 | mkdir -p COCO/annotations 26 | mkdir -p COCO/images 27 | mv annotations COCO/ 28 | mv train2017 COCO/images 29 | mv val2017 COCO/images 30 | 31 | rm -r annotations_trainval2017.zip 32 | rm train2017.zip 33 | rm val2017.zip 34 | 35 | 36 | 37 | -------------------------------------------------------------------------------- /data/download_rrn.sh: -------------------------------------------------------------------------------- 1 | # Copied from https://github.com/yilundu/ired_code_release/blob/main/data/download-rrn.sh 2 | # Original RRN Sodoku data 3 | wget https://www.dropbox.com/s/rp3hbjs91xiqdgc/sudoku-hard.zip?dl=1 4 | mv sudoku-hard.zip?dl=1 sudoku-hard.zip 5 | unzip sudoku-hard.zip 6 | mv sudoku-hard sudoku-rrn 7 | rm sudoku-hard.zip 8 | rm -rf __MACOSX -------------------------------------------------------------------------------- /data/download_satnet.sh: -------------------------------------------------------------------------------- 1 | # Copied from https://github.com/yilundu/ired_code_release/blob/main/data/download-satnet.sh 2 | # Original SAT-Net repo 3 | wget -cq powei.tw/sudoku.zip && unzip -qq sudoku.zip && rm sudoku.zip 4 | wget -cq powei.tw/parity.zip && unzip -qq parity.zip && rm parity.zip -------------------------------------------------------------------------------- /data/download_synths.sh: -------------------------------------------------------------------------------- 1 | # Need gsutil installed. 
run `conda install conda-forge::gsutil` to install it or manually download datasets from https://console.cloud.google.com/storage/browser/multi-object-datasets;tab=objects?pli=1&inv=1&invt=AbjJBg&prefix=&forceOnObjectsSortingFiltering=false 2 | for dataset in tetrominoes multi_dsprites clevr_with_masks; do 3 | gsutil cp -r gs://multi-object-datasets/$dataset ./ 4 | python convert_tfrecord_to_np.py --dataset_name=$dataset 5 | done 6 | -------------------------------------------------------------------------------- /data/tfloaders/clevr_with_masks.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """CLEVR (with masks) dataset reader.""" 16 | 17 | import tensorflow.compat.v1 as tf 18 | 19 | 20 | COMPRESSION_TYPE = tf.io.TFRecordOptions.get_compression_type_string('GZIP') 21 | IMAGE_SIZE = [240, 320] 22 | # The maximum number of foreground and background entities in the provided 23 | # dataset. This corresponds to the number of segmentation masks returned per 24 | # scene. 25 | MAX_NUM_ENTITIES = 11 26 | BYTE_FEATURES = ['mask', 'image', 'color', 'material', 'shape', 'size'] 27 | 28 | # Create a dictionary mapping feature names to `tf.Example`-compatible 29 | # shape and data type descriptors. 30 | features = { 31 | 'image': tf.FixedLenFeature(IMAGE_SIZE+[3], tf.string), 32 | 'mask': tf.FixedLenFeature([MAX_NUM_ENTITIES]+IMAGE_SIZE+[1], tf.string), 33 | 'x': tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32), 34 | 'y': tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32), 35 | 'z': tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32), 36 | 'pixel_coords': tf.FixedLenFeature([MAX_NUM_ENTITIES, 3], tf.float32), 37 | 'rotation': tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32), 38 | 'size': tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.string), 39 | 'material': tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.string), 40 | 'shape': tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.string), 41 | 'color': tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.string), 42 | 'visibility': tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32), 43 | } 44 | 45 | 46 | def _decode(example_proto): 47 | # Parse the input `tf.Example` proto using the feature description dict above. 48 | single_example = tf.parse_single_example(example_proto, features) 49 | for k in BYTE_FEATURES: 50 | single_example[k] = tf.squeeze(tf.decode_raw(single_example[k], tf.uint8), 51 | axis=-1) 52 | return single_example 53 | 54 | 55 | def dataset(tfrecords_path, read_buffer_size=None, map_parallel_calls=None): 56 | """Read, decompress, and parse the TFRecords file. 57 | 58 | Args: 59 | tfrecords_path: str. Path to the dataset file. 60 | read_buffer_size: int. Number of bytes in the read buffer. See documentation 61 | for `tf.data.TFRecordDataset.__init__`. 62 | map_parallel_calls: int. 
Number of elements decoded asynchronously in 63 | parallel. See documentation for `tf.data.Dataset.map`. 64 | 65 | Returns: 66 | An unbatched `tf.data.TFRecordDataset`. 67 | """ 68 | raw_dataset = tf.data.TFRecordDataset( 69 | tfrecords_path, compression_type=COMPRESSION_TYPE, 70 | buffer_size=read_buffer_size) 71 | return raw_dataset.map(_decode, num_parallel_calls=map_parallel_calls) 72 | -------------------------------------------------------------------------------- /data/tfloaders/multi_dsprites.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Multi-dSprites dataset reader.""" 16 | 17 | import functools 18 | 19 | import tensorflow.compat.v1 as tf 20 | 21 | COMPRESSION_TYPE = tf.io.TFRecordOptions.get_compression_type_string("GZIP") 22 | IMAGE_SIZE = [64, 64] 23 | # The maximum number of foreground and background entities in each variant 24 | # of the provided datasets. The values correspond to the number of 25 | # segmentation masks returned per scene. 26 | MAX_NUM_ENTITIES = {"binarized": 4, "colored_on_grayscale": 6, "colored_on_colored": 5} 27 | BYTE_FEATURES = ["mask", "image"] 28 | 29 | 30 | def feature_descriptions(max_num_entities, is_grayscale=False): 31 | """Create a dictionary desc ribing the dataset features. 32 | 33 | Args: 34 | max_num_entities: int. The maximum number of foreground and background 35 | entities in each image. This corresponds to the number of segmentation 36 | masks and generative factors returned per scene. 37 | is_grayscale: bool. Whether images are grayscale. Otherwise they're assumed 38 | to be RGB. 39 | 40 | Returns: 41 | A dictionary which maps feature names to `tf.Example`-compatible shape and 42 | data type descriptors. 43 | """ 44 | num_channels = 1 if is_grayscale else 3 45 | return { 46 | "image": tf.FixedLenFeature(IMAGE_SIZE + [num_channels], tf.string), 47 | "mask": tf.FixedLenFeature(IMAGE_SIZE + [max_num_entities, 1], tf.string), 48 | "x": tf.FixedLenFeature([max_num_entities], tf.float32), 49 | "y": tf.FixedLenFeature([max_num_entities], tf.float32), 50 | "shape": tf.FixedLenFeature([max_num_entities], tf.float32), 51 | "color": tf.FixedLenFeature([max_num_entities, num_channels], tf.float32), 52 | "visibility": tf.FixedLenFeature([max_num_entities], tf.float32), 53 | "orientation": tf.FixedLenFeature([max_num_entities], tf.float32), 54 | "scale": tf.FixedLenFeature([max_num_entities], tf.float32), 55 | } 56 | 57 | 58 | def _decode(example_proto, features): 59 | # Parse the input `tf.Example` proto using a feature description dictionary. 
60 | single_example = tf.parse_single_example(example_proto, features) 61 | for k in BYTE_FEATURES: 62 | single_example[k] = tf.squeeze( 63 | tf.decode_raw(single_example[k], tf.uint8), axis=-1 64 | ) 65 | # To return masks in the canonical [entities, height, width, channels] format, 66 | # we need to transpose the tensor axes. 67 | single_example["mask"] = tf.transpose(single_example["mask"], [2, 0, 1, 3]) 68 | return single_example 69 | 70 | 71 | def dataset( 72 | tfrecords_path, dataset_variant, read_buffer_size=None, map_parallel_calls=None 73 | ): 74 | """Read, decompress, and parse the TFRecords file. 75 | 76 | Args: 77 | tfrecords_path: str. Path to the dataset file. 78 | dataset_variant: str. One of ['binarized', 'colored_on_grayscale', 79 | 'colored_on_colored']. This is used to identify the maximum number of 80 | entities in each scene. If an incorrect identifier is passed in, the 81 | TFRecords file will not be read correctly. 82 | read_buffer_size: int. Number of bytes in the read buffer. See documentation 83 | for `tf.data.TFRecordDataset.__init__`. 84 | map_parallel_calls: int. Number of elements decoded asynchronously in 85 | parallel. See documentation for `tf.data.Dataset.map`. 86 | 87 | Returns: 88 | An unbatched `tf.data.TFRecordDataset`. 89 | """ 90 | if dataset_variant not in MAX_NUM_ENTITIES: 91 | raise ValueError( 92 | "Invalid `dataset_variant` provided. The supported values" 93 | " are: {}".format(list(MAX_NUM_ENTITIES.keys())) 94 | ) 95 | max_num_entities = MAX_NUM_ENTITIES[dataset_variant] 96 | is_grayscale = dataset_variant == "binarized" 97 | raw_dataset = tf.data.TFRecordDataset( 98 | tfrecords_path, compression_type=COMPRESSION_TYPE, buffer_size=read_buffer_size 99 | ) 100 | features = feature_descriptions(max_num_entities, is_grayscale) 101 | partial_decode_fn = functools.partial(_decode, features=features) 102 | return raw_dataset.map(partial_decode_fn, num_parallel_calls=map_parallel_calls) 103 | 104 | 105 | -------------------------------------------------------------------------------- /data/tfloaders/tetrominoes.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Tetrominoes dataset reader.""" 16 | 17 | import tensorflow.compat.v1 as tf 18 | 19 | 20 | COMPRESSION_TYPE = tf.io.TFRecordOptions.get_compression_type_string('GZIP') 21 | IMAGE_SIZE = [35, 35] 22 | # The maximum number of foreground and background entities in the provided 23 | # dataset. This corresponds to the number of segmentation masks returned per 24 | # scene. 25 | MAX_NUM_ENTITIES = 4 26 | BYTE_FEATURES = ['mask', 'image'] 27 | 28 | # Create a dictionary mapping feature names to `tf.Example`-compatible 29 | # shape and data type descriptors. 
30 | features = { 31 | 'image': tf.FixedLenFeature(IMAGE_SIZE+[3], tf.string), 32 | 'mask': tf.FixedLenFeature([MAX_NUM_ENTITIES]+IMAGE_SIZE+[1], tf.string), 33 | 'x': tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32), 34 | 'y': tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32), 35 | 'shape': tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32), 36 | 'color': tf.FixedLenFeature([MAX_NUM_ENTITIES, 3], tf.float32), 37 | 'visibility': tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32), 38 | } 39 | 40 | 41 | def _decode(example_proto): 42 | # Parse the input `tf.Example` proto using the feature description dict above. 43 | single_example = tf.parse_single_example(example_proto, features) 44 | for k in BYTE_FEATURES: 45 | single_example[k] = tf.squeeze(tf.decode_raw(single_example[k], tf.uint8), 46 | axis=-1) 47 | return single_example 48 | 49 | 50 | def dataset(tfrecords_path, read_buffer_size=None, map_parallel_calls=None): 51 | """Read, decompress, and parse the TFRecords file. 52 | 53 | Args: 54 | tfrecords_path: str. Path to the dataset file. 55 | read_buffer_size: int. Number of bytes in the read buffer. See documentation 56 | for `tf.data.TFRecordDataset.__init__`. 57 | map_parallel_calls: int. Number of elements decoded asynchronously in 58 | parallel. See documentation for `tf.data.Dataset.map`. 59 | 60 | Returns: 61 | An unbatched `tf.data.TFRecordDataset`. 62 | """ 63 | raw_dataset = tf.data.TFRecordDataset( 64 | tfrecords_path, compression_type=COMPRESSION_TYPE, 65 | buffer_size=read_buffer_size) 66 | return raw_dataset.map(_decode, num_parallel_calls=map_parallel_calls) 67 | 68 | -------------------------------------------------------------------------------- /eval_obj.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | 3 | # sys.path.append(rootdir) 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim 7 | import tqdm 8 | from ema_pytorch import EMA 9 | import matplotlib.pyplot as plt 10 | 11 | from source.models.objs.knet import AKOrN 12 | from source.models.objs.vit import ViT 13 | from source.utils import get_worker_init_fn 14 | from torch.nn import functional as F 15 | from source.layers.common_layers import RGBNormalize 16 | import numpy as np 17 | 18 | import timm 19 | from timm.models import VisionTransformer 20 | from source.utils import gen_saccade_imgs, apply_pca_torch, str2bool 21 | import argparse 22 | 23 | from source.evals.objs.mbo import calc_mean_best_overlap 24 | from source.evals.objs.fgari import calc_fgari_score 25 | 26 | from typing import Callable 27 | 28 | class Wrapper(nn.Module): 29 | def __init__(self, model): 30 | super().__init__() 31 | self.module = model 32 | 33 | 34 | from collections import OrderedDict 35 | from typing import Dict, Callable 36 | import torch 37 | 38 | noise = 0.0 39 | 40 | 41 | def remove_all_forward_hooks(model: torch.nn.Module) -> None: 42 | for name, child in model._modules.items(): 43 | if child is not None: 44 | if hasattr(child, "_forward_hooks"): 45 | child._forward_hooks: Dict[int, Callable] = OrderedDict() 46 | remove_all_forward_hooks(child) 47 | 48 | 49 | def model_preds(model, org_images): 50 | activation = {} 51 | imsize_h, imsize_w = org_images.shape[-2], org_images.shape[-1] 52 | 53 | def get_activation(name): 54 | def hook(model, input, output): 55 | activation[name] = output.detach() 56 | 57 | return hook 58 | 59 | if isinstance(model, AKOrN): 60 | model.out[0].register_forward_hook(get_activation("z")) 61 | elif isinstance(model, ViT): 62 | 
model.out[0].register_forward_hook(get_activation("z")) 63 | 64 | else: 65 | raise Exception() 66 | 67 | model.eval() 68 | imgs = org_images.cuda() 69 | 70 | with torch.no_grad(): 71 | if ( 72 | isinstance(model, AKOrN) 73 | or isinstance(model, ViT) 74 | ): 75 | output, _xs = model(imgs, return_xs=True) 76 | else: 77 | output = model(imgs) 78 | _xs = None 79 | v = activation["z"] 80 | 81 | if isinstance(model, AKOrN) or isinstance(model, ViT): 82 | v = F.normalize(v, dim=1) 83 | #elif isinstance(model, ViTWrapper): 84 | # v = F.normalize(v, dim=2) 85 | # v = v.permute(0, 2, 1)[..., 1:] 86 | # h, w = int(np.sqrt(x.shape[-1])), int(np.sqrt(x.shape[-1])) # estimated inpsize 87 | # v = v.unflatten(-1, (h, w)) 88 | remove_all_forward_hooks(model) 89 | return v 90 | 91 | def clustering(x, h, w, method="spectral", n_clusters=3): 92 | from sklearn.cluster import KMeans 93 | 94 | if method == "agglomerative": 95 | import fastcluster 96 | from scipy.cluster.hierarchy import fcluster 97 | from scipy.cluster.hierarchy import linkage 98 | 99 | x = x.view(x.shape[0], -1).transpose(-2, -1).to("cpu").detach() 100 | Z = fastcluster.average(x) 101 | label = fcluster(Z, t=n_clusters, criterion="maxclust") 102 | return label.reshape(h, w) 103 | elif method == "kmeans": 104 | kmeans = KMeans(n_clusters=n_clusters, random_state=0, n_init="auto").fit( 105 | x.view(x.shape[0], -1).transpose(-2, -1).to("cpu").detach() 106 | ) 107 | label = kmeans.labels_ 108 | return label.reshape(h, w) 109 | 110 | else: 111 | raise ValueError("Clustering method not found") 112 | 113 | 114 | 115 | 116 | 117 | from source.layers.common_fns import positionalencoding2d 118 | 119 | def eval( 120 | model, 121 | images, 122 | gt, 123 | method="agglomerative", 124 | n_clusters=7, 125 | saccade_r=1, 126 | pca=False, 127 | pca_dim=128, 128 | ): 129 | preds = [] 130 | N = images.shape[0] 131 | _imgs, _ = gen_saccade_imgs(images, model.psize, model.psize // saccade_r) 132 | outputs = [] 133 | for img in _imgs: 134 | v = model_preds(model, img) 135 | outputs.append(v.detach().cpu()) 136 | 137 | nh, nw = int(np.sqrt(len(_imgs))), int(np.sqrt(len(_imgs))) 138 | ho, wo = outputs[0].shape[-2], outputs[0].shape[-1] 139 | nimg = torch.zeros(N, outputs[0].shape[1], ho, nh, wo, nw) 140 | for h in range(nh): 141 | for w in range(nw): 142 | nimg[:, :, :, h, :, w] = outputs[h * (nh) + w] 143 | nimg = nimg.view(N, -1, ho * nh, wo * nw) 144 | 145 | from source.utils import apply_pca_torch 146 | 147 | with torch.no_grad(): 148 | if pca: 149 | pcaimg_ = apply_pca_torch(nimg, n_components=pca_dim) 150 | x = pcaimg_ 151 | else: 152 | x = nimg 153 | 154 | for idx in range(N): 155 | _x = x[idx] 156 | pred = clustering(_x, *_x.shape[1:], method, n_clusters) 157 | pred = torch.nn.Upsample( 158 | scale_factor=(images.shape[-2]/pred.shape[-2], images.shape[-1]/pred.shape[-1]), 159 | mode='nearest')(torch.Tensor(pred[None, None]).float())[0, 0] 160 | preds.append(pred) 161 | 162 | preds = torch.stack(preds, 0).long() 163 | 164 | scores = {} 165 | evaluate_sem = False 166 | if isinstance(gt, list): 167 | gt_sem = gt[1] 168 | gt = gt[0] 169 | evaluate_sem = True 170 | 171 | _gt = ((gt > 0).float() * gt).long() # set ignore bg (-1) to 0 172 | # compute fgari 173 | scores["fgari"] = np.array(calc_fgari_score(_gt, preds)) 174 | 175 | # compute mean best overlap 176 | score, _scores = calc_mean_best_overlap(gt.numpy(), preds.numpy()) 177 | scores["mbo"] = score 178 | scores["mbo_scores"] = _scores 179 | 180 | if evaluate_sem: 181 | score, _scores = 
calc_mean_best_overlap(gt_sem.numpy(), preds.numpy()) 182 | scores["mbo_c"] = score 183 | scores["mbo_c_scores"] = _scores 184 | 185 | return scores, preds 186 | 187 | 188 | def get_loader(data, data_root, imsize, batchsize): 189 | 190 | from source.data.datasets.objs.load_data import load_data 191 | 192 | dataset, imsize, collate_fn = load_data(data, data_root, imsize, is_eval=True) 193 | 194 | if data == "clevrtex_full" or data == "clevrtex_outd" or data == "clevrtex_camo": 195 | loader = torch.utils.data.DataLoader( 196 | dataset, 197 | batch_size=batchsize, 198 | num_workers=0, 199 | shuffle=True, 200 | collate_fn=collate_fn, 201 | ) 202 | 203 | elif data == "coco": 204 | loader = torch.utils.data.DataLoader( 205 | dataset, 206 | batch_size=batchsize, 207 | num_workers=0, 208 | shuffle=True, 209 | collate_fn=collate_fn, 210 | ) 211 | else: 212 | loader = torch.utils.data.DataLoader( 213 | dataset, 214 | batch_size=batchsize, 215 | num_workers=0, 216 | shuffle=True, 217 | ) 218 | return loader, imsize 219 | 220 | 221 | def eval_dataset( 222 | model, 223 | data, 224 | data_root=None, 225 | imsize=None, 226 | batchsize=100, 227 | method="agglomerative", 228 | instance=True, 229 | saccade_r=1, 230 | pca=False, 231 | ): 232 | 233 | scores = [] 234 | preds = [] 235 | masks = [] 236 | 237 | loader, imsize = get_loader(data, data_root, imsize, batchsize) 238 | for ret in tqdm.tqdm(loader): 239 | pca_dim = 128 240 | if data == "clevr": 241 | images = ret[0] 242 | if instance: 243 | labels = ret[1]["pixelwise_instance_labels"] 244 | else: 245 | labels = ret[1]["pixelwise_class_labels"] 246 | n_clusters = 11 247 | 248 | elif data == "clevrtex_camo" or data == "clevrtex_full" or data == "clevrtex_outd": 249 | images = ret[1] 250 | labels = ret[2][:, 0] 251 | n_clusters = 11 252 | 253 | elif data == "pascal": 254 | images = ret[0] 255 | labels_instance = ret[1]["pixelwise_instance_labels"] 256 | labels_sem = ret[1]["pixelwise_class_labels"] 257 | labels = [labels_instance, labels_sem] 258 | n_clusters = 4 259 | 260 | elif data == "coco": 261 | images = ret["img"] 262 | labels_instance = ret["masks"].long() 263 | labels_sem = ret["sem_masks"].long() 264 | ovlp = ret["inst_overlap_masks"].long() 265 | labels_instance[ovlp == 1] = -1 266 | labels_sem[ovlp == 1] = -1 267 | labels = [labels_instance, labels_sem] 268 | n_clusters = 7 269 | score, pred = eval( 270 | model, 271 | images, 272 | labels, 273 | method, 274 | n_clusters, 275 | saccade_r=saccade_r, 276 | pca=pca, 277 | pca_dim=pca_dim, 278 | ) 279 | scores.append(score) 280 | preds.append(pred) 281 | masks.append(labels) 282 | return scores, preds 283 | 284 | 285 | def print_stats(scores): 286 | fgaris = [] 287 | mbos = [] 288 | mbocs = [] 289 | for _s in scores: 290 | fgaris.append(_s["fgari"]) 291 | mbos.append(_s["mbo_scores"]) 292 | if "mbo_c" in _s: 293 | mbocs.append(_s["mbo_c_scores"]) 294 | print(np.concatenate(fgaris, 0).mean(), np.concatenate(fgaris, 0).std()) 295 | _mbos = np.concatenate(mbos) 296 | _mbos = _mbos[_mbos != -1] 297 | print(np.mean(_mbos), np.std(_mbos)) 298 | if len(mbocs) > 0: 299 | _mbocs = np.concatenate(mbocs) 300 | _mbocs = _mbocs[_mbocs != -1] 301 | print(np.mean(_mbocs), np.std(_mbocs)) 302 | 303 | 304 | if __name__ == "__main__": 305 | 306 | parser = argparse.ArgumentParser() 307 | 308 | # Eval options 309 | parser.add_argument("--model_path", type=str, help="path to the model") 310 | parser.add_argument("--saccade_r", type=int, default=1, help="Uptiling factor") 311 | parser.add_argument("--pca", type=str2bool, 
default=True) 312 | 313 | # Data loading 314 | parser.add_argument("--limit_cores_used", type=str2bool, default=False) 315 | parser.add_argument("--cpu_core_start", type=int, default=0, help="start core") 316 | parser.add_argument("--cpu_core_end", type=int, default=32, help="end core") 317 | parser.add_argument("--data", type=str, default="clevrtex_full") 318 | parser.add_argument( 319 | "--data_root", 320 | type=str, 321 | default=None, 322 | help="optional. you can specify the dir path if the default path of each dataset is not appropritate one. Currently only applied to ImageNet", 323 | ) 324 | parser.add_argument("--batchsize", type=int, default=250) 325 | parser.add_argument("--num_workers", type=int, default=8) 326 | parser.add_argument( 327 | "--data_imsize", 328 | type=int, 329 | default=None, 330 | help="Image size. If None, use the default size of each dataset", 331 | ) 332 | 333 | # General model options 334 | parser.add_argument("--model", type=str, default="knet", help="model") 335 | parser.add_argument("--L", type=int, default=2, help="num of layers") 336 | parser.add_argument("--ch", type=int, default=256, help="num of channels") 337 | parser.add_argument( 338 | "--model_imsize", 339 | type=int, 340 | default=None, 341 | help=""" 342 | Model's imsize that was set when it was initialized. 343 | This is used when evaluating or when finetuning a pretrained model. 344 | """, 345 | ) 346 | parser.add_argument("--autorescale", type=str2bool, default=False) 347 | parser.add_argument("--psize", type=int, default=8, help="patch size") 348 | parser.add_argument("--ksize", type=int, default=1, help="kernel size") 349 | parser.add_argument("--T", type=int, default=8, help="num of recurrence") 350 | parser.add_argument( 351 | "--maxpool", type=str2bool, default=True, help="max pooling or avg pooling" 352 | ) 353 | parser.add_argument( 354 | "--heads", type=int, default=8, help="num of heads in self-attention" 355 | ) 356 | parser.add_argument( 357 | "--gta", 358 | type=str2bool, 359 | default=True, 360 | help=""" 361 | use Geometric Transform Attention (https://github.com/autonomousvision/gta) as positional encoding. 362 | If False, use standard absolute positional encoding 363 | """, 364 | ) 365 | 366 | # AKOrN options 367 | parser.add_argument("--N", type=int, default=4, help="num of rotating dimensions") 368 | parser.add_argument("--J", type=str, default="conv", help="connectivity") 369 | parser.add_argument("--use_omega", type=str2bool, default=False) 370 | parser.add_argument("--global_omg", type=str2bool, default=False) 371 | parser.add_argument( 372 | "--c_norm", 373 | type=str, 374 | default="gn", 375 | help="normalization. 
gn, sandb(scale and bias), or none", 376 | ) 377 | 378 | parser.add_argument( 379 | "--use_ro_x", 380 | type=str2bool, 381 | default=False, 382 | help="apply linear transform to oscillators between consecutive layers", 383 | ) 384 | 385 | # ablation of some components in the AKOrN's block 386 | parser.add_argument( 387 | "--no_ro", type=str2bool, default=False, help="ablation: no use readout module" 388 | ) 389 | parser.add_argument( 390 | "--project", 391 | type=str2bool, 392 | default=True, 393 | help="use projection or not in the Kuramoto layer", 394 | ) 395 | 396 | args = parser.parse_args() 397 | 398 | torch.backends.cudnn.benchmark = True 399 | torch.backends.cuda.enable_flash_sdp(enabled=True) 400 | 401 | if args.limit_cores_used: 402 | def worker_init_fn(worker_id): 403 | os.sched_setaffinity(0, range(args.cpu_core_start, args.cpu_core_end)) 404 | 405 | if args.model == "akorn": 406 | net = AKOrN( 407 | args.N, 408 | ch=args.ch, 409 | L=args.L, 410 | T=args.T, 411 | J=args.J, # "conv" or "attn", 412 | use_omega=args.use_omega, 413 | global_omg=args.global_omg, 414 | c_norm=args.c_norm, 415 | psize=args.psize, 416 | imsize=args.model_imsize, 417 | autorescale=args.autorescale, 418 | maxpool=args.maxpool, 419 | project=args.project, 420 | heads=args.heads, 421 | use_ro_x=args.use_ro_x, 422 | no_ro=args.no_ro, 423 | gta=args.gta, 424 | ).to("cuda") 425 | elif args.model == "vit": 426 | net = ViT( 427 | psize=args.psize, 428 | imsize=args.model_imsize, 429 | autorescale=args.autorescale, 430 | ch=args.ch, 431 | blocks=args.L, 432 | heads=args.heads, 433 | mlp_dim=2 * args.ch, 434 | T=args.T, 435 | maxpool=args.maxpool, 436 | gta=args.gta, 437 | ).cuda() 438 | 439 | model = EMA(net) 440 | model.load_state_dict(torch.load(args.model_path, weights_only=True)["model_state_dict"]) 441 | model = model.ema_model 442 | 443 | with torch.no_grad(): 444 | scores, preds = eval_dataset( 445 | model, 446 | data=args.data, 447 | data_root=args.data_root, 448 | imsize=args.data_imsize, 449 | batchsize=args.batchsize, 450 | instance=True, 451 | method="agglomerative", 452 | saccade_r=args.saccade_r, 453 | pca=args.pca, 454 | ) 455 | print_stats(scores) 456 | -------------------------------------------------------------------------------- /eval_sudoku.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | import torch 3 | import torch.nn 4 | import torch.optim 5 | import tqdm 6 | import torchvision 7 | from torchvision import transforms 8 | import numpy as np 9 | from torch.optim.swa_utils import AveragedModel 10 | import matplotlib.pyplot as plt 11 | 12 | from source.data.datasets.sudoku.sudoku import SudokuDataset, HardSudokuDataset 13 | from source.models.sudoku.knet import SudokuAKOrN 14 | from source.models.sudoku.transformer import SudokuTransformer 15 | from source.evals.sudoku.evals import compute_board_accuracy 16 | from source.utils import str2bool 17 | from ema_pytorch import EMA 18 | import argparse 19 | 20 | if __name__ == "__main__": 21 | 22 | parser = argparse.ArgumentParser() 23 | 24 | parser.add_argument("--model_path", type=str, help="path to the model") 25 | 26 | # Data loading 27 | parser.add_argument("--data", type=str, default="id", help="data") 28 | parser.add_argument("--limit_cores_used", type=str2bool, default=False) 29 | parser.add_argument("--cpu_core_start", type=int, default=0, help="start core") 30 | parser.add_argument("--cpu_core_end", type=int, default=16, help="end core") 31 | parser.add_argument( 32 | "--data_root", 
33 | type=str, 34 | default=None, 35 | help="Optional. Specify the root dir of the dataset. If None, use a default path set for each dataset", 36 | ) 37 | parser.add_argument("--batchsize", type=int, default=100) 38 | parser.add_argument("--num_workers", type=int, default=4) 39 | 40 | # General model options 41 | parser.add_argument("--model", type=str, default="akorn", help="model") 42 | parser.add_argument("--L", type=int, default=1, help="num of layers") 43 | parser.add_argument("--T", type=int, default=16, help="Timesteps") 44 | parser.add_argument("--ch", type=int, default=512, help="num of channels") 45 | parser.add_argument("--heads", type=int, default=8) 46 | 47 | # AKOrN options 48 | parser.add_argument("--N", type=int, default=4) 49 | parser.add_argument( 50 | "--K", 51 | type=int, 52 | default=1, 53 | help="num of random oscillator samples for each input", 54 | ) 55 | parser.add_argument("--minimum_chunk", type=int, default=None) 56 | parser.add_argument("--evote_type", type=str, default="last", help="last or sum") 57 | parser.add_argument("--gamma", type=float, default=1.0, help="step size") 58 | parser.add_argument("--J", type=str, default="attn", help="connectivity") 59 | parser.add_argument("--use_omega", type=str2bool, default=True) 60 | parser.add_argument("--global_omg", type=str2bool, default=True) 61 | parser.add_argument("--learn_omg", type=str2bool, default=False) 62 | parser.add_argument("--init_omg", type=float, default=0.1) 63 | parser.add_argument("--nl", type=str2bool, default=True) 64 | 65 | parser.add_argument("--speed_test", action="store_true") 66 | 67 | args = parser.parse_args() 68 | 69 | torch.backends.cudnn.benchmark = True 70 | torch.backends.cuda.enable_flash_sdp(enabled=True) 71 | 72 | if args.limit_cores_used: 73 | 74 | def worker_init_fn(worker_id): 75 | os.sched_setaffinity(0, range(args.cpu_core_start, args.cpu_core_end)) 76 | 77 | else: 78 | worker_init_fn = None 79 | 80 | if args.data == "id": 81 | loader = torch.utils.data.DataLoader( 82 | SudokuDataset( 83 | args.data_root if args.data_root is not None else "./data/sudoku", 84 | train=False, 85 | ), 86 | batch_size=args.batchsize, 87 | shuffle=False, 88 | num_workers=args.num_workers, 89 | worker_init_fn=worker_init_fn, 90 | ) 91 | elif args.data == "ood": 92 | loader = torch.utils.data.DataLoader( 93 | HardSudokuDataset( 94 | args.data_root if args.data_root is not None else "./data/sudoku-rrn", 95 | split="test", 96 | ), 97 | batch_size=args.batchsize, 98 | shuffle=False, 99 | num_workers=args.num_workers, 100 | worker_init_fn=worker_init_fn, 101 | ) 102 | else: 103 | raise NotImplementedError 104 | 105 | if args.model == "akorn": 106 | print( 107 | f"n: {args.N}, ch: {args.ch}, L: {args.L}, T: {args.T}, type of J: {args.J}" 108 | ) 109 | net = SudokuAKOrN( 110 | n=args.N, 111 | ch=args.ch, 112 | L=args.L, 113 | T=args.T, 114 | gamma=args.gamma, 115 | J=args.J, 116 | use_omega=args.use_omega, 117 | global_omg=args.global_omg, 118 | init_omg=args.init_omg, 119 | learn_omg=args.learn_omg, 120 | nl=args.nl, 121 | heads=args.heads, 122 | ) 123 | elif args.model == "itrsa": 124 | net = SudokuTransformer( 125 | ch=args.ch, 126 | blocks=args.L, 127 | heads=args.heads, 128 | mlp_dim=args.ch * 2, 129 | T=args.T, 130 | gta=False, 131 | ) 132 | else: 133 | raise NotImplementedError 134 | 135 | model = EMA(net).cuda() 136 | model.load_state_dict( 137 | torch.load(args.model_path, weights_only=True)["model_state_dict"] 138 | ) 139 | model = model.ema_model 140 | model.eval() 141 | 142 | K = args.K 143 | 
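    # Energy-based voting (used only when --model=akorn and K > 1): each board below is
    # duplicated K times with different random oscillator initializations, processed in
    # chunks of `minimum_chunk` to bound memory. The model also returns per-sample
    # energies (return_es=True); depending on --evote_type, either the last energy value
    # or the sum of energies over timesteps scores each sample, and the lowest-energy
    # prediction is taken as the final answer for that board.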
144 | corrects_vote = 0 145 | corrects_avg = 0 146 | totals = 0 147 | 148 | minimum_chunk = args.minimum_chunk if args.minimum_chunk is not None else K 149 | 150 | for i, (X, Y, is_input) in tqdm.tqdm(enumerate(loader)): 151 | B = X.shape[0] 152 | if args.model == 'akorn' and K > 1: # Energy-based voting 153 | for j in range(B): 154 | preds = [] 155 | es_list = [] 156 | for k in range(K//minimum_chunk): 157 | 158 | _X = X[j : j + 1].repeat(minimum_chunk, 1, 1, 1) 159 | _Y = Y[j : j + 1].repeat(minimum_chunk, 1, 1, 1) 160 | _is_input = is_input[j : j + 1].repeat(minimum_chunk, 1, 1, 1) 161 | _X, _Y, _is_input = ( 162 | _X.to(torch.int32).cuda(), 163 | _Y.cuda(), 164 | _is_input.cuda(), 165 | ) 166 | 167 | with torch.no_grad(): 168 | pred, es = model(_X, _is_input, return_es=True) 169 | preds.append(pred.detach()) 170 | if args.evote_type =='sum': 171 | # the sum of energy values over timesteps as board correctness indicator 172 | es = torch.stack(es[-1], 0).sum(0).detach() 173 | elif args.evote_type == 'last': 174 | es = es[-1][-1].detach() 175 | es_list.append(es) 176 | 177 | pred = torch.cat(preds, 0) 178 | es = torch.cat(es_list, 0) 179 | 180 | idxes = torch.argsort(es) # minimum energy first 181 | pred_vote = pred[idxes[:1]].mean(0, keepdim=True) 182 | pred_avg = pred.mean(0, keepdim=True) 183 | 184 | _, _, board_correct_vote = compute_board_accuracy( 185 | pred_vote, _Y[:1], _is_input[:1] 186 | ) 187 | 188 | corrects_vote += board_correct_vote.sum().item() 189 | totals += board_correct_vote.numel() 190 | 191 | else: 192 | X, Y, is_input = X.to(torch.int32).cuda(), Y.cuda(), is_input.cuda() 193 | with torch.no_grad(): 194 | pred = model(X, is_input) 195 | num_blanks, num_corrects, board_correct = compute_board_accuracy(pred, Y, is_input) 196 | corrects_vote += board_correct.sum().item() 197 | totals += board_correct.numel() 198 | 199 | # Compute mean and standard deviation across networks 200 | accuracy_vote = corrects_vote / totals 201 | 202 | print(f"Vote accuracy: {accuracy_vote:.4f}") 203 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | torchaudio 4 | einops 5 | jupyter 6 | matplotlib 7 | tensorboard 8 | tqdm 9 | argparse 10 | git+https://github.com/fra31/auto-attack 11 | tensorflow-cpu 12 | ema_pytorch 13 | accelerate 14 | scipy 15 | scikit-learn 16 | scikit-image 17 | timm 18 | opencv-python 19 | pycocotools 20 | fastcluster 21 | -------------------------------------------------------------------------------- /scripts/sudoku.md: -------------------------------------------------------------------------------- 1 | ## Donwload the sudoku datasets 2 | ``` 3 | cd data 4 | bash download_satnet.sh 5 | bash download_rrn.sh 6 | cd .. 7 | ``` 8 | 9 | ## Training 10 | ### AKOrN 11 | ``` 12 | python train_sudoku.py --exp_name=sudoku_akorn --eval_freq=10 --epochs=100 --model=akorn --lr=0.001 --T=16 --use_omega=True --global_omg=True --init_omg=0.5 --learn_omg=True 13 | ``` 14 | 15 | ### ItrSA and Transformer 16 | ``` 17 | # ItrSA 18 | python train_sudoku.py --exp_name=sudoku_itrsa --eval_freq=10 --epochs=100 --model=itrsa --lr=0.0005 --T=16 19 | # Transformer 20 | python train_sudoku.py --exp_name=sudoku_itrsa --eval_freq=10 --epochs=100 --model=itrsa --lr=0.0005 --T=1 --L=8 21 | ``` 22 | 23 | ## Evaluation 24 | 25 | ### Inference with the test-time extension of the Kuramoto updates($T_{\rm eval}=128$) . 
26 | ``` 27 | export data=ood # id or ood 28 | python eval_sudoku.py --data=${data} --model=akorn --model_path=runs/sudoku_akorn/ema_99.pth --T=128 29 | ``` 30 | 31 | ### Test-time extension of the K-updates + Energy-based voting ($T_{\rm eval}=128$, $K=100$ random samples). 32 | ``` 33 | python eval_sudoku.py --data=${data} --model=akorn --model_path=runs/sudoku_akorn/ema_99.pth --T=128 --K=100 --evote_type=sum 34 | ``` 35 | 36 | 37 | ### Performance table 38 | | Model | ID | OOD | 39 | |-------------------------|--------------------|-----------------| 40 | | Transformer | 98.6±0.3 | 5.2±0.2 | 41 | | ItrSA ($T_{\rm eval}=32$) | 95.7±8.5 | 34.4±5.4 | 42 | | AKOrN ($T_{\rm eval}=128$) | 100.0±0.0 | 51.7±3.3 | 43 | | AKOrN ($T_{\rm eval}=128, K=100$) | 100.0±0.0 | 81.6±1.5 | 44 | | AKOrN ($T_{\rm eval}=128, K=4096$) | 100.0±0.0 | 89.5±2.5 | 45 | 46 | ### Visualization of oscillator dynamics over timesteps 47 | ![sudoku](https://github.com/user-attachments/assets/97f9e6ed-0667-40c9-93a8-c45b5886b43b) 48 | 49 | -------------------------------------------------------------------------------- /scripts/synths.md: -------------------------------------------------------------------------------- 1 | ## Download the synthetic datasets (Tetrominoes, dSprites, CLEVR) 2 | ``` 3 | # Needs gsutil installed. Run `conda install conda-forge::gsutil` to install it, or manually download the datasets from https://console.cloud.google.com/storage/browser/multi-object-datasets;tab=objects?pli=1&inv=1&invt=AbjJBg&prefix=&forceOnObjectsSortingFiltering=false 4 | cd data 5 | bash download_synths.sh 6 | cd .. 7 | ``` 8 | 9 | ## Training 10 | ``` 11 | export NUM_GPUS= # If you use a single GPU, run a command without the multi GPU option arguments (`--multi-gpu --num_processes=$NUM_GPUS`).
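# e.g., on a 4-GPU machine:
# export NUM_GPUS=4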
12 | ``` 13 | 14 | #### AKOrN (Attentive models) 15 | ``` 16 | #Tetrominoes 17 | export dataset=tetrominoes; accelerate launch --num_processes=${NUM_GPUS} train_obj.py --exp_name=${dataset}_akorn_attn --model=akorn --data=${dataset} --J=attn --L=1 --ch=128 --psize=4 --epochs=50 --c_norm=none 18 | #dSprites 19 | export dataset=dsprites; accelerate launch --num_processes=${NUM_GPUS} train_obj.py --exp_name=${dataset}_akorn_attn --model=akorn --data=${dataset} --J=attn --L=1 --ch=128 --psize=4 --epochs=50 --c_norm=none 20 | #CLEVR 21 | export dataset=clevr; accelerate launch --num_processes=${NUM_GPUS} train_obj.py --exp_name=${dataset}_akorn_attn --model=akorn --data=${dataset} --J=attn --L=1 --ch=256 --psize=8 --epochs=300 --c_norm=none 22 | ``` 23 | 24 | #### AKOrN (Convolutional models) 25 | ``` 26 | #Tetrominoes 27 | export dataset=tetrominoes; accelerate launch --num_processes=${NUM_GPUS} train_obj.py --exp_name=${dataset}_akorn_conv --model=akorn --data=${dataset} --J=conv --L=1 --ksize=5 --ch=128 --psize=4 --epochs=50 --c_norm=none 28 | #dSprites 29 | export dataset=dsprites; accelerate launch --num_processes=${NUM_GPUS} train_obj.py --exp_name=${dataset}_akorn_conv --model=akorn --data=${dataset} --J=conv --L=1 --ksize=7 --ch=128 --psize=4 --epochs=50 --c_norm=none 30 | #CLEVR 31 | export dataset=clevr; accelerate launch --num_processes=${NUM_GPUS} train_obj.py --exp_name=${dataset}_akorn_conv --model=akorn --data=${dataset} --J=conv --L=1 --ksize=7 --ch=256 --psize=8 --epochs=300 --c_norm=none 32 | ``` 33 | 34 | #### ItrSA 35 | ``` 36 | export dataset=tetrominoes; accelerate launch --multi-gpu --num_processes=$NUM_GPUS train_obj.py --exp_name=${dataset}_itrsa --data=${dataset} --model=vit --L=1 --gta=False --T=8 --ch=128 --psize=4 --epochs=50 37 | export dataset=dsprites; accelerate launch --multi-gpu --num_processes=$NUM_GPUS train_obj.py --exp_name=${dataset}_itrsa --data=${dataset} --model=vit --L=1 --gta=False --T=8 --ch=128 --psize=4 --epochs=50 38 | export dataset=clevr; accelerate launch --multi-gpu --num_processes=$NUM_GPUS train_obj.py --exp_name=${dataset}_itrsa --data=${dataset} --model=vit --L=1 --gta=False --T=8 --ch=256 --psize=8 --epochs=300 39 | ``` 40 | 41 | ## Evaluation 42 | 43 | ``` 44 | export DATA=tetrominoes #{tetrominoes, dsprites, clevr}. Please adjust the model parameters (--model, --J, --ch, --psize) based on the dataset and model you want to evaluate. 45 | export IMSIZE=32 # {32:tetrominoes, 64:dsprites, 128:clevr}.
46 | python eval_obj.py --model=akorn --data=${DATA} --J=attn --L=1 --T=8 --ch=128 --psize=4 --c_norm=none --model_path=runs/${DATA}_akorn_attn --model_imsize=${IMSIZE} 47 | ``` 48 | 49 | #### Performance table 50 | | Model | Tetrominoes FG-ARI | Tetrominoes MBO | dSprites FG-ARI | dSprites MBO | CLEVR FG-ARI | CLEVR MBO | 51 | |-------------------------|--------------------|-----------------|-----------------|--------------|--------------|-----------| 52 | | ItrConv | 59.0±2.9 | 51.6±2.2 | 29.1±6.2 | 38.5±5.2 | 49.3±5.1 | 29.7±3.0 | 53 | | AKOrN (conv) | 76.4±0.8 | 51.9±1.5 | 63.8±7.7 | 50.7±4.7 | 59.0±4.3 | 44.4±2.0 | 54 | | ItrSA | 85.8±0.8 | 54.9±3.4 | 68.1±1.4 | 63.0±1.2 | 82.5±1.7 | 39.4±1.9 | 55 | | AKOrN (attn) | 88.6±1.7 | 56.4±0.9 | 78.3±1.3 | 63.0±1.8 | 91.0±0.5 | 45.5±1.4 | 56 | 57 | ##### (+up-tiling (×4)) 58 | | Model | Tetrominoes FG-ARI | Tetrominoes MBO | dSprites FG-ARI | dSprites MBO | CLEVR FG-ARI | CLEVR MBO | 59 | |-------------------------|--------------------|-----------------|-----------------|--------------|--------------|-----------| 60 | | AKOrN^attn | 93.1±0.3 | 56.3±0.0 | 87.1±1.0 | 60.2±1.9 | 94.6±0.7 | 44.7±0.7 | 61 | -------------------------------------------------------------------------------- /source/data/augs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import transforms 3 | 4 | 5 | def gauss_noise_tensor(sigma=0.1): 6 | def fn(img): 7 | out = img + sigma * torch.randn_like(img) 8 | out = torch.clamp(out, 0, 1) # pixel space is [0, 1] 9 | return out 10 | 11 | return fn 12 | 13 | 14 | def augmentation_strong(noise=0.0, imsize=32): 15 | transform_aug = transforms.Compose( 16 | [ 17 | transforms.RandomHorizontalFlip(), 18 | transforms.RandomResizedCrop(imsize, scale=(0.2, 1.0)), 19 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1), 20 | transforms.AugMix(), 21 | transforms.ToTensor(), 22 | gauss_noise_tensor(noise) if noise > 0 else lambda x: x, 23 | ] 24 | ) 25 | return transform_aug 26 | 27 | 28 | def simclr_augmentation(imsize, hflip=False): 29 | return transforms.Compose( 30 | [ 31 | transforms.RandomResizedCrop(imsize), 32 | transforms.RandomHorizontalFlip(0.5) if hflip else lambda x: x, 33 | get_color_distortion(s=0.5), 34 | transforms.ToTensor(), 35 | ] 36 | ) 37 | 38 | 39 | def random_Linf_noise(trnsfms: transforms.Compose = None, epsilon=64 / 255): 40 | if trnsfms is None: 41 | trnsfms = transforms.Compose([transforms.ToTensor()]) 42 | 43 | randeps = torch.rand(1).item() * epsilon 44 | 45 | def fn(x): 46 | x = x + randeps * torch.randn_like(x).sign() 47 | return torch.clamp(x, 0, 1) 48 | 49 | trnsfms.transforms.append(fn) 50 | return trnsfms 51 | 52 | 53 | def get_color_distortion(s=0.5): 54 | # s is the strength of color distortion 55 | color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s) 56 | rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8) 57 | rnd_gray = transforms.RandomGrayscale(p=0.2) 58 | color_distort = transforms.Compose([rnd_color_jitter, rnd_gray]) 59 | return color_distort 60 | -------------------------------------------------------------------------------- /source/data/datasets/objs/clevr.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from torchvision import transforms 4 | 5 | from source.data.augs import simclr_augmentation 6 | from source.data.datasets.objs.npdataset import NumpyDataset, PairDataset 7 | 8 | 9 | def get_clevr_pair(root,
split="train", imsize=128, hflip=False): 10 | path = Path(root, f"clevr_{split}.npz") 11 | return PairDataset(path, transform=simclr_augmentation(imsize=imsize, hflip=hflip)) 12 | 13 | 14 | def get_clevr(root, split="train", imsize=128): 15 | path = Path(root, f"clevr_{split}.npz") 16 | transform = transforms.Compose( 17 | [ 18 | transforms.Resize(imsize), 19 | transforms.ToTensor(), 20 | ] 21 | ) 22 | return NumpyDataset(path, transform=transform) 23 | -------------------------------------------------------------------------------- /source/data/datasets/objs/clevr_tex.py: -------------------------------------------------------------------------------- 1 | # Copied from https://github.com/karazijal/clevrtex-generation/blob/main/clevrtex_eval.py 2 | import json 3 | from collections import defaultdict 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | import torchvision.transforms.functional as Ft 10 | from PIL import Image 11 | from scipy.optimize import linear_sum_assignment 12 | from sklearn.metrics import adjusted_rand_score 13 | 14 | 15 | class DatasetReadError(ValueError): 16 | pass 17 | 18 | 19 | def collate_fn(batch): 20 | return ( 21 | *torch.utils.data._utils.collate.default_collate( 22 | [(b[0], b[1], b[2]) for b in batch] 23 | ), 24 | [b[3] for b in batch], 25 | ) 26 | 27 | 28 | class CLEVRTEX: 29 | ccrop_frac = 0.8 30 | splits = {"test": (0.0, 0.1), "val": (0.1, 0.2), "train": (0.2, 1.0)} 31 | shape = (3, 240, 320) 32 | variants = {"full", "pbg", "vbg", "grassbg", "camo", "outd"} 33 | 34 | def _index_with_bias_and_limit(self, idx): 35 | if idx >= 0: 36 | idx += self.bias 37 | if idx >= self.limit: 38 | raise IndexError() 39 | else: 40 | idx = self.limit + idx 41 | if idx < self.bias: 42 | raise IndexError() 43 | return idx 44 | 45 | def _reindex(self): 46 | print(f"Indexing {self.basepath}") 47 | 48 | img_index = {} 49 | msk_index = {} 50 | met_index = {} 51 | 52 | prefix = f"CLEVRTEX_{self.dataset_variant}_" 53 | 54 | img_suffix = ".png" 55 | msk_suffix = "_flat.png" 56 | met_suffix = ".json" 57 | 58 | _max = 0 59 | for img_path in self.basepath.glob(f"**/{prefix}??????{img_suffix}"): 60 | indstr = img_path.name.replace(prefix, "").replace(img_suffix, "") 61 | msk_path = img_path.parent / f"{prefix}{indstr}{msk_suffix}" 62 | met_path = img_path.parent / f"{prefix}{indstr}{met_suffix}" 63 | indstr_stripped = indstr.lstrip("0") 64 | if indstr_stripped: 65 | ind = int(indstr) 66 | else: 67 | ind = 0 68 | if ind > _max: 69 | _max = ind 70 | 71 | if not msk_path.exists(): 72 | raise DatasetReadError(f"Missing {msk_suffix.name}") 73 | 74 | if ind in img_index: 75 | raise DatasetReadError(f"Duplica {ind}") 76 | 77 | img_index[ind] = img_path 78 | msk_index[ind] = msk_path 79 | if self.return_metadata: 80 | if not met_path.exists(): 81 | raise DatasetReadError(f"Missing {met_path.name}") 82 | met_index[ind] = met_path 83 | else: 84 | met_index[ind] = None 85 | 86 | if len(img_index) == 0: 87 | raise DatasetReadError(f"No values found") 88 | missing = [i for i in range(0, _max) if i not in img_index] 89 | if missing: 90 | raise DatasetReadError(f"Missing images numbers {missing}") 91 | 92 | return img_index, msk_index, met_index 93 | 94 | def _variant_subfolder(self): 95 | return f"clevrtex_{self.dataset_variant.lower()}" 96 | 97 | def __init__( 98 | self, 99 | path: Path, 100 | dataset_variant="full", 101 | split="train", 102 | crop=True, 103 | resize=(128, 128), 104 | return_metadata=True, 105 | transform=None, 106 | ): 107 
| self.transform = transform 108 | self.return_metadata = return_metadata 109 | self.crop = crop 110 | self.resize = resize 111 | if dataset_variant not in self.variants: 112 | raise DatasetReadError( 113 | f"Unknown variant {dataset_variant}; [{', '.join(self.variants)}] available " 114 | ) 115 | 116 | if split not in self.splits: 117 | raise DatasetReadError( 118 | f"Unknown split {split}; [{', '.join(self.splits)}] available " 119 | ) 120 | if dataset_variant == "outd": 121 | # No dataset splits in 122 | split = None 123 | 124 | self.dataset_variant = dataset_variant 125 | self.split = split 126 | 127 | self.basepath = Path(path) 128 | if not self.basepath.exists(): 129 | raise DatasetReadError() 130 | sub_fold = self._variant_subfolder() 131 | if self.basepath.name != sub_fold: 132 | self.basepath = self.basepath / sub_fold 133 | # try: 134 | # with (self.basepath / 'manifest_ind.json').open('r') as inf: 135 | # self.index = json.load(inf) 136 | # except (json.JSONDecodeError, IOError, FileNotFoundError): 137 | self.index, self.mask_index, self.metadata_index = self._reindex() 138 | 139 | print(f"Sourced {dataset_variant} ({split}) from {self.basepath}") 140 | 141 | bias, limit = self.splits.get(split, (0.0, 1.0)) 142 | if isinstance(bias, float): 143 | bias = int(bias * len(self.index)) 144 | if isinstance(limit, float): 145 | limit = int(limit * len(self.index)) 146 | self.limit = limit 147 | self.bias = bias 148 | 149 | def _format_metadata(self, meta): 150 | """ 151 | Drop unimportanat, unsued or incorrect data from metadata. 152 | Data may become incorrect due to transformations, 153 | such as cropping and resizing would make pixel coordinates incorrect. 154 | Furthermore, only VBG dataset has color assigned to objects, we delete the value for others. 
155 | """ 156 | objs = [] 157 | for obj in meta["objects"]: 158 | o = { 159 | "material": obj["material"], 160 | "shape": obj["shape"], 161 | "size": obj["size"], 162 | "rotation": obj["rotation"], 163 | } 164 | if self.dataset_variant == "vbg": 165 | o["color"] = obj["color"] 166 | objs.append(o) 167 | return {"ground_material": meta["ground_material"], "objects": objs} 168 | 169 | def __len__(self): 170 | return self.limit - self.bias 171 | 172 | def __getitem__(self, ind): 173 | ind = self._index_with_bias_and_limit(ind) 174 | 175 | img = Image.open(self.index[ind]).convert("RGB") 176 | msk = Image.open(self.mask_index[ind]) 177 | 178 | if self.crop: 179 | crop_size = int(0.8 * float(min(img.width, img.height))) 180 | img = img.crop( 181 | ( 182 | (img.width - crop_size) // 2, 183 | (img.height - crop_size) // 2, 184 | (img.width + crop_size) // 2, 185 | (img.height + crop_size) // 2, 186 | ) 187 | ) 188 | msk = msk.crop( 189 | ( 190 | (msk.width - crop_size) // 2, 191 | (msk.height - crop_size) // 2, 192 | (msk.width + crop_size) // 2, 193 | (msk.height + crop_size) // 2, 194 | ) 195 | ) 196 | if self.resize: 197 | img = img.resize(self.resize, resample=Image.BILINEAR) 198 | msk = msk.resize(self.resize, resample=Image.NEAREST) 199 | 200 | if self.transform is not None: 201 | img = self.transform(np.array(img)) 202 | # img = Ft.to_tensor(np.array(img)[..., :3]) 203 | msk = torch.from_numpy(np.array(msk))[None] 204 | 205 | ret = (ind, img, msk) 206 | 207 | if self.return_metadata: 208 | with self.metadata_index[ind].open("r") as inf: 209 | meta = json.load(inf) 210 | ret = (ind, img, msk, self._format_metadata(meta)) 211 | 212 | return ret 213 | 214 | 215 | def collate_fn(batch): 216 | return ( 217 | *torch.utils.data._utils.collate.default_collate( 218 | [(b[0], b[1], b[2]) for b in batch] 219 | ), 220 | [b[3] for b in batch], 221 | ) 222 | 223 | 224 | class RunningMean: 225 | def __init__(self): 226 | self.v = 0.0 227 | self.n = 0 228 | 229 | def update(self, v, n=1): 230 | self.v += v * n 231 | self.n += n 232 | 233 | def value(self): 234 | if self.n: 235 | return self.v / (self.n) 236 | else: 237 | return float("nan") 238 | 239 | def __str__(self): 240 | return str(self.value()) 241 | 242 | 243 | class CLEVRTEX_Evaluator: 244 | def __init__(self, masks_have_background=True): 245 | self.masks_have_background = masks_have_background 246 | self.stats = defaultdict(RunningMean) 247 | self.tags = defaultdict(lambda: defaultdict(lambda: defaultdict(RunningMean))) 248 | 249 | def ari(self, pred_mask, true_mask, skip_0=False): 250 | B = pred_mask.shape[0] 251 | pm = pred_mask.argmax(axis=1).squeeze().view(B, -1).cpu().detach().numpy() 252 | tm = true_mask.argmax(axis=1).squeeze().view(B, -1).cpu().detach().numpy() 253 | aris = [] 254 | for bi in range(B): 255 | t = tm[bi] 256 | p = pm[bi] 257 | if skip_0: 258 | p = p[t > 0] 259 | t = t[t > 0] 260 | ari_score = adjusted_rand_score(t, p) 261 | if ari_score != ari_score: 262 | print(f"NaN at bi") 263 | aris.append(ari_score) 264 | aris = torch.tensor(np.array(aris), device=pred_mask.device) 265 | return aris 266 | 267 | def msc(self, pred_mask, true_mask): 268 | B = pred_mask.shape[0] 269 | bpm = pred_mask.argmax(axis=1).squeeze() 270 | btm = true_mask.argmax(axis=1).squeeze() 271 | covering = torch.zeros(B, device=pred_mask.device, dtype=torch.float) 272 | for bi in range(B): 273 | score = 0.0 274 | norms = 0.0 275 | for ti in range(btm[bi].max()): 276 | tm = btm[bi] == ti 277 | if not torch.any(tm): 278 | continue 279 | iou_max = 0.0 280 | 
for pi in range(bpm[bi].max()): 281 | pm = bpm[bi] == pi 282 | if not torch.any(pm): 283 | continue 284 | iou = (tm & pm).to(torch.float).sum() / (tm | pm).to( 285 | torch.float 286 | ).sum() 287 | if iou > iou_max: 288 | iou_max = iou 289 | r = tm.to(torch.float).sum() 290 | score += r * iou_max 291 | norms += r 292 | covering[bi] = score / norms 293 | return covering 294 | 295 | def reindex(self, tensor, reindex_tensor, dim=1): 296 | """ 297 | Reindexes tensor along using reindex_tensor. 298 | Effectivelly permutes for each dimensions 1), "Predicted masks sum out to more than 1." 392 | if not self.masks_have_background: 393 | # Some models predict only foreground masks. 394 | # For convenienve we calculate background masks. 395 | pred_masks = torch.cat([1.0 - total_pred_masks, pred_masks], dim=1) 396 | 397 | # Decide the masks Should we effectivelly threshold them? 398 | K = pred_masks.shape[1] 399 | pred_masks = pred_masks.argmax(dim=1) 400 | pred_masks = ( 401 | pred_masks.unsqueeze(1) 402 | == torch.arange(K, device=pred_masks.device).view(1, -1, 1, 1, 1) 403 | ).to(torch.float) 404 | # Coerce true_Masks into known form 405 | if len(true_masks.shape) == 4: 406 | if true_masks.shape[1] == 1: 407 | # Need to expand into masks 408 | true_masks = ( 409 | true_masks.unsqueeze(1) 410 | == torch.arange( 411 | max(true_masks.max() + 1, pred_masks.shape[1]), 412 | device=true_masks.device, 413 | ).view(1, -1, 1, 1, 1) 414 | ).to(pred_image.dtype) 415 | else: 416 | true_masks = true_masks.unsqueeze(2) 417 | true_masks = true_masks.view(pred_image.shape[0], -1, 1, *pred_image.shape[-2:]) 418 | 419 | K = max(true_masks.shape[1], pred_masks.shape[1]) 420 | if true_masks.shape[1] < K: 421 | true_masks = torch.cat( 422 | [ 423 | true_masks, 424 | true_masks.new_zeros( 425 | true_masks.shape[0], 426 | K - true_masks.shape[1], 427 | 1, 428 | *true_masks.shape[-2:], 429 | ), 430 | ], 431 | dim=1, 432 | ) 433 | if pred_masks.shape[1] < K: 434 | pred_masks = torch.cat( 435 | [ 436 | pred_masks, 437 | pred_masks.new_zeros( 438 | pred_masks.shape[0], 439 | K - pred_masks.shape[1], 440 | 1, 441 | *pred_masks.shape[-2:], 442 | ), 443 | ], 444 | dim=1, 445 | ) 446 | 447 | mse = F.mse_loss(pred_image, true_image, reduction="none").sum((1, 2, 3)) 448 | self.add_statistic("MSE", mse) 449 | 450 | # If argmax above, these masks are either 0 or 1 451 | pred_count = ( 452 | (pred_masks >= 0.5).any(-1).any(-1).any(-1).to(torch.float).sum(-1) 453 | ) # shape: (B,) 454 | true_count = ( 455 | (true_masks >= 0.5).any(-1).any(-1).any(-1).to(torch.float).sum(-1) 456 | ) # shape: (B,) 457 | accuracy = (true_count == pred_count).to(torch.float) 458 | self.add_statistic("acc", accuracy) 459 | 460 | pred_reindex, ious, _ = self.ious_alignment(pred_masks, true_masks) 461 | pred_masks = self.reindex(pred_masks, pred_reindex, dim=1) 462 | 463 | truem = true_masks.any(-1).any(-1).any(-1) 464 | predm = pred_masks.any(-1).any(-1).any(-1) 465 | 466 | vism = truem | predm 467 | num_pairs = vism.to(torch.float).sum(-1) 468 | 469 | # mIoU 470 | mIoU = ious.sum(-1) / num_pairs 471 | mIoU_fg = ious[:, 1:].sum(-1) / ( 472 | num_pairs - 1 473 | ) # do not consider the background 474 | mIoU_gt = ious.sum(-1) / truem.to(torch.float).sum(-1) 475 | 476 | self.add_statistic("mIoU", mIoU) 477 | self.add_statistic("mIoU_fg", mIoU_fg) 478 | self.add_statistic("mIoU_gt", mIoU_gt) 479 | 480 | msc = self.msc(pred_masks, true_masks) 481 | self.add_statistic("mSC", msc) 482 | 483 | # DICE 484 | dices = ( 485 | 2 486 | * (pred_masks * 
true_masks).sum((-3, -2, -1)) 487 | / (pred_masks.sum((-3, -2, -1)) + true_masks.sum((-3, -2, -1))) 488 | ) 489 | dices = torch.nan_to_num( 490 | dices, nan=0.0, posinf=0.0 491 | ) # if there were any empties, they now have 0. DICE 492 | 493 | dice = dices.sum(-1) / num_pairs 494 | dice_fg = dices[:, 1:].sum(-1) / (num_pairs - 1) 495 | self.add_statistic("DICE", dice) 496 | self.add_statistic("DICE_FG", dice_fg) 497 | 498 | # ARI 499 | ari = self.ari(pred_masks, true_masks) 500 | ari_fg = self.ari(pred_masks, true_masks, skip_0=True) 501 | if torch.any(torch.isnan(ari_fg)): 502 | print("NaN ari_fg") 503 | if torch.any(torch.isinf(ari_fg)): 504 | print("Inf ari_fg") 505 | self.add_statistic("ARI", ari) 506 | self.add_statistic("ARI_FG", ari_fg) 507 | 508 | # mAP --? 509 | 510 | if true_metadata is not None: 511 | smses = F.mse_loss( 512 | pred_image[:, None] * true_masks, 513 | true_image[:, None] * true_masks, 514 | reduction="none", 515 | ).sum((-1, -2, -3)) 516 | 517 | for bi, meta in enumerate(true_metadata): 518 | # ground 519 | self.add_statistic( 520 | "ground_mse", smses[bi, 0], ground_material=meta["ground_material"] 521 | ) 522 | self.add_statistic( 523 | "ground_iou", ious[bi, 0], ground_material=meta["ground_material"] 524 | ) 525 | 526 | for i, obj in enumerate(meta["objects"]): 527 | tags = {k: v for k, v in obj.items() if k != "rotation"} 528 | if truem[bi, i + 1]: 529 | self.add_statistic("obj_mse", smses[bi, i + 1], **tags) 530 | self.add_statistic("obj_iou", ious[bi, i + 1], **tags) 531 | # Maybe number of components? 532 | return pred_masks, true_masks 533 | 534 | 535 | class CLEVRTEXPair(CLEVRTEX): 536 | """Generate mini-batche pairs on CIFAR10 training set.""" 537 | 538 | def __getitem__(self, ind): 539 | ind = self._index_with_bias_and_limit(ind) 540 | 541 | img = Image.open(self.index[ind]).convert("RGB") 542 | 543 | if self.crop: 544 | crop_size = int(0.8 * float(min(img.width, img.height))) 545 | img = img.crop( 546 | ( 547 | (img.width - crop_size) // 2, 548 | (img.height - crop_size) // 2, 549 | (img.width + crop_size) // 2, 550 | (img.height + crop_size) // 2, 551 | ) 552 | ) 553 | if self.resize: 554 | img = img.resize(self.resize, resample=Image.BILINEAR) 555 | imgs = [self.transform(img), self.transform(img)] 556 | return torch.stack(imgs) # stack a positive pair 557 | 558 | 559 | import torchvision 560 | from torchvision import transforms 561 | from PIL import Image 562 | from source.data.augs import simclr_augmentation 563 | 564 | 565 | def get_clevrtex_pair(root, split="train", imsize=128, hflip=False): 566 | 567 | return CLEVRTEXPair( 568 | root, 569 | "full", 570 | split, 571 | crop=True, 572 | resize=(imsize, imsize), 573 | return_metadata=False, 574 | transform=simclr_augmentation(imsize=imsize, hflip=hflip), 575 | ) 576 | 577 | 578 | def get_clevrtex( 579 | root, split="test", data_type="full", return_meta_data=False, imsize=128 580 | ): 581 | return CLEVRTEX( 582 | root, 583 | data_type, 584 | split, 585 | crop=True, 586 | resize=(imsize, imsize), 587 | return_metadata=return_meta_data, 588 | transform=torchvision.transforms.ToTensor(), 589 | ) 590 | -------------------------------------------------------------------------------- /source/data/datasets/objs/coco.py: -------------------------------------------------------------------------------- 1 | # Copied from https://github.com/Wuziyi616/SlotDiffusion 2 | 3 | import os 4 | import cv2 5 | import numpy as np 6 | from PIL import Image 7 | from pycocotools.coco import COCO 8 | 9 | import torch 10 | 
from torch.utils.data import Dataset 11 | import torchvision.utils as vutils 12 | 13 | 14 | import torch 15 | import torchvision.transforms as transforms 16 | 17 | 18 | def suppress_mask_idx(masks): 19 | """Make the mask index 0, 1, 2, ...""" 20 | # the original mask could have not continuous index, 0, 3, 4, 6, 9, 13, ... 21 | # we make them 0, 1, 2, 3, 4, 5, ... 22 | if isinstance(masks, np.ndarray): 23 | pkg = np 24 | elif isinstance(masks, torch.Tensor): 25 | pkg = torch 26 | else: 27 | raise NotImplementedError 28 | obj_idx = pkg.unique(masks) 29 | idx_mapping = pkg.arange(obj_idx.max() + 1) 30 | idx_mapping[obj_idx] = pkg.arange(len(obj_idx)) 31 | masks = idx_mapping[masks] 32 | return masks 33 | 34 | 35 | class RandomHorizontalFlip: 36 | """Flip the image and bbox horizontally.""" 37 | 38 | def __init__(self, prob=0.5): 39 | self.prob = prob 40 | 41 | def __call__(self, sample): 42 | # [H, W, 3], [H, W(, 2)], [N, 5] 43 | image, masks, annos, scale, size = ( 44 | sample["image"], 45 | sample["masks"], 46 | sample["annos"], 47 | sample["scale"], 48 | sample["size"], 49 | ) 50 | 51 | if np.random.uniform(0, 1) < self.prob: 52 | image = np.ascontiguousarray(image[:, ::-1, :]) 53 | masks = np.ascontiguousarray(masks[:, ::-1]) 54 | _, w, _ = image.shape 55 | # adjust annos 56 | if annos.shape[0] > 0: 57 | x1 = annos[:, 0].copy() 58 | x2 = annos[:, 2].copy() 59 | annos[:, 0] = w - x2 60 | annos[:, 2] = w - x1 61 | 62 | return { 63 | "image": image, 64 | "masks": masks, 65 | "annos": annos, 66 | "scale": scale, 67 | "size": size, 68 | } 69 | 70 | 71 | class ResizeMinShape: 72 | """Resize for later center crop.""" 73 | 74 | def __init__(self, resolution=(224, 224)): 75 | self.resolution = resolution 76 | 77 | def __call__(self, sample): 78 | image, masks, annos, scale, size = ( 79 | sample["image"], 80 | sample["masks"], 81 | sample["annos"], 82 | sample["scale"], 83 | sample["size"], 84 | ) 85 | h, w, _ = image.shape 86 | # resize so that the h' is at lease resolution[0] 87 | # and the w' is at lease resolution[1] 88 | factor = max(self.resolution[0] / h, self.resolution[1] / w) 89 | resize_h, resize_w = int(round(h * factor)), int(round(w * factor)) 90 | image = cv2.resize(image, (resize_w, resize_h), interpolation=cv2.INTER_LINEAR) 91 | masks = cv2.resize(masks, (resize_w, resize_h), interpolation=cv2.INTER_NEAREST) 92 | # adjust annos 93 | factor = float(factor) 94 | annos[:, :4] *= factor 95 | scale *= factor 96 | return { 97 | "image": image, 98 | "masks": masks, 99 | "annos": annos, 100 | "scale": scale, 101 | "size": size, 102 | } 103 | 104 | 105 | class CenterCrop: 106 | """Crop the center square of the image.""" 107 | 108 | def __init__(self, resolution=(224, 224)): 109 | self.resolution = resolution 110 | 111 | def __call__(self, sample): 112 | image, masks, annos, scale, size = ( 113 | sample["image"], 114 | sample["masks"], 115 | sample["annos"], 116 | sample["scale"], 117 | sample["size"], 118 | ) 119 | 120 | h, w, _ = image.shape 121 | assert h >= self.resolution[0] and w >= self.resolution[1] 122 | assert h == self.resolution[0] or w == self.resolution[1] 123 | 124 | if h == self.resolution[0]: 125 | crop_ymin = 0 126 | crop_ymax = h 127 | crop_xmin = (w - self.resolution[0]) // 2 128 | crop_xmax = crop_xmin + self.resolution[0] 129 | else: 130 | crop_xmin = 0 131 | crop_xmax = w 132 | crop_ymin = (h - self.resolution[1]) // 2 133 | crop_ymax = crop_ymin + self.resolution[1] 134 | image = image[crop_ymin:crop_ymax, crop_xmin:crop_xmax] 135 | masks = masks[crop_ymin:crop_ymax, 
crop_xmin:crop_xmax] 136 | # adjust annos 137 | if annos.shape[0] > 0: 138 | annos[:, [0, 2]] = annos[:, [0, 2]] - crop_xmin 139 | annos[:, [1, 3]] = annos[:, [1, 3]] - crop_ymin 140 | # filter out annos that are out of the image 141 | mask1 = (annos[:, 2] > 0) & (annos[:, 3] > 0) 142 | mask2 = (annos[:, 0] < self.resolution[0]) & ( 143 | annos[:, 1] < self.resolution[1] 144 | ) 145 | annos = annos[mask1 & mask2] 146 | annos[:, [0, 2]] = np.clip(annos[:, [0, 2]], 0, self.resolution[0]) 147 | annos[:, [1, 3]] = np.clip(annos[:, [1, 3]], 0, self.resolution[1]) 148 | 149 | return { 150 | "image": image, 151 | "masks": masks, 152 | "annos": annos, 153 | "scale": scale, 154 | "size": size, 155 | } 156 | 157 | 158 | class Normalize: 159 | """Normalize the image with mean and std.""" 160 | 161 | def __init__(self, mean=0.5, std=0.5): 162 | if isinstance(mean, (list, tuple)): 163 | mean = np.array(mean, dtype=np.float32)[None, None] # [1, 1, 3] 164 | if isinstance(std, (list, tuple)): 165 | std = np.array(std, dtype=np.float32)[None, None] # [1, 1, 3] 166 | self.mean = mean 167 | self.std = std 168 | 169 | def normalize_image(self, image): 170 | image = image.astype(np.float32) / 255.0 171 | image = (image - self.mean) / self.std 172 | return image 173 | 174 | def denormalize_image(self, image): 175 | # simple numbers 176 | if isinstance(self.mean, (int, float)) and isinstance(self.std, (int, float)): 177 | image = image * self.std + self.mean 178 | return image.clamp(0, 1) 179 | # need to convert the shapes 180 | mean = image.new_tensor(self.mean.squeeze()) # [3] 181 | std = image.new_tensor(self.std.squeeze()) # [3] 182 | if image.shape[-1] == 3: # C last 183 | mean = mean[None, None] # [1, 1, 3] 184 | std = std[None, None] # [1, 1, 3] 185 | else: # C first 186 | mean = mean[:, None, None] # [3, 1, 1] 187 | std = std[:, None, None] # [3, 1, 1] 188 | if len(image.shape) == 4: # [B, C, H, W] or [B, H, W, C], batch dim 189 | mean = mean[None] 190 | std = std[None] 191 | image = image * self.std + self.mean 192 | return image.clamp(0, 1) 193 | 194 | def __call__(self, sample): 195 | # [H, W, C] 196 | image, masks, annos, scale, size = ( 197 | sample["image"], 198 | sample["masks"], 199 | sample["annos"], 200 | sample["scale"], 201 | sample["size"], 202 | ) 203 | image = self.normalize_image(image) 204 | # make mask index start from 0 and continuous 205 | # `masks` is [H, W(, 2 or 3)] 206 | if len(masks.shape) == 3: 207 | assert masks.shape[-1] in [2, 3] 208 | # we don't suppress the last mask since it is the overlapping mask 209 | # i.e. regions with overlapping instances 210 | for i in range(masks.shape[-1] - 1): 211 | masks[:, :, i] = suppress_mask_idx(masks[:, :, i]) 212 | else: 213 | masks = suppress_mask_idx(masks) 214 | return { 215 | "image": image, 216 | "masks": masks, 217 | "annos": annos, 218 | "scale": scale, 219 | "size": size, 220 | } 221 | 222 | 223 | class COCOCollater: 224 | """Collect images, annotations, etc. 
into a batch.""" 225 | 226 | def __init__(self): 227 | pass 228 | 229 | def __call__(self, data): 230 | images = [s["image"] for s in data] 231 | masks = [s["masks"] for s in data] 232 | annos = [s["annos"] for s in data] 233 | scales = [s["scale"] for s in data] 234 | sizes = [s["size"] for s in data] 235 | 236 | images = np.stack(images, axis=0) # [B, H, W, C] 237 | images = torch.from_numpy(images).permute(0, 3, 1, 2) # [B, C, H, W] 238 | 239 | masks = np.stack(masks, axis=0) 240 | masks = torch.from_numpy(masks) # [B, H, W(, 2 or 3)] 241 | 242 | max_annos_num = max(anno.shape[0] for anno in annos) 243 | if max_annos_num > 0: 244 | input_annos = np.ones((len(annos), max_annos_num, 5), dtype=np.float32) * ( 245 | -1 246 | ) 247 | for i, anno in enumerate(annos): 248 | if anno.shape[0] > 0: 249 | input_annos[i, : anno.shape[0], :] = anno 250 | else: 251 | input_annos = np.ones((len(annos), 1, 5), dtype=np.float32) * (-1) 252 | input_annos = torch.from_numpy(input_annos).float() 253 | 254 | scales = torch.from_numpy(np.array(scales)).float() 255 | sizes = torch.from_numpy(np.array(sizes)).float() 256 | 257 | data_dict = { 258 | "img": images.contiguous().float(), 259 | "masks": masks.contiguous().long(), 260 | "annos": input_annos, 261 | "scale": scales, 262 | "size": sizes, 263 | } 264 | if len(masks.shape) == 4: 265 | assert masks.shape[-1] in [2, 3] 266 | if masks.shape[-1] == 3: 267 | data_dict["masks"] = masks[:, :, :, 0] 268 | data_dict["sem_masks"] = masks[:, :, :, 1] 269 | data_dict["inst_overlap_masks"] = masks[:, :, :, 2] 270 | else: 271 | data_dict["masks"] = masks[:, :, :, 0] 272 | data_dict["inst_overlap_masks"] = masks[:, :, :, 1] 273 | return data_dict 274 | 275 | 276 | class COCOTransforms(object): 277 | """Data pre-processing steps.""" 278 | 279 | def __init__( 280 | self, 281 | resolution, 282 | val=False, 283 | ): 284 | self.normalize = Normalize(0.0, 1.0) 285 | if val: 286 | self.transforms = transforms.Compose( 287 | [ 288 | ResizeMinShape(resolution), 289 | CenterCrop(resolution), 290 | self.normalize, 291 | ] 292 | ) 293 | else: 294 | from source.data.augs import get_color_distortion 295 | 296 | self.transforms = transforms.Compose( 297 | [ 298 | transforms.RandomResizedCrop(resolution), 299 | ResizeMinShape(resolution), 300 | CenterCrop(resolution), 301 | ] 302 | ) 303 | self.resolution = resolution 304 | 305 | def __call__(self, input): 306 | return self.transforms(input) 307 | 308 | 309 | COCO_CLASSES = [ 310 | "person", 311 | "bicycle", 312 | "car", 313 | "motorcycle", 314 | "airplane", 315 | "bus", 316 | "train", 317 | "truck", 318 | "boat", 319 | "traffic light", 320 | "fire hydrant", 321 | "stop sign", 322 | "parking meter", 323 | "bench", 324 | "bird", 325 | "cat", 326 | "dog", 327 | "horse", 328 | "sheep", 329 | "cow", 330 | "elephant", 331 | "bear", 332 | "zebra", 333 | "giraffe", 334 | "backpack", 335 | "umbrella", 336 | "handbag", 337 | "tie", 338 | "suitcase", 339 | "frisbee", 340 | "skis", 341 | "snowboard", 342 | "sports ball", 343 | "kite", 344 | "baseball bat", 345 | "baseball glove", 346 | "skateboard", 347 | "surfboard", 348 | "tennis racket", 349 | "bottle", 350 | "wine glass", 351 | "cup", 352 | "fork", 353 | "knife", 354 | "spoon", 355 | "bowl", 356 | "banana", 357 | "apple", 358 | "sandwich", 359 | "orange", 360 | "broccoli", 361 | "carrot", 362 | "hot dog", 363 | "pizza", 364 | "donut", 365 | "cake", 366 | "chair", 367 | "couch", 368 | "potted plant", 369 | "bed", 370 | "dining table", 371 | "toilet", 372 | "tv", 373 | "laptop", 374 | "mouse", 
375 | "remote", 376 | "keyboard", 377 | "cell phone", 378 | "microwave", 379 | "oven", 380 | "toaster", 381 | "sink", 382 | "refrigerator", 383 | "book", 384 | "clock", 385 | "vase", 386 | "scissors", 387 | "teddy bear", 388 | "hair drier", 389 | "toothbrush", 390 | ] 391 | 392 | COCO_CLASSES_COLOR = [ 393 | (241, 23, 78), 394 | (63, 71, 49), 395 | (67, 79, 143), 396 | (32, 250, 205), 397 | (136, 228, 157), 398 | (135, 125, 104), 399 | (151, 46, 171), 400 | (129, 37, 28), 401 | (3, 248, 159), 402 | (154, 129, 58), 403 | (93, 155, 200), 404 | (201, 98, 152), 405 | (187, 194, 70), 406 | (122, 144, 121), 407 | (168, 31, 32), 408 | (168, 68, 189), 409 | (173, 68, 45), 410 | (200, 81, 154), 411 | (171, 114, 139), 412 | (216, 211, 39), 413 | (187, 119, 238), 414 | (201, 120, 112), 415 | (129, 16, 164), 416 | (211, 3, 208), 417 | (169, 41, 248), 418 | (100, 77, 159), 419 | (140, 104, 243), 420 | (26, 165, 41), 421 | (225, 176, 197), 422 | (35, 212, 67), 423 | (160, 245, 68), 424 | (7, 87, 70), 425 | (52, 107, 85), 426 | (103, 64, 188), 427 | (245, 76, 17), 428 | (248, 154, 59), 429 | (77, 45, 123), 430 | (210, 95, 230), 431 | (172, 188, 171), 432 | (250, 44, 233), 433 | (161, 71, 46), 434 | (144, 14, 134), 435 | (231, 142, 186), 436 | (34, 1, 200), 437 | (144, 42, 108), 438 | (222, 70, 139), 439 | (138, 62, 77), 440 | (178, 99, 61), 441 | (17, 94, 132), 442 | (93, 248, 254), 443 | (244, 116, 204), 444 | (138, 165, 238), 445 | (44, 216, 225), 446 | (224, 164, 12), 447 | (91, 126, 184), 448 | (116, 254, 49), 449 | (70, 250, 105), 450 | (252, 237, 54), 451 | (196, 136, 21), 452 | (234, 13, 149), 453 | (66, 43, 47), 454 | (2, 73, 234), 455 | (118, 181, 5), 456 | (105, 99, 225), 457 | (150, 253, 92), 458 | (59, 2, 121), 459 | (176, 190, 223), 460 | (91, 62, 47), 461 | (198, 124, 140), 462 | (100, 135, 185), 463 | (20, 207, 98), 464 | (216, 38, 133), 465 | (17, 202, 208), 466 | (216, 135, 81), 467 | (212, 203, 33), 468 | (108, 135, 76), 469 | (28, 47, 170), 470 | (142, 128, 121), 471 | (23, 161, 179), 472 | (33, 183, 224), 473 | ] 474 | 475 | 476 | def to_rgb_from_tensor(x): 477 | """Reverse the Normalize operation in torchvision.""" 478 | return (x * 0.5 + 0.5).clamp(0, 1) 479 | 480 | 481 | def _draw_bbox(img, anno, bbox_width=2): 482 | """Draw bbox on images. 483 | 484 | Args: 485 | img: (3, H, W), torch.Tensor 486 | anno: (N, 5) 487 | """ 488 | anno = anno[anno[:, -1] != -1] 489 | img = torch.round((to_rgb_from_tensor(img) * 255.0)).to(dtype=torch.uint8) 490 | bbox = anno[:, :4] 491 | label = anno[:, -1] 492 | draw_label = [COCO_CLASSES[int(lbl)] for lbl in label] 493 | draw_color = [COCO_CLASSES_COLOR[int(lbl)] for lbl in label] 494 | bbox_img = vutils.draw_bounding_boxes( 495 | img, bbox, labels=draw_label, colors=draw_color, width=bbox_width 496 | ) 497 | bbox_img = bbox_img.float() / 255.0 * 2.0 - 1.0 498 | return bbox_img 499 | 500 | 501 | def draw_coco_bbox(imgs, annos, bbox_width=2): 502 | """Draw bbox on batch images. 
503 | 504 | Args: 505 | imgs: (B, 3, H, W), torch.Tensor 506 | annos: (B, N, 5) 507 | """ 508 | if len(imgs.shape) == 3: 509 | return draw_coco_bbox(imgs[None], annos[None], bbox_width)[0] 510 | 511 | bbox_imgs = [] 512 | for img, anno in zip(imgs, annos): 513 | bbox_imgs.append(_draw_bbox(img, anno, bbox_width=bbox_width)) 514 | bbox_imgs = torch.stack(bbox_imgs, dim=0) 515 | return bbox_imgs 516 | 517 | 518 | class COCO2017Dataset(Dataset): 519 | """COCO 2017 dataset.""" 520 | 521 | def __init__( 522 | self, 523 | data_root, 524 | split, 525 | coco_transforms=None, 526 | load_anno=True, 527 | ): 528 | set_name = f"{split}2017" 529 | assert set_name in ["train2017", "val2017"], "Wrong set name!" 530 | 531 | self.split = split 532 | self.load_anno = load_anno 533 | self.coco_transforms = coco_transforms 534 | 535 | self.image_dir = os.path.join(data_root, "images", set_name) 536 | self.anno_dir = os.path.join( 537 | data_root, "annotations", f"instances_{set_name}.json" 538 | ) 539 | self.coco = COCO(self.anno_dir) 540 | 541 | self.image_ids = self.coco.getImgIds() 542 | 543 | if split == "train": 544 | # filter image id without annotation 545 | ids = [] 546 | for image_id in self.image_ids: 547 | anno_ids = self.coco.getAnnIds(imgIds=image_id) 548 | annos = self.coco.loadAnns(anno_ids) 549 | if len(annos) == 0: 550 | continue 551 | ids.append(image_id) 552 | self.image_ids = ids 553 | 554 | self.cat_ids = self.coco.getCatIds() 555 | self.cats = sorted(self.coco.loadCats(self.cat_ids), key=lambda x: x["id"]) 556 | self.num_classes = len(self.cats) 557 | 558 | # cat_id is an original cat id,coco_label is set from 0 to 79 559 | self.cat_id_to_cat_name = {cat["id"]: cat["name"] for cat in self.cats} 560 | self.cat_id_to_coco_label = {cat["id"]: i for i, cat in enumerate(self.cats)} 561 | self.coco_label_to_cat_id = {i: cat["id"] for i, cat in enumerate(self.cats)} 562 | self.coco_label_to_cat_name = { 563 | coco_label: self.cat_id_to_cat_name[cat_id] 564 | for coco_label, cat_id in self.coco_label_to_cat_id.items() 565 | } 566 | 567 | print(f"Dataset Size:{len(self.image_ids)}") 568 | print(f"Dataset Class Num:{self.num_classes}") 569 | 570 | # by default only load instance seg_mask, not semantic seg_mask 571 | self.load_sem_mask = False 572 | 573 | def __len__(self): 574 | return len(self.image_ids) 575 | 576 | def __getitem__(self, idx): 577 | image = self.load_image(idx) 578 | H, W = image.shape[:2] 579 | 580 | if self.load_anno: 581 | annos = self.load_annos(idx) # [N, 5] 582 | masks, inst_overlap_masks = self.load_inst_masks(idx) # [H, W]x2 583 | masks = [masks, inst_overlap_masks] 584 | if self.load_sem_mask: 585 | sem_masks = self.load_sem_masks(idx) # [H, W] 586 | masks.insert(1, sem_masks) # [inst, sem, inst_overlap] 587 | masks = np.stack(masks, axis=-1) # [H, W, 2 or 3] 588 | else: 589 | annos = np.zeros((0, 5), dtype=np.float32) 590 | masks = np.zeros((H, W), dtype=np.int32) 591 | 592 | scale = np.array(1.0).astype(np.float32) 593 | size = np.array([image.shape[0], image.shape[1]]).astype(np.float32) 594 | 595 | sample = { 596 | "image": image, 597 | "masks": masks, 598 | # if load_sem_mask, will have a `sem_masks` key after collate_fn 599 | "annos": annos, 600 | "scale": scale, 601 | "size": size, 602 | } 603 | return self.coco_transforms(sample) 604 | 605 | def load_image(self, idx): 606 | """Read image.""" 607 | file_name = self.coco.loadImgs(self.image_ids[idx])[0]["file_name"] 608 | image = cv2.imdecode( 609 | np.fromfile(os.path.join(self.image_dir, file_name), 
dtype=np.uint8), 610 | cv2.IMREAD_COLOR, 611 | ) 612 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 613 | return image.astype(np.uint8) 614 | 615 | def load_annos(self, idx): 616 | """Load bbox and cls.""" 617 | anno_ids = self.coco.getAnnIds(imgIds=self.image_ids[idx]) 618 | annos = self.coco.loadAnns(anno_ids) 619 | 620 | image_info = self.coco.loadImgs(self.image_ids[idx])[0] 621 | image_h, image_w = image_info["height"], image_info["width"] 622 | 623 | targets = np.zeros((0, 5)) 624 | if len(annos) == 0: 625 | return targets.astype(np.float32) 626 | 627 | # filter annos 628 | for anno in annos: 629 | if anno.get("ignore", False): 630 | continue 631 | if anno.get("iscrowd", False): 632 | continue 633 | if anno["category_id"] not in self.cat_ids: 634 | continue 635 | 636 | # bbox format: [x_min, y_min, w, h] 637 | bbox = anno["bbox"] 638 | inter_w = max(0, min(bbox[0] + bbox[2], image_w) - max(bbox[0], 0)) 639 | inter_h = max(0, min(bbox[1] + bbox[3], image_h) - max(bbox[1], 0)) 640 | if inter_w * inter_h == 0: 641 | continue 642 | if bbox[2] * bbox[3] < 1 or bbox[2] < 1 or bbox[3] < 1: 643 | continue 644 | 645 | target = np.zeros((1, 5)) 646 | target[0, :4] = bbox 647 | target[0, 4] = self.cat_id_to_coco_label[anno["category_id"]] 648 | targets = np.append(targets, target, axis=0) 649 | 650 | # [x_min, y_min, w, h] --> [x_min, y_min, x_max, y_max] 651 | targets[:, 2] = targets[:, 0] + targets[:, 2] 652 | targets[:, 3] = targets[:, 1] + targets[:, 3] 653 | 654 | return targets.astype(np.float32) # [N, 5 (x1, y1, x2, y2, cat_id)] 655 | 656 | def load_inst_masks(self, idx): 657 | """Load instance seg_mask and merge them into an argmaxed mask.""" 658 | anno_ids = self.coco.getAnnIds(imgIds=self.image_ids[idx]) 659 | annos = self.coco.loadAnns(anno_ids) 660 | 661 | image_info = self.coco.loadImgs(self.image_ids[idx])[0] 662 | image_h, image_w = image_info["height"], image_info["width"] 663 | 664 | masks = np.zeros((image_h, image_w), dtype=np.int32) 665 | inst_overlap_masks = np.zeros_like(masks) # for overlap check 666 | for i, anno in enumerate(annos): 667 | if anno.get("ignore", False): 668 | continue 669 | if anno.get("iscrowd", False): 670 | continue 671 | if anno["category_id"] not in self.cat_ids: 672 | continue 673 | mask = self.coco.annToMask(anno) 674 | masks[mask > 0] = i + 1 # to put background as 0 675 | inst_overlap_masks[mask > 0] += 1 676 | # overlap value > 1 indicates overlap 677 | inst_overlap_masks = (inst_overlap_masks > 1).astype(np.int32) 678 | # [H, W], [H, W], 1 is overlapping pixels 679 | return masks, inst_overlap_masks 680 | 681 | def load_sem_masks(self, idx): 682 | """Load instance seg_mask and merge them into an argmaxed mask.""" 683 | anno_ids = self.coco.getAnnIds(imgIds=self.image_ids[idx]) 684 | annos = self.coco.loadAnns(anno_ids) 685 | 686 | image_info = self.coco.loadImgs(self.image_ids[idx])[0] 687 | image_h, image_w = image_info["height"], image_info["width"] 688 | 689 | masks = np.zeros((image_h, image_w), dtype=np.int32) 690 | for i, anno in enumerate(annos): 691 | if anno.get("ignore", False): 692 | continue 693 | if anno.get("iscrowd", False): 694 | continue 695 | if anno["category_id"] not in self.cat_ids: 696 | continue 697 | mask = self.coco.annToMask(anno) 698 | coco_lbl = self.cat_id_to_coco_label[anno["category_id"]] 699 | masks[mask > 0] = coco_lbl + 1 # to put background as 0 700 | # [H, W] 701 | return masks 702 | 703 | 704 | def build_coco_dataset(root_dir, resolution, load_anno=True, val_only=False): 705 | """Build COCO2017 dataset 
that load images.""" 706 | val_transforms = COCOTransforms( 707 | resolution, 708 | val=True, 709 | ) 710 | args = dict( 711 | data_root=root_dir, 712 | coco_transforms=val_transforms, 713 | split="val", 714 | load_anno=load_anno, 715 | ) 716 | val_dataset = COCO2017Dataset(**args) 717 | if val_only: 718 | return val_dataset, COCOCollater() 719 | args["split"] = "train" 720 | args["load_anno"] = False 721 | args["coco_transforms"] = COCOTransforms( 722 | resolution, 723 | val=False, 724 | ) 725 | train_dataset = COCO2017Dataset(**args) 726 | return train_dataset, val_dataset, COCOCollater() 727 | -------------------------------------------------------------------------------- /source/data/datasets/objs/dsprites.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | from torch.utils.data import Dataset 5 | import torch 6 | import torchvision 7 | from torchvision import transforms 8 | from PIL import Image 9 | 10 | from source.data.augs import get_color_distortion 11 | from source.data.datasets.objs.npdataset import NumpyDataset, PairDataset 12 | from source.data.augs import simclr_augmentation 13 | 14 | 15 | def get_dsprites_pair(root, split="train", imsize=64, hflip=False): 16 | path = Path(root, f"colored_on_grayscale_{split}.npz") 17 | return PairDataset(path, transform=simclr_augmentation(imsize=imsize, hflip=hflip)) 18 | 19 | 20 | def get_dsprites(root, split="train", imsize=64): 21 | path = Path(root, f"colored_on_grayscale_{split}.npz") 22 | return NumpyDataset( 23 | path, 24 | transform=transforms.Compose( 25 | [transforms.Resize(imsize), transforms.ToTensor()] 26 | ), 27 | ) 28 | -------------------------------------------------------------------------------- /source/data/datasets/objs/imagenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from source.data.augs import get_color_distortion 4 | from torchvision.datasets import ImageNet 5 | from torchvision import transforms 6 | from PIL import Image 7 | 8 | 9 | class ImageNetPair(ImageNet): 10 | """Generate mini-batche pairs on CIFAR10 training set.""" 11 | 12 | def __getitem__(self, idx): 13 | path, target = self.imgs[idx][0], self.targets[idx] 14 | img = self.loader(path) 15 | imgs = [self.transform(img), self.transform(img)] 16 | return torch.stack(imgs) 17 | 18 | 19 | def get_imagenet( 20 | root, 21 | split="train", 22 | transform=transforms.Compose( 23 | [ 24 | transforms.CenterCrop(256), 25 | transforms.ToTensor(), 26 | ] 27 | ), 28 | ): 29 | return ImageNet( 30 | root=root, 31 | split=split, 32 | transform=transform, 33 | ) 34 | 35 | 36 | def get_imagenet_pair(root, split="train", imsize=256, hflip=False): 37 | from source.data.augs import simclr_augmentation 38 | 39 | return ImageNetPair( 40 | root=root, 41 | split=split, 42 | transform=simclr_augmentation(imsize=imsize, hflip=hflip), 43 | ) 44 | -------------------------------------------------------------------------------- /source/data/datasets/objs/load_data.py: -------------------------------------------------------------------------------- 1 | def load_data(data, data_root, data_imsize, is_eval=False): 2 | # Data loading 3 | 4 | collate_fn = None 5 | 6 | if data == "clevrtex_full" or data == "clevrtex_camo" or data == "clevrtex_outd": 7 | from source.data.datasets.objs.clevr_tex import get_clevrtex_pair, get_clevrtex, collate_fn 8 | if data == "clevrtex_camo": 9 | assert is_eval, "Camo dataset is 
only for evaluation" 10 | data_type = "camo" 11 | default_path = "./data/clevr_tex/clevrtex_camo" 12 | elif data == "clevrtex_outd": 13 | assert is_eval, "OOD dataset is only for evaluation" 14 | data_type = "outd" 15 | default_path = "./data/clevr_tex/clevrtex_outd" 16 | else: 17 | data_type = "full" 18 | default_path = "./data/clevr_tex/clevrtex_full" 19 | 20 | data_root = ( 21 | default_path 22 | if data_root is None 23 | else data_root 24 | ) 25 | imsize = 128 if data_imsize is None else data_imsize 26 | if is_eval: 27 | dataset = get_clevrtex(data_root, split="test", data_type=data_type, imsize=imsize, return_meta_data=True) 28 | else: 29 | dataset = get_clevrtex_pair( 30 | root=data_root, 31 | split="train", 32 | ) 33 | 34 | elif data == "clevr": 35 | from source.data.datasets.objs.clevr import get_clevr_pair, get_clevr 36 | imsize = 128 if data_imsize is None else data_imsize 37 | data_root = "./data/clevr_with_masks/" 38 | if is_eval: 39 | dataset = get_clevr(data_root, split="test", imsize=imsize) 40 | else: 41 | dataset = get_clevr_pair( 42 | root="./data/clevr_with_masks/", 43 | split="train", 44 | ) 45 | 46 | elif data == "imagenet": 47 | from source.data.datasets.objs.imagenet import get_imagenet_pair 48 | data_root = ( 49 | "./data/ImageNet2012/" 50 | if data_root is None 51 | else data_root 52 | ) 53 | dataset = get_imagenet_pair( 54 | root=data_root, 55 | split="train", 56 | hflip=True, 57 | imsize=256 if data_imsize is None else data_imsize, 58 | ) 59 | elif data == "coco": 60 | imsize = 256 if data_imsize is None else data_imsize 61 | from source.data.datasets.objs.coco import get_coco_dataset 62 | data_root = "./data/COCO" 63 | if not is_eval: 64 | raise ValueError("COCO dataset is only for evaluation") 65 | else: 66 | dataset = get_coco_dataset(root=data_root) 67 | 68 | elif data == "pascal": 69 | imsize = 256 if data_imsize is None else data_imsize 70 | from source.data.datasets.objs.pascal import get_pascal_dataset 71 | data_root = "./data/VOCdevkit/VOC2012" 72 | if not is_eval: 73 | raise ValueError("Pascal dataset is only for evaluation") 74 | else: 75 | dataset = get_coco_dataset(root=data_root) 76 | 77 | elif data == "dsprites": 78 | imsize = 64 if data_imsize is None else data_imsize 79 | from source.data.datasets.objs.dsprites import get_dsprites_pair 80 | dataset = get_dsprites_pair( 81 | root="./data/multi_dsprites/", 82 | split="train", 83 | imsize=imsize, 84 | ) 85 | 86 | elif data == "tetrominoes": 87 | from source.data.datasets.objs.tetrominoes import get_tetrominoes_pair 88 | imsize = 32 if data_imsize is None else data_imsize 89 | dataset = get_tetrominoes_pair( 90 | root="./data/tetrominoes/", 91 | split="train", 92 | imsize=imsize, 93 | ) 94 | 95 | elif data == "Shapes": 96 | imsize = 40 if data_imsize is None else data_imsize 97 | from source.data.datasets.objs.shapes import get_shapes_pair 98 | dataset = get_shapes_pair( 99 | root="./data/Shapes/", 100 | split="train", 101 | imsize=imsize, 102 | ) 103 | return dataset, imsize, collate_fn 104 | 105 | -------------------------------------------------------------------------------- /source/data/datasets/objs/npdataset.py: -------------------------------------------------------------------------------- 1 | 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | from torch.utils.data import Dataset 6 | import torch 7 | import torchvision 8 | from torchvision import transforms 9 | from PIL import Image 10 | 11 | from source.data.augs import get_color_distortion 12 | 13 | 14 | class 
NumpyDataset(Dataset): 15 | """NpzDataset: loads a npz file as dataset.""" 16 | 17 | def __init__(self, filename, transform=torchvision.transforms.ToTensor()): 18 | super().__init__() 19 | 20 | dataset = np.load(filename) 21 | self.images = dataset["images"].astype(np.float32) 22 | if self.images.shape[1] == 1: 23 | self.images = np.repeat(self.images, 3, axis=1) 24 | self.pixelwise_instance_labels = dataset["labels"] 25 | 26 | if "class_labels" in dataset: 27 | self.class_labels = dataset["class_labels"] 28 | else: 29 | self.class_labels = None 30 | 31 | if "pixelwise_class_labels" in dataset: 32 | self.pixelwise_class_labels = dataset["pixelwise_class_labels"] 33 | else: 34 | self.pixelwise_class_labels = None 35 | 36 | self.transform = transform 37 | 38 | def __len__(self): 39 | return self.images.shape[0] 40 | 41 | def __getitem__(self, idx): 42 | img = self.images[idx] # {"input_images": self.images[idx]} 43 | img = np.transpose(img, (1, 2, 0)) 44 | img = (255 * img).astype(np.uint8) 45 | img = Image.fromarray(img) # .convert('RGB') 46 | labels = {"pixelwise_instance_labels": self.pixelwise_instance_labels[idx]} 47 | 48 | if self.class_labels is not None: 49 | labels["class_labels"] = self.class_labels[idx] 50 | labels["pixelwise_class_labels"] = self.pixelwise_class_labels[idx] 51 | return self.transform(img), labels 52 | 53 | 54 | class PairDataset(NumpyDataset): 55 | """Generate mini-batche pairs on CIFAR10 training set.""" 56 | 57 | def __getitem__(self, idx): 58 | img = self.images[idx] 59 | img = np.transpose(img, (1, 2, 0)) 60 | img = (255 * img).astype(np.uint8) 61 | img = Image.fromarray(img) # .convert('RGB') 62 | imgs = [self.transform(img), self.transform(img)] 63 | return torch.stack(imgs) # stack a positive pair 64 | -------------------------------------------------------------------------------- /source/data/datasets/objs/pascal.py: -------------------------------------------------------------------------------- 1 | # Copied from https://github.com/loeweX/RotatingFeatures/blob/main/codebase/data/PascalDataset.py 2 | 3 | import os 4 | from typing import Tuple, Dict 5 | 6 | import torch 7 | import torchvision.transforms as transforms 8 | from PIL import Image 9 | from torch.utils.data import Dataset 10 | 11 | 12 | class PascalDataset(Dataset): 13 | def __init__(self, root_dir, partition: str, transform, transform_label) -> None: 14 | """ 15 | Initialize the Pascal VOC 2012 Dataset. 16 | See http://host.robots.ox.ac.uk/pascal/VOC/ for more information about the dataset. 17 | 18 | We make use of the “trainaug” variant of this dataset, an unofficial split consisting of 10,582 images, 19 | which includes 1,464 images from the original segmentation train set and 9,118 images from the 20 | Semantic Boundaries dataset. 21 | 22 | Args: 23 | opt (DictConfig): Configuration options. 24 | partition (str): Dataset partition ("train", "val", or "test"). 25 | """ 26 | super(PascalDataset, self).__init__() 27 | 28 | self.partition = partition 29 | self.to_tensor = transforms.ToTensor() 30 | if self.partition == "train": 31 | self.partition = "trainaug" 32 | self.transform = transform 33 | self.transform_label = transform_label 34 | 35 | # As is common in the literature, we test our model on the validation set of the Pascal VOC dataset. 36 | # For validation, create a train/validation split of the official training set manually and 37 | # adjust the code accordingly. 38 | if self.partition == "test": 39 | self.partition = "val" 40 | 41 | # Load Pascal dataset. 
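# The partition file under ImageSets/Segmentation/<partition>.txt lists image ids; images and per-pixel labels are then read from JPEGImages/, SegmentationClass/, and SegmentationObject/ below.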
42 | partition_dir = os.path.join( 43 | root_dir, "ImageSets", "Segmentation", f"{self.partition}.txt" 44 | ) 45 | 46 | with open(partition_dir) as f: 47 | file_names = [x.strip() for x in f.readlines()] 48 | 49 | self.images = [ 50 | os.path.join(root_dir, "JPEGImages", f"{x}.jpg") for x in file_names 51 | ] 52 | self.pixelwise_class_labels = [ 53 | os.path.join(root_dir, "SegmentationClass", f"{x}.png") for x in file_names 54 | ] 55 | self.pixelwise_instance_labels = [ 56 | os.path.join(root_dir, "SegmentationObject", f"{x}.png") for x in file_names 57 | ] 58 | 59 | # Normalize input images using mean and standard deviation of ImageNet. 60 | # self.normalize = transforms.Normalize( 61 | # (0.485, 0.456, 0.406), (0.229, 0.224, 0.225) 62 | # ) 63 | 64 | self.num_classes = 20 65 | 66 | def __len__(self) -> int: 67 | """ 68 | Get the number of images in the dataset. 69 | 70 | Returns: 71 | int: Number of images. 72 | """ 73 | return len(self.images) 74 | 75 | @staticmethod 76 | def _preprocess_pascal_labels(labels: torch.Tensor) -> torch.Tensor: 77 | """ 78 | Preprocess Pascal VOC labels by converting to integer 255-scale and 79 | marking object boundaries as "ignore" label. 80 | 81 | Args: 82 | labels (torch.Tensor): The input labels. 83 | 84 | Returns: 85 | torch.Tensor: The preprocessed labels. 86 | """ 87 | labels = labels * 255 88 | labels[labels == 255] = -1 # "Ignore" label throughout the codebase is -1. 89 | return labels 90 | 91 | def __getitem__(self, idx: int) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: 92 | """ 93 | Get an item from the dataset. 94 | 95 | Args: 96 | idx (int): Index of the item to retrieve. 97 | 98 | Returns: 99 | tuple: A tuple containing the input image and corresponding gt_labels. 100 | """ 101 | input_image = Image.open(self.images[idx]).convert("RGB") 102 | pixelwise_class_labels = Image.open(self.pixelwise_class_labels[idx]) 103 | 104 | try: 105 | pixelwise_instance_labels = Image.open(self.pixelwise_instance_labels[idx]) 106 | except FileNotFoundError as e: 107 | # Instance labels are not available for all images in the trainaug set. 108 | if self.partition != "trainaug": 109 | raise FileNotFoundError( 110 | "Instance labels should only be missing for the trainaug partition." 111 | ) from e 112 | # Create an empty target. 
113 | pixelwise_instance_labels = Image.new( 114 | "L", (pixelwise_class_labels.width, pixelwise_class_labels.height), 0 115 | ) 116 | 117 | pixelwise_class_labels = self._preprocess_pascal_labels( 118 | self.to_tensor(pixelwise_class_labels) 119 | ) 120 | pixelwise_instance_labels = self._preprocess_pascal_labels( 121 | self.to_tensor(pixelwise_instance_labels) 122 | ) 123 | 124 | input_image = self.transform(input_image) 125 | pixelwise_instance_labels = self.transform_label(pixelwise_instance_labels)[0] 126 | pixelwise_class_labels = self.transform_label(pixelwise_class_labels)[0] 127 | 128 | labels = { 129 | "pixelwise_class_labels": pixelwise_class_labels, 130 | "pixelwise_instance_labels": pixelwise_instance_labels, 131 | } 132 | return input_image, labels 133 | 134 | 135 | def get_pascal(root, split="train", imsize=256, imsize_label=320): 136 | transform = transforms.Compose( 137 | [ 138 | transforms.Resize( 139 | imsize, interpolation=transforms.InterpolationMode.NEAREST 140 | ), 141 | transforms.CenterCrop(imsize), 142 | transforms.ToTensor(), 143 | ] 144 | ) 145 | 146 | transform_label = transforms.Compose( 147 | [ 148 | transforms.Resize( 149 | imsize_label, interpolation=transforms.InterpolationMode.NEAREST 150 | ), 151 | transforms.CenterCrop(imsize_label), 152 | ] 153 | ) 154 | return PascalDataset( 155 | root, split, transform=transform, transform_label=transform_label 156 | ) 157 | -------------------------------------------------------------------------------- /source/data/datasets/objs/shapes.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | from torch.utils.data import Dataset 5 | import torch 6 | import torchvision 7 | from torchvision import transforms 8 | from PIL import Image 9 | 10 | from source.data.augs import simclr_augmentation 11 | from source.data.datasets.objs.npdataset import NumpyDataset, PairDataset 12 | 13 | 14 | def get_shapes_pair(root, split="train", imsize=40): 15 | 16 | path = Path(root, f"{split}.npz") 17 | return PairDataset(path, transform=simclr_augmentation(imsize=imsize, hflip=False)) 18 | 19 | 20 | def get_shapes(root, split="train", imsize=40): 21 | path = Path(root, f"{split}.npz") 22 | return NumpyDataset( 23 | path, 24 | transform=transforms.Compose( 25 | [ 26 | transforms.Resize((imsize, imsize)), 27 | transforms.ToTensor(), 28 | ] 29 | ), 30 | ) 31 | -------------------------------------------------------------------------------- /source/data/datasets/objs/tetrominoes.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | from torch.utils.data import Dataset 5 | import torch 6 | import torchvision 7 | from torchvision import transforms 8 | from PIL import Image 9 | 10 | from source.data.augs import simclr_augmentation 11 | from source.data.datasets.objs.npdataset import NumpyDataset, PairDataset 12 | 13 | 14 | def get_tetrominoes_pair(root, split="train", imsize=32, hflip=False): 15 | path = Path(root, f"tetrominoes_{split}.npz") 16 | return PairDataset(path, transform=simclr_augmentation(imsize=imsize, hflip=hflip)) 17 | 18 | 19 | def get_tetrominoes(root, split="train"): 20 | path = Path(root, f"tetrominoes_{split}.npz") 21 | return NumpyDataset(path, transform=transforms.ToTensor()) 22 | -------------------------------------------------------------------------------- /source/data/datasets/sudoku/sudoku.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import numpy as np 4 | import pandas as pd 5 | import torch 6 | 7 | 8 | def convert_onehot_to_int(X): 9 | # [B, H, W, 9]->[B, H, W] 10 | is_input = X.sum(dim=-1) 11 | return (is_input * (X.argmax(-1) + 1)).to(torch.int32) 12 | 13 | 14 | # copied from https://github.com/yilundu/ired_code_release/blob/3d74b85fab7fcf5e28aaf15e9ed3bf51c1a1d545/sat_dataset.py#L17 15 | def load_rrn_dataset(data_dir, split): 16 | if not osp.exists(data_dir): 17 | raise ValueError( 18 | f"Data directory {data_dir} does not exist. Run data/download-rrn.sh to download the dataset." 19 | ) 20 | 21 | split_to_filename = {"train": "train.csv", "val": "valid.csv", "test": "test.csv"} 22 | 23 | filename = osp.join(data_dir, split_to_filename[split]) 24 | df = pd.read_csv(filename, header=None) 25 | 26 | def str2onehot(x): 27 | x = np.array(list(map(int, x)), dtype="int64") 28 | y = np.zeros((len(x), 9), dtype="float32") 29 | idx = np.where(x > 0)[0] 30 | y[idx, x[idx] - 1] = 1 31 | return y.reshape((9, 9, 9)) 32 | 33 | features = list() 34 | labels = list() 35 | for i in range(len(df)): 36 | inp = df.iloc[i][0] 37 | out = df.iloc[i][1] 38 | features.append(str2onehot(inp)) 39 | labels.append(str2onehot(out)) 40 | 41 | return torch.tensor(np.array(features)), torch.tensor(np.array(labels)) 42 | 43 | 44 | def load_sat_dataset(path): 45 | with open(os.path.join(path, "features.pt"), "rb") as f: 46 | X = torch.load(f) 47 | with open(os.path.join(path, "labels.pt"), "rb") as f: 48 | Y = torch.load(f) 49 | with open(os.path.join(path, "perm.pt"), "rb") as f: 50 | perm = torch.load(f) 51 | return X, Y, perm 52 | 53 | 54 | class SudokuDataset: 55 | 56 | def __init__(self, path='./data/sudoku/', train=True): 57 | 58 | X, Y, _ = load_sat_dataset(path) 59 | 60 | is_input = X.sum(dim=3, keepdim=True).int() 61 | 62 | indices = torch.arange(0, 9000) if train else torch.arange(9000, 10000) 63 | 64 | self.X = X[indices] 65 | self.Y = Y[indices] 66 | self.is_input = is_input[indices] 67 | 68 | def __len__(self): 69 | return len(self.X) 70 | 71 | def __getitem__(self, idx): 72 | return self.X[idx], self.Y[idx], self.is_input[idx] 73 | 74 | 75 | class HardSudokuDataset: 76 | def __init__(self, path='./data/sudoku-rnn/', split="test"): 77 | 78 | X, Y = load_rrn_dataset(path, split) 79 | 80 | is_input = X.sum(dim=3, keepdim=True).int() 81 | 82 | self.X = X 83 | self.Y = Y 84 | self.is_input = is_input 85 | 86 | def __len__(self): 87 | return len(self.X) 88 | 89 | def __getitem__(self, idx): 90 | return self.X[idx], self.Y[idx], self.is_input[idx] 91 | -------------------------------------------------------------------------------- /source/evals/objs/fgari.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics import adjusted_rand_score 3 | 4 | def calc_fgari_score( 5 | gt_labels: np.ndarray, pred_labels: np.ndarray 6 | ) -> float: 7 | """ 8 | Calculate Adjusted Rand Index (ARI) score for object discovery evaluation. 9 | 10 | Args: 11 | opt (DictConfig): Configuration options. 12 | gt_labels (np.ndarray): Ground truth labels, shape ((b, h, w)). 13 | pred_labels (np.ndarray): Predicted labels, shape (b, h, w). 14 | 15 | Returns: 16 | float: ARI score. 17 | """ 18 | aris = [] 19 | for idx in range(gt_labels.shape[0]): 20 | # Remove "ignore" (-1) and background (0) gt_labels. 
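# FG-ARI: the ARI is computed only over pixels whose ground-truth label is > 0, so background and "ignore" regions do not affect the score.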
21 | area_to_eval = np.where(gt_labels[idx] > 0) 22 | 23 | ari = adjusted_rand_score( 24 | gt_labels[idx][area_to_eval], pred_labels[idx][area_to_eval] 25 | ) 26 | aris.append(ari) 27 | return aris -------------------------------------------------------------------------------- /source/evals/objs/mbo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def compute_iou_matrix(gt_labels: np.ndarray, pred_labels: np.ndarray) -> np.ndarray: 5 | """ 6 | Compute the Intersection over Union (IoU) matrix between ground truth and predicted labels. 7 | 8 | Args: 9 | gt_labels (np.ndarray): Ground truth labels, shape (m, h, w). 10 | pred_labels (np.ndarray): Predicted labels, shape (o, h, w). 11 | 12 | Returns: 13 | np.ndarray: IoU matrix, shape (m, o). 14 | """ 15 | intersection = np.logical_and( 16 | gt_labels[:, None, :, :], pred_labels[None, :, :, :] 17 | ).sum(axis=(2, 3)) 18 | union = np.logical_or(gt_labels[:, None, :, :], pred_labels[None, :, :, :]).sum( 19 | axis=(2, 3) 20 | ) 21 | return intersection / (union + 1e-9) 22 | 23 | 24 | def mean_best_overlap_single_sample( 25 | gt_labels: np.ndarray, pred_labels: np.ndarray 26 | ) -> float: 27 | """ 28 | Compute the Mean Best Overlap (MBO) for a single sample between ground truth and predicted labels. 29 | 30 | Args: 31 | gt_labels (np.ndarray): Ground truth labels, shape (h, w). 32 | pred_labels (np.ndarray): Predicted labels, shape (h, w). 33 | 34 | Returns: 35 | float: MBO score for the sample. 36 | """ 37 | from copy import deepcopy 38 | 39 | pred_labels = deepcopy(pred_labels) 40 | 41 | unique_gt_labels = np.unique(gt_labels) 42 | # Remove "ignore" (-1) label. 43 | unique_gt_labels = unique_gt_labels[unique_gt_labels != -1] 44 | 45 | # Mask areas with "ignore" gt_labels in pred_labels. 46 | pred_labels[np.where(gt_labels < 0)] = -1 47 | 48 | # Ignore background (0) gt_labels. 49 | unique_gt_labels = unique_gt_labels[unique_gt_labels != 0] 50 | 51 | if len(unique_gt_labels) == 0: 52 | return -1 # If no gt_labels left, skip this element. 53 | 54 | unique_pred_labels = np.unique(pred_labels) 55 | 56 | # Remove "ignore" (-1) label. 57 | unique_pred_labels = unique_pred_labels[unique_pred_labels != -1] 58 | 59 | gt_masks = np.equal(gt_labels[None, :, :], unique_gt_labels[:, None, None]) 60 | pred_masks = np.equal(pred_labels[None, :, :], unique_pred_labels[:, None, None]) 61 | 62 | iou_matrix = compute_iou_matrix(gt_masks, pred_masks) 63 | best_iou = np.max(iou_matrix, axis=1) 64 | return np.mean(best_iou) 65 | 66 | 67 | def calc_mean_best_overlap(gt_labels: np.ndarray, pred_labels: np.ndarray) -> float: 68 | """ 69 | Calculate the Mean Best Overlap (MBO) for a batch of ground truth and predicted labels. 70 | 71 | Args: 72 | opt (DictConfig): Configuration options. 73 | gt_labels (np.ndarray): Ground truth labels, shape (b, h, w). 74 | pred_labels (np.ndarray): Predicted labels, shape (b, h, w). 75 | 76 | Returns: 77 | float: MBO score for the batch. 
78 | """ 79 | mean_best_overlap = np.array( 80 | [ 81 | mean_best_overlap_single_sample(gt_labels[b_idx], pred_labels[b_idx]) 82 | for b_idx in range(gt_labels.shape[0]) 83 | ] 84 | ) 85 | 86 | if np.any(mean_best_overlap != -1): 87 | return np.mean(mean_best_overlap[mean_best_overlap != -1]), mean_best_overlap 88 | else: 89 | return 0.0, mean_best_overlap 90 | -------------------------------------------------------------------------------- /source/evals/sudoku/evals.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | def compute_board_accuracy(pred, Y, is_input): 5 | #print(pred.shape) 6 | B = pred.shape[0] 7 | pred = pred.reshape((B, -1, 9)).argmax(-1) 8 | Y = Y.argmax(dim=-1).reshape(B, -1) 9 | mask = 1 - is_input.reshape(B, -1) 10 | 11 | num_blanks = mask.sum(1) 12 | num_correct = (mask * (pred == Y)).sum(1) 13 | board_correct = (num_correct == num_blanks).int() 14 | return num_blanks, num_correct, board_correct 15 | 16 | 17 | -------------------------------------------------------------------------------- /source/layers/common_fns.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | def positionalencoding2d(d_model, height, width): 5 | """ 6 | :param d_model: dimension of the model 7 | :param height: height of the positions 8 | :param width: width of the positions 9 | :return: d_model*height*width position matrix 10 | """ 11 | if d_model % 4 != 0: 12 | raise ValueError("Cannot use sin/cos positional encoding with " 13 | "odd dimension (got dim={:d})".format(d_model)) 14 | pe = torch.zeros(d_model, height, width) 15 | # Each dimension use half of d_model 16 | d_model = int(d_model / 2) 17 | div_term = torch.exp(torch.arange(0., d_model, 2) * 18 | -(math.log(10000.0) / d_model)) 19 | pos_w = torch.arange(0., width).unsqueeze(1) 20 | pos_h = torch.arange(0., height).unsqueeze(1) 21 | pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1) 22 | pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1) 23 | pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width) 24 | pe[d_model + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width) 25 | 26 | return pe 27 | -------------------------------------------------------------------------------- /source/layers/common_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | import math 5 | import einops 6 | import numpy as np 7 | 8 | from source.layers.gta import ( 9 | make_2dcoord, 10 | make_SO2mats, 11 | rep_mul_x, 12 | ) 13 | 14 | 15 | class Interpolate(nn.Module): 16 | 17 | def __init__(self, r, mode="bilinear"): 18 | super().__init__() 19 | self.r = r 20 | self.mode = mode 21 | 22 | def forward(self, x): 23 | return F.interpolate( 24 | x, scale_factor=self.r, mode=self.mode, align_corners=False 25 | ) 26 | 27 | 28 | class Reshape(nn.Module): 29 | def __init__(self, *args): 30 | super().__init__() 31 | self.shape = args 32 | 33 | def forward(self, x): 34 | return x.view(self.shape) 35 | 36 | 37 | class ResBlock(nn.Module): 38 | 39 | def __init__(self, fn): 40 | super().__init__() 41 | self.fn = fn 42 | 43 | def forward(self, x): 44 | return x + self.fn(x) 45 | 46 | 47 | class PatchEmbedding(nn.Module): 48 | def 
__init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=128): 49 | super(PatchEmbedding, self).__init__() 50 | self.patch_size = patch_size 51 | self.embed_dim = embed_dim 52 | self.num_patches = (img_size // patch_size) ** 2 53 | self.proj = nn.Conv2d( 54 | in_channels, embed_dim, kernel_size=patch_size, stride=patch_size 55 | ) 56 | 57 | def forward(self, x): 58 | x = self.proj(x) # (B, embed_dim, H/patch_size, W/patch_size) 59 | x = x.flatten(2) # (B, embed_dim, num_patches) 60 | x = x.transpose(1, 2) # (B, num_patches, embed_dim) 61 | return x 62 | 63 | 64 | class ReadOutConv(nn.Module): 65 | def __init__( 66 | self, 67 | inch, 68 | outch, 69 | out_dim, 70 | kernel_size=1, 71 | stride=1, 72 | padding=0, 73 | ): 74 | super().__init__() 75 | self.outch = outch 76 | self.out_dim = out_dim 77 | self.invconv = nn.Conv2d( 78 | inch, 79 | outch * out_dim, 80 | kernel_size=kernel_size, 81 | stride=stride, 82 | padding=padding, 83 | ) 84 | self.bias = nn.Parameter(torch.zeros(outch)) 85 | 86 | def forward(self, x): 87 | x = self.invconv(x).unflatten(1, (self.outch, -1)) 88 | x = torch.linalg.norm(x, dim=2) + self.bias[None, :, None, None] 89 | return x 90 | 91 | 92 | class BNReLUConv2d(nn.Module): 93 | 94 | def __init__( 95 | self, 96 | inch, 97 | outch, 98 | kernel_size=1, 99 | stride=1, 100 | padding=0, 101 | norm=None, 102 | act=nn.ReLU(), 103 | ): 104 | super().__init__() 105 | if norm == "gn": 106 | norm = lambda ch: nn.GroupNorm(8, ch) 107 | elif norm == "bn": 108 | norm = lambda ch: nn.BatchNorm2d(ch) 109 | elif norm == None: 110 | norm = lambda ch: nn.Identity() 111 | else: 112 | raise NotImplementedError 113 | 114 | conv = nn.Conv2d( 115 | inch, 116 | outch, 117 | kernel_size=kernel_size, 118 | stride=stride, 119 | padding=padding, 120 | ) 121 | 122 | self.fn = nn.Sequential( 123 | norm(inch), 124 | act, 125 | conv, 126 | ) 127 | 128 | def forward(self, x): 129 | return self.fn(x) 130 | 131 | 132 | class FF(nn.Module): 133 | 134 | def __init__( 135 | self, 136 | inch, 137 | outch, 138 | hidch=None, 139 | kernel_size=1, 140 | stride=1, 141 | padding=0, 142 | norm=None, 143 | act=nn.ReLU(), 144 | ): 145 | super().__init__() 146 | if hidch is None: 147 | hidch = 4 * inch 148 | self.fn = nn.Sequential( 149 | BNReLUConv2d( 150 | inch, 151 | hidch, 152 | kernel_size=kernel_size, 153 | stride=stride, 154 | padding=padding, 155 | norm=norm, 156 | act=act, 157 | ), 158 | BNReLUConv2d( 159 | hidch, 160 | outch, 161 | kernel_size=kernel_size, 162 | stride=stride, 163 | padding=padding, 164 | norm=norm, 165 | act=act, 166 | ), 167 | ) 168 | 169 | def forward(self, x): 170 | x = self.fn(x) 171 | return x 172 | 173 | 174 | class LayerNormForImage(nn.Module): 175 | def __init__(self, num_features, eps=1e-5): 176 | super().__init__() 177 | self.eps = eps 178 | self.gamma = nn.Parameter(torch.ones(1, num_features, 1, 1)) 179 | self.beta = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 180 | 181 | def forward(self, x): 182 | # x shape: [B, C, H, W] 183 | mean = x.mean(dim=1, keepdim=True) 184 | var = x.var(dim=1, keepdim=True, unbiased=False) 185 | x_normalized = (x - mean) / torch.sqrt(var + self.eps) 186 | if x.ndim == 2: 187 | x_normalized = self.gamma[..., 0, 0] * x_normalized + self.beta[..., 0, 0] 188 | else: 189 | x_normalized = self.gamma * x_normalized + self.beta 190 | return x_normalized 191 | 192 | 193 | class ScaleAndBias(nn.Module): 194 | def __init__(self, num_channels, token_input=False): 195 | super().__init__() 196 | self.scale = nn.Parameter(torch.ones(num_channels)) 197 
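A minimal shape-check sketch (not part of the repository, assuming the repo root is on `PYTHONPATH`): `PatchEmbedding` turns an image into a token sequence, and `ReadOutConv` maps each group of `out_dim` channels to a magnitude via an L2 norm, which is how oscillator states are read out into ordinary features.

```
import torch
from source.layers.common_layers import PatchEmbedding, ReadOutConv

x = torch.randn(2, 3, 32, 32)
tokens = PatchEmbedding(img_size=32, patch_size=4, in_channels=3, embed_dim=128)(x)
print(tokens.shape)        # torch.Size([2, 64, 128]) -> (B, num_patches, embed_dim)

feat = torch.randn(2, 256, 16, 16)
ro = ReadOutConv(inch=256, outch=256, out_dim=4)   # 1x1 conv to 256*4 channels, then norm over each group of 4
print(ro(feat).shape)      # torch.Size([2, 256, 16, 16])
```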
| self.bias = nn.Parameter(torch.zeros(num_channels)) 198 | self.token_input = token_input 199 | 200 | def forward(self, x): 201 | # Determine the shape for scale and bias based on input dimensions 202 | if self.token_input: 203 | # token input 204 | shape = [1, 1, -1] 205 | scale = self.scale.view(*shape) 206 | bias = self.bias.view(*shape) 207 | else: 208 | # image input 209 | shape = [1, -1] + [1] * (x.dim() - 2) 210 | scale = self.scale.view(*shape) 211 | bias = self.bias.view(*shape) 212 | return x * scale + bias 213 | 214 | 215 | class RGBNormalize(nn.Module): 216 | def __init__(self, mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)): 217 | super().__init__() 218 | 219 | self.mean = torch.tensor(mean).view(1, len(mean), 1, 1) 220 | self.std = torch.tensor(std).view(1, len(std), 1, 1) 221 | 222 | def forward(self, x): 223 | if x.device != self.mean.device: 224 | self.mean = self.mean.to(x.device) 225 | self.std = self.std.to(x.device) 226 | return (x - self.mean) / self.std 227 | 228 | def inverse(self, x): 229 | if x.device != self.mean.device: 230 | self.mean = self.mean.to(x.device) 231 | self.std = self.std.to(x.device) 232 | return (x * self.std) + self.mean 233 | 234 | 235 | class FeatureAttention(nn.Module): 236 | def __init__(self, n, ch): 237 | super().__init__() 238 | self.n = n 239 | self.ch = ch 240 | self.q_linear = nn.Linear(n, n) 241 | self.k_linear = nn.Linear(n, n) 242 | self.v_linear = nn.Linear(n, n) 243 | self.o_linear = nn.Linear(n, n) 244 | 245 | def forward(self, x): 246 | B = x.shape[0] 247 | q, k, v = map(lambda x: x.view(B, -1, self.n), (x, x, x)) 248 | q = self.q_linear(q) 249 | k = self.k_linear(k) 250 | v = self.v_linear(v) 251 | 252 | o = F.scaled_dot_product_attention(q, k, v) 253 | return self.o_linear(o).view(B, -1) 254 | 255 | 256 | class Attention(nn.Module): 257 | def __init__( 258 | self, 259 | ch, 260 | heads=8, 261 | weight="conv", 262 | kernel_size=1, 263 | stride=1, 264 | padding=0, 265 | gta=False, 266 | rope=False, 267 | hw=None, 268 | ): 269 | super().__init__() 270 | 271 | self.heads = heads 272 | self.head_dim = ch // heads 273 | self.weight = weight 274 | self.stride = stride 275 | 276 | if weight == "conv": 277 | self.W_qkv = nn.Conv2d( 278 | ch, 279 | 3 * ch, 280 | kernel_size=kernel_size, 281 | stride=stride, 282 | padding=padding, 283 | ) 284 | self.W_o = nn.Conv2d( 285 | ch, 286 | ch, 287 | kernel_size=kernel_size, 288 | stride=stride, 289 | padding=padding, 290 | ) 291 | elif weight == "fc": 292 | self.W_qkv = nn.Linear(ch, 3 * ch) 293 | self.W_o = nn.Linear(ch, ch) 294 | else: 295 | raise ValueError("weight should be 'conv' or 'fc': {}".format(weight)) 296 | 297 | self.gta = gta 298 | self.rope = rope 299 | assert (int(self.gta) + int(self.rope)) <= 1 # either gta or rope 300 | 301 | self.hw = hw 302 | 303 | if gta or rope: 304 | assert hw is not None 305 | F = self.head_dim // 4 306 | if self.head_dim % 4 != 0: 307 | F = F + 1 308 | 309 | if not isinstance(hw, list): 310 | coord = hw 311 | _mat = make_SO2mats(coord, F).flatten(1, 2) # [h*w, head_dim/2, 2, 2] 312 | else: 313 | coord = make_2dcoord(hw[0], hw[1]) 314 | _mat = ( 315 | make_SO2mats(coord, F).flatten(2, 3).flatten(0, 1) 316 | ) # [h*w, head_dim/2, 2, 2] 317 | 318 | _mat = _mat[..., : self.head_dim // 2, :, :] 319 | # set indentity matrix for additional tokens 320 | 321 | if gta: 322 | self.mat_q = nn.Parameter(_mat) 323 | self.mat_k = nn.Parameter(_mat) 324 | self.mat_v = nn.Parameter(_mat) 325 | self.mat_o = nn.Parameter(_mat.transpose(-2, -1)) 326 | elif 
rope: 327 | self.mat_q = nn.Parameter(_mat) 328 | self.mat_k = nn.Parameter(_mat) 329 | 330 | def rescale_gta_mat(self, mat, hw): 331 | # _mat = [h*w, head_dim/2, 2, 2] 332 | if hw[0] == self.hw[0] and hw[1] == self.hw[1]: 333 | return mat 334 | else: 335 | f, c, d = mat.shape[1:] 336 | mat = einops.rearrange( 337 | mat, "(h w) f c d -> (f c d) h w", h=self.hw[0], w=self.hw[1] 338 | ) 339 | mat = F.interpolate(mat[None], size=hw, mode="bilinear")[0] 340 | mat = einops.rearrange(mat, "(f c d) h w -> (h w) f c d", f=f, c=c, d=d) 341 | return mat 342 | 343 | def forward(self, x): 344 | 345 | if self.weight == "conv": 346 | h, w = x.shape[2] // self.stride, x.shape[3] // self.stride 347 | else: 348 | h, w = self.hw 349 | 350 | reshape_str = ( 351 | "b (c nh) h w -> b nh (h w) c" 352 | if self.weight == "conv" 353 | else "b k (c nh) -> b nh k c" 354 | ) 355 | dim = 1 if self.weight == "conv" else 2 356 | q, k, v = self.W_qkv(x).chunk(3, dim=dim) 357 | q, k, v = map( 358 | lambda x: einops.rearrange(x, reshape_str, nh=self.heads), 359 | (q, k, v), 360 | ) 361 | if self.gta: 362 | q, k, v = map( 363 | lambda args: rep_mul_x(self.rescale_gta_mat(args[0], (h, w)), args[1]), 364 | ((self.mat_q, q), (self.mat_k, k), (self.mat_v, v)), 365 | ) 366 | elif self.rope: 367 | q, k = map( 368 | lambda args: rep_mul_x(args[0], args[1]), 369 | ((self.mat_q, q), (self.mat_k, k)), 370 | ) 371 | 372 | x = torch.nn.functional.scaled_dot_product_attention( 373 | q, k, v, attn_mask=self.mask if hasattr(self, "mask") else None 374 | ) 375 | 376 | if self.gta: 377 | x = rep_mul_x(self.rescale_gta_mat(self.mat_o, (h, w)), x) 378 | 379 | if self.weight == "conv": 380 | x = einops.rearrange(x, "b nh (h w) c -> b (c nh) h w", h=h, w=w) 381 | else: 382 | x = einops.rearrange(x, "b nh k c -> b k (c nh)") 383 | 384 | x = self.W_o(x) 385 | 386 | return x 387 | -------------------------------------------------------------------------------- /source/layers/gta.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import math 4 | from einops import rearrange 5 | 6 | 7 | def make_2dcoord(H, W, normalize=False): 8 | """ 9 | Return(torch.Tensor): 2d coord values of shape [H, W, 2] 10 | """ 11 | x = np.arange(H, dtype=np.float32) # [0, H) 12 | y = np.arange(W, dtype=np.float32) # [0, W) 13 | if normalize: 14 | x = x / H 15 | y = y / W 16 | x_grid, y_grid = np.meshgrid(x, y, indexing="ij") 17 | return torch.Tensor( 18 | np.stack([x_grid.flatten(), y_grid.flatten()], -1).reshape(H, W, 2) 19 | ) 20 | 21 | 22 | def make_SO2mats(coord, nfreqs): 23 | """ 24 | Args: 25 | coord: [..., 2 or 3] 26 | freqs: [n_freqs, 2 or 3] 27 | Return: 28 | mats of shape [..., n_freqs, (2 or 3), 2, 2] 29 | """ 30 | dim = coord.shape[-1] 31 | b = 10000.0 32 | freqs = torch.exp(torch.arange(0.0, 2 * nfreqs, 2) * -(math.log(b) / (2 * nfreqs))) 33 | grid_ths = [ 34 | torch.einsum("...i,j->...ij", coord[..., d : d + 1], freqs).flatten(-2, -1) 35 | for d in range(dim) 36 | ] 37 | 38 | _mats = [ 39 | [ 40 | torch.cos(grid_ths[d]), 41 | -torch.sin(grid_ths[d]), 42 | torch.sin(grid_ths[d]), 43 | torch.cos(grid_ths[d]), 44 | ] 45 | for d in range(dim) 46 | ] 47 | mats = [ 48 | rearrange(torch.stack(_mats[d], -1), "... (h w)->... 
h w", h=2, w=2) 49 | for d in range(dim) 50 | ] 51 | mat = torch.stack(mats, -3) 52 | return mat 53 | 54 | 55 | # GTA 56 | @torch.jit.script 57 | def rep_mul_x(rep, x): 58 | # rep.shape=[T, F, 2, 2], x.shape=[B, H, T, F*2] 59 | shape = x.shape 60 | d = rep.shape[-1] 61 | return ( 62 | (rep[None, None] * (x.unflatten(-1, (-1, d))[..., None, :])).sum(-1).view(shape) 63 | ) 64 | 65 | 66 | @torch.jit.script 67 | def rep_mul_qkv(rep, q, k, v): 68 | return rep_mul_x(rep, q), rep_mul_x(rep, k), rep_mul_x(rep, v) 69 | 70 | 71 | @torch.jit.script 72 | def rep_mul_qk(rep, q, k): 73 | return rep_mul_x(rep, q), rep_mul_x(rep, k) 74 | 75 | 76 | def embed_block_diagonal(M, n): 77 | """ 78 | Embed a [h*w, d/2, 2, 2] tensor M into a [h*w, d//2n, 4, 4] tensor M' 79 | with block diagonal structure. 80 | 81 | Args: 82 | M (torch.Tensor): Tensor of shape [h*w, d/2, 2, 2] 83 | n (int): Number of blocks to embed into 2nx2n structure 84 | 85 | Returns: 86 | torch.Tensor: Tensor of shape [h*w, d//2n, 4, 4] 87 | """ 88 | h_w, d_half, _, _ = M.shape 89 | 90 | # Initialize an empty tensor for the block diagonal tensor M' 91 | M_prime = torch.zeros((h_w, d_half // n, 4, 4)) 92 | 93 | # Embed M into the block diagonal structure of M_prime 94 | for t in range(h_w): 95 | for d in range(d_half // n): 96 | M_prime[t, d] = torch.block_diag(*[M[t, n * d + i] for i in range(n)]) 97 | print(M_prime.shape) 98 | return M_prime 99 | -------------------------------------------------------------------------------- /source/layers/klayer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.utils.parametrizations import weight_norm 4 | 5 | import numpy as np 6 | 7 | from source.layers.common_layers import ( 8 | ScaleAndBias, 9 | Attention, 10 | ) 11 | 12 | from source.layers.kutils import ( 13 | reshape, 14 | reshape_back, 15 | normalize, 16 | ) 17 | 18 | from einops.layers.torch import Rearrange 19 | 20 | 21 | class OmegaLayer(nn.Module): 22 | 23 | def __init__(self, n, ch, init_omg=0.1, global_omg=False, learn_omg=True): 24 | super().__init__() 25 | self.n = n 26 | self.ch = ch 27 | self.global_omg = global_omg 28 | 29 | if not learn_omg: 30 | print("Not learning omega") 31 | 32 | if n % 2 != 0: 33 | # n is odd 34 | raise NotImplementedError 35 | else: 36 | # n is even 37 | if global_omg: 38 | self.omg_param = nn.Parameter( 39 | init_omg * (1 / np.sqrt(2)) * torch.ones(2), requires_grad=learn_omg 40 | ) 41 | else: 42 | self.omg_param = nn.Parameter( 43 | init_omg * (1 / np.sqrt(2)) * torch.ones(ch // 2, 2), 44 | requires_grad=learn_omg, 45 | ) 46 | 47 | def forward(self, x): 48 | _x = reshape(x, 2) 49 | if self.global_omg: 50 | omg = torch.linalg.norm(self.omg_param).repeat(_x.shape[1]) 51 | else: 52 | omg = torch.linalg.norm(self.omg_param, dim=1) 53 | omg = omg[None] 54 | for _ in range(_x.ndim - 3): 55 | omg = omg.unsqueeze(-1) 56 | omg_x = torch.stack([omg * _x[:, :, 1], -omg * _x[:, :, 0]], dim=2) 57 | omg_x = reshape_back(omg_x) 58 | return omg_x 59 | 60 | 61 | class KLayer(nn.Module): # Kuramoto layer 62 | 63 | def __init__( 64 | self, 65 | n, 66 | ch, 67 | J="conv", 68 | c_norm="gn", 69 | use_omega=False, 70 | init_omg=1.0, 71 | ksize=3, 72 | gta=False, 73 | hw=None, 74 | global_omg=False, 75 | heads=8, 76 | learn_omg=True, 77 | apply_proj=True, 78 | ): 79 | # connnectivity is either 'conv' or 'ca' 80 | super().__init__() 81 | assert (ch % n) == 0 82 | self.n = n 83 | self.ch = ch 84 | self.use_omega = use_omega 85 | self.global_omg 
= global_omg 86 | self.apply_proj = apply_proj 87 | 88 | self.omg = ( 89 | OmegaLayer(n, ch, init_omg, global_omg, learn_omg) 90 | if self.use_omega 91 | else nn.Identity() 92 | ) 93 | 94 | if J == "conv": 95 | self.connectivity = nn.Conv2d(ch, ch, ksize, 1, ksize // 2) 96 | self.x_type = "image" 97 | elif J == "attn": 98 | self.connectivity = Attention( 99 | ch, 100 | heads=heads, 101 | weight="conv", 102 | kernel_size=1, 103 | stride=1, 104 | padding=0, 105 | gta=gta, 106 | hw=hw, 107 | ) 108 | self.x_type = "image" 109 | else: 110 | raise NotImplementedError 111 | 112 | if c_norm == "gn": 113 | self.c_norm = nn.GroupNorm(ch // n, ch, affine=True) 114 | elif c_norm == "sandb": 115 | self.c_norm = ScaleAndBias(ch, token_input=False) 116 | elif c_norm is None or c_norm == "none": 117 | self.c_norm = nn.Identity() 118 | else: 119 | raise NotImplementedError 120 | 121 | def project(self, y, x): 122 | sim = x * y # similarity between update and current state 123 | yxx = torch.sum(sim, 2, keepdim=True) * x 124 | return y - yxx, sim 125 | 126 | def kupdate(self, x: torch.Tensor, c: torch.Tensor = None): 127 | # compute \sum_j[J_ij x_j] 128 | _y = self.connectivity(x) 129 | # add bias c. 130 | y = _y + c 131 | 132 | if hasattr(self, "omg"): 133 | omg_x = self.omg(x) 134 | else: 135 | omg_x = torch.zeros_like(x) 136 | 137 | y = reshape(y, self.n) 138 | x = reshape(x, self.n) 139 | 140 | # project y onto the tangent space 141 | if self.apply_proj: 142 | y_yxx, sim = self.project(y, x) 143 | else: 144 | y_yxx = y 145 | sim = y * x 146 | 147 | dxdt = omg_x + reshape_back(y_yxx) 148 | sim = reshape_back(sim) 149 | 150 | return dxdt, sim 151 | 152 | def forward(self, x: torch.Tensor, c: torch.Tensor, T: int, gamma): 153 | # x.shape = c.shape = [B, C,...] or [B, T, C] 154 | xs, es = [], [] 155 | c = self.c_norm(c) 156 | x = normalize(x, self.n) 157 | es.append(torch.zeros(x.shape[0]).to(x.device)) 158 | # Iterate kuramoto update with condition c 159 | for t in range(T): 160 | dxdt, _sim = self.kupdate(x, c) 161 | x = normalize(x + gamma * dxdt, self.n) 162 | xs.append(x) 163 | es.append((-_sim).reshape(x.shape[0], -1).sum(-1)) 164 | 165 | return xs, es 166 | -------------------------------------------------------------------------------- /source/layers/kutils.py: -------------------------------------------------------------------------------- 1 | from sympy import prod 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import functional as F 5 | import math 6 | import einops 7 | 8 | 9 | def reshape(x: torch.Tensor, n: int): 10 | if x.ndim == 3: # x.shape = ([B, T, C ]) 11 | return x.transpose(1, 2).unflatten(1, (-1, n)) 12 | else: # x.shape = ([B, C, ..., ]) 13 | return x.unflatten(1, (-1, n)) 14 | 15 | 16 | def reshape_back(x): 17 | if x.ndim == 4: # Tokens 18 | return x.flatten(1, 2).transpose(1, 2) 19 | else: 20 | return x.flatten(1, 2) 21 | 22 | 23 | def _l2normalize(x): 24 | return torch.nn.functional.normalize(x, dim=2) 25 | 26 | 27 | def norm(n, x, dim=2, keepdim=True): 28 | return torch.linalg.norm(reshape(x, n), dim=dim, keepdim=keepdim) 29 | 30 | 31 | def normalize(x: torch.Tensor, n): 32 | x = reshape(x, n) 33 | x = _l2normalize(x) 34 | x = reshape_back(x) 35 | return x 36 | 37 | # currently not used 38 | def compute_exponential_map(n, x, dxdt, reshaped_inputs=False): 39 | if not reshaped_inputs: 40 | dxdt = reshape(dxdt, n) 41 | x = reshape(x, n) 42 | norm = torch.linalg.norm(dxdt, dim=2, keepdim=True) 43 | norm = torch.clip(norm, 0, math.pi) 44 | nx = torch.cos(norm) * x + 
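A minimal sketch (not part of the repository, assuming the repo root is on `PYTHONPATH`): a single Kuramoto layer driven by a conditioning tensor `c`; `forward` returns the oscillator states after every update step together with the per-sample energies.

```
import torch
from source.layers.klayer import KLayer

B, ch, n, H, W = 2, 64, 4, 8, 8
layer = KLayer(n=n, ch=ch, J="conv", ksize=3)
x = torch.randn(B, ch, H, W)      # oscillators (normalized to unit norm per group of n inside forward)
c = torch.randn(B, ch, H, W)      # conditioning / data term
xs, es = layer(x, c, T=8, gamma=1.0)
print(len(xs), xs[-1].shape)      # 8 torch.Size([2, 64, 8, 8])
print(len(es), es[-1].shape)      # 9 torch.Size([2])   (includes the initial zero energy)
```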
torch.sin(norm) * (dxdt / (norm + 1e-5)) 45 | if not reshaped_inputs: 46 | nx = reshape_back(nx) 47 | return nx 48 | 49 | 50 | class Normalize(nn.Module): 51 | 52 | def __init__(self, n): 53 | super().__init__() 54 | self.n = n 55 | 56 | def forward(self, x): 57 | return normalize(self.n, x) 58 | -------------------------------------------------------------------------------- /source/models/objs/knet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | from source.layers.klayer import ( 7 | KLayer, 8 | ) 9 | 10 | 11 | from source.layers.common_layers import ( 12 | RGBNormalize, 13 | ReadOutConv, 14 | Reshape, 15 | ) 16 | 17 | from source.layers.common_fns import ( 18 | positionalencoding2d, 19 | ) 20 | 21 | 22 | class AKOrN(nn.Module): 23 | 24 | def __init__( 25 | self, 26 | n=4, 27 | ch=256, 28 | L=1, 29 | T=8, 30 | psize=4, 31 | gta=True, 32 | J="attn", 33 | ksize=1, 34 | c_norm="gn", 35 | gamma=1.0, 36 | imsize=128, 37 | use_omega=False, 38 | init_omg=1.0, 39 | global_omg=False, 40 | maxpool=True, 41 | project=True, 42 | heads=8, 43 | use_ro_x=False, 44 | learn_omg=True, 45 | no_ro=False, 46 | autorescale=True, 47 | ): 48 | super().__init__() 49 | # assuming input's range is [0, 1] 50 | self.patchfy = nn.Sequential( 51 | RGBNormalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 52 | nn.Conv2d(3, ch, kernel_size=psize, stride=psize, padding=0), 53 | ) 54 | 55 | if not gta: 56 | self.pos_enc = True 57 | self.pemb_x = nn.Parameter( 58 | positionalencoding2d(ch, imsize // psize, imsize // psize).reshape( 59 | -1, imsize // psize, imsize // psize 60 | ) 61 | ) 62 | self.pemb_c = nn.Parameter( 63 | positionalencoding2d(ch, imsize // psize, imsize // psize).reshape( 64 | -1, imsize // psize, imsize // psize 65 | ) 66 | ) 67 | else: 68 | self.pos_enc = False 69 | 70 | self.n = n 71 | self.ch = ch 72 | self.L = L 73 | if isinstance(T, int): 74 | self.T = [T] * L 75 | else: 76 | self.T = T 77 | if isinstance(J, str): 78 | self.J = [J] * L 79 | else: 80 | self.J = J 81 | self.gamma = torch.nn.Parameter(torch.Tensor([gamma]), requires_grad=False) 82 | self.psize = psize 83 | self.imsize = imsize 84 | 85 | self.layers = nn.ModuleList() 86 | feature_hw = imsize // psize 87 | 88 | feature_hws = [feature_hw] * self.L 89 | chs = [ch] * (self.L + 1) 90 | 91 | for l in range(self.L): 92 | ch = chs[l] 93 | if l == self.L - 1: 94 | ch_next = chs[l + 1] 95 | else: 96 | ch_next = chs[l + 1] 97 | 98 | klayer = KLayer( 99 | n=n, 100 | ch=ch, 101 | J=self.J[l], 102 | gta=gta, 103 | c_norm=c_norm, 104 | use_omega=use_omega, 105 | init_omg=init_omg, 106 | global_omg=global_omg, 107 | heads=heads, 108 | learn_omg=learn_omg, 109 | ksize=ksize, 110 | hw=[feature_hws[l], feature_hws[l]], 111 | apply_proj=project, 112 | ) 113 | readout = ( 114 | ReadOutConv(ch, ch_next, self.n, 1, 1, 0) 115 | if not no_ro 116 | else nn.Identity() 117 | ) 118 | linear_x = ( 119 | nn.Conv2d(ch, ch_next, 1, 1, 0) 120 | if use_ro_x and l < self.L - 1 121 | else nn.Identity() 122 | ) 123 | self.layers.append(nn.ModuleList([klayer, readout, linear_x])) 124 | ch = ch_next 125 | 126 | if maxpool: 127 | pool = nn.AdaptiveMaxPool2d((1, 1)) 128 | else: 129 | pool = nn.AdaptiveAvgPool2d((1, 1)) 130 | 131 | self.out = nn.Sequential( 132 | nn.Identity(), 133 | pool, 134 | Reshape(-1, ch), 135 | nn.Linear(ch, 4 * ch), 136 | nn.ReLU(), 137 | nn.Linear(4 * ch, ch), 138 | ) 139 | 140 | self.fixed_ptb = False 141 | self.autorescale = 
autorescale 142 | 143 | def feature(self, inp): 144 | if self.autorescale and ( 145 | inp.shape[2] != self.imsize or inp.shape[3] != self.imsize 146 | ): 147 | inp = F.interpolate( 148 | inp, 149 | (self.imsize, self.imsize), 150 | mode="bilinear", 151 | ) 152 | c = self.patchfy(inp) 153 | 154 | if self.fixed_ptb: 155 | g = torch.Generator(device="cpu").manual_seed(1234) 156 | x = torch.randn(*(c.shape), generator=g).to(c.device) 157 | else: 158 | x = torch.randn_like(c) 159 | 160 | if self.pos_enc: 161 | c = c + self.pemb_c[None] 162 | x = x + self.pemb_x[None] 163 | xs = [x] 164 | es = [torch.zeros(x.shape[0], device=x.device)] 165 | for l, (kblock, ro, lin_x) in enumerate(self.layers): 166 | _xs, _es = kblock(x, c, T=self.T[l], gamma=self.gamma) 167 | x = _xs[-1] 168 | c = ro(x) 169 | x = lin_x(x) 170 | xs.append(_xs) 171 | es.append(_es) 172 | 173 | return c, x, xs, es 174 | 175 | def forward(self, input, return_xs=False, return_es=False): 176 | c, x, xs, es = self.feature(input) 177 | c = self.out(c) 178 | 179 | ret = [c] 180 | if return_xs: 181 | ret.append(xs) 182 | if return_es: 183 | ret.append(es) 184 | 185 | if len(ret) == 1: 186 | return ret[0] 187 | return ret 188 | -------------------------------------------------------------------------------- /source/models/objs/patchconv.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/akorn/b018ff1ab4e1b3c32398d1e07d2ee9231dc425c1/source/models/objs/patchconv.py -------------------------------------------------------------------------------- /source/models/objs/pretrained_vits.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import timm 4 | from timm.models import VisionTransformer 5 | from source.layers.common_layers import RGBNormalize 6 | 7 | 8 | class ViTWrapper(nn.Module): 9 | def __init__(self, model): 10 | super().__init__() 11 | self.model = model 12 | self.norm = RGBNormalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 13 | 14 | def forward(self, x): 15 | x = self.norm(x) 16 | return self.model(x) 17 | 18 | 19 | def load_dino(): 20 | model = timm.create_model( 21 | "vit_base_patch16_224_dino", pretrained=True, img_size=256 22 | ) 23 | model = ViTWrapper(model).cuda() 24 | model.psize = 16 25 | return model 26 | 27 | 28 | def load_dinov2(imsize=256): 29 | model = timm.create_model( 30 | "vit_large_patch14_dinov2.lvd142m", pretrained=True, img_size=imsize 31 | ) 32 | model = ViTWrapper(model).cuda() 33 | model.psize = 16 34 | return model 35 | 36 | 37 | def load_mocov3(): 38 | from timm.models.vision_transformer import vit_base_patch16_224 39 | 40 | model = vit_base_patch16_224(pretrained=False, dynamic_img_size=True) 41 | checkpoint = torch.hub.load_state_dict_from_url( 42 | "https://dl.fbaipublicfiles.com/moco-v3/vit-b-300ep/vit-b-300ep.pth.tar", 43 | map_location="cpu", 44 | ) 45 | # Load the MoCo v3 state dict into the model 46 | state_dict = checkpoint["state_dict"] 47 | new_state_dict = { 48 | k.replace("module.base_encoder.", ""): v for k, v in state_dict.items() 49 | } 50 | model.load_state_dict(new_state_dict, strict=False) 51 | model.eval() 52 | model = ViTWrapper(model).cuda() 53 | model.img_size = 256 54 | model.psize = 16 55 | return model 56 | 57 | 58 | def load_mae(): 59 | from timm.models.vision_transformer import vit_base_patch16_224 60 | 61 | model = vit_base_patch16_224(pretrained=False, dynamic_img_size=True) 62 | checkpoint = 
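A minimal sketch (not part of the repository, assuming the repo root is on `PYTHONPATH`): the object-centric AKOrN on a dummy batch, with a small image size to keep it cheap. `forward` returns a `[B, ch]` embedding used by the SimCLR objective, and `return_xs=True` also exposes the per-step oscillator states that can be clustered for object discovery.

```
import torch
from source.models.objs.knet import AKOrN

net = AKOrN(n=4, ch=128, L=1, T=8, psize=8, J="attn", gta=True, imsize=64)
imgs = torch.rand(2, 3, 64, 64)        # inputs are assumed to be in [0, 1]
emb, xs = net(imgs, return_xs=True)
print(emb.shape)                       # torch.Size([2, 128])
print(len(xs), xs[1][-1].shape)        # 2 torch.Size([2, 128, 8, 8])  (initial state + one list of T states per layer)
```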
torch.hub.load_state_dict_from_url( 63 | "https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth", 64 | map_location="cpu", 65 | ) # Load the state_dict into the model 66 | model.load_state_dict( 67 | checkpoint["model"], strict=False 68 | ) # Set the model to evaluation mode model.eval() 69 | model.eval() 70 | model = ViTWrapper(model).cuda() 71 | model.img_size = 256 72 | model.psize = 16 73 | return model 74 | -------------------------------------------------------------------------------- /source/models/objs/vit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import math 6 | import einops 7 | from einops.layers.torch import Rearrange, Reduce 8 | from source.layers.common_layers import PatchEmbedding, RGBNormalize, LayerNormForImage 9 | from source.layers.common_fns import positionalencoding2d 10 | from source.layers.common_layers import Attention 11 | from source.layers.common_layers import Reshape 12 | 13 | 14 | class TransformerBlock(nn.Module): 15 | def __init__( 16 | self, 17 | embed_dim, 18 | num_heads, 19 | mlp_dim, 20 | dropout=0.0, 21 | hw=None, 22 | gta=False, 23 | ): 24 | super().__init__() 25 | self.layernorm1 = LayerNormForImage(embed_dim) 26 | self.attn = Attention( 27 | embed_dim, 28 | num_heads, 29 | weight="conv", 30 | gta=gta, 31 | hw=hw, 32 | ) 33 | self.layernorm2 = LayerNormForImage(embed_dim) 34 | self.mlp = nn.Sequential( 35 | nn.Conv2d(embed_dim, mlp_dim, 1, 1, 0), 36 | nn.GELU(), 37 | nn.Conv2d(mlp_dim, embed_dim, 1, 1, 0), 38 | nn.Dropout(dropout), 39 | ) 40 | 41 | def forward(self, src, T): 42 | xs = [] 43 | # Repeat attention T times 44 | for _ in range(T): 45 | src2 = self.layernorm1(src) 46 | src2 = self.attn(src2) 47 | src = src + src2 48 | xs.append(src) 49 | 50 | src2 = self.layernorm2(src) 51 | src2 = self.mlp(src2) 52 | src = src + src2 53 | return src, xs 54 | 55 | 56 | class ViT(nn.Module): 57 | # ViT with iterlative self-attention 58 | def __init__( 59 | self, 60 | imsize=128, 61 | psize=8, 62 | ch=128, 63 | blocks=1, 64 | heads=4, 65 | mlp_dim=256, 66 | T=8, 67 | maxpool=False, 68 | gta=False, 69 | autorescale=False, 70 | ): 71 | super().__init__() 72 | self.T = T 73 | self.psize = psize 74 | self.autorescale = autorescale 75 | self.patchfy = nn.Sequential( 76 | RGBNormalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 77 | nn.Conv2d(3, ch, kernel_size=psize, stride=psize, padding=0), 78 | ) 79 | if not gta: 80 | self.pos_embed = nn.Parameter( 81 | positionalencoding2d(ch, imsize // psize, imsize // psize) 82 | .reshape(-1, imsize // psize, imsize // psize) 83 | ) 84 | 85 | self.transformer_encoder = nn.ModuleList( 86 | [ 87 | TransformerBlock( 88 | ch, 89 | heads, 90 | mlp_dim, 91 | 0.0, 92 | hw=[imsize // psize, imsize // psize], 93 | gta=gta, 94 | ) 95 | for _ in range(blocks) 96 | ] 97 | ) 98 | 99 | self.out = torch.nn.Sequential( 100 | LayerNormForImage(ch), 101 | ( 102 | nn.AdaptiveMaxPool2d((1, 1)) 103 | if not maxpool 104 | else nn.AdaptiveMaxPool2d((1, 1)) 105 | ), 106 | Reshape(-1, ch), 107 | nn.Linear(ch, 4 * ch), 108 | nn.ReLU(), 109 | nn.Linear(4 * ch, ch), 110 | ) 111 | 112 | def feature(self, x): 113 | if self.autorescale and ( 114 | x.shape[2] != self.imsize or x.shape[3] != self.imsize 115 | ): 116 | x = F.interpolate( 117 | x, 118 | (self.imsize, self.imsize), 119 | mode="bilinear", 120 | ) 121 | x = self.patchfy(x) 122 | if hasattr(self, "pos_embed"): 123 | x = x + self.pos_embed[None] 124 | xs = [x] 
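A minimal sketch (not part of the repository, assuming the repo root is on `PYTHONPATH`): the same `ViT` backbone acts as a standard transformer when `T=1` and as iterative self-attention (ItrSA) when `T>1`, because each block reapplies its attention `T` times before the MLP.

```
import torch
from source.models.objs.vit import ViT

net = ViT(imsize=64, psize=8, ch=128, blocks=1, heads=4, mlp_dim=256, T=4, gta=False)
imgs = torch.rand(2, 3, 64, 64)
print(net(imgs).shape)     # torch.Size([2, 128])
```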
125 | for block in self.transformer_encoder: 126 | x, _xs = block(x, self.T) 127 | xs.append(_xs) 128 | return x, xs 129 | 130 | def forward(self, x, return_xs=False): 131 | x, xs = self.feature(x) 132 | x = self.out(x) 133 | if return_xs: 134 | return x, xs 135 | else: 136 | return x 137 | -------------------------------------------------------------------------------- /source/models/sudoku/knet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from source.layers.klayer import ( 4 | KLayer, 5 | ) 6 | from source.layers.common_layers import ( 7 | ReadOutConv, 8 | BNReLUConv2d, 9 | FF, 10 | ResBlock, 11 | ) 12 | from source.layers.common_fns import positionalencoding2d 13 | 14 | 15 | from source.data.datasets.sudoku.sudoku import convert_onehot_to_int 16 | 17 | 18 | class SudokuAKOrN(nn.Module): 19 | 20 | def __init__( 21 | self, 22 | n, 23 | ch=64, 24 | L=1, 25 | T=16, 26 | gamma=1.0, 27 | J="attn", 28 | use_omega=True, 29 | global_omg=True, 30 | init_omg=0.1, 31 | learn_omg=False, 32 | nl=True, 33 | heads=8, 34 | ): 35 | super().__init__() 36 | self.n = n 37 | self.L = L 38 | self.ch = ch 39 | self.embedding = nn.Embedding(10, ch) 40 | 41 | hw = [9, 9] 42 | 43 | self.layers = nn.ModuleList() 44 | for l in range(self.L): 45 | self.layers.append( 46 | nn.ModuleList( 47 | [ 48 | KLayer( 49 | n, 50 | ch, 51 | J, 52 | use_omega=use_omega, 53 | c_norm=None, 54 | hw=hw, 55 | global_omg=global_omg, 56 | init_omg=init_omg, 57 | heads=heads, 58 | learn_omg=learn_omg, 59 | gta=True, 60 | ), 61 | nn.Sequential( 62 | ReadOutConv(ch, ch, n, 1, 1, 0), 63 | ResBlock(FF(ch, ch, ch, 1, 1, 0)) if nl else nn.Identity(), 64 | BNReLUConv2d(ch, ch, 1, 1, 0) if nl else nn.Identity(), 65 | ), 66 | ] 67 | ) 68 | ) 69 | 70 | self.out = nn.Sequential(nn.ReLU(), nn.Conv2d(ch, 9, 1, 1, 0)) 71 | 72 | self.T = T 73 | self.gamma = torch.nn.Parameter(torch.Tensor([gamma])) 74 | self.fixed_noise = False 75 | self.x0 = nn.Parameter(torch.randn(1, ch, 9, 9)) 76 | 77 | def feature(self, inp, is_input): 78 | # inp: torch.Tensor of shape [B, 9, 9, 9] the last dim repreents the digit in the one-hot representation. 79 | inp = convert_onehot_to_int(inp) 80 | c = self.embedding(inp).permute(0, 3, 1, 2) 81 | is_input = is_input.permute(0, 3, 1, 2) 82 | xs = [] 83 | es = [] 84 | 85 | # generate random oscillatores 86 | if self.fixed_noise: 87 | n = torch.randn( 88 | *(c.shape), generator=torch.Generator(device="cpu").manual_seed(42) 89 | ).to(c.device) 90 | x = is_input * c + (1 - is_input) * n 91 | else: 92 | n = torch.randn_like(c) 93 | x = is_input * c + (1 - is_input) * n 94 | 95 | for _, (klayer, readout) in enumerate(self.layers): 96 | # Process x and c. 
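A minimal sketch (not part of the repository, assuming the repo root is on `PYTHONPATH`): the Sudoku solver takes the one-hot puzzle plus the givens mask and returns per-cell logits of shape `[B, 9, 9, 9]`; an all-blank board is used here purely for illustration.

```
import torch
from source.models.sudoku.knet import SudokuAKOrN

net = SudokuAKOrN(n=4, ch=64, L=1, T=16, J="attn", heads=8)
X = torch.zeros(2, 9, 9, 9)                     # one-hot boards (all blank here)
is_input = X.sum(dim=3, keepdim=True).int()     # [2, 9, 9, 1] givens mask
logits = net(X, is_input)
print(logits.shape)                             # torch.Size([2, 9, 9, 9])
```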
97 | _xs, _es = klayer( 98 | x, 99 | c, 100 | self.T, 101 | self.gamma, 102 | ) 103 | xs.append(_xs) 104 | es.append(_es) 105 | 106 | x = _xs[-1] 107 | c = readout(x) 108 | 109 | return c, xs, es 110 | 111 | def forward(self, c, is_input, return_xs=False, return_es=False): 112 | out, xs, es = self.feature(c, is_input) 113 | out = self.out(out).permute(0, 2, 3, 1) 114 | ret = [out] 115 | if return_xs: 116 | ret.append(xs) 117 | if return_es: 118 | ret.append(es) 119 | if len(ret) == 1: 120 | return ret[0] 121 | else: 122 | return ret 123 | -------------------------------------------------------------------------------- /source/models/sudoku/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import math 6 | from source.layers.common_layers import Attention 7 | from source.layers.common_fns import positionalencoding2d 8 | 9 | from source.data.datasets.sudoku.sudoku import convert_onehot_to_int 10 | 11 | 12 | class TransformerBlock(nn.Module): 13 | def __init__(self, embed_dim, num_heads, mlp_dim, dropout=0.0, hw=None, gta=True): 14 | super().__init__() 15 | self.layernorm1 = nn.LayerNorm(embed_dim) 16 | self.attn = Attention(embed_dim, num_heads, weight="fc", gta=gta, hw=hw) 17 | self.layernorm2 = nn.LayerNorm(embed_dim) 18 | self.mlp = nn.Sequential( 19 | nn.Linear(embed_dim, mlp_dim), 20 | nn.GELU(), 21 | nn.Linear(mlp_dim, embed_dim), 22 | nn.Dropout(dropout), 23 | ) 24 | 25 | def forward(self, src, T): 26 | # Repeat attention T times 27 | for _ in range(T): 28 | src2 = self.layernorm1(src) 29 | src2 = self.attn(src2) 30 | src = src + src2 31 | 32 | src2 = self.layernorm2(src) 33 | src2 = self.mlp(src2) 34 | src = src + src2 35 | return src 36 | 37 | 38 | class SudokuTransformer(nn.Module): 39 | def __init__( 40 | self, 41 | ch=64, 42 | blocks=6, 43 | heads=4, 44 | mlp_dim=1024, 45 | T=16, 46 | gta=False, 47 | ): 48 | super().__init__() 49 | self.T = T 50 | 51 | self.embedding = nn.Embedding(10, ch) 52 | if not gta: 53 | self.pos_embed = nn.Parameter( 54 | positionalencoding2d(ch, 9, 9) 55 | .reshape(-1,9,9) 56 | .flatten(1, 2) 57 | .transpose(0, 1) 58 | ) 59 | 60 | self.transformer_encoder = nn.ModuleList( 61 | [ 62 | TransformerBlock( 63 | ch, heads, mlp_dim, 0.0, hw=(9, 9), gta=gta 64 | ) 65 | for _ in range(blocks) 66 | ] 67 | ) 68 | 69 | self.out = torch.nn.Sequential(nn.LayerNorm(ch), nn.Linear(ch, 9)) 70 | 71 | def forward(self, x, is_input): 72 | B = x.size(0) 73 | x = convert_onehot_to_int(x) 74 | x = self.embedding(x) 75 | x = x.view(B, -1, x.shape[-1]) # [B, H*W, C] 76 | is_input = is_input.view(B, -1) 77 | if hasattr(self, "pos_embed"): 78 | x = x + self.pos_embed[None] 79 | for block in self.transformer_encoder: 80 | x = block(x, self.T) 81 | x = self.out(x) 82 | x = x.view(-1, 9, 9, 9) 83 | return x 84 | -------------------------------------------------------------------------------- /source/training_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from torch.optim.lr_scheduler import _LRScheduler 4 | 5 | def add_gradient_histograms(writer, model, epoch): 6 | for name, param in model.named_parameters(): 7 | if param.grad is not None: 8 | writer.add_histogram(name + "/grad", param.grad, epoch) 9 | 10 | 11 | def save_model(model, epoch, checkpoint_dir, prefix="checkpoint"): 12 | 13 | if not os.path.exists(checkpoint_dir): 14 | os.makedirs(checkpoint_dir) 15 | 16 | 
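A minimal evaluation sketch (not part of the repository; `board_accuracy` is a hypothetical helper, and the repo root is assumed to be on `PYTHONPATH`): `compute_board_accuracy` from `source/evals/sudoku/evals.py` only scores the non-given cells, and a board counts as solved when all of its blanks are correct.

```
import torch
from source.evals.sudoku.evals import compute_board_accuracy


def board_accuracy(model, loader, device="cpu"):
    model.eval()
    solved, total = 0, 0
    with torch.no_grad():
        for X, Y, is_input in loader:           # as yielded by SudokuDataset / HardSudokuDataset
            X, Y, is_input = X.to(device), Y.to(device), is_input.to(device)
            pred = model(X, is_input)           # [B, 9, 9, 9] logits
            _, _, board_correct = compute_board_accuracy(pred, Y, is_input)
            solved += board_correct.sum().item()
            total += board_correct.numel()
    return solved / max(total, 1)
```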
checkpoint_path = os.path.join(checkpoint_dir, f"{prefix}_{epoch}.pth") 17 | 18 | torch.save( 19 | { 20 | "epoch": epoch, 21 | "model_state_dict": model.state_dict(), 22 | }, 23 | checkpoint_path, 24 | ) 25 | 26 | print(f"Model saved: {checkpoint_path}") 27 | 28 | 29 | def save_checkpoint( 30 | model, optimizer, epoch, loss, checkpoint_dir, max_checkpoints=None 31 | ): 32 | 33 | if not os.path.exists(checkpoint_dir): 34 | os.makedirs(checkpoint_dir) 35 | 36 | checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_{epoch}.pth") 37 | 38 | torch.save( 39 | { 40 | "epoch": epoch, 41 | "model_state_dict": model.state_dict(), 42 | "optimizer_state_dict": optimizer.state_dict(), 43 | "loss": loss, 44 | }, 45 | checkpoint_path, 46 | ) 47 | 48 | print(f"Checkpoint saved: {checkpoint_path}") 49 | 50 | manage_checkpoints(checkpoint_dir, max_checkpoints) 51 | 52 | 53 | def manage_checkpoints(checkpoint_dir, max_checkpoints): 54 | if max_checkpoints is None: 55 | return 56 | else: 57 | checkpoints = [ 58 | f 59 | for f in os.listdir(checkpoint_dir) 60 | if f.startswith("checkpoint_") and f.endswith(".pth") 61 | ] 62 | checkpoints.sort(key=lambda f: int(f.split("_")[1].split(".")[0])) 63 | 64 | while len(checkpoints) > max_checkpoints: 65 | old_checkpoint = checkpoints.pop(0) 66 | os.remove(os.path.join(checkpoint_dir, old_checkpoint)) 67 | print(f"Old checkpoint removed: {old_checkpoint}") 68 | 69 | 70 | class LinearWarmupScheduler(_LRScheduler): 71 | def __init__(self, optimizer, warmup_iters, last_iter=-1): 72 | self.warmup_iters = warmup_iters 73 | self.current_iter = 0 if last_iter == -1 else last_iter 74 | self.base_lrs = [group["lr"] for group in optimizer.param_groups] 75 | super(LinearWarmupScheduler, self).__init__(optimizer, last_epoch=last_iter) 76 | 77 | def get_lr(self): 78 | if self.current_iter < self.warmup_iters: 79 | # Linear warmup phase 80 | return [ 81 | base_lr * (self.current_iter + 1) / self.warmup_iters 82 | for base_lr in self.base_lrs 83 | ] 84 | else: 85 | # Maintain the base learning rate after warmup 86 | return [base_lr for base_lr in self.base_lrs] 87 | 88 | def step(self, it=None): 89 | if it is None: 90 | it = self.current_iter + 1 91 | self.current_iter = it 92 | super(LinearWarmupScheduler, self).step(it) 93 | -------------------------------------------------------------------------------- /source/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | import torch.nn.functional as F 5 | 6 | PLOTCOLORS = { 7 | "blue": "#377eb8", 8 | "orange": "#ff7f00", 9 | "green": "#4daf4a", 10 | "pink": "#f781bf", 11 | "brown": "#a65628", 12 | "purple": "#984ea3", 13 | "gray": "#999999", 14 | "red": "#e41a1c", 15 | "lightgray": "#d3d3d3", 16 | "lightgreen": "#90ee90", 17 | "yellow": "#dede00", 18 | } 19 | 20 | 21 | def ConvSingularValues(kernel, input_shape): 22 | transforms = torch.fft.fft2(kernel.permute(2, 3, 0, 1), input_shape, dim=[0, 1]) 23 | print(transforms.shape) 24 | return torch.linalg.svd(transforms) 25 | 26 | 27 | def ConvSingularValuesNumpy(kernel, input_shape): 28 | kernel = kernel.detach().cpu().numpy() 29 | transforms = np.fft.fft2(kernel.transpose(2, 3, 0, 1), input_shape, axes=[0, 1]) 30 | # transforms = np.fft.fft2(kernel.permute(2,3,0,1), input_shape, dim=[0,1]) 31 | print(transforms.shape) 32 | return np.linalg.svd(transforms) 33 | 34 | 35 | def EigenValues(kernel, input_shape): 36 | # transforms = np.fft.fft2(kernel, input_shape, axes=[0, 1]) 37 | transforms = 
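A minimal sketch (not part of the repository, assuming the repo root is on `PYTHONPATH`): `LinearWarmupScheduler` ramps the learning rate linearly over `warmup_iters` steps and then holds the base value; it is stepped once per iteration, as in `train_obj.py`.

```
import torch
from source.training_utils import LinearWarmupScheduler

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.Adam([param], lr=1e-3)
sched = LinearWarmupScheduler(opt, warmup_iters=100)

lrs = []
for _ in range(200):
    opt.step()
    sched.step()                        # one scheduler step per training iteration
    lrs.append(opt.param_groups[0]["lr"])
print(lrs[0], lrs[-1])                  # small at the start, back at the base lr (1e-3) after warmup
```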
torch.fft.fft2(kernel.permute(2, 3, 0, 1), input_shape, dim=[0, 1]) 38 | print(transforms.shape) 39 | return torch.linalg.eig(transforms) 40 | 41 | 42 | def load_state_dict_ignore_size_mismatch(model, state_dict): 43 | model_state_dict = model.state_dict() 44 | matched_state_dict = {} 45 | 46 | for key, param in state_dict.items(): 47 | if key in model_state_dict: 48 | if model_state_dict[key].shape == param.shape: 49 | matched_state_dict[key] = param 50 | else: 51 | print( 52 | f"Size mismatch for key '{key}': model {model_state_dict[key].shape}, checkpoint {param.shape}" 53 | ) 54 | else: 55 | print(f"Key '{key}' not found in model state dict.") 56 | 57 | model_state_dict.update(matched_state_dict) 58 | model.load_state_dict(model_state_dict) 59 | 60 | 61 | def compare_optimizer_state_dicts(original, modified): 62 | diff = {} 63 | for key in original.keys(): 64 | if key not in modified: 65 | diff[key] = "Removed" 66 | elif original[key] != modified[key]: 67 | diff[key] = {"Original": original[key], "Modified": modified[key]} 68 | for key in modified.keys(): 69 | if key not in original: 70 | diff[key] = "Added" 71 | return diff 72 | 73 | 74 | def get_worker_init_fn(start, end): 75 | return lambda worker_id: os.sched_setaffinity(0, range(start, end)) 76 | 77 | 78 | def str2bool(x): 79 | if isinstance(x, bool): 80 | return x 81 | x = x.lower() 82 | if x[0] in ["0", "n", "f"]: 83 | return False 84 | elif x[0] in ["1", "y", "t"]: 85 | return True 86 | raise ValueError("Invalid value: {}".format(x)) 87 | 88 | 89 | def apply_pca(x, n_components=3): 90 | # x.shape = [B, C, H, W] 91 | from sklearn.decomposition import PCA 92 | 93 | pca = PCA(n_components) 94 | nx = [] 95 | d = x.shape[1] 96 | for _x in x: 97 | _x = _x.permute(1, 2, 0).reshape(-1, d) 98 | _x = pca.fit_transform(_x) 99 | _x = _x.transpose(1, 0).reshape(n_components, x.shape[2], x.shape[3]) 100 | nx.append(torch.tensor(_x)) 101 | nx = torch.stack(nx, 0) 102 | # normalize to [0, 1] 103 | nx = (nx - nx.min()) / (nx.max() - nx.min()) 104 | return nx 105 | 106 | 107 | def apply_pca_torch(x, n_components=3): 108 | # x: [B, C, H, W] 109 | B, C, H, W = x.shape 110 | N = H * W 111 | 112 | if n_components >= C: 113 | return x 114 | 115 | # Reshape to [B, N, C] 116 | x = x.permute(0, 2, 3, 1).reshape(B, N, C) 117 | 118 | # Center the data per sample 119 | x_mean = x.mean(dim=1, keepdim=True) # [B, 1, C] 120 | x_centered = x - x_mean # [B, N, C] 121 | 122 | # Compute covariance matrix per sample: [B, C, C] 123 | cov = torch.bmm(x_centered.transpose(1, 2), x_centered) / (N - 1) 124 | 125 | # Compute eigenvalues and eigenvectors per sample 126 | eigenvalues, eigenvectors = torch.linalg.eigh( 127 | cov 128 | ) # eigenvalues: [B, C], eigenvectors: [B, C, C] 129 | 130 | # Reverse the order of eigenvalues and eigenvectors to get descending order 131 | eigenvalues = eigenvalues.flip(dims=[1]) 132 | eigenvectors = eigenvectors.flip(dims=[2]) 133 | 134 | # Select the top 'dim' eigenvectors 135 | top_eigenvectors = eigenvectors[:, :, :n_components] # [B, C, dim] 136 | 137 | # Project the centered data onto the top eigenvectors 138 | x_pca = torch.bmm(x_centered, top_eigenvectors) # [B, N, dim] 139 | 140 | # Reshape back to [B, dim, H, W] 141 | x_pca = x_pca.transpose(1, 2).reshape(B, n_components, H, W) 142 | 143 | return x_pca 144 | 145 | 146 | def gen_saccade_imgs(img, psize, r): 147 | H, W = img.shape[-2:] 148 | img = F.interpolate(img, (H + psize - r, W + psize - r), mode="bicubic") 149 | imgs = [] 150 | for h in range(0, psize, r): 151 | for w in 
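A minimal sketch (not part of the repository, assuming the repo root is on `PYTHONPATH`): the batched PCA helper projects per-pixel features down to a few channels, e.g. to render oscillator states as RGB maps.

```
import torch
from source.utils import apply_pca_torch

feats = torch.randn(2, 64, 16, 16)            # [B, C, H, W] feature / oscillator maps
rgb = apply_pca_torch(feats, n_components=3)
print(rgb.shape)                              # torch.Size([2, 3, 16, 16])
```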
range(0, psize, r): 152 | imgs.append(img[:, :, h : h + H, w : w + W]) 153 | return imgs, img[:, :, psize // 2 : H + psize // 2, psize // 2 : W + psize // 2] 154 | -------------------------------------------------------------------------------- /train_obj.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import logging 4 | 5 | import torch 6 | import torch.distributed 7 | import torch.nn as nn 8 | from torch import optim 9 | from source.training_utils import save_checkpoint, save_model, LinearWarmupScheduler, add_gradient_histograms 10 | from source.data.datasets.objs.load_data import load_data 11 | 12 | from source.utils import str2bool 13 | from torch.utils.tensorboard import SummaryWriter 14 | 15 | import accelerate 16 | from accelerate import Accelerator 17 | 18 | # Visualization 19 | import torch.nn.functional as F 20 | from tqdm import tqdm 21 | 22 | # for distributed training 23 | from torch.distributed.nn.functional import all_gather 24 | 25 | 26 | def create_logger(logging_dir): 27 | """ 28 | Create a logger that writes to a log file and stdout. 29 | """ 30 | logging.basicConfig( 31 | level=logging.INFO, 32 | format="[\033[34m%(asctime)s\033[0m] %(message)s", 33 | datefmt="%Y-%m-%d %H:%M:%S", 34 | handlers=[ 35 | logging.StreamHandler(), 36 | logging.FileHandler(f"{logging_dir}/log.txt"), 37 | ], 38 | ) 39 | logger = logging.getLogger(__name__) 40 | return logger 41 | 42 | 43 | def simclr(zs, temperature=1.0, normalize=True, loss_type="ip"): 44 | # zs: list of tensors. Each tensor has shape (n, d) 45 | if normalize: 46 | zs = [F.normalize(z, p=2, dim=-1) for z in zs] 47 | if zs[0].dim() == 3: 48 | zs = [z.flatten(1, 2) for z in zs] 49 | m = len(zs) 50 | n = zs[0].shape[0] 51 | device = zs[0].device 52 | mask = torch.eye(n * m, device=device) 53 | label0 = torch.fmod(n + torch.arange(0, m * n, device=device), n * m) 54 | z = torch.cat(zs, 0) 55 | if loss_type == "euclid": # euclidean distance 56 | sim = -torch.cdist(z, z) 57 | elif loss_type == "sq": # squared euclidean distance 58 | sim = -(torch.cdist(z, z) ** 2) 59 | elif loss_type == "ip": # inner product 60 | sim = torch.matmul(z, z.transpose(0, 1)) 61 | else: 62 | raise NotImplementedError 63 | logit_zz = sim / temperature 64 | logit_zz += mask * -1e8 65 | loss = nn.CrossEntropyLoss()(logit_zz, label0) 66 | return loss 67 | 68 | 69 | if __name__ == "__main__": 70 | 71 | parser = argparse.ArgumentParser() 72 | 73 | # Training options 74 | parser.add_argument("--exp_name", type=str, help="expname") 75 | parser.add_argument("--seed", type=int, default=1234) 76 | parser.add_argument("--beta", type=float, default=0.998, help="ema decay") 77 | parser.add_argument("--epochs", type=int, default=500, help="num of epochs") 78 | parser.add_argument( 79 | "--checkpoint_every", 80 | type=int, 81 | default=50, 82 | help="save checkpoint every specified epochs", 83 | ) 84 | parser.add_argument("--lr", type=float, default=1e-3, help="lr") 85 | parser.add_argument("--warmup_iters", type=int, default=0) 86 | parser.add_argument( 87 | "--finetune", 88 | type=str, 89 | default=None, 90 | help="path to the checkpoint. 
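A minimal sketch (not part of the repository): the `simclr` loss defined above takes the two augmented views as a list of `[n, d]` embeddings, with matching rows treated as positives; importing it assumes `train_obj.py` is on the path.

```
import torch
from train_obj import simclr

z1 = torch.randn(8, 128)     # embeddings of view 1
z2 = torch.randn(8, 128)     # embeddings of view 2
loss = simclr([z1, z2], temperature=0.1, normalize=True, loss_type="ip")
print(loss)                  # scalar tensor; lower when z1[i] and z2[i] align
```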
Training starts from that checkpoint", 91 | ) 92 | 93 | # Data loading 94 | parser.add_argument("--limit_cores_used", type=str2bool, default=False) 95 | parser.add_argument("--cpu_core_start", type=int, default=0, help="start core") 96 | parser.add_argument("--cpu_core_end", type=int, default=32, help="end core") 97 | parser.add_argument("--data", type=str, default="clevrtex") 98 | parser.add_argument( 99 | "--data_root", 100 | type=str, 101 | default=None, 102 | help="Optional. Specify the root dir of the dataset. If None, use a default path set for each dataset", 103 | ) 104 | parser.add_argument("--batchsize", type=int, default=256) 105 | parser.add_argument("--num_workers", type=int, default=8) 106 | parser.add_argument( 107 | "--data_imsize", 108 | type=int, 109 | default=None, 110 | help="Image size. If None, use the default size of each dataset", 111 | ) 112 | 113 | # Simclr options 114 | parser.add_argument("--normalize", type=str2bool, default=True) 115 | parser.add_argument("--temp", type=float, default=0.1, help="simclr temperature.") 116 | 117 | # General model options 118 | parser.add_argument("--model", type=str, default="akorn", help="model") 119 | parser.add_argument("--L", type=int, default=1, help="num of layers") 120 | parser.add_argument("--ch", type=int, default=256, help="num of channels") 121 | parser.add_argument( 122 | "--model_imsize", 123 | type=int, 124 | default=None, 125 | help= 126 | """ 127 | Model's imsize. This is used when you want finetune a pretrained model 128 | that was trained on images with different resolution than the finetune image dataset. 129 | """ 130 | ) 131 | parser.add_argument("--autorescale", type=str2bool, default=False) 132 | parser.add_argument("--psize", type=int, default=8, help="patch size") 133 | parser.add_argument("--ksize", type=int, default=1, help="kernel size") 134 | parser.add_argument("--T", type=int, default=8, help="num of recurrence") 135 | parser.add_argument( 136 | "--maxpool", type=str2bool, default=True, help="max pooling or avg pooling" 137 | ) 138 | parser.add_argument( 139 | "--heads", type=int, default=8, help="num of heads in self-attention" 140 | ) 141 | parser.add_argument( 142 | "--gta", 143 | type=str2bool, 144 | default=True, 145 | help=""" 146 | use Geometric Transform Attention (https://github.com/autonomousvision/gta) as positional encoding. 147 | Note that, different from the original GTA, the rotating matrices are learnable. 148 | If False, use standard absolute positional encoding used in the original transformer paper. 149 | """, 150 | ) 151 | 152 | # AKOrN options 153 | parser.add_argument("--N", type=int, default=4, help="num of rotating dimensions") 154 | parser.add_argument("--gamma", type=float, default=1.0, help="step size") 155 | parser.add_argument("--J", type=str, default="conv", help="connectivity") 156 | parser.add_argument("--use_omega", type=str2bool, default=False) 157 | parser.add_argument("--global_omg", type=str2bool, default=False) 158 | parser.add_argument( 159 | "--c_norm", 160 | type=str, 161 | default="gn", 162 | help="normalization. 
gn(GroupNorm), sandb(scale and bias), or none", 163 | ) 164 | parser.add_argument( 165 | "--init_omg", type=float, default=0.01, help="initial omega length" 166 | ) 167 | parser.add_argument("--learn_omg", type=str2bool, default=False) 168 | parser.add_argument( 169 | "--use_ro_x", 170 | type=str2bool, 171 | default=False, 172 | help="apply linear transform to oscillators between consecutive layers", 173 | ) 174 | 175 | # ablation of some components in the AKOrN's block 176 | parser.add_argument( 177 | "--no_ro", type=str2bool, default=False, help="ablation: no use readout module" 178 | ) 179 | parser.add_argument( 180 | "--project", 181 | type=str2bool, 182 | default=True, 183 | help="use projection or not in the Kuramoto layer", 184 | ) 185 | 186 | args = parser.parse_args() 187 | torch.backends.cudnn.benchmark = True 188 | torch.backends.cuda.enable_flash_sdp(enabled=True) 189 | # Setup accelerator 190 | accelerator = Accelerator() 191 | device = accelerator.device 192 | accelerate.utils.set_seed(args.seed + accelerator.process_index) 193 | 194 | import random 195 | import numpy as np 196 | torch.manual_seed(args.seed) 197 | random.seed(args.seed) 198 | np.random.seed(args.seed) 199 | 200 | # Create job directory and logger 201 | jobdir = f"runs/{args.exp_name}/" 202 | if accelerator.is_main_process: 203 | if not os.path.exists(jobdir): 204 | os.makedirs(jobdir) # Make results folder (holds all experiment subfolders) 205 | logger = create_logger(jobdir) 206 | logger.info(f"Experiment directory created at {jobdir}") 207 | else: 208 | logger = create_logger(jobdir) 209 | 210 | if args.limit_cores_used: 211 | def worker_init_fn(worker_id): 212 | os.sched_setaffinity(0, range(args.cpu_core_start, args.cpu_core_end)) 213 | 214 | else: 215 | worker_init_fn = None 216 | 217 | sstrainset, imsize, _ = load_data(args.data, args.data_root, args.data_imsize, False) 218 | 219 | if accelerator.is_main_process: 220 | logger.info(f"Dataset contains {len(sstrainset):,} images") 221 | 222 | ssloader = torch.utils.data.DataLoader( 223 | sstrainset, 224 | batch_size=int(args.batchsize // accelerator.num_processes), 225 | shuffle=True, 226 | num_workers=args.num_workers, 227 | worker_init_fn=worker_init_fn, 228 | ) 229 | 230 | if accelerator.is_main_process: 231 | writer = SummaryWriter(jobdir) 232 | 233 | def train(net, ema, opt, scheduler, loader, epoch): 234 | losses = [] 235 | initial_params = {name: param.clone() for name, param in net.named_parameters()} 236 | running_loss = 0.0 237 | n = 0 238 | 239 | for i, data in tqdm(enumerate(loader, 0)): 240 | net.train() 241 | inputs = data.view(-1, 3, imsize, imsize).to(device) # 2x batchsize 242 | 243 | # forward 244 | outputs = net(inputs) 245 | 246 | # gather outputs because simclr loss requires all outputs across all processes 247 | if accelerator.num_processes > 1: 248 | outputs = torch.cat(all_gather(outputs), 0) 249 | outputs = outputs.unflatten(0, (outputs.shape[0] // 2, 2)) 250 | 251 | loss = simclr( 252 | [outputs[:, 0], outputs[:, 1]], 253 | temperature=args.temp, 254 | normalize=args.normalize, 255 | loss_type="ip", 256 | ) 257 | 258 | opt.zero_grad() 259 | accelerator.backward(loss) 260 | opt.step() 261 | 262 | scheduler.step() 263 | 264 | running_loss += loss.item() * inputs.shape[0] 265 | n += inputs.shape[0] 266 | 267 | ema.update() 268 | 269 | if accelerator.is_main_process: 270 | add_gradient_histograms(writer, net, epoch) 271 | for name, param in net.named_parameters(): 272 | diff = param - initial_params[name] 273 | 
writer.add_histogram(f"{name}_diff", diff, epoch) 274 | if accelerator.is_main_process: 275 | logger.info( 276 | f"[Epoch {epoch + 1}, Batch {i + 1}] loss: {running_loss/n:.3f}" 277 | ) 278 | 279 | total_loss = running_loss / n 280 | if accelerator.is_main_process: 281 | writer.add_scalar("training loss", total_loss, epoch) 282 | 283 | return total_loss 284 | 285 | if args.model == "akorn": 286 | from source.models.objs.knet import AKOrN 287 | 288 | net = AKOrN( 289 | args.N, 290 | ch=args.ch, 291 | L=args.L, 292 | T=args.T, 293 | gamma=args.gamma, 294 | J=args.J, # "conv" or "attn", 295 | use_omega=args.use_omega, 296 | global_omg=args.global_omg, 297 | c_norm=args.c_norm, 298 | psize=args.psize, 299 | imsize=imsize if args.model_imsize is None else args.model_imsize, 300 | autorescale=args.autorescale, 301 | init_omg=args.init_omg, 302 | learn_omg=args.learn_omg, 303 | maxpool=args.maxpool, 304 | project=args.project, 305 | heads=args.heads, 306 | use_ro_x=args.use_ro_x, 307 | no_ro=args.no_ro, 308 | gta=args.gta, 309 | ).to("cuda") 310 | 311 | elif args.model == "vit": 312 | from source.models.objs.vit import ViT 313 | # T=1: ViT. T > 1: ItrSA. 314 | net = ViT( 315 | psize=args.psize, 316 | imsize=imsize if args.model_imsize is None else args.model_imsize, 317 | autorescale=args.autorescale, 318 | ch=args.ch, 319 | blocks=args.L, 320 | heads=args.heads, 321 | mlp_dim=2 * args.ch, 322 | T=args.T, 323 | maxpool=args.maxpool, 324 | gta=args.gta, 325 | ).cuda() 326 | else: 327 | raise NotImplementedError 328 | 329 | total_params = sum(p.numel() for p in net.parameters() if p.requires_grad) 330 | print(f"Total number of basemodel parameters: {total_params}") 331 | 332 | if args.finetune: 333 | if accelerator.is_main_process: 334 | logger.info("Loading checkpoint...") 335 | net.load_state_dict(torch.load(args.finetune)["model_state_dict"]) 336 | 337 | optimizer = optim.Adam(net.parameters(), lr=args.lr, weight_decay=0.0) 338 | 339 | if args.finetune: 340 | if accelerator.is_main_process: 341 | logger.info("Loading optimizer state...") 342 | optimizer.load_state_dict(torch.load(args.finetune)["optimizer_state_dict"]) 343 | for param_group in optimizer.param_groups: 344 | param_group["lr"] = args.lr 345 | 346 | from ema_pytorch import EMA 347 | 348 | ema = EMA(net, beta=args.beta, update_every=10, update_after_step=200) 349 | 350 | if args.finetune: 351 | if accelerator.is_main_process: 352 | logger.info("Loading checkpoint...") 353 | dir_name, file_name = os.path.split(args.finetune) 354 | file_name = file_name.replace("checkpoint", "ema") 355 | ema_path = os.path.join(dir_name, file_name) 356 | ema.load_state_dict(torch.load(ema_path)["model_state_dict"]) 357 | 358 | if accelerator.is_main_process: 359 | logger.info(f"Training for {args.epochs} epochs...") 360 | 361 | net, optimizer, ssloader = accelerator.prepare(net, optimizer, ssloader) 362 | 363 | scheduler = LinearWarmupScheduler(optimizer, warmup_iters=args.warmup_iters) 364 | 365 | for epoch in range(0, args.epochs): 366 | total_loss = train(net, ema, optimizer, scheduler, ssloader, epoch) 367 | if (epoch + 1) % args.checkpoint_every == 0: 368 | if accelerator.is_main_process: 369 | save_checkpoint( 370 | accelerator.unwrap_model(net), 371 | optimizer, 372 | epoch, 373 | total_loss, 374 | checkpoint_dir=jobdir, 375 | ) 376 | save_model(ema, epoch, checkpoint_dir=jobdir, prefix="ema") 377 | if accelerator.is_main_process: 378 | torch.save( 379 | accelerator.unwrap_model(net).state_dict(), 380 | os.path.join(jobdir, f"model.pth"), 381 | 
) 382 | torch.save(ema.state_dict(), os.path.join(jobdir, f"ema_model.pth")) 383 | -------------------------------------------------------------------------------- /train_sudoku.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys, os 3 | import tqdm 4 | import argparse 5 | 6 | from source.models.sudoku.transformer import SudokuTransformer 7 | 8 | from source.training_utils import save_checkpoint, save_model 9 | from source.data.datasets.sudoku.sudoku import SudokuDataset, HardSudokuDataset 10 | from source.models.sudoku.knet import SudokuAKOrN 11 | from source.utils import str2bool 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | import math 16 | from ema_pytorch import EMA 17 | 18 | from torch.utils.tensorboard import SummaryWriter 19 | 20 | 21 | def apply_threshold(model, threshold): 22 | with torch.no_grad(): 23 | for param in model.parameters(): 24 | param.data = torch.where( 25 | param.abs() < threshold, torch.tensor(0.0), param.data 26 | ) 27 | 28 | 29 | if __name__ == "__main__": 30 | 31 | parser = argparse.ArgumentParser() 32 | 33 | parser.add_argument("--exp_name", type=str, help="expname") 34 | parser.add_argument("--seed", type=int, default=None, help="seed") 35 | parser.add_argument("--epochs", type=int, default=100, help="num of epochs") 36 | parser.add_argument("--lr", type=float, default=1e-3, help="lr") 37 | parser.add_argument("--beta", type=float, default=0.995, help="ema decay") 38 | parser.add_argument( 39 | "--clip_grad_norm", type=float, default=1.0, help="clip grad norm" 40 | ) 41 | parser.add_argument( 42 | "--checkpoint_every", 43 | type=int, 44 | default=100, 45 | help="save a checkpoint every this many epochs", 46 | ) 47 | parser.add_argument("--eval_freq", type=int, default=10, help="evaluation frequency (in epochs)") 48 | 49 | # Data loading 50 | parser.add_argument("--limit_cores_used", type=str2bool, default=False) 51 | parser.add_argument("--cpu_core_start", type=int, default=0, help="start core") 52 | parser.add_argument("--cpu_core_end", type=int, default=16, help="end core") 53 | parser.add_argument( 54 | "--data_root", 55 | type=str, 56 | default=None, 57 | help="Optional. Specify the root dir of the dataset. 
If None, use a default path set for each dataset", 58 | ) 59 | parser.add_argument("--batchsize", type=int, default=100) 60 | parser.add_argument("--num_workers", type=int, default=4) 61 | 62 | # General model options 63 | parser.add_argument("--model", type=str, default="akorn", help="model") 64 | parser.add_argument("--L", type=int, default=1, help="num of layers") 65 | parser.add_argument("--T", type=int, default=16, help="Timesteps") 66 | parser.add_argument("--ch", type=int, default=512, help="num of channels") 67 | parser.add_argument("--heads", type=int, default=8) 68 | 69 | # AKOrN options 70 | parser.add_argument("--N", type=int, default=4) 71 | parser.add_argument("--gamma", type=float, default=1.0, help="step size") 72 | parser.add_argument("--J", type=str, default="attn", help="connectivity") 73 | parser.add_argument("--use_omega", type=str2bool, default=True) 74 | parser.add_argument("--global_omg", type=str2bool, default=True) 75 | parser.add_argument("--learn_omg", type=str2bool, default=False) 76 | parser.add_argument("--init_omg", type=float, default=0.1) 77 | parser.add_argument("--nl", type=str2bool, default=True) 78 | 79 | parser.add_argument("--speed_test", action="store_true") 80 | 81 | args = parser.parse_args() 82 | 83 | print("Exp name: ", args.exp_name) 84 | 85 | torch.backends.cudnn.benchmark = True 86 | torch.backends.cuda.enable_flash_sdp(enabled=True) 87 | 88 | if args.seed is not None: 89 | import random 90 | import numpy as np 91 | 92 | torch.manual_seed(args.seed) 93 | random.seed(args.seed) 94 | np.random.seed(args.seed) 95 | 96 | def worker_init_fn(worker_id): 97 | os.sched_setaffinity(0, range(args.cpu_core_start, args.cpu_core_end)) 98 | 99 | if args.data_root is not None: 100 | rootdir = args.data_root 101 | else: 102 | rootdir = "./data/sudoku" 103 | 104 | trainloader = torch.utils.data.DataLoader( 105 | SudokuDataset(rootdir, train=True), 106 | batch_size=args.batchsize, 107 | shuffle=True, 108 | num_workers=args.num_workers, 109 | worker_init_fn=worker_init_fn, 110 | ) 111 | testloader = torch.utils.data.DataLoader( 112 | SudokuDataset(rootdir, train=False), 113 | batch_size=100, 114 | shuffle=False, 115 | num_workers=args.num_workers, 116 | worker_init_fn=worker_init_fn, 117 | ) 118 | 119 | jobdir = f"runs/{args.exp_name}/" 120 | writer = SummaryWriter(jobdir) 121 | 122 | # compute board-level accuracy plus digit-wise accuracy on the given digits 123 | from source.evals.sudoku.evals import compute_board_accuracy 124 | def compute_acc(net, loader): 125 | net.eval() 126 | correct = 0 127 | total = 0 128 | correct_input = 0 129 | total_input = 0 130 | for X, Y, is_input in loader: 131 | X, Y, is_input = X.to(torch.int32).cuda(), Y.cuda(), is_input.cuda() 132 | 133 | with torch.no_grad(): 134 | out = net(X, is_input) 135 | 136 | _, _, board_accuracy = compute_board_accuracy(out, Y, is_input) 137 | correct += board_accuracy.sum().item() 138 | total += board_accuracy.shape[0] 139 | 140 | # digit-wise accuracy on the given (input) digits 141 | out = out.argmax(dim=-1) 142 | Y = Y.argmax(dim=-1) 143 | mask = (1 - is_input).view(out.shape) 144 | correct_input += ((1 - mask) * (out == Y)).sum().item() 145 | total_input += (1 - mask).sum().item() 146 | 147 | acc = correct / total 148 | input_acc = correct_input / total_input 149 | return acc, input_acc, (total, correct), (total_input, correct_input) 150 | 151 | if args.model == "akorn": 152 | print( 153 | f"n: {args.N}, ch: {args.ch}, L: {args.L}, T: {args.T}, type of J: {args.J}" 154 | ) 155 | net = SudokuAKOrN( 156 | n=args.N, 157 | ch=args.ch, 158 | L=args.L, 159 | 
T=args.T, 160 | gamma=args.gamma, 161 | J=args.J, 162 | use_omega=args.use_omega, 163 | global_omg=args.global_omg, 164 | init_omg=args.init_omg, 165 | learn_omg=args.learn_omg, 166 | nl=args.nl, 167 | heads=args.heads, 168 | ) 169 | elif args.model == "itrsa": 170 | net = SudokuTransformer( 171 | ch=args.ch, 172 | blocks=args.L, 173 | heads=args.heads, 174 | mlp_dim=args.ch * 2, 175 | T=args.T, 176 | gta=False, 177 | ) 178 | else: 179 | raise NotImplementedError 180 | 181 | net.cuda() 182 | 183 | total_params = sum(p.numel() for p in net.parameters() if p.requires_grad) 184 | print(f"Total number of parameters: {total_params}") 185 | 186 | optimizer = torch.optim.Adam(net.parameters(), lr=args.lr) 187 | 188 | ema = EMA(net, beta=args.beta, update_every=10, update_after_step=100) 189 | 190 | criterion = torch.nn.CrossEntropyLoss(reduction="none") 191 | 192 | # Measure speed 193 | if args.speed_test: 194 | it_sp = 0 195 | time_per_iter = [] 196 | import numpy as np 197 | 198 | for epoch in range(args.epochs): 199 | total_loss = 0 200 | 201 | for X, Y, is_input in tqdm.tqdm(trainloader): 202 | net.train() 203 | ema.train() 204 | X, Y, is_input = X.to(torch.int32).cuda(), Y.cuda(), is_input.cuda() 205 | 206 | if args.speed_test: 207 | start = torch.cuda.Event(enable_timing=True) 208 | end = torch.cuda.Event(enable_timing=True) 209 | start.record() 210 | 211 | out = net(X, is_input) 212 | 213 | out = out.reshape(-1, 9) 214 | Y = Y.argmax(dim=-1).reshape(-1) 215 | 216 | loss = criterion(out, Y).mean() 217 | 218 | optimizer.zero_grad() 219 | loss.backward() 220 | if args.clip_grad_norm > 0.: 221 | torch.nn.utils.clip_grad_norm_(net.parameters(), args.clip_grad_norm) 222 | optimizer.step() 223 | 224 | if args.speed_test: 225 | end.record() 226 | torch.cuda.synchronize() 227 | time_elapsed_per_iter = start.elapsed_time(end) 228 | time_per_iter.append(time_elapsed_per_iter) 229 | print(time_elapsed_per_iter) 230 | it_sp = it_sp + 1 231 | if it_sp == 100: 232 | np.save(os.path.join(jobdir, "time.npy"), np.array(time_per_iter)) 233 | exit(0) 234 | 235 | total_loss += loss.item() 236 | ema.update() 237 | 238 | total_loss = total_loss / len(trainloader) 239 | 240 | writer.add_scalar("training loss", total_loss, epoch) 241 | print(f"Epoch [{epoch+1}/{args.epochs}], Loss: {total_loss:.4f}") 242 | 243 | if (epoch + 1) % args.eval_freq == 0: 244 | 245 | acc, input_acc, stats, stats_input = compute_acc(net, testloader) 246 | writer.add_scalar("test/accuracy", acc, epoch) 247 | writer.add_scalar("test/input_accuracy", input_acc, epoch) 248 | print(f"[Test]: Total blanks:{stats[0]}, Accuracy: {acc}") 249 | print( 250 | f"[Test]: Total given squares:{stats_input[0]}, Accuracy on given digits: {input_acc}" 251 | ) 252 | 253 | # EMA evals 254 | acc, input_acc, stats, stats_input = compute_acc(ema.ema_model, testloader) 255 | writer.add_scalar("ema_test/accuracy", acc, epoch) 256 | writer.add_scalar("ema_test/input_accuracy", input_acc, epoch) 257 | print(f"[EMA Test]: Total blanks:{stats[0]}, Accuracy: {acc}") 258 | print( 259 | f"[EMA Test]: Total given squares:{stats_input[0]}, Accuracy on given digits: {input_acc}" 260 | ) 261 | 262 | if (epoch + 1) % args.checkpoint_every == 0: 263 | save_checkpoint(net, optimizer, epoch, total_loss, checkpoint_dir=jobdir) 264 | save_model(ema, epoch, checkpoint_dir=jobdir, prefix="ema") 265 | 266 | torch.save(net.state_dict(), os.path.join(jobdir, f"model.pth")) 267 | torch.save(ema.state_dict(), os.path.join(jobdir, f"ema_model.pth")) 268 | 
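# ----------------------------------------------------------------------
# Hedged usage sketch: the two torch.save calls above write a plain
# state_dict ("model.pth") and the EMA wrapper's state_dict ("ema_model.pth")
# into runs/<exp_name>/. Assuming the same model and loader construction as
# above, they could be restored for offline evaluation roughly like this
# (illustrative only, not part of the training run):
#
#   net.load_state_dict(torch.load(os.path.join(jobdir, "model.pth")))
#   ema.load_state_dict(torch.load(os.path.join(jobdir, "ema_model.pth")))
#   acc, input_acc, stats, stats_input = compute_acc(ema.ema_model, testloader)
# ----------------------------------------------------------------------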
--------------------------------------------------------------------------------
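A note on the contrastive objective in train_obj.py: the training loop reshapes each batch into two augmented views and passes their outputs to a `simclr(...)` helper (defined elsewhere in the repository, outside this excerpt) with `temperature`, `normalize`, and `loss_type="ip"` arguments. The snippet below is a minimal, self-contained sketch of an NT-Xent-style loss with the same call shape; the name `simclr_sketch` and its internals are illustrative assumptions, not the repository's actual implementation.

```
import torch
import torch.nn.functional as F

def simclr_sketch(views, temperature=0.1, normalize=True, loss_type="ip"):
    # views = [z1, z2]: embeddings of the two augmented views, each of shape (B, D).
    # Hypothetical stand-in for the repository's own `simclr` helper; details may differ.
    if loss_type != "ip":
        raise NotImplementedError("this sketch only covers inner-product similarity")
    z1, z2 = views
    if normalize:
        z1 = F.normalize(z1, dim=-1)
        z2 = F.normalize(z2, dim=-1)
    z = torch.cat([z1, z2], dim=0)        # (2B, D)
    sim = z @ z.t() / temperature         # pairwise inner-product similarities
    sim.fill_diagonal_(float("-inf"))     # a sample is never its own positive
    B = z1.shape[0]
    # The positive for row i is the other augmented view of the same image.
    targets = torch.cat([torch.arange(B, 2 * B), torch.arange(0, B)]).to(z.device)
    return F.cross_entropy(sim, targets)

# Shape-compatible with the call site in train_obj.py:
# loss = simclr_sketch([outputs[:, 0], outputs[:, 1]], temperature=0.1, normalize=True, loss_type="ip")
```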