├── .gitignore
├── README.md
├── data
│   ├── convert_tfrecord_to_np.py
│   ├── create_shapes.py
│   ├── download_clevrtex.sh
│   ├── download_pascal_and_coco.sh
│   ├── download_rrn.sh
│   ├── download_satnet.sh
│   ├── download_synths.sh
│   └── tfloaders
│       ├── clevr_with_masks.py
│       ├── multi_dsprites.py
│       └── tetrominoes.py
├── eval_obj.py
├── eval_sudoku.py
├── requirements.txt
├── scripts
│   ├── sudoku.md
│   └── synths.md
├── source
│   ├── data
│   │   ├── augs.py
│   │   └── datasets
│   │       ├── objs
│   │       │   ├── clevr.py
│   │       │   ├── clevr_tex.py
│   │       │   ├── coco.py
│   │       │   ├── dsprites.py
│   │       │   ├── imagenet.py
│   │       │   ├── load_data.py
│   │       │   ├── npdataset.py
│   │       │   ├── pascal.py
│   │       │   ├── shapes.py
│   │       │   └── tetrominoes.py
│   │       └── sudoku
│   │           └── sudoku.py
│   ├── evals
│   │   ├── objs
│   │   │   ├── fgari.py
│   │   │   └── mbo.py
│   │   └── sudoku
│   │       └── evals.py
│   ├── layers
│   │   ├── common_fns.py
│   │   ├── common_layers.py
│   │   ├── gta.py
│   │   ├── klayer.py
│   │   └── kutils.py
│   ├── models
│   │   ├── objs
│   │   │   ├── knet.py
│   │   │   ├── patchconv.py
│   │   │   ├── pretrained_vits.py
│   │   │   └── vit.py
│   │   └── sudoku
│   │       ├── knet.py
│   │       └── transformer.py
│   ├── training_utils.py
│   └── utils.py
├── train_obj.py
└── train_sudoku.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Artificial Kuramoto Oscillatory Neurons (AKOrN)
2 |
3 | Takeru Miyato · Sindy Löwe · Andreas Geiger · Max Welling
4 |
5 | ICLR 2025 (Oral)
6 |
21 | This page contains instructions for the initial environment setup and for running the CLEVRTex experiments.
22 | - Minimal AKOrN model on Google Colab (the fish example in the paper): [here](https://colab.research.google.com/drive/1n8x2uskNxRIqJvvNaljWDuLAMvxkw0Qn)
23 | - Code for the other synthetic datasets (Tetrominoes, dSprites, CLEVR): [here](https://github.com/autonomousvision/akorn/blob/main/scripts/synths.md)
24 | - Sudoku solving: [here](https://github.com/autonomousvision/akorn/blob/main/scripts/sudoku.md)
25 |
26 | ## Setup Conda env
27 |
28 | ```
29 | yes | conda create -n akorn python=3.12
30 | conda activate akorn
31 | pip3 install -r requirements.txt
32 | ```
33 |
34 | ## Download the CLEVRTex dataset
35 | ```
36 | cd data
37 | bash download_clevrtex.sh
38 | cd ..
39 | ```
40 |
41 | ## Training
42 | ```
43 | export NUM_GPUS= # Number of GPUs. If you use a single GPU, run the commands below without the multi-GPU options (remove `--multi-gpu`).
44 | ```
45 |
46 | ### CLEVRTex
47 |
48 | #### AKOrN
49 | ```
50 | export L=1 # Number of layers. L=1 or 2 (values >2 are possible, but we only experimented with one- and two-layer models).
51 | accelerate launch --multi-gpu --num_processes=$NUM_GPUS train_obj.py --exp_name=clvtex_akorn --data_root=./data/clevrtex_full/ --model=akorn --data=clevrtex_full --J=attn --L=${L}
52 |
53 | # Larger model (L=2, ch=512, bs=512)
54 | accelerate launch --multi-gpu --num_processes=$NUM_GPUS train_obj.py --exp_name=clvtex_large_akorn --data_root=./data/clevrtex_full/ --model=akorn --data=clevrtex_full --J=attn --L=2 --ch=512 --batchsize=512 --epochs=1024 --lr=0.0005
55 | ```
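
For a single GPU (see the `NUM_GPUS` note above), the same runs can be launched without the multi-GPU options; for example:
```
accelerate launch train_obj.py --exp_name=clvtex_akorn --data_root=./data/clevrtex_full/ --model=akorn --data=clevrtex_full --J=attn --L=${L}
```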
56 |
57 | #### ItrSA
58 | ```
59 | export L=1
60 | accelerate launch --multi-gpu --num_processes=$NUM_GPUS train_obj.py --exp_name=clvtex_itrsa --data_root=./data/clevrtex_full/ --model=vit --data=clevrtex_full --L=${L} --gta=False
61 | ```
62 |
63 | ## Evaluation
64 |
65 | ### CLEVRTex (-OOD, -CAMO)
66 |
67 | ```
68 | export DATA_TYPE=full #{full, outd, camo}
69 | export L=1
70 | # AKOrN
71 | python eval_obj.py --data_root=./data/clevrtex_${DATA_TYPE}/ --model=akorn --data=clevrtex_${DATA_TYPE} --J=attn --L=${L} --model_path=runs/clvtex_akorn/ema_499.pth --model_imsize=128
72 | # ItrSA
73 | python eval_obj.py --data_root=./data/clevrtex_${DATA_TYPE}/ --model=vit --data=clevrtex_${DATA_TYPE} --gta=False --L=${L} --model_path=runs/clvtex_itrsa/ema_499.pth --model_imsize=128
74 | ```
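
To cover all three CLEVRTex variants with one checkpoint, the same commands can be looped over `DATA_TYPE`; for example, for the AKOrN model:
```
for DATA_TYPE in full outd camo; do
  python eval_obj.py --data_root=./data/clevrtex_${DATA_TYPE}/ --model=akorn --data=clevrtex_${DATA_TYPE} --J=attn --L=${L} --model_path=runs/clvtex_akorn/ema_499.pth --model_imsize=128
done
```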
75 |
76 | ### Eval with Up-tiling (see the Appendix of the paper)
77 | ```
78 | # Might take a long time depending on the CPU spec
79 | python eval_obj.py --data_root=./data/clevrtex_${DATA_TYPE}/ --saccade_r=4 --model=akorn --data=clevrtex_${DATA_TYPE} --J=attn --L=${L} --model_path=runs/clvtex_akorn/ema_499.pth --model_imsize=128
80 | ```
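
Up-tiling (`--saccade_r`) evaluates the model on multiple copies of each image shifted by a fraction of the patch size, and interleaves the resulting feature maps into a higher-resolution map before clustering (see `eval()` in `eval_obj.py`). A minimal sketch of the interleaving step, with illustrative names:
```
import torch

def interleave_shifted_features(outputs, nh, nw):
    # outputs: list of nh*nw feature maps, each of shape (N, C, ho, wo),
    # one per shifted copy of the input image.
    N, C, ho, wo = outputs[0].shape
    stitched = torch.zeros(N, C, ho, nh, wo, nw)
    for h in range(nh):
        for w in range(nw):
            # The (h, w)-shifted copy fills every (h, w)-th cell of the upscaled grid.
            stitched[:, :, :, h, :, w] = outputs[h * nh + w]
    return stitched.view(N, C, ho * nh, wo * nw)
```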
81 |
82 | #### Performance table
83 | | Model | CLEVRTex FG-ARI | CLEVRTex MBO | OOD FG-ARI | OOD MBO | CAMO FG-ARI | CAMO MBO |
84 | |------------------------------------|-----------------|--------------|------------|---------|-------------|----------|
85 | | ViT | 46.4±0.6 | 25.1±0.7 | 44.1±0.5 | 27.2±0.5 | 32.5±0.6 | 16.1±1.1 |
86 | | ItrSA (L = 1) | 65.7±0.3 | 44.6±0.9 | 64.6±0.8 | 45.1±0.4 | 49.0±0.7 | 30.2±0.8 |
87 | | ItrSA (L = 2) | 76.3±0.4 | 48.5±0.1 | 74.9±0.8 | 46.4±0.5 | 61.9±1.3 | 37.1±0.5 |
88 | | AKOrN (attn, L = 1) | 75.6±0.2 | 55.0±0.0 | 73.4±0.4 | 56.1±1.1 | 59.9±0.1 | 44.3±0.9 |
89 | | AKOrN (attn, L = 2) | 80.5±1.5 | 54.9±0.6 | 79.2±1.2 | 55.7±0.5 | 67.7±1.5 | 46.2±0.9 |
90 |
91 | ##### With Up-tiling (x4)
92 | | Model | CLEVRTex FG-ARI | CLEVRTex MBO | OOD FG-ARI | OOD MBO | CAMO FG-ARI | CAMO MBO |
93 | |------------------------------------|-----------------|--------------|------------|---------|-------------|----------|
94 | | AKOrN (attn, L = 2) | 87.7±1.0 | 55.3±2.1 | 85.2±0.9 | 55.6±1.5 | 74.5±1.2 | 45.6±3.4 |
95 | | Large AKOrN (attn, L = 2) | 88.5±0.9 | 59.7±0.9 | 87.7±0.5 | 60.8±0.6 | 77.0±0.5 | 53.4±0.7 |
96 |
--------------------------------------------------------------------------------
/data/convert_tfrecord_to_np.py:
--------------------------------------------------------------------------------
1 | # provided by @loeweX
2 | import os
3 | from typing import List
4 |
5 | import numpy as np
6 | from PIL import Image
7 | from einops import rearrange
8 | import argparse
9 |
10 |
11 |
12 | from tfloaders import (
13 | clevr_with_masks,
14 | multi_dsprites,
15 | tetrominoes,
16 | )
17 |
18 | def resize_input(x):
19 | # Center-crop with boundaries [29, 221] (height) and [64, 256] (width).
20 | x = x[29:221, 64:256]
21 | # Resize cropped image to 128x128 resolution.
22 | return np.array(
23 | Image.fromarray(x).resize((128, 128), resample=Image.Resampling.NEAREST)
24 | )
25 |
26 | def resize_input_tetrominoes(x):
27 | # Center-crop
28 | x = x[1:33, 1:33]
29 | return x
30 |
31 |
32 | def get_hparams(dataset_name):
33 |
34 | pdir = './'
35 |
36 | if dataset_name == "multi_dsprites":
37 | variant = "colored_on_grayscale" # binarized, colored_on_colored
38 | input_path = f"{pdir}/multi_dsprites/multi_dsprites_{variant}.tfrecords"
39 | output_path = f"{pdir}/multi_dsprites/"
40 |
41 | dataset = multi_dsprites.dataset(input_path, variant)
42 | train_size = 60000
43 | dataset_name = variant
44 |
45 | elif dataset_name == "tetrominoes":
46 | input_path = f"{pdir}/tetrominoes/tetrominoes_train.tfrecords"
47 | output_path = f"{pdir}/tetrominoes/"
48 |
49 | dataset = tetrominoes.dataset(input_path)
50 | train_size = 60000
51 |
52 | elif dataset_name in ["clevr_with_masks"]:
53 | input_path = f"{pdir}/clevr_with_masks/clevr_with_masks_train.tfrecords"
54 | output_path = f"{pdir}/clevr_with_masks/"
55 |
56 | dataset = clevr_with_masks.dataset(input_path)
57 | train_size = 70000
58 |
59 | val_size = 10000 # 5000
60 | test_size = 320
61 | eval_size = 64
62 | return (
63 | input_path,
64 | output_path,
65 | dataset,
66 | train_size,
67 | val_size,
68 | test_size,
69 | eval_size,
70 | dataset_name,
71 | )
72 |
73 |
74 | if __name__ == "__main__":
75 |
76 | parser = argparse.ArgumentParser(description="Convert TFRecord to NPZ")
77 | parser.add_argument(
78 | "--dataset_name",
79 | type=str,
80 | required=True,
81 | choices=["multi_dsprites", "tetrominoes", "clevr_with_masks"],
82 | help="Name of the dataset to convert",
83 | )
84 | args = parser.parse_args()
85 | dataset_name = args.dataset_name
86 |
87 | clevr6 = True
88 |
89 | (
90 | input_path,
91 | output_path,
92 | dataset,
93 | train_size,
94 | val_size,
95 | test_size,
96 | eval_size,
97 | dataset_name,
98 | ) = get_hparams(dataset_name)
99 |
100 | if not os.path.exists(output_path):
101 | os.mkdir(output_path)
102 |
103 | batched_dataset = dataset.batch(1)
104 | iterator = iter(batched_dataset)
105 |
106 | counter = 0
107 | images: List[np.ndarray] = []
108 | labels: List[np.ndarray] = []
109 |
110 |
111 | while True:
112 | try:
113 | data = next(iterator)
114 | except StopIteration:
115 | break
116 |
117 | input_image = np.squeeze(data["image"].numpy())
118 |
119 | pixelwise_label = np.zeros(
120 | (1, input_image.shape[0], input_image.shape[1]), dtype=np.uint8
121 | )
122 | for idx in range(data["mask"].shape[1]):
123 | pixelwise_label[
124 | np.where(data["mask"].numpy()[:, idx, :, :, 0] == 255)
125 | ] = idx
126 |
127 | pixelwise_label = np.squeeze(pixelwise_label)
128 |
129 | if dataset_name in ["clevr_with_masks"]:
130 | input_image = resize_input(input_image)
131 | pixelwise_label = resize_input(pixelwise_label)
132 |
133 | if clevr6 and np.max(pixelwise_label) > 6:
134 | # CLEVR6: only use images with maximally 6 objects
135 | continue
136 |
137 | if dataset_name in ["tetrominoes"]:
138 | input_image = resize_input_tetrominoes(input_image)
139 | pixelwise_label = resize_input_tetrominoes(pixelwise_label)
140 |
141 | input_image = rearrange(input_image, "h w c -> c h w")
142 | input_image = (input_image / 255)
143 |
144 | # pixelwise_label = rearrange(pixelwise_label, "w h -> h w")
145 |
146 | images.append(input_image)
147 | labels.append(pixelwise_label)
148 |
149 | counter += 1
150 |
151 | if counter % 1000 == 0:
152 | print(counter)
153 |
154 | if counter % (train_size + val_size + test_size + eval_size) == 0:
155 | break
156 |
157 | print("Save files")
158 |
159 | np.savez_compressed(
160 | os.path.join(output_path, f"{dataset_name}_eval"),
161 | images=np.squeeze(np.array(images[:eval_size])),
162 | labels=np.squeeze(np.array(labels[:eval_size])),
163 | )
164 |
165 | start_idx = eval_size
166 | np.savez_compressed(
167 | os.path.join(output_path, f"{dataset_name}_test"),
168 | images=np.squeeze(np.array(images[start_idx : start_idx + test_size])),
169 | labels=np.squeeze(np.array(labels[start_idx : start_idx + test_size])),
170 | )
171 |
172 | start_idx += test_size
173 | np.savez_compressed(
174 | os.path.join(output_path, f"{dataset_name}_val"),
175 | images=np.squeeze(np.array(images[start_idx : start_idx + val_size])),
176 | labels=np.squeeze(np.array(labels[start_idx : start_idx + val_size])),
177 | )
178 |
179 | start_idx += val_size
180 | np.savez_compressed(
181 | os.path.join(output_path, f"{dataset_name}_train"),
182 | images=np.squeeze(np.array(images[start_idx:])),
183 | labels=np.squeeze(np.array(labels[start_idx:])),
184 | )
185 |
186 | print(f"Train dataset size: {len(images[start_idx:])}")
187 |
188 |
189 |
--------------------------------------------------------------------------------
/data/create_shapes.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import random
4 | from skimage.draw import polygon, disk
5 | from tqdm import tqdm
6 |
7 | seed = 1234
8 | np.random.seed(seed)
9 | random.seed(seed)
10 |
11 | # Initialize dimensions
12 | H, W = 18, 18
13 |
14 | # Triangle (hollow with thicker lines)
15 | triangle = np.zeros((H, W), dtype=np.float32)
16 |
17 | outer_rr, outer_cc = polygon([4, 11, 11], [8, 1, 15], shape=(H, W))
18 | inner_rr, inner_cc = polygon([6, 9, 9], [8, 5, 11], shape=(H, W))
19 | triangle[outer_rr, outer_cc] = 1.0
20 | triangle[inner_rr, inner_cc] = 0.0
21 |
22 | # Square (hollow with thicker lines)
23 | square = np.zeros((H, W), dtype=np.float32)
24 | square[3:13, 3:13] = 1.0 # Outer square
25 | square[5:11, 5:11] = 0.0 # Inner square
26 |
27 | # Circle (hollow, larger)
28 | circle = np.zeros((H, W), dtype=np.float32)
29 | outer_rr, outer_cc = disk((8, 8), 7, shape=(H, W))
30 | inner_rr, inner_cc = disk((8, 8), 5, shape=(H, W))
31 | circle[outer_rr, outer_cc] = 1.0
32 | circle[inner_rr, inner_cc] = 0.0
33 |
34 | # Diamond (hollow, larger and thicker lines)
35 | diamond = np.zeros((H, W), dtype=np.float32)
36 | outer_rr, outer_cc = polygon([8, 2, 8, 14], [2, 8, 14, 8], shape=(H, W))
37 | inner_rr, inner_cc = polygon([8, 4, 8, 12], [4, 8, 12, 8], shape=(H, W))
38 | diamond[outer_rr, outer_cc] = 1.0
39 | diamond[inner_rr, inner_cc] = 0.0
40 |
41 | # Shape dictionary
42 | shapes = {1: triangle, 2: square, 3: circle, 4: diamond}
43 |
44 | # Parameters
45 | image_size = 40
46 | num_images = 42000
47 | min_shapes = 2
48 | max_shapes = 4
49 | padding = 0
50 |
51 | # Storage for images and labels
52 | images = np.zeros((num_images, image_size, image_size), dtype=np.float32)
53 | class_labels = [] # Class label for each pixel
54 | instance_labels = [] # Instance label for each pixel
55 |
56 | # Function to place a shape in an image with bounds checking
57 | def place_shape(image, class_label, instance_label, shape, shape_id, instance_id, pos_x, pos_y):
58 | for i in range(H):
59 | for j in range(W):
60 | if shape[i, j] == 1.0:
61 | xi, yj = pos_x + i, pos_y + j
62 | if class_label[xi, yj] == 0: # If background, place shape
63 | image[xi, yj] = 1.0 # Shape pixel value
64 | class_label[xi, yj] = shape_id
65 | instance_label[xi, yj] = instance_id
66 | else: # Overlapping area
67 | class_label[xi, yj] = -1
68 | instance_label[xi, yj] = -1
69 | image[xi, yj] = 1.0 # Ensure overlapping area is also bright
70 |
71 | # Generate images and labels
72 | for img_index in tqdm(range(num_images)):
73 | # Generate a random background color between 0.1 and 0.6
74 | bg_color = random.uniform(0.1, 0.6)
75 |
76 | # Initialize the image with the random background color
77 | image = np.full((image_size, image_size), bg_color, dtype=np.float32)
78 | class_label = np.zeros((image_size, image_size), dtype=np.int8)
79 | instance_label = np.zeros((image_size, image_size), dtype=np.int8)
80 |
81 | # Randomly select number of shapes
82 | num_shapes = random.randint(min_shapes, max_shapes)
83 |
84 | # Randomly place shapes
85 | for instance_id in range(1, num_shapes + 1):
86 | shape_id = random.randint(1, 4) # Randomly select a shape
87 | shape = shapes[shape_id]
88 |
89 | # Random position with padding
90 | pos_x = random.randint(padding, image_size - H - padding)
91 | pos_y = random.randint(padding, image_size - W - padding)
92 |
93 | # Place shape
94 | place_shape(image, class_label, instance_label, shape, shape_id, instance_id, pos_x, pos_y)
95 |
96 | # Store the image and labels
97 | images[img_index] = image
98 | class_labels.append(class_label)
99 | instance_labels.append(instance_label)
100 |
101 | # Convert class_labels and instance_labels lists to numpy arrays for easy saving
102 | class_labels = np.array(class_labels, dtype=np.int8)
103 | instance_labels = np.array(instance_labels, dtype=np.int8)
104 |
105 | # Example of displaying one generated image with labels
106 | fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(12, 4))
107 | ax1.imshow(images[0], cmap="gray", vmin=0, vmax=1)
108 | ax1.set_title(f"Generated Image (Background color: {images[0][0,0]:.2f})")
109 | ax2.imshow(class_labels[0], cmap="gray")
110 | ax2.set_title("Class Label (0: background, -1: overlap, 1-4: shapes)")
111 | ax3.imshow(instance_labels[0], cmap="nipy_spectral")
112 | ax3.set_title("Instance Label (-1: overlap)")
113 | plt.show()
114 |
115 | train = np.arange(0, 40000)
116 | val = np.arange(40000, 41000)
117 | test = np.arange(41000, 42000)
118 |
119 | images = images[:, None]
120 | # Saving images and labels (optional)
121 | import os
122 | os.makedirs("./Shapes", exist_ok=True)
123 | dataset = {"images":images[train], "labels": instance_labels[train], "pixelwise_class_labels": class_labels[train]}
124 | np.savez_compressed("./Shapes/train.npz", **dataset)
125 | dataset = {"images":images[val], "labels": instance_labels[val], "pixelwise_class_labels": class_labels[val]}
126 | np.savez_compressed("./Shapes/val.npz", **dataset)
127 | dataset = {"images":images[test], "labels": instance_labels[test], "pixelwise_class_labels": class_labels[test]}
128 | np.savez_compressed("./Shapes/test.npz", **dataset)
129 |
130 | # np.save("instance_labels.npy", instance_labels)
131 |
--------------------------------------------------------------------------------
/data/download_clevrtex.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | ##### Download CLEVRTEX dataset.
4 | urls=(
5 | "https://thor.robots.ox.ac.uk/datasets/clevrtex/clevrtex_full_part1.tar.gz"
6 | "https://thor.robots.ox.ac.uk/datasets/clevrtex/clevrtex_full_part2.tar.gz"
7 | "https://thor.robots.ox.ac.uk/datasets/clevrtex/clevrtex_full_part3.tar.gz"
8 | "https://thor.robots.ox.ac.uk/datasets/clevrtex/clevrtex_full_part4.tar.gz"
9 | "https://thor.robots.ox.ac.uk/datasets/clevrtex/clevrtex_full_part5.tar.gz"
10 | "https://thor.robots.ox.ac.uk/datasets/clevrtex/clevrtex_outd.tar.gz"
11 | "https://thor.robots.ox.ac.uk/datasets/clevrtex/clevrtex_camo.tar.gz"
12 | )
13 |
14 | output_dir="./"
15 |
16 | mkdir -p $output_dir
17 |
18 | for url in "${urls[@]}"; do
19 |
20 | filename=$(basename "$url")
21 |
22 | echo "Downloading $filename..."
23 | wget -q --show-progress "$url" -P "$output_dir"
24 |
25 | echo "Extracting $filename..."
26 | tar -xzf "$output_dir/$filename" -C "$output_dir"
27 |
28 | rm "$output_dir/$filename"
29 | done
30 |
31 |
--------------------------------------------------------------------------------
/data/download_pascal_and_coco.sh:
--------------------------------------------------------------------------------
1 |
2 | ##### Download Pascal VOC 2012 dataset and add trainaug split.
3 | wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
4 | wget https://www.dropbox.com/s/oeu149j8qtbs1x0/SegmentationClassAug.zip
5 |
6 | tar -xf VOCtrainval_11-May-2012.tar
7 | unzip SegmentationClassAug.zip -d VOCdevkit/VOC2012
8 |
9 | mv trainaug.txt VOCdevkit/VOC2012/ImageSets/Segmentation
10 | mv VOCdevkit/VOC2012/SegmentationClassAug/* VOCdevkit/VOC2012/SegmentationClass/
11 |
12 | rm -r VOCdevkit/VOC2012/__MACOSX
13 | rm SegmentationClassAug.zip
14 | rm VOCtrainval_11-May-2012.tar
15 |
16 | ###### Download COCO 2017 dataset.
17 | wget http://images.cocodataset.org/zips/train2017.zip
18 | wget http://images.cocodataset.org/zips/val2017.zip
19 | wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
20 |
21 | unzip train2017.zip
22 | unzip val2017.zip
23 | unzip annotations_trainval2017.zip
24 |
25 | mkdir -p COCO/annotations
26 | mkdir -p COCO/images
27 | mv annotations COCO/
28 | mv train2017 COCO/images
29 | mv val2017 COCO/images
30 |
31 | rm -r annotations_trainval2017.zip
32 | rm train2017.zip
33 | rm val2017.zip
34 |
35 |
36 |
37 |
--------------------------------------------------------------------------------
/data/download_rrn.sh:
--------------------------------------------------------------------------------
1 | # Copied from https://github.com/yilundu/ired_code_release/blob/main/data/download-rrn.sh
2 | # Original RRN Sudoku data
3 | wget https://www.dropbox.com/s/rp3hbjs91xiqdgc/sudoku-hard.zip?dl=1
4 | mv sudoku-hard.zip?dl=1 sudoku-hard.zip
5 | unzip sudoku-hard.zip
6 | mv sudoku-hard sudoku-rrn
7 | rm sudoku-hard.zip
8 | rm -rf __MACOSX
--------------------------------------------------------------------------------
/data/download_satnet.sh:
--------------------------------------------------------------------------------
1 | # Copied from https://github.com/yilundu/ired_code_release/blob/main/data/download-satnet.sh
2 | # Original SAT-Net repo
3 | wget -cq powei.tw/sudoku.zip && unzip -qq sudoku.zip && rm sudoku.zip
4 | wget -cq powei.tw/parity.zip && unzip -qq parity.zip && rm parity.zip
--------------------------------------------------------------------------------
/data/download_synths.sh:
--------------------------------------------------------------------------------
1 | # Requires gsutil. Run `conda install conda-forge::gsutil` to install it, or manually download the datasets from https://console.cloud.google.com/storage/browser/multi-object-datasets;tab=objects?pli=1&inv=1&invt=AbjJBg&prefix=&forceOnObjectsSortingFiltering=false
2 | for dataset in tetrominoes multi_dsprites clevr_with_masks; do
3 | gsutil cp -r gs://multi-object-datasets/$dataset ./
4 | python convert_tfrecord_to_np.py --dataset_name=$dataset
5 | done
6 |
--------------------------------------------------------------------------------
/data/tfloaders/clevr_with_masks.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """CLEVR (with masks) dataset reader."""
16 |
17 | import tensorflow.compat.v1 as tf
18 |
19 |
20 | COMPRESSION_TYPE = tf.io.TFRecordOptions.get_compression_type_string('GZIP')
21 | IMAGE_SIZE = [240, 320]
22 | # The maximum number of foreground and background entities in the provided
23 | # dataset. This corresponds to the number of segmentation masks returned per
24 | # scene.
25 | MAX_NUM_ENTITIES = 11
26 | BYTE_FEATURES = ['mask', 'image', 'color', 'material', 'shape', 'size']
27 |
28 | # Create a dictionary mapping feature names to `tf.Example`-compatible
29 | # shape and data type descriptors.
30 | features = {
31 | 'image': tf.FixedLenFeature(IMAGE_SIZE+[3], tf.string),
32 | 'mask': tf.FixedLenFeature([MAX_NUM_ENTITIES]+IMAGE_SIZE+[1], tf.string),
33 | 'x': tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32),
34 | 'y': tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32),
35 | 'z': tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32),
36 | 'pixel_coords': tf.FixedLenFeature([MAX_NUM_ENTITIES, 3], tf.float32),
37 | 'rotation': tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32),
38 | 'size': tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.string),
39 | 'material': tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.string),
40 | 'shape': tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.string),
41 | 'color': tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.string),
42 | 'visibility': tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32),
43 | }
44 |
45 |
46 | def _decode(example_proto):
47 | # Parse the input `tf.Example` proto using the feature description dict above.
48 | single_example = tf.parse_single_example(example_proto, features)
49 | for k in BYTE_FEATURES:
50 | single_example[k] = tf.squeeze(tf.decode_raw(single_example[k], tf.uint8),
51 | axis=-1)
52 | return single_example
53 |
54 |
55 | def dataset(tfrecords_path, read_buffer_size=None, map_parallel_calls=None):
56 | """Read, decompress, and parse the TFRecords file.
57 |
58 | Args:
59 | tfrecords_path: str. Path to the dataset file.
60 | read_buffer_size: int. Number of bytes in the read buffer. See documentation
61 | for `tf.data.TFRecordDataset.__init__`.
62 | map_parallel_calls: int. Number of elements decoded asynchronously in
63 | parallel. See documentation for `tf.data.Dataset.map`.
64 |
65 | Returns:
66 | An unbatched `tf.data.TFRecordDataset`.
67 | """
68 | raw_dataset = tf.data.TFRecordDataset(
69 | tfrecords_path, compression_type=COMPRESSION_TYPE,
70 | buffer_size=read_buffer_size)
71 | return raw_dataset.map(_decode, num_parallel_calls=map_parallel_calls)
72 |
--------------------------------------------------------------------------------
/data/tfloaders/multi_dsprites.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Multi-dSprites dataset reader."""
16 |
17 | import functools
18 |
19 | import tensorflow.compat.v1 as tf
20 |
21 | COMPRESSION_TYPE = tf.io.TFRecordOptions.get_compression_type_string("GZIP")
22 | IMAGE_SIZE = [64, 64]
23 | # The maximum number of foreground and background entities in each variant
24 | # of the provided datasets. The values correspond to the number of
25 | # segmentation masks returned per scene.
26 | MAX_NUM_ENTITIES = {"binarized": 4, "colored_on_grayscale": 6, "colored_on_colored": 5}
27 | BYTE_FEATURES = ["mask", "image"]
28 |
29 |
30 | def feature_descriptions(max_num_entities, is_grayscale=False):
31 | """Create a dictionary desc ribing the dataset features.
32 |
33 | Args:
34 | max_num_entities: int. The maximum number of foreground and background
35 | entities in each image. This corresponds to the number of segmentation
36 | masks and generative factors returned per scene.
37 | is_grayscale: bool. Whether images are grayscale. Otherwise they're assumed
38 | to be RGB.
39 |
40 | Returns:
41 | A dictionary which maps feature names to `tf.Example`-compatible shape and
42 | data type descriptors.
43 | """
44 | num_channels = 1 if is_grayscale else 3
45 | return {
46 | "image": tf.FixedLenFeature(IMAGE_SIZE + [num_channels], tf.string),
47 | "mask": tf.FixedLenFeature(IMAGE_SIZE + [max_num_entities, 1], tf.string),
48 | "x": tf.FixedLenFeature([max_num_entities], tf.float32),
49 | "y": tf.FixedLenFeature([max_num_entities], tf.float32),
50 | "shape": tf.FixedLenFeature([max_num_entities], tf.float32),
51 | "color": tf.FixedLenFeature([max_num_entities, num_channels], tf.float32),
52 | "visibility": tf.FixedLenFeature([max_num_entities], tf.float32),
53 | "orientation": tf.FixedLenFeature([max_num_entities], tf.float32),
54 | "scale": tf.FixedLenFeature([max_num_entities], tf.float32),
55 | }
56 |
57 |
58 | def _decode(example_proto, features):
59 | # Parse the input `tf.Example` proto using a feature description dictionary.
60 | single_example = tf.parse_single_example(example_proto, features)
61 | for k in BYTE_FEATURES:
62 | single_example[k] = tf.squeeze(
63 | tf.decode_raw(single_example[k], tf.uint8), axis=-1
64 | )
65 | # To return masks in the canonical [entities, height, width, channels] format,
66 | # we need to transpose the tensor axes.
67 | single_example["mask"] = tf.transpose(single_example["mask"], [2, 0, 1, 3])
68 | return single_example
69 |
70 |
71 | def dataset(
72 | tfrecords_path, dataset_variant, read_buffer_size=None, map_parallel_calls=None
73 | ):
74 | """Read, decompress, and parse the TFRecords file.
75 |
76 | Args:
77 | tfrecords_path: str. Path to the dataset file.
78 | dataset_variant: str. One of ['binarized', 'colored_on_grayscale',
79 | 'colored_on_colored']. This is used to identify the maximum number of
80 | entities in each scene. If an incorrect identifier is passed in, the
81 | TFRecords file will not be read correctly.
82 | read_buffer_size: int. Number of bytes in the read buffer. See documentation
83 | for `tf.data.TFRecordDataset.__init__`.
84 | map_parallel_calls: int. Number of elements decoded asynchronously in
85 | parallel. See documentation for `tf.data.Dataset.map`.
86 |
87 | Returns:
88 | An unbatched `tf.data.TFRecordDataset`.
89 | """
90 | if dataset_variant not in MAX_NUM_ENTITIES:
91 | raise ValueError(
92 | "Invalid `dataset_variant` provided. The supported values"
93 | " are: {}".format(list(MAX_NUM_ENTITIES.keys()))
94 | )
95 | max_num_entities = MAX_NUM_ENTITIES[dataset_variant]
96 | is_grayscale = dataset_variant == "binarized"
97 | raw_dataset = tf.data.TFRecordDataset(
98 | tfrecords_path, compression_type=COMPRESSION_TYPE, buffer_size=read_buffer_size
99 | )
100 | features = feature_descriptions(max_num_entities, is_grayscale)
101 | partial_decode_fn = functools.partial(_decode, features=features)
102 | return raw_dataset.map(partial_decode_fn, num_parallel_calls=map_parallel_calls)
103 |
104 |
105 |
--------------------------------------------------------------------------------
/data/tfloaders/tetrominoes.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Tetrominoes dataset reader."""
16 |
17 | import tensorflow.compat.v1 as tf
18 |
19 |
20 | COMPRESSION_TYPE = tf.io.TFRecordOptions.get_compression_type_string('GZIP')
21 | IMAGE_SIZE = [35, 35]
22 | # The maximum number of foreground and background entities in the provided
23 | # dataset. This corresponds to the number of segmentation masks returned per
24 | # scene.
25 | MAX_NUM_ENTITIES = 4
26 | BYTE_FEATURES = ['mask', 'image']
27 |
28 | # Create a dictionary mapping feature names to `tf.Example`-compatible
29 | # shape and data type descriptors.
30 | features = {
31 | 'image': tf.FixedLenFeature(IMAGE_SIZE+[3], tf.string),
32 | 'mask': tf.FixedLenFeature([MAX_NUM_ENTITIES]+IMAGE_SIZE+[1], tf.string),
33 | 'x': tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32),
34 | 'y': tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32),
35 | 'shape': tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32),
36 | 'color': tf.FixedLenFeature([MAX_NUM_ENTITIES, 3], tf.float32),
37 | 'visibility': tf.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32),
38 | }
39 |
40 |
41 | def _decode(example_proto):
42 | # Parse the input `tf.Example` proto using the feature description dict above.
43 | single_example = tf.parse_single_example(example_proto, features)
44 | for k in BYTE_FEATURES:
45 | single_example[k] = tf.squeeze(tf.decode_raw(single_example[k], tf.uint8),
46 | axis=-1)
47 | return single_example
48 |
49 |
50 | def dataset(tfrecords_path, read_buffer_size=None, map_parallel_calls=None):
51 | """Read, decompress, and parse the TFRecords file.
52 |
53 | Args:
54 | tfrecords_path: str. Path to the dataset file.
55 | read_buffer_size: int. Number of bytes in the read buffer. See documentation
56 | for `tf.data.TFRecordDataset.__init__`.
57 | map_parallel_calls: int. Number of elements decoded asynchronously in
58 | parallel. See documentation for `tf.data.Dataset.map`.
59 |
60 | Returns:
61 | An unbatched `tf.data.TFRecordDataset`.
62 | """
63 | raw_dataset = tf.data.TFRecordDataset(
64 | tfrecords_path, compression_type=COMPRESSION_TYPE,
65 | buffer_size=read_buffer_size)
66 | return raw_dataset.map(_decode, num_parallel_calls=map_parallel_calls)
67 |
68 |
--------------------------------------------------------------------------------
/eval_obj.py:
--------------------------------------------------------------------------------
1 | import sys, os
2 |
3 | # sys.path.append(rootdir)
4 | import torch
5 | import torch.nn as nn
6 | import torch.optim
7 | import tqdm
8 | from ema_pytorch import EMA
9 | import matplotlib.pyplot as plt
10 |
11 | from source.models.objs.knet import AKOrN
12 | from source.models.objs.vit import ViT
13 | from source.utils import get_worker_init_fn
14 | from torch.nn import functional as F
15 | from source.layers.common_layers import RGBNormalize
16 | import numpy as np
17 |
18 | import timm
19 | from timm.models import VisionTransformer
20 | from source.utils import gen_saccade_imgs, apply_pca_torch, str2bool
21 | import argparse
22 |
23 | from source.evals.objs.mbo import calc_mean_best_overlap
24 | from source.evals.objs.fgari import calc_fgari_score
25 |
26 | from typing import Callable
27 |
28 | class Wrapper(nn.Module):
29 | def __init__(self, model):
30 | super().__init__()
31 | self.module = model
32 |
33 |
34 | from collections import OrderedDict
35 | from typing import Dict, Callable
36 | import torch
37 |
38 | noise = 0.0
39 |
40 |
41 | def remove_all_forward_hooks(model: torch.nn.Module) -> None:
42 | for name, child in model._modules.items():
43 | if child is not None:
44 | if hasattr(child, "_forward_hooks"):
45 | child._forward_hooks: Dict[int, Callable] = OrderedDict()
46 | remove_all_forward_hooks(child)
47 |
48 |
49 | def model_preds(model, org_images):
50 | activation = {}
51 | imsize_h, imsize_w = org_images.shape[-2], org_images.shape[-1]
52 |
53 | def get_activation(name):
54 | def hook(model, input, output):
55 | activation[name] = output.detach()
56 |
57 | return hook
58 |
59 | if isinstance(model, AKOrN):
60 | model.out[0].register_forward_hook(get_activation("z"))
61 | elif isinstance(model, ViT):
62 | model.out[0].register_forward_hook(get_activation("z"))
63 |
64 | else:
65 | raise Exception()
66 |
67 | model.eval()
68 | imgs = org_images.cuda()
69 |
70 | with torch.no_grad():
71 | if (
72 | isinstance(model, AKOrN)
73 | or isinstance(model, ViT)
74 | ):
75 | output, _xs = model(imgs, return_xs=True)
76 | else:
77 | output = model(imgs)
78 | _xs = None
79 | v = activation["z"]
80 |
81 | if isinstance(model, AKOrN) or isinstance(model, ViT):
82 | v = F.normalize(v, dim=1)
83 | #elif isinstance(model, ViTWrapper):
84 | # v = F.normalize(v, dim=2)
85 | # v = v.permute(0, 2, 1)[..., 1:]
86 | # h, w = int(np.sqrt(x.shape[-1])), int(np.sqrt(x.shape[-1])) # estimated inpsize
87 | # v = v.unflatten(-1, (h, w))
88 | remove_all_forward_hooks(model)
89 | return v
90 |
91 | def clustering(x, h, w, method="agglomerative", n_clusters=3):
92 | from sklearn.cluster import KMeans
93 |
94 | if method == "agglomerative":
95 | import fastcluster
96 | from scipy.cluster.hierarchy import fcluster
97 | from scipy.cluster.hierarchy import linkage
98 |
99 | x = x.view(x.shape[0], -1).transpose(-2, -1).to("cpu").detach()
100 | Z = fastcluster.average(x)
101 | label = fcluster(Z, t=n_clusters, criterion="maxclust")
102 | return label.reshape(h, w)
103 | elif method == "kmeans":
104 | kmeans = KMeans(n_clusters=n_clusters, random_state=0, n_init="auto").fit(
105 | x.view(x.shape[0], -1).transpose(-2, -1).to("cpu").detach()
106 | )
107 | label = kmeans.labels_
108 | return label.reshape(h, w)
109 |
110 | else:
111 | raise ValueError("Clustering method not found")
112 |
113 |
114 |
115 |
116 |
117 | from source.layers.common_fns import positionalencoding2d
118 |
119 | def eval(
120 | model,
121 | images,
122 | gt,
123 | method="agglomerative",
124 | n_clusters=7,
125 | saccade_r=1,
126 | pca=False,
127 | pca_dim=128,
128 | ):
129 | preds = []
130 | N = images.shape[0]
131 | _imgs, _ = gen_saccade_imgs(images, model.psize, model.psize // saccade_r)
132 | outputs = []
133 | for img in _imgs:
134 | v = model_preds(model, img)
135 | outputs.append(v.detach().cpu())
136 |
137 | nh, nw = int(np.sqrt(len(_imgs))), int(np.sqrt(len(_imgs)))
138 | ho, wo = outputs[0].shape[-2], outputs[0].shape[-1]
139 | nimg = torch.zeros(N, outputs[0].shape[1], ho, nh, wo, nw)
140 | for h in range(nh):
141 | for w in range(nw):
142 | nimg[:, :, :, h, :, w] = outputs[h * (nh) + w]
143 | nimg = nimg.view(N, -1, ho * nh, wo * nw)
144 |
145 | from source.utils import apply_pca_torch
146 |
147 | with torch.no_grad():
148 | if pca:
149 | pcaimg_ = apply_pca_torch(nimg, n_components=pca_dim)
150 | x = pcaimg_
151 | else:
152 | x = nimg
153 |
154 | for idx in range(N):
155 | _x = x[idx]
156 | pred = clustering(_x, *_x.shape[1:], method, n_clusters)
157 | pred = torch.nn.Upsample(
158 | scale_factor=(images.shape[-2]/pred.shape[-2], images.shape[-1]/pred.shape[-1]),
159 | mode='nearest')(torch.Tensor(pred[None, None]).float())[0, 0]
160 | preds.append(pred)
161 |
162 | preds = torch.stack(preds, 0).long()
163 |
164 | scores = {}
165 | evaluate_sem = False
166 | if isinstance(gt, list):
167 | gt_sem = gt[1]
168 | gt = gt[0]
169 | evaluate_sem = True
170 |
171 | _gt = ((gt > 0).float() * gt).long() # set ignore bg (-1) to 0
172 | # compute fgari
173 | scores["fgari"] = np.array(calc_fgari_score(_gt, preds))
174 |
175 | # compute mean best overlap
176 | score, _scores = calc_mean_best_overlap(gt.numpy(), preds.numpy())
177 | scores["mbo"] = score
178 | scores["mbo_scores"] = _scores
179 |
180 | if evaluate_sem:
181 | score, _scores = calc_mean_best_overlap(gt_sem.numpy(), preds.numpy())
182 | scores["mbo_c"] = score
183 | scores["mbo_c_scores"] = _scores
184 |
185 | return scores, preds
186 |
187 |
188 | def get_loader(data, data_root, imsize, batchsize):
189 |
190 | from source.data.datasets.objs.load_data import load_data
191 |
192 | dataset, imsize, collate_fn = load_data(data, data_root, imsize, is_eval=True)
193 |
194 | if data == "clevrtex_full" or data == "clevrtex_outd" or data == "clevrtex_camo":
195 | loader = torch.utils.data.DataLoader(
196 | dataset,
197 | batch_size=batchsize,
198 | num_workers=0,
199 | shuffle=True,
200 | collate_fn=collate_fn,
201 | )
202 |
203 | elif data == "coco":
204 | loader = torch.utils.data.DataLoader(
205 | dataset,
206 | batch_size=batchsize,
207 | num_workers=0,
208 | shuffle=True,
209 | collate_fn=collate_fn,
210 | )
211 | else:
212 | loader = torch.utils.data.DataLoader(
213 | dataset,
214 | batch_size=batchsize,
215 | num_workers=0,
216 | shuffle=True,
217 | )
218 | return loader, imsize
219 |
220 |
221 | def eval_dataset(
222 | model,
223 | data,
224 | data_root=None,
225 | imsize=None,
226 | batchsize=100,
227 | method="agglomerative",
228 | instance=True,
229 | saccade_r=1,
230 | pca=False,
231 | ):
232 |
233 | scores = []
234 | preds = []
235 | masks = []
236 |
237 | loader, imsize = get_loader(data, data_root, imsize, batchsize)
238 | for ret in tqdm.tqdm(loader):
239 | pca_dim = 128
240 | if data == "clevr":
241 | images = ret[0]
242 | if instance:
243 | labels = ret[1]["pixelwise_instance_labels"]
244 | else:
245 | labels = ret[1]["pixelwise_class_labels"]
246 | n_clusters = 11
247 |
248 | elif data == "clevrtex_camo" or data == "clevrtex_full" or data == "clevrtex_outd":
249 | images = ret[1]
250 | labels = ret[2][:, 0]
251 | n_clusters = 11
252 |
253 | elif data == "pascal":
254 | images = ret[0]
255 | labels_instance = ret[1]["pixelwise_instance_labels"]
256 | labels_sem = ret[1]["pixelwise_class_labels"]
257 | labels = [labels_instance, labels_sem]
258 | n_clusters = 4
259 |
260 | elif data == "coco":
261 | images = ret["img"]
262 | labels_instance = ret["masks"].long()
263 | labels_sem = ret["sem_masks"].long()
264 | ovlp = ret["inst_overlap_masks"].long()
265 | labels_instance[ovlp == 1] = -1
266 | labels_sem[ovlp == 1] = -1
267 | labels = [labels_instance, labels_sem]
268 | n_clusters = 7
269 | score, pred = eval(
270 | model,
271 | images,
272 | labels,
273 | method,
274 | n_clusters,
275 | saccade_r=saccade_r,
276 | pca=pca,
277 | pca_dim=pca_dim,
278 | )
279 | scores.append(score)
280 | preds.append(pred)
281 | masks.append(labels)
282 | return scores, preds
283 |
284 |
285 | def print_stats(scores):
286 | fgaris = []
287 | mbos = []
288 | mbocs = []
289 | for _s in scores:
290 | fgaris.append(_s["fgari"])
291 | mbos.append(_s["mbo_scores"])
292 | if "mbo_c" in _s:
293 | mbocs.append(_s["mbo_c_scores"])
294 | print(np.concatenate(fgaris, 0).mean(), np.concatenate(fgaris, 0).std())
295 | _mbos = np.concatenate(mbos)
296 | _mbos = _mbos[_mbos != -1]
297 | print(np.mean(_mbos), np.std(_mbos))
298 | if len(mbocs) > 0:
299 | _mbocs = np.concatenate(mbocs)
300 | _mbocs = _mbocs[_mbocs != -1]
301 | print(np.mean(_mbocs), np.std(_mbocs))
302 |
303 |
304 | if __name__ == "__main__":
305 |
306 | parser = argparse.ArgumentParser()
307 |
308 | # Eval options
309 | parser.add_argument("--model_path", type=str, help="path to the model")
310 | parser.add_argument("--saccade_r", type=int, default=1, help="Uptiling factor")
311 | parser.add_argument("--pca", type=str2bool, default=True)
312 |
313 | # Data loading
314 | parser.add_argument("--limit_cores_used", type=str2bool, default=False)
315 | parser.add_argument("--cpu_core_start", type=int, default=0, help="start core")
316 | parser.add_argument("--cpu_core_end", type=int, default=32, help="end core")
317 | parser.add_argument("--data", type=str, default="clevrtex_full")
318 | parser.add_argument(
319 | "--data_root",
320 | type=str,
321 | default=None,
322 | help="optional. you can specify the dir path if the default path of each dataset is not appropritate one. Currently only applied to ImageNet",
323 | )
324 | parser.add_argument("--batchsize", type=int, default=250)
325 | parser.add_argument("--num_workers", type=int, default=8)
326 | parser.add_argument(
327 | "--data_imsize",
328 | type=int,
329 | default=None,
330 | help="Image size. If None, use the default size of each dataset",
331 | )
332 |
333 | # General model options
334 | parser.add_argument("--model", type=str, default="knet", help="model")
335 | parser.add_argument("--L", type=int, default=2, help="num of layers")
336 | parser.add_argument("--ch", type=int, default=256, help="num of channels")
337 | parser.add_argument(
338 | "--model_imsize",
339 | type=int,
340 | default=None,
341 | help="""
342 | Model's imsize that was set when it was initialized.
343 | This is used when evaluating or when finetuning a pretrained model.
344 | """,
345 | )
346 | parser.add_argument("--autorescale", type=str2bool, default=False)
347 | parser.add_argument("--psize", type=int, default=8, help="patch size")
348 | parser.add_argument("--ksize", type=int, default=1, help="kernel size")
349 | parser.add_argument("--T", type=int, default=8, help="num of recurrence")
350 | parser.add_argument(
351 | "--maxpool", type=str2bool, default=True, help="max pooling or avg pooling"
352 | )
353 | parser.add_argument(
354 | "--heads", type=int, default=8, help="num of heads in self-attention"
355 | )
356 | parser.add_argument(
357 | "--gta",
358 | type=str2bool,
359 | default=True,
360 | help="""
361 | use Geometric Transform Attention (https://github.com/autonomousvision/gta) as positional encoding.
362 | If False, use standard absolute positional encoding
363 | """,
364 | )
365 |
366 | # AKOrN options
367 | parser.add_argument("--N", type=int, default=4, help="num of rotating dimensions")
368 | parser.add_argument("--J", type=str, default="conv", help="connectivity")
369 | parser.add_argument("--use_omega", type=str2bool, default=False)
370 | parser.add_argument("--global_omg", type=str2bool, default=False)
371 | parser.add_argument(
372 | "--c_norm",
373 | type=str,
374 | default="gn",
375 | help="normalization. gn, sandb(scale and bias), or none",
376 | )
377 |
378 | parser.add_argument(
379 | "--use_ro_x",
380 | type=str2bool,
381 | default=False,
382 | help="apply linear transform to oscillators between consecutive layers",
383 | )
384 |
385 | # ablation of some components in the AKOrN's block
386 | parser.add_argument(
387 | "--no_ro", type=str2bool, default=False, help="ablation: no use readout module"
388 | )
389 | parser.add_argument(
390 | "--project",
391 | type=str2bool,
392 | default=True,
393 | help="use projection or not in the Kuramoto layer",
394 | )
395 |
396 | args = parser.parse_args()
397 |
398 | torch.backends.cudnn.benchmark = True
399 | torch.backends.cuda.enable_flash_sdp(enabled=True)
400 |
401 | if args.limit_cores_used:
402 | def worker_init_fn(worker_id):
403 | os.sched_setaffinity(0, range(args.cpu_core_start, args.cpu_core_end))
404 |
405 | if args.model == "akorn":
406 | net = AKOrN(
407 | args.N,
408 | ch=args.ch,
409 | L=args.L,
410 | T=args.T,
411 | J=args.J, # "conv" or "attn",
412 | use_omega=args.use_omega,
413 | global_omg=args.global_omg,
414 | c_norm=args.c_norm,
415 | psize=args.psize,
416 | imsize=args.model_imsize,
417 | autorescale=args.autorescale,
418 | maxpool=args.maxpool,
419 | project=args.project,
420 | heads=args.heads,
421 | use_ro_x=args.use_ro_x,
422 | no_ro=args.no_ro,
423 | gta=args.gta,
424 | ).to("cuda")
425 | elif args.model == "vit":
426 | net = ViT(
427 | psize=args.psize,
428 | imsize=args.model_imsize,
429 | autorescale=args.autorescale,
430 | ch=args.ch,
431 | blocks=args.L,
432 | heads=args.heads,
433 | mlp_dim=2 * args.ch,
434 | T=args.T,
435 | maxpool=args.maxpool,
436 | gta=args.gta,
437 | ).cuda()
438 |
439 | model = EMA(net)
440 | model.load_state_dict(torch.load(args.model_path, weights_only=True)["model_state_dict"])
441 | model = model.ema_model
442 |
443 | with torch.no_grad():
444 | scores, preds = eval_dataset(
445 | model,
446 | data=args.data,
447 | data_root=args.data_root,
448 | imsize=args.data_imsize,
449 | batchsize=args.batchsize,
450 | instance=True,
451 | method="agglomerative",
452 | saccade_r=args.saccade_r,
453 | pca=args.pca,
454 | )
455 | print_stats(scores)
456 |
--------------------------------------------------------------------------------
/eval_sudoku.py:
--------------------------------------------------------------------------------
1 | import sys, os
2 | import torch
3 | import torch.nn
4 | import torch.optim
5 | import tqdm
6 | import torchvision
7 | from torchvision import transforms
8 | import numpy as np
9 | from torch.optim.swa_utils import AveragedModel
10 | import matplotlib.pyplot as plt
11 |
12 | from source.data.datasets.sudoku.sudoku import SudokuDataset, HardSudokuDataset
13 | from source.models.sudoku.knet import SudokuAKOrN
14 | from source.models.sudoku.transformer import SudokuTransformer
15 | from source.evals.sudoku.evals import compute_board_accuracy
16 | from source.utils import str2bool
17 | from ema_pytorch import EMA
18 | import argparse
19 |
20 | if __name__ == "__main__":
21 |
22 | parser = argparse.ArgumentParser()
23 |
24 | parser.add_argument("--model_path", type=str, help="path to the model")
25 |
26 | # Data loading
27 | parser.add_argument("--data", type=str, default="id", help="data")
28 | parser.add_argument("--limit_cores_used", type=str2bool, default=False)
29 | parser.add_argument("--cpu_core_start", type=int, default=0, help="start core")
30 | parser.add_argument("--cpu_core_end", type=int, default=16, help="end core")
31 | parser.add_argument(
32 | "--data_root",
33 | type=str,
34 | default=None,
35 | help="Optional. Specify the root dir of the dataset. If None, use a default path set for each dataset",
36 | )
37 | parser.add_argument("--batchsize", type=int, default=100)
38 | parser.add_argument("--num_workers", type=int, default=4)
39 |
40 | # General model options
41 | parser.add_argument("--model", type=str, default="akorn", help="model")
42 | parser.add_argument("--L", type=int, default=1, help="num of layers")
43 | parser.add_argument("--T", type=int, default=16, help="Timesteps")
44 | parser.add_argument("--ch", type=int, default=512, help="num of channels")
45 | parser.add_argument("--heads", type=int, default=8)
46 |
47 | # AKOrN options
48 | parser.add_argument("--N", type=int, default=4)
49 | parser.add_argument(
50 | "--K",
51 | type=int,
52 | default=1,
53 | help="num of random oscillator samples for each input",
54 | )
55 | parser.add_argument("--minimum_chunk", type=int, default=None)
56 | parser.add_argument("--evote_type", type=str, default="last", help="last or sum")
57 | parser.add_argument("--gamma", type=float, default=1.0, help="step size")
58 | parser.add_argument("--J", type=str, default="attn", help="connectivity")
59 | parser.add_argument("--use_omega", type=str2bool, default=True)
60 | parser.add_argument("--global_omg", type=str2bool, default=True)
61 | parser.add_argument("--learn_omg", type=str2bool, default=False)
62 | parser.add_argument("--init_omg", type=float, default=0.1)
63 | parser.add_argument("--nl", type=str2bool, default=True)
64 |
65 | parser.add_argument("--speed_test", action="store_true")
66 |
67 | args = parser.parse_args()
68 |
69 | torch.backends.cudnn.benchmark = True
70 | torch.backends.cuda.enable_flash_sdp(enabled=True)
71 |
72 | if args.limit_cores_used:
73 |
74 | def worker_init_fn(worker_id):
75 | os.sched_setaffinity(0, range(args.cpu_core_start, args.cpu_core_end))
76 |
77 | else:
78 | worker_init_fn = None
79 |
80 | if args.data == "id":
81 | loader = torch.utils.data.DataLoader(
82 | SudokuDataset(
83 | args.data_root if args.data_root is not None else "./data/sudoku",
84 | train=False,
85 | ),
86 | batch_size=args.batchsize,
87 | shuffle=False,
88 | num_workers=args.num_workers,
89 | worker_init_fn=worker_init_fn,
90 | )
91 | elif args.data == "ood":
92 | loader = torch.utils.data.DataLoader(
93 | HardSudokuDataset(
94 | args.data_root if args.data_root is not None else "./data/sudoku-rrn",
95 | split="test",
96 | ),
97 | batch_size=args.batchsize,
98 | shuffle=False,
99 | num_workers=args.num_workers,
100 | worker_init_fn=worker_init_fn,
101 | )
102 | else:
103 | raise NotImplementedError
104 |
105 | if args.model == "akorn":
106 | print(
107 | f"n: {args.N}, ch: {args.ch}, L: {args.L}, T: {args.T}, type of J: {args.J}"
108 | )
109 | net = SudokuAKOrN(
110 | n=args.N,
111 | ch=args.ch,
112 | L=args.L,
113 | T=args.T,
114 | gamma=args.gamma,
115 | J=args.J,
116 | use_omega=args.use_omega,
117 | global_omg=args.global_omg,
118 | init_omg=args.init_omg,
119 | learn_omg=args.learn_omg,
120 | nl=args.nl,
121 | heads=args.heads,
122 | )
123 | elif args.model == "itrsa":
124 | net = SudokuTransformer(
125 | ch=args.ch,
126 | blocks=args.L,
127 | heads=args.heads,
128 | mlp_dim=args.ch * 2,
129 | T=args.T,
130 | gta=False,
131 | )
132 | else:
133 | raise NotImplementedError
134 |
135 | model = EMA(net).cuda()
136 | model.load_state_dict(
137 | torch.load(args.model_path, weights_only=True)["model_state_dict"]
138 | )
139 | model = model.ema_model
140 | model.eval()
141 |
142 | K = args.K
143 |
144 | corrects_vote = 0
145 | corrects_avg = 0
146 | totals = 0
147 |
148 | minimum_chunk = args.minimum_chunk if args.minimum_chunk is not None else K
149 |
150 | for i, (X, Y, is_input) in tqdm.tqdm(enumerate(loader)):
151 | B = X.shape[0]
152 | if args.model == 'akorn' and K > 1: # Energy-based voting
153 | for j in range(B):
154 | preds = []
155 | es_list = []
156 | for k in range(K//minimum_chunk):
157 |
158 | _X = X[j : j + 1].repeat(minimum_chunk, 1, 1, 1)
159 | _Y = Y[j : j + 1].repeat(minimum_chunk, 1, 1, 1)
160 | _is_input = is_input[j : j + 1].repeat(minimum_chunk, 1, 1, 1)
161 | _X, _Y, _is_input = (
162 | _X.to(torch.int32).cuda(),
163 | _Y.cuda(),
164 | _is_input.cuda(),
165 | )
166 |
167 | with torch.no_grad():
168 | pred, es = model(_X, _is_input, return_es=True)
169 | preds.append(pred.detach())
170 | if args.evote_type =='sum':
171 | # the sum of energy values over timesteps as board correctness indicator
172 | es = torch.stack(es[-1], 0).sum(0).detach()
173 | elif args.evote_type == 'last':
174 | es = es[-1][-1].detach()
175 | es_list.append(es)
176 |
177 | pred = torch.cat(preds, 0)
178 | es = torch.cat(es_list, 0)
179 |
180 | idxes = torch.argsort(es) # minimum energy first
181 | pred_vote = pred[idxes[:1]].mean(0, keepdim=True)
182 | pred_avg = pred.mean(0, keepdim=True)
183 |
184 | _, _, board_correct_vote = compute_board_accuracy(
185 | pred_vote, _Y[:1], _is_input[:1]
186 | )
187 |
188 | corrects_vote += board_correct_vote.sum().item()
189 | totals += board_correct_vote.numel()
190 |
191 | else:
192 | X, Y, is_input = X.to(torch.int32).cuda(), Y.cuda(), is_input.cuda()
193 | with torch.no_grad():
194 | pred = model(X, is_input)
195 | num_blanks, num_corrects, board_correct = compute_board_accuracy(pred, Y, is_input)
196 | corrects_vote += board_correct.sum().item()
197 | totals += board_correct.numel()
198 |
199 | # Compute the overall board accuracy
200 | accuracy_vote = corrects_vote / totals
201 |
202 | print(f"Vote accuracy: {accuracy_vote:.4f}")
203 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | torchaudio
4 | einops
5 | jupyter
6 | matplotlib
7 | tensorboard
8 | tqdm
9 | argparse
10 | git+https://github.com/fra31/auto-attack
11 | tensorflow-cpu
12 | ema_pytorch
13 | accelerate
14 | scipy
15 | scikit-learn
16 | scikit-image
17 | timm
18 | opencv-python
19 | pycocotools
20 | fastcluster
21 |
--------------------------------------------------------------------------------
/scripts/sudoku.md:
--------------------------------------------------------------------------------
1 | ## Download the Sudoku datasets
2 | ```
3 | cd data
4 | bash download_satnet.sh
5 | bash download_rrn.sh
6 | cd ..
7 | ```
8 |
9 | ## Training
10 | ### AKOrN
11 | ```
12 | python train_sudoku.py --exp_name=sudoku_akorn --eval_freq=10 --epochs=100 --model=akorn --lr=0.001 --T=16 --use_omega=True --global_omg=True --init_omg=0.5 --learn_omg=True
13 | ```
14 |
15 | ### ItrSA and Transformer
16 | ```
17 | # ItrSA
18 | python train_sudoku.py --exp_name=sudoku_itrsa --eval_freq=10 --epochs=100 --model=itrsa --lr=0.0005 --T=16
19 | # Transformer
20 | python train_sudoku.py --exp_name=sudoku_transformer --eval_freq=10 --epochs=100 --model=itrsa --lr=0.0005 --T=1 --L=8
21 | ```
22 |
23 | ## Evaluation
24 |
25 | ### Inference with test-time extension of the Kuramoto updates ($T_{\rm eval}=128$)
26 | ```
27 | export data=ood # id or ood
28 | python eval_sudoku.py --data=${data} --model=akorn --model_path=runs/sudoku_akorn/ema_99.pth --T=128
29 | ```
30 |
31 | ### Test-time extension of the K-updates + energy-based voting ($T_{\rm eval}=128$, $K=100$)
32 | ```
33 | python eval_sudoku.py --data=${data} --model=akorn --model_path=runs/sudoku_akorn/ema_99.pth --T=128 --K=100 --evote_type=sum
34 | ```
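
The voting rule itself is simple: the network is run on `K` copies of each board with different random oscillator initializations, each run is scored by its energy (`--evote_type` selects the last-step or the summed energy), and the lowest-energy prediction is kept. A minimal sketch of the selection step, mirroring `eval_sudoku.py` (tensor shapes are illustrative):
```
import torch

def energy_vote(preds, energies):
    # preds:    (K, ...) digit predictions from K runs of the model.
    # energies: (K,) energy score per run (last-step or summed over timesteps).
    order = torch.argsort(energies)  # lowest energy first
    return preds[order[:1]].mean(0, keepdim=True)
```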
35 |
36 |
37 | ### Performance table
38 | | Model | ID | OOD |
39 | |-------------------------|--------------------|-----------------|
40 | | Transformer | 98.6±0.3 | 5.2±0.2 |
41 | | ItrSA ($T_{\rm eval}=32$) | 95.7±8.5 | 34.4±5.4 |
42 | | AKOrN ($T_{\rm eval}=128$) | 100.0±0.0 | 51.7±3.3 |
43 | | AKOrN ($T_{\rm eval}=128, K=100$) | 100.0±0.0 | 81.6±1.5 |
44 | | AKOrN ($T_{\rm eval}=128, K=4096$) | 100.0±0.0 | 89.5±2.5 |
45 |
46 | ### Visualization of oscillator dynamics over timesteps
47 | 
48 |
49 |
--------------------------------------------------------------------------------
/scripts/synths.md:
--------------------------------------------------------------------------------
1 | ## Download the synthetic datasets (Tetrominoes, dSprites, CLEVR)
2 | ```
3 | # Requires gsutil. Run `conda install conda-forge::gsutil` to install it, or manually download the datasets from https://console.cloud.google.com/storage/browser/multi-object-datasets;tab=objects?pli=1&inv=1&invt=AbjJBg&prefix=&forceOnObjectsSortingFiltering=false
4 | cd data
5 | bash download_synths.sh
6 | cd ..
7 | ```
8 |
9 | ## Training
10 | ```
11 | export NUM_GPUS= # Number of GPUs. If you use a single GPU, run the commands without the multi-GPU arguments (`--multi-gpu --num_processes=$NUM_GPUS`).
12 | ```
13 |
14 | #### AKOrN (Attentive models)
15 | ```
16 | #Tetrominoes
17 | export dataset=tetrominoes; accelerate launch --num_processes=${NUM_GPUS} train_obj.py --exp_name=${dataset}_akorn_attn --model=akorn --data=${dataset} --J=attn --L=1 --ch=128 --psize=4 --epochs=50 --c_norm=none
18 | #dSprites
19 | export dataset=dsprites; accelerate launch --num_processes=${NUM_GPUS} train_obj.py --exp_name=${dataset}_akorn_attn --model=akorn --data=${dataset} --J=attn --L=1 --ch=128 --psize=4 --epochs=50 --c_norm=none
20 | #CLEVR
21 | export dataset=clevr; accelerate launch --num_processes=${NUM_GPUS} train_obj.py --exp_name=${dataset}_akorn_attn --model=akorn --data=${dataset} --J=attn --L=1 --ch=256 --psize=8 --epochs=300 --c_norm=none
22 | ```
23 |
24 | #### AKOrN (Convolutional models)
25 | ```
26 | #Tetrominoes
27 | export dataset=tetrominoes; accelerate launch --num_processes=${NUM_GPUS} train_obj.py --exp_name=${dataset}_akorn_conv --model=akorn --data=${dataset} --J=conv --L=1 --ksize=5 --ch=128 --psize=4 --epochs=50 --c_norm=none
28 | #dSprites
29 | export dataset=dsprites; accelerate launch --num_processes=${NUM_GPUS} train_obj.py --exp_name=${dataset}_akorn_conv --model=akorn --data=${dataset} --J=conv --L=1 --ksize=7 --ch=128 --psize=4 --epochs=50 --c_norm=none
30 | #CLEVR
31 | export dataset=clevr; accelerate launch --num_processes=${NUM_GPUS} train_obj.py --exp_name=${dataset}_akorn_conv --model=akorn --data=${dataset} --J=conv --L=1 --ksize=7 --ch=256 --psize=8 --epochs=300 --c_norm=none
32 | ```
33 |
34 | #### ItrSA
35 | ```
36 | export dataset=tetrominoes; accelerate launch --multi-gpu --num_processes=$NUM_GPUS train_obj.py --exp_name=${dataset}_itrsa --data=${dataset} --model=vit --L=1 --gta=False --T=8 --ch=128 --psize=4 --epochs=50
37 | export dataset=dsprites; accelerate launch --multi-gpu --num_processes=$NUM_GPUS train_obj.py --exp_name=${dataset}_itrsa --data=${dataset} --model=vit --L=1 --gta=False --T=8 --ch=128 --psize=4 --epochs=50
38 | export dataset=clevr; accelerate launch --multi-gpu --num_processes=$NUM_GPUS train_obj.py --exp_name=${dataset}_itrsa --data=${dataset} --model=vit --L=1 --gta=False --T=8 --ch=256 --psize=8 --epochs=300
39 | ```
40 |
41 | ## Evaluation
42 |
43 | ```
44 | export DATA=tetrominoes #{tetrominoes, dsprites, clevr}. Please adjust the model parameters (--model, --J, --ch, --psize) based on the dataset and model you want to evaluate.
45 | export IMSIZE=32 # {32:tetrominoes, 64:dsprites, 128:clevr}.
46 | python eval_obj.py --model=akorn --data=${DATA} --model_path=runs/${DATA}_akorn_attn --model_imsize=${IMSIZE} --J=attn --L=1 --T=8 --ch=128 --psize=4 --c_norm=none
47 | ```
48 |
49 | #### Performance table
50 | | Model | Tetrominoes FG-ARI | Tetrominoes MBO | dSprites FG-ARI | dSprites MBO | CLEVR FG-ARI | CLEVR MBO |
51 | |-------------------------|--------------------|-----------------|-----------------|--------------|--------------|-----------|
52 | | ItrConv | 59.0±2.9 | 51.6±2.2 | 29.1±6.2 | 38.5±5.2 | 49.3±5.1 | 29.7±3.0 |
53 | | AKOrN (conv) | 76.4±0.8 | 51.9±1.5 | 63.8±7.7 | 50.7±4.7 | 59.0±4.3 | 44.4±2.0 |
54 | | ItrSA | 85.8±0.8 | 54.9±3.4 | 68.1±1.4 | 63.0±1.2 | 82.5±1.7 | 39.4±1.9 |
55 | | AKOrN (attn) | 88.6±1.7 | 56.4±0.9 | 78.3±1.3 | 63.0±1.8 | 91.0±0.5 | 45.5±1.4 |
56 |
57 | ##### (+up-tiling (×4))
58 | | Model | Tetrominoes FG-ARI | Tetrominoes MBO | dSprites FG-ARI | dSprites MBO | CLEVR FG-ARI | CLEVR MBO |
59 | |-------------------------|--------------------|-----------------|-----------------|--------------|--------------|-----------|
60 | | AKOrN^attn | 93.1±0.3 | 56.3±0.0 | 87.1±1.0 | 60.2±1.9 | 94.6±0.7 | 44.7±0.7 |
61 |
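FG-ARI in the tables above is the adjusted Rand index computed over foreground pixels only: the background label is dropped from the ground-truth mask before scoring, mirroring the `skip_0` logic of the CLEVRTex evaluator in `source/data/datasets/objs/clevr_tex.py`. A minimal per-image sketch with scikit-learn (the toy arrays are purely illustrative):

```
import numpy as np
from sklearn.metrics import adjusted_rand_score

def fg_ari(pred_labels, true_labels):
    # ARI restricted to foreground pixels (true label 0 = background).
    pred = pred_labels.reshape(-1)
    true = true_labels.reshape(-1)
    fg = true > 0
    return adjusted_rand_score(true[fg], pred[fg])

true = np.array([[0, 1, 1], [0, 2, 2]])
pred = np.array([[5, 3, 3], [5, 4, 4]])
print(fg_ari(pred, true))  # 1.0: both objects are grouped correctly
```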
--------------------------------------------------------------------------------
/source/data/augs.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision import transforms
3 |
4 |
5 | def gauss_noise_tensor(sigma=0.1):
6 | def fn(img):
7 | out = img + sigma * torch.randn_like(img)
8 | out = torch.clamp(out, 0, 1) # pixel space is [0, 1]
9 | return out
10 |
11 | return fn
12 |
13 |
14 | def augmentation_strong(noise=0.0, imsize=32):
15 | transform_aug = transforms.Compose(
16 | [
17 | transforms.RandomHorizontalFlip(),
18 | transforms.RandomResizedCrop(imsize, scale=(0.2, 1.0)),
19 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
20 | transforms.AugMix(),
21 | transforms.ToTensor(),
22 | gauss_noise_tensor(noise) if noise > 0 else lambda x: x,
23 | ]
24 | )
25 | return transform_aug
26 |
27 |
28 | def simclr_augmentation(imsize, hflip=False):
29 | return transforms.Compose(
30 | [
31 | transforms.RandomResizedCrop(imsize),
32 | transforms.RandomHorizontalFlip(0.5) if hflip else lambda x: x,
33 | get_color_distortion(s=0.5),
34 | transforms.ToTensor(),
35 | ]
36 | )
37 |
38 |
39 | def random_Linf_noise(trnsfms: transforms.Compose = None, epsilon=64 / 255):
40 | if trnsfms is None:
41 | trnsfms = transforms.Compose([transforms.ToTensor()])
42 |
43 | randeps = torch.rand(1).item() * epsilon
44 |
45 | def fn(x):
46 | x = x + randeps * torch.randn_like(x).sign()
47 | return torch.clamp(x, 0, 1)
48 |
49 | trnsfms.transforms.append(fn)
50 | return trnsfms
51 |
52 |
53 | def get_color_distortion(s=0.5):
54 | # s is the strength of color distortion
55 | color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
56 | rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
57 | rnd_gray = transforms.RandomGrayscale(p=0.2)
58 | color_distort = transforms.Compose([rnd_color_jitter, rnd_gray])
59 | return color_distort
60 |
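# Usage sketch (illustrative, not part of the original module): build a
# positive pair from a single PIL image, the same way PairDataset does.
# The image below is a stand-in; any RGB PIL image works.
if __name__ == "__main__":
    from PIL import Image

    aug = simclr_augmentation(imsize=128, hflip=True)
    img = Image.new("RGB", (320, 240), color=(120, 80, 40))
    pair = torch.stack([aug(img), aug(img)])  # shape: (2, 3, 128, 128)
    print(pair.shape)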
--------------------------------------------------------------------------------
/source/data/datasets/objs/clevr.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from torchvision import transforms
4 |
5 | from source.data.augs import simclr_augmentation
6 | from source.data.datasets.objs.npdataset import NumpyDataset, PairDataset
7 |
8 |
9 | def get_clevr_pair(root, split="train", imsize=128, hflip=False):
10 | path = Path(root, f"clevr_{split}.npz")
11 | return PairDataset(path, transform=simclr_augmentation(imsize=imsize, hflip=hflip))
12 |
13 |
14 | def get_clevr(root, split="train", imsize=128):
15 | path = Path(root, f"clevr_{split}.npz")
16 | transform = transforms.Compose(
17 | [
18 | transforms.Resize(imsize),
19 | transforms.ToTensor(),
20 | ]
21 | )
22 | return NumpyDataset(path, transform=transform)
23 |
--------------------------------------------------------------------------------
/source/data/datasets/objs/clevr_tex.py:
--------------------------------------------------------------------------------
1 | # Copied from https://github.com/karazijal/clevrtex-generation/blob/main/clevrtex_eval.py
2 | import json
3 | from collections import defaultdict
4 | from pathlib import Path
5 |
6 | import numpy as np
7 | import torch
8 | import torch.nn.functional as F
9 | import torchvision.transforms.functional as Ft
10 | from PIL import Image
11 | from scipy.optimize import linear_sum_assignment
12 | from sklearn.metrics import adjusted_rand_score
13 |
14 |
15 | class DatasetReadError(ValueError):
16 | pass
17 |
18 |
19 | def collate_fn(batch):
20 | return (
21 | *torch.utils.data._utils.collate.default_collate(
22 | [(b[0], b[1], b[2]) for b in batch]
23 | ),
24 | [b[3] for b in batch],
25 | )
26 |
27 |
28 | class CLEVRTEX:
29 | ccrop_frac = 0.8
30 | splits = {"test": (0.0, 0.1), "val": (0.1, 0.2), "train": (0.2, 1.0)}
31 | shape = (3, 240, 320)
32 | variants = {"full", "pbg", "vbg", "grassbg", "camo", "outd"}
33 |
34 | def _index_with_bias_and_limit(self, idx):
35 | if idx >= 0:
36 | idx += self.bias
37 | if idx >= self.limit:
38 | raise IndexError()
39 | else:
40 | idx = self.limit + idx
41 | if idx < self.bias:
42 | raise IndexError()
43 | return idx
44 |
45 | def _reindex(self):
46 | print(f"Indexing {self.basepath}")
47 |
48 | img_index = {}
49 | msk_index = {}
50 | met_index = {}
51 |
52 | prefix = f"CLEVRTEX_{self.dataset_variant}_"
53 |
54 | img_suffix = ".png"
55 | msk_suffix = "_flat.png"
56 | met_suffix = ".json"
57 |
58 | _max = 0
59 | for img_path in self.basepath.glob(f"**/{prefix}??????{img_suffix}"):
60 | indstr = img_path.name.replace(prefix, "").replace(img_suffix, "")
61 | msk_path = img_path.parent / f"{prefix}{indstr}{msk_suffix}"
62 | met_path = img_path.parent / f"{prefix}{indstr}{met_suffix}"
63 | indstr_stripped = indstr.lstrip("0")
64 | if indstr_stripped:
65 | ind = int(indstr)
66 | else:
67 | ind = 0
68 | if ind > _max:
69 | _max = ind
70 |
71 | if not msk_path.exists():
72 | raise DatasetReadError(f"Missing {msk_suffix.name}")
73 |
74 | if ind in img_index:
75 | raise DatasetReadError(f"Duplica {ind}")
76 |
77 | img_index[ind] = img_path
78 | msk_index[ind] = msk_path
79 | if self.return_metadata:
80 | if not met_path.exists():
81 | raise DatasetReadError(f"Missing {met_path.name}")
82 | met_index[ind] = met_path
83 | else:
84 | met_index[ind] = None
85 |
86 | if len(img_index) == 0:
87 | raise DatasetReadError(f"No values found")
88 | missing = [i for i in range(0, _max) if i not in img_index]
89 | if missing:
90 | raise DatasetReadError(f"Missing images numbers {missing}")
91 |
92 | return img_index, msk_index, met_index
93 |
94 | def _variant_subfolder(self):
95 | return f"clevrtex_{self.dataset_variant.lower()}"
96 |
97 | def __init__(
98 | self,
99 | path: Path,
100 | dataset_variant="full",
101 | split="train",
102 | crop=True,
103 | resize=(128, 128),
104 | return_metadata=True,
105 | transform=None,
106 | ):
107 | self.transform = transform
108 | self.return_metadata = return_metadata
109 | self.crop = crop
110 | self.resize = resize
111 | if dataset_variant not in self.variants:
112 | raise DatasetReadError(
113 | f"Unknown variant {dataset_variant}; [{', '.join(self.variants)}] available "
114 | )
115 |
116 | if split not in self.splits:
117 | raise DatasetReadError(
118 | f"Unknown split {split}; [{', '.join(self.splits)}] available "
119 | )
120 | if dataset_variant == "outd":
121 |             # The OOD variant ships without train/val/test splits
122 | split = None
123 |
124 | self.dataset_variant = dataset_variant
125 | self.split = split
126 |
127 | self.basepath = Path(path)
128 | if not self.basepath.exists():
129 | raise DatasetReadError()
130 | sub_fold = self._variant_subfolder()
131 | if self.basepath.name != sub_fold:
132 | self.basepath = self.basepath / sub_fold
133 | # try:
134 | # with (self.basepath / 'manifest_ind.json').open('r') as inf:
135 | # self.index = json.load(inf)
136 | # except (json.JSONDecodeError, IOError, FileNotFoundError):
137 | self.index, self.mask_index, self.metadata_index = self._reindex()
138 |
139 | print(f"Sourced {dataset_variant} ({split}) from {self.basepath}")
140 |
141 | bias, limit = self.splits.get(split, (0.0, 1.0))
142 | if isinstance(bias, float):
143 | bias = int(bias * len(self.index))
144 | if isinstance(limit, float):
145 | limit = int(limit * len(self.index))
146 | self.limit = limit
147 | self.bias = bias
148 |
149 | def _format_metadata(self, meta):
150 | """
151 |         Drop unimportant, unused, or incorrect data from the metadata.
152 |         Some fields become incorrect after transformations; for example,
153 |         cropping and resizing invalidate pixel coordinates.
154 |         Only the VBG variant assigns colors to objects, so the color value is dropped for the other variants.
155 | """
156 | objs = []
157 | for obj in meta["objects"]:
158 | o = {
159 | "material": obj["material"],
160 | "shape": obj["shape"],
161 | "size": obj["size"],
162 | "rotation": obj["rotation"],
163 | }
164 | if self.dataset_variant == "vbg":
165 | o["color"] = obj["color"]
166 | objs.append(o)
167 | return {"ground_material": meta["ground_material"], "objects": objs}
168 |
169 | def __len__(self):
170 | return self.limit - self.bias
171 |
172 | def __getitem__(self, ind):
173 | ind = self._index_with_bias_and_limit(ind)
174 |
175 | img = Image.open(self.index[ind]).convert("RGB")
176 | msk = Image.open(self.mask_index[ind])
177 |
178 | if self.crop:
179 | crop_size = int(0.8 * float(min(img.width, img.height)))
180 | img = img.crop(
181 | (
182 | (img.width - crop_size) // 2,
183 | (img.height - crop_size) // 2,
184 | (img.width + crop_size) // 2,
185 | (img.height + crop_size) // 2,
186 | )
187 | )
188 | msk = msk.crop(
189 | (
190 | (msk.width - crop_size) // 2,
191 | (msk.height - crop_size) // 2,
192 | (msk.width + crop_size) // 2,
193 | (msk.height + crop_size) // 2,
194 | )
195 | )
196 | if self.resize:
197 | img = img.resize(self.resize, resample=Image.BILINEAR)
198 | msk = msk.resize(self.resize, resample=Image.NEAREST)
199 |
200 | if self.transform is not None:
201 | img = self.transform(np.array(img))
202 | # img = Ft.to_tensor(np.array(img)[..., :3])
203 | msk = torch.from_numpy(np.array(msk))[None]
204 |
205 | ret = (ind, img, msk)
206 |
207 | if self.return_metadata:
208 | with self.metadata_index[ind].open("r") as inf:
209 | meta = json.load(inf)
210 | ret = (ind, img, msk, self._format_metadata(meta))
211 |
212 | return ret
213 |
214 |
215 | def collate_fn(batch):
216 | return (
217 | *torch.utils.data._utils.collate.default_collate(
218 | [(b[0], b[1], b[2]) for b in batch]
219 | ),
220 | [b[3] for b in batch],
221 | )
222 |
223 |
224 | class RunningMean:
225 | def __init__(self):
226 | self.v = 0.0
227 | self.n = 0
228 |
229 | def update(self, v, n=1):
230 | self.v += v * n
231 | self.n += n
232 |
233 | def value(self):
234 | if self.n:
235 | return self.v / (self.n)
236 | else:
237 | return float("nan")
238 |
239 | def __str__(self):
240 | return str(self.value())
241 |
242 |
243 | class CLEVRTEX_Evaluator:
244 | def __init__(self, masks_have_background=True):
245 | self.masks_have_background = masks_have_background
246 | self.stats = defaultdict(RunningMean)
247 | self.tags = defaultdict(lambda: defaultdict(lambda: defaultdict(RunningMean)))
248 |
249 | def ari(self, pred_mask, true_mask, skip_0=False):
250 | B = pred_mask.shape[0]
251 | pm = pred_mask.argmax(axis=1).squeeze().view(B, -1).cpu().detach().numpy()
252 | tm = true_mask.argmax(axis=1).squeeze().view(B, -1).cpu().detach().numpy()
253 | aris = []
254 | for bi in range(B):
255 | t = tm[bi]
256 | p = pm[bi]
257 | if skip_0:
258 | p = p[t > 0]
259 | t = t[t > 0]
260 | ari_score = adjusted_rand_score(t, p)
261 | if ari_score != ari_score:
262 | print(f"NaN at bi")
263 | aris.append(ari_score)
264 | aris = torch.tensor(np.array(aris), device=pred_mask.device)
265 | return aris
266 |
267 | def msc(self, pred_mask, true_mask):
268 | B = pred_mask.shape[0]
269 | bpm = pred_mask.argmax(axis=1).squeeze()
270 | btm = true_mask.argmax(axis=1).squeeze()
271 | covering = torch.zeros(B, device=pred_mask.device, dtype=torch.float)
272 | for bi in range(B):
273 | score = 0.0
274 | norms = 0.0
275 | for ti in range(btm[bi].max()):
276 | tm = btm[bi] == ti
277 | if not torch.any(tm):
278 | continue
279 | iou_max = 0.0
280 | for pi in range(bpm[bi].max()):
281 | pm = bpm[bi] == pi
282 | if not torch.any(pm):
283 | continue
284 | iou = (tm & pm).to(torch.float).sum() / (tm | pm).to(
285 | torch.float
286 | ).sum()
287 | if iou > iou_max:
288 | iou_max = iou
289 | r = tm.to(torch.float).sum()
290 | score += r * iou_max
291 | norms += r
292 | covering[bi] = score / norms
293 | return covering
294 |
295 | def reindex(self, tensor, reindex_tensor, dim=1):
296 | """
297 |         Reindexes tensor along `dim` using reindex_tensor.
298 |         Effectively permutes the entries of each batch element.
392 | if not self.masks_have_background:
393 | # Some models predict only foreground masks.
394 |             # For convenience we calculate background masks.
395 | pred_masks = torch.cat([1.0 - total_pred_masks, pred_masks], dim=1)
396 |
397 |         # Discretize the masks. Should we effectively threshold them instead?
398 | K = pred_masks.shape[1]
399 | pred_masks = pred_masks.argmax(dim=1)
400 | pred_masks = (
401 | pred_masks.unsqueeze(1)
402 | == torch.arange(K, device=pred_masks.device).view(1, -1, 1, 1, 1)
403 | ).to(torch.float)
404 |         # Coerce true_masks into a known form
405 | if len(true_masks.shape) == 4:
406 | if true_masks.shape[1] == 1:
407 | # Need to expand into masks
408 | true_masks = (
409 | true_masks.unsqueeze(1)
410 | == torch.arange(
411 | max(true_masks.max() + 1, pred_masks.shape[1]),
412 | device=true_masks.device,
413 | ).view(1, -1, 1, 1, 1)
414 | ).to(pred_image.dtype)
415 | else:
416 | true_masks = true_masks.unsqueeze(2)
417 | true_masks = true_masks.view(pred_image.shape[0], -1, 1, *pred_image.shape[-2:])
418 |
419 | K = max(true_masks.shape[1], pred_masks.shape[1])
420 | if true_masks.shape[1] < K:
421 | true_masks = torch.cat(
422 | [
423 | true_masks,
424 | true_masks.new_zeros(
425 | true_masks.shape[0],
426 | K - true_masks.shape[1],
427 | 1,
428 | *true_masks.shape[-2:],
429 | ),
430 | ],
431 | dim=1,
432 | )
433 | if pred_masks.shape[1] < K:
434 | pred_masks = torch.cat(
435 | [
436 | pred_masks,
437 | pred_masks.new_zeros(
438 | pred_masks.shape[0],
439 | K - pred_masks.shape[1],
440 | 1,
441 | *pred_masks.shape[-2:],
442 | ),
443 | ],
444 | dim=1,
445 | )
446 |
447 | mse = F.mse_loss(pred_image, true_image, reduction="none").sum((1, 2, 3))
448 | self.add_statistic("MSE", mse)
449 |
450 | # If argmax above, these masks are either 0 or 1
451 | pred_count = (
452 | (pred_masks >= 0.5).any(-1).any(-1).any(-1).to(torch.float).sum(-1)
453 | ) # shape: (B,)
454 | true_count = (
455 | (true_masks >= 0.5).any(-1).any(-1).any(-1).to(torch.float).sum(-1)
456 | ) # shape: (B,)
457 | accuracy = (true_count == pred_count).to(torch.float)
458 | self.add_statistic("acc", accuracy)
459 |
460 | pred_reindex, ious, _ = self.ious_alignment(pred_masks, true_masks)
461 | pred_masks = self.reindex(pred_masks, pred_reindex, dim=1)
462 |
463 | truem = true_masks.any(-1).any(-1).any(-1)
464 | predm = pred_masks.any(-1).any(-1).any(-1)
465 |
466 | vism = truem | predm
467 | num_pairs = vism.to(torch.float).sum(-1)
468 |
469 | # mIoU
470 | mIoU = ious.sum(-1) / num_pairs
471 | mIoU_fg = ious[:, 1:].sum(-1) / (
472 | num_pairs - 1
473 | ) # do not consider the background
474 | mIoU_gt = ious.sum(-1) / truem.to(torch.float).sum(-1)
475 |
476 | self.add_statistic("mIoU", mIoU)
477 | self.add_statistic("mIoU_fg", mIoU_fg)
478 | self.add_statistic("mIoU_gt", mIoU_gt)
479 |
480 | msc = self.msc(pred_masks, true_masks)
481 | self.add_statistic("mSC", msc)
482 |
483 | # DICE
484 | dices = (
485 | 2
486 | * (pred_masks * true_masks).sum((-3, -2, -1))
487 | / (pred_masks.sum((-3, -2, -1)) + true_masks.sum((-3, -2, -1)))
488 | )
489 | dices = torch.nan_to_num(
490 | dices, nan=0.0, posinf=0.0
491 | ) # if there were any empties, they now have 0. DICE
492 |
493 | dice = dices.sum(-1) / num_pairs
494 | dice_fg = dices[:, 1:].sum(-1) / (num_pairs - 1)
495 | self.add_statistic("DICE", dice)
496 | self.add_statistic("DICE_FG", dice_fg)
497 |
498 | # ARI
499 | ari = self.ari(pred_masks, true_masks)
500 | ari_fg = self.ari(pred_masks, true_masks, skip_0=True)
501 | if torch.any(torch.isnan(ari_fg)):
502 | print("NaN ari_fg")
503 | if torch.any(torch.isinf(ari_fg)):
504 | print("Inf ari_fg")
505 | self.add_statistic("ARI", ari)
506 | self.add_statistic("ARI_FG", ari_fg)
507 |
508 | # mAP --?
509 |
510 | if true_metadata is not None:
511 | smses = F.mse_loss(
512 | pred_image[:, None] * true_masks,
513 | true_image[:, None] * true_masks,
514 | reduction="none",
515 | ).sum((-1, -2, -3))
516 |
517 | for bi, meta in enumerate(true_metadata):
518 | # ground
519 | self.add_statistic(
520 | "ground_mse", smses[bi, 0], ground_material=meta["ground_material"]
521 | )
522 | self.add_statistic(
523 | "ground_iou", ious[bi, 0], ground_material=meta["ground_material"]
524 | )
525 |
526 | for i, obj in enumerate(meta["objects"]):
527 | tags = {k: v for k, v in obj.items() if k != "rotation"}
528 | if truem[bi, i + 1]:
529 | self.add_statistic("obj_mse", smses[bi, i + 1], **tags)
530 | self.add_statistic("obj_iou", ious[bi, i + 1], **tags)
531 | # Maybe number of components?
532 | return pred_masks, true_masks
533 |
534 |
535 | class CLEVRTEXPair(CLEVRTEX):
536 | """Generate mini-batche pairs on CIFAR10 training set."""
537 |
538 | def __getitem__(self, ind):
539 | ind = self._index_with_bias_and_limit(ind)
540 |
541 | img = Image.open(self.index[ind]).convert("RGB")
542 |
543 | if self.crop:
544 | crop_size = int(0.8 * float(min(img.width, img.height)))
545 | img = img.crop(
546 | (
547 | (img.width - crop_size) // 2,
548 | (img.height - crop_size) // 2,
549 | (img.width + crop_size) // 2,
550 | (img.height + crop_size) // 2,
551 | )
552 | )
553 | if self.resize:
554 | img = img.resize(self.resize, resample=Image.BILINEAR)
555 | imgs = [self.transform(img), self.transform(img)]
556 | return torch.stack(imgs) # stack a positive pair
557 |
558 |
559 | import torchvision
560 | from torchvision import transforms
561 | from PIL import Image
562 | from source.data.augs import simclr_augmentation
563 |
564 |
565 | def get_clevrtex_pair(root, split="train", imsize=128, hflip=False):
566 |
567 | return CLEVRTEXPair(
568 | root,
569 | "full",
570 | split,
571 | crop=True,
572 | resize=(imsize, imsize),
573 | return_metadata=False,
574 | transform=simclr_augmentation(imsize=imsize, hflip=hflip),
575 | )
576 |
577 |
578 | def get_clevrtex(
579 | root, split="test", data_type="full", return_meta_data=False, imsize=128
580 | ):
581 | return CLEVRTEX(
582 | root,
583 | data_type,
584 | split,
585 | crop=True,
586 | resize=(imsize, imsize),
587 | return_metadata=return_meta_data,
588 | transform=torchvision.transforms.ToTensor(),
589 | )
590 |
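# Usage sketch (illustrative, not part of the original module): load the
# CLEVRTex test split and batch it with collate_fn, which keeps the
# per-image metadata dicts as a plain Python list. The data path and
# batch size are assumptions.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = get_clevrtex(
        "./data/clevrtex_full", split="test", data_type="full",
        return_meta_data=True, imsize=128,
    )
    loader = DataLoader(dataset, batch_size=8, collate_fn=collate_fn)
    ind, img, msk, meta = next(iter(loader))
    print(img.shape, msk.shape, len(meta))  # (8, 3, 128, 128), (8, 1, 128, 128), 8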
--------------------------------------------------------------------------------
/source/data/datasets/objs/coco.py:
--------------------------------------------------------------------------------
1 | # Copied from https://github.com/Wuziyi616/SlotDiffusion
2 |
3 | import os
4 | import cv2
5 | import numpy as np
6 | from PIL import Image
7 | from pycocotools.coco import COCO
8 |
9 | import torch
10 | from torch.utils.data import Dataset
11 | import torchvision.utils as vutils
12 |
13 |
14 | import torch
15 | import torchvision.transforms as transforms
16 |
17 |
18 | def suppress_mask_idx(masks):
19 | """Make the mask index 0, 1, 2, ..."""
20 | # the original mask could have not continuous index, 0, 3, 4, 6, 9, 13, ...
21 | # we make them 0, 1, 2, 3, 4, 5, ...
22 | if isinstance(masks, np.ndarray):
23 | pkg = np
24 | elif isinstance(masks, torch.Tensor):
25 | pkg = torch
26 | else:
27 | raise NotImplementedError
28 | obj_idx = pkg.unique(masks)
29 | idx_mapping = pkg.arange(obj_idx.max() + 1)
30 | idx_mapping[obj_idx] = pkg.arange(len(obj_idx))
31 | masks = idx_mapping[masks]
32 | return masks
33 |
34 |
35 | class RandomHorizontalFlip:
36 | """Flip the image and bbox horizontally."""
37 |
38 | def __init__(self, prob=0.5):
39 | self.prob = prob
40 |
41 | def __call__(self, sample):
42 | # [H, W, 3], [H, W(, 2)], [N, 5]
43 | image, masks, annos, scale, size = (
44 | sample["image"],
45 | sample["masks"],
46 | sample["annos"],
47 | sample["scale"],
48 | sample["size"],
49 | )
50 |
51 | if np.random.uniform(0, 1) < self.prob:
52 | image = np.ascontiguousarray(image[:, ::-1, :])
53 | masks = np.ascontiguousarray(masks[:, ::-1])
54 | _, w, _ = image.shape
55 | # adjust annos
56 | if annos.shape[0] > 0:
57 | x1 = annos[:, 0].copy()
58 | x2 = annos[:, 2].copy()
59 | annos[:, 0] = w - x2
60 | annos[:, 2] = w - x1
61 |
62 | return {
63 | "image": image,
64 | "masks": masks,
65 | "annos": annos,
66 | "scale": scale,
67 | "size": size,
68 | }
69 |
70 |
71 | class ResizeMinShape:
72 | """Resize for later center crop."""
73 |
74 | def __init__(self, resolution=(224, 224)):
75 | self.resolution = resolution
76 |
77 | def __call__(self, sample):
78 | image, masks, annos, scale, size = (
79 | sample["image"],
80 | sample["masks"],
81 | sample["annos"],
82 | sample["scale"],
83 | sample["size"],
84 | )
85 | h, w, _ = image.shape
86 |         # resize so that h' is at least resolution[0]
87 |         # and w' is at least resolution[1]
88 | factor = max(self.resolution[0] / h, self.resolution[1] / w)
89 | resize_h, resize_w = int(round(h * factor)), int(round(w * factor))
90 | image = cv2.resize(image, (resize_w, resize_h), interpolation=cv2.INTER_LINEAR)
91 | masks = cv2.resize(masks, (resize_w, resize_h), interpolation=cv2.INTER_NEAREST)
92 | # adjust annos
93 | factor = float(factor)
94 | annos[:, :4] *= factor
95 | scale *= factor
96 | return {
97 | "image": image,
98 | "masks": masks,
99 | "annos": annos,
100 | "scale": scale,
101 | "size": size,
102 | }
103 |
104 |
105 | class CenterCrop:
106 | """Crop the center square of the image."""
107 |
108 | def __init__(self, resolution=(224, 224)):
109 | self.resolution = resolution
110 |
111 | def __call__(self, sample):
112 | image, masks, annos, scale, size = (
113 | sample["image"],
114 | sample["masks"],
115 | sample["annos"],
116 | sample["scale"],
117 | sample["size"],
118 | )
119 |
120 | h, w, _ = image.shape
121 | assert h >= self.resolution[0] and w >= self.resolution[1]
122 | assert h == self.resolution[0] or w == self.resolution[1]
123 |
124 | if h == self.resolution[0]:
125 | crop_ymin = 0
126 | crop_ymax = h
127 | crop_xmin = (w - self.resolution[0]) // 2
128 | crop_xmax = crop_xmin + self.resolution[0]
129 | else:
130 | crop_xmin = 0
131 | crop_xmax = w
132 | crop_ymin = (h - self.resolution[1]) // 2
133 | crop_ymax = crop_ymin + self.resolution[1]
134 | image = image[crop_ymin:crop_ymax, crop_xmin:crop_xmax]
135 | masks = masks[crop_ymin:crop_ymax, crop_xmin:crop_xmax]
136 | # adjust annos
137 | if annos.shape[0] > 0:
138 | annos[:, [0, 2]] = annos[:, [0, 2]] - crop_xmin
139 | annos[:, [1, 3]] = annos[:, [1, 3]] - crop_ymin
140 | # filter out annos that are out of the image
141 | mask1 = (annos[:, 2] > 0) & (annos[:, 3] > 0)
142 | mask2 = (annos[:, 0] < self.resolution[0]) & (
143 | annos[:, 1] < self.resolution[1]
144 | )
145 | annos = annos[mask1 & mask2]
146 | annos[:, [0, 2]] = np.clip(annos[:, [0, 2]], 0, self.resolution[0])
147 | annos[:, [1, 3]] = np.clip(annos[:, [1, 3]], 0, self.resolution[1])
148 |
149 | return {
150 | "image": image,
151 | "masks": masks,
152 | "annos": annos,
153 | "scale": scale,
154 | "size": size,
155 | }
156 |
157 |
158 | class Normalize:
159 | """Normalize the image with mean and std."""
160 |
161 | def __init__(self, mean=0.5, std=0.5):
162 | if isinstance(mean, (list, tuple)):
163 | mean = np.array(mean, dtype=np.float32)[None, None] # [1, 1, 3]
164 | if isinstance(std, (list, tuple)):
165 | std = np.array(std, dtype=np.float32)[None, None] # [1, 1, 3]
166 | self.mean = mean
167 | self.std = std
168 |
169 | def normalize_image(self, image):
170 | image = image.astype(np.float32) / 255.0
171 | image = (image - self.mean) / self.std
172 | return image
173 |
174 | def denormalize_image(self, image):
175 | # simple numbers
176 | if isinstance(self.mean, (int, float)) and isinstance(self.std, (int, float)):
177 | image = image * self.std + self.mean
178 | return image.clamp(0, 1)
179 | # need to convert the shapes
180 | mean = image.new_tensor(self.mean.squeeze()) # [3]
181 | std = image.new_tensor(self.std.squeeze()) # [3]
182 | if image.shape[-1] == 3: # C last
183 | mean = mean[None, None] # [1, 1, 3]
184 | std = std[None, None] # [1, 1, 3]
185 | else: # C first
186 | mean = mean[:, None, None] # [3, 1, 1]
187 | std = std[:, None, None] # [3, 1, 1]
188 | if len(image.shape) == 4: # [B, C, H, W] or [B, H, W, C], batch dim
189 | mean = mean[None]
190 | std = std[None]
191 |         image = image * std + mean
192 | return image.clamp(0, 1)
193 |
194 | def __call__(self, sample):
195 | # [H, W, C]
196 | image, masks, annos, scale, size = (
197 | sample["image"],
198 | sample["masks"],
199 | sample["annos"],
200 | sample["scale"],
201 | sample["size"],
202 | )
203 | image = self.normalize_image(image)
204 | # make mask index start from 0 and continuous
205 | # `masks` is [H, W(, 2 or 3)]
206 | if len(masks.shape) == 3:
207 | assert masks.shape[-1] in [2, 3]
208 | # we don't suppress the last mask since it is the overlapping mask
209 | # i.e. regions with overlapping instances
210 | for i in range(masks.shape[-1] - 1):
211 | masks[:, :, i] = suppress_mask_idx(masks[:, :, i])
212 | else:
213 | masks = suppress_mask_idx(masks)
214 | return {
215 | "image": image,
216 | "masks": masks,
217 | "annos": annos,
218 | "scale": scale,
219 | "size": size,
220 | }
221 |
222 |
223 | class COCOCollater:
224 | """Collect images, annotations, etc. into a batch."""
225 |
226 | def __init__(self):
227 | pass
228 |
229 | def __call__(self, data):
230 | images = [s["image"] for s in data]
231 | masks = [s["masks"] for s in data]
232 | annos = [s["annos"] for s in data]
233 | scales = [s["scale"] for s in data]
234 | sizes = [s["size"] for s in data]
235 |
236 | images = np.stack(images, axis=0) # [B, H, W, C]
237 | images = torch.from_numpy(images).permute(0, 3, 1, 2) # [B, C, H, W]
238 |
239 | masks = np.stack(masks, axis=0)
240 | masks = torch.from_numpy(masks) # [B, H, W(, 2 or 3)]
241 |
242 | max_annos_num = max(anno.shape[0] for anno in annos)
243 | if max_annos_num > 0:
244 | input_annos = np.ones((len(annos), max_annos_num, 5), dtype=np.float32) * (
245 | -1
246 | )
247 | for i, anno in enumerate(annos):
248 | if anno.shape[0] > 0:
249 | input_annos[i, : anno.shape[0], :] = anno
250 | else:
251 | input_annos = np.ones((len(annos), 1, 5), dtype=np.float32) * (-1)
252 | input_annos = torch.from_numpy(input_annos).float()
253 |
254 | scales = torch.from_numpy(np.array(scales)).float()
255 | sizes = torch.from_numpy(np.array(sizes)).float()
256 |
257 | data_dict = {
258 | "img": images.contiguous().float(),
259 | "masks": masks.contiguous().long(),
260 | "annos": input_annos,
261 | "scale": scales,
262 | "size": sizes,
263 | }
264 | if len(masks.shape) == 4:
265 | assert masks.shape[-1] in [2, 3]
266 | if masks.shape[-1] == 3:
267 | data_dict["masks"] = masks[:, :, :, 0]
268 | data_dict["sem_masks"] = masks[:, :, :, 1]
269 | data_dict["inst_overlap_masks"] = masks[:, :, :, 2]
270 | else:
271 | data_dict["masks"] = masks[:, :, :, 0]
272 | data_dict["inst_overlap_masks"] = masks[:, :, :, 1]
273 | return data_dict
274 |
275 |
276 | class COCOTransforms(object):
277 | """Data pre-processing steps."""
278 |
279 | def __init__(
280 | self,
281 | resolution,
282 | val=False,
283 | ):
284 | self.normalize = Normalize(0.0, 1.0)
285 | if val:
286 | self.transforms = transforms.Compose(
287 | [
288 | ResizeMinShape(resolution),
289 | CenterCrop(resolution),
290 | self.normalize,
291 | ]
292 | )
293 | else:
294 | from source.data.augs import get_color_distortion
295 |
296 | self.transforms = transforms.Compose(
297 | [
298 | transforms.RandomResizedCrop(resolution),
299 | ResizeMinShape(resolution),
300 | CenterCrop(resolution),
301 | ]
302 | )
303 | self.resolution = resolution
304 |
305 | def __call__(self, input):
306 | return self.transforms(input)
307 |
308 |
309 | COCO_CLASSES = [
310 | "person",
311 | "bicycle",
312 | "car",
313 | "motorcycle",
314 | "airplane",
315 | "bus",
316 | "train",
317 | "truck",
318 | "boat",
319 | "traffic light",
320 | "fire hydrant",
321 | "stop sign",
322 | "parking meter",
323 | "bench",
324 | "bird",
325 | "cat",
326 | "dog",
327 | "horse",
328 | "sheep",
329 | "cow",
330 | "elephant",
331 | "bear",
332 | "zebra",
333 | "giraffe",
334 | "backpack",
335 | "umbrella",
336 | "handbag",
337 | "tie",
338 | "suitcase",
339 | "frisbee",
340 | "skis",
341 | "snowboard",
342 | "sports ball",
343 | "kite",
344 | "baseball bat",
345 | "baseball glove",
346 | "skateboard",
347 | "surfboard",
348 | "tennis racket",
349 | "bottle",
350 | "wine glass",
351 | "cup",
352 | "fork",
353 | "knife",
354 | "spoon",
355 | "bowl",
356 | "banana",
357 | "apple",
358 | "sandwich",
359 | "orange",
360 | "broccoli",
361 | "carrot",
362 | "hot dog",
363 | "pizza",
364 | "donut",
365 | "cake",
366 | "chair",
367 | "couch",
368 | "potted plant",
369 | "bed",
370 | "dining table",
371 | "toilet",
372 | "tv",
373 | "laptop",
374 | "mouse",
375 | "remote",
376 | "keyboard",
377 | "cell phone",
378 | "microwave",
379 | "oven",
380 | "toaster",
381 | "sink",
382 | "refrigerator",
383 | "book",
384 | "clock",
385 | "vase",
386 | "scissors",
387 | "teddy bear",
388 | "hair drier",
389 | "toothbrush",
390 | ]
391 |
392 | COCO_CLASSES_COLOR = [
393 | (241, 23, 78),
394 | (63, 71, 49),
395 | (67, 79, 143),
396 | (32, 250, 205),
397 | (136, 228, 157),
398 | (135, 125, 104),
399 | (151, 46, 171),
400 | (129, 37, 28),
401 | (3, 248, 159),
402 | (154, 129, 58),
403 | (93, 155, 200),
404 | (201, 98, 152),
405 | (187, 194, 70),
406 | (122, 144, 121),
407 | (168, 31, 32),
408 | (168, 68, 189),
409 | (173, 68, 45),
410 | (200, 81, 154),
411 | (171, 114, 139),
412 | (216, 211, 39),
413 | (187, 119, 238),
414 | (201, 120, 112),
415 | (129, 16, 164),
416 | (211, 3, 208),
417 | (169, 41, 248),
418 | (100, 77, 159),
419 | (140, 104, 243),
420 | (26, 165, 41),
421 | (225, 176, 197),
422 | (35, 212, 67),
423 | (160, 245, 68),
424 | (7, 87, 70),
425 | (52, 107, 85),
426 | (103, 64, 188),
427 | (245, 76, 17),
428 | (248, 154, 59),
429 | (77, 45, 123),
430 | (210, 95, 230),
431 | (172, 188, 171),
432 | (250, 44, 233),
433 | (161, 71, 46),
434 | (144, 14, 134),
435 | (231, 142, 186),
436 | (34, 1, 200),
437 | (144, 42, 108),
438 | (222, 70, 139),
439 | (138, 62, 77),
440 | (178, 99, 61),
441 | (17, 94, 132),
442 | (93, 248, 254),
443 | (244, 116, 204),
444 | (138, 165, 238),
445 | (44, 216, 225),
446 | (224, 164, 12),
447 | (91, 126, 184),
448 | (116, 254, 49),
449 | (70, 250, 105),
450 | (252, 237, 54),
451 | (196, 136, 21),
452 | (234, 13, 149),
453 | (66, 43, 47),
454 | (2, 73, 234),
455 | (118, 181, 5),
456 | (105, 99, 225),
457 | (150, 253, 92),
458 | (59, 2, 121),
459 | (176, 190, 223),
460 | (91, 62, 47),
461 | (198, 124, 140),
462 | (100, 135, 185),
463 | (20, 207, 98),
464 | (216, 38, 133),
465 | (17, 202, 208),
466 | (216, 135, 81),
467 | (212, 203, 33),
468 | (108, 135, 76),
469 | (28, 47, 170),
470 | (142, 128, 121),
471 | (23, 161, 179),
472 | (33, 183, 224),
473 | ]
474 |
475 |
476 | def to_rgb_from_tensor(x):
477 | """Reverse the Normalize operation in torchvision."""
478 | return (x * 0.5 + 0.5).clamp(0, 1)
479 |
480 |
481 | def _draw_bbox(img, anno, bbox_width=2):
482 | """Draw bbox on images.
483 |
484 | Args:
485 | img: (3, H, W), torch.Tensor
486 | anno: (N, 5)
487 | """
488 | anno = anno[anno[:, -1] != -1]
489 | img = torch.round((to_rgb_from_tensor(img) * 255.0)).to(dtype=torch.uint8)
490 | bbox = anno[:, :4]
491 | label = anno[:, -1]
492 | draw_label = [COCO_CLASSES[int(lbl)] for lbl in label]
493 | draw_color = [COCO_CLASSES_COLOR[int(lbl)] for lbl in label]
494 | bbox_img = vutils.draw_bounding_boxes(
495 | img, bbox, labels=draw_label, colors=draw_color, width=bbox_width
496 | )
497 | bbox_img = bbox_img.float() / 255.0 * 2.0 - 1.0
498 | return bbox_img
499 |
500 |
501 | def draw_coco_bbox(imgs, annos, bbox_width=2):
502 | """Draw bbox on batch images.
503 |
504 | Args:
505 | imgs: (B, 3, H, W), torch.Tensor
506 | annos: (B, N, 5)
507 | """
508 | if len(imgs.shape) == 3:
509 | return draw_coco_bbox(imgs[None], annos[None], bbox_width)[0]
510 |
511 | bbox_imgs = []
512 | for img, anno in zip(imgs, annos):
513 | bbox_imgs.append(_draw_bbox(img, anno, bbox_width=bbox_width))
514 | bbox_imgs = torch.stack(bbox_imgs, dim=0)
515 | return bbox_imgs
516 |
517 |
518 | class COCO2017Dataset(Dataset):
519 | """COCO 2017 dataset."""
520 |
521 | def __init__(
522 | self,
523 | data_root,
524 | split,
525 | coco_transforms=None,
526 | load_anno=True,
527 | ):
528 | set_name = f"{split}2017"
529 | assert set_name in ["train2017", "val2017"], "Wrong set name!"
530 |
531 | self.split = split
532 | self.load_anno = load_anno
533 | self.coco_transforms = coco_transforms
534 |
535 | self.image_dir = os.path.join(data_root, "images", set_name)
536 | self.anno_dir = os.path.join(
537 | data_root, "annotations", f"instances_{set_name}.json"
538 | )
539 | self.coco = COCO(self.anno_dir)
540 |
541 | self.image_ids = self.coco.getImgIds()
542 |
543 | if split == "train":
544 | # filter image id without annotation
545 | ids = []
546 | for image_id in self.image_ids:
547 | anno_ids = self.coco.getAnnIds(imgIds=image_id)
548 | annos = self.coco.loadAnns(anno_ids)
549 | if len(annos) == 0:
550 | continue
551 | ids.append(image_id)
552 | self.image_ids = ids
553 |
554 | self.cat_ids = self.coco.getCatIds()
555 | self.cats = sorted(self.coco.loadCats(self.cat_ids), key=lambda x: x["id"])
556 | self.num_classes = len(self.cats)
557 |
558 |         # cat_id is the original COCO category id; coco_label is remapped to 0..79
559 | self.cat_id_to_cat_name = {cat["id"]: cat["name"] for cat in self.cats}
560 | self.cat_id_to_coco_label = {cat["id"]: i for i, cat in enumerate(self.cats)}
561 | self.coco_label_to_cat_id = {i: cat["id"] for i, cat in enumerate(self.cats)}
562 | self.coco_label_to_cat_name = {
563 | coco_label: self.cat_id_to_cat_name[cat_id]
564 | for coco_label, cat_id in self.coco_label_to_cat_id.items()
565 | }
566 |
567 | print(f"Dataset Size:{len(self.image_ids)}")
568 | print(f"Dataset Class Num:{self.num_classes}")
569 |
570 | # by default only load instance seg_mask, not semantic seg_mask
571 | self.load_sem_mask = False
572 |
573 | def __len__(self):
574 | return len(self.image_ids)
575 |
576 | def __getitem__(self, idx):
577 | image = self.load_image(idx)
578 | H, W = image.shape[:2]
579 |
580 | if self.load_anno:
581 | annos = self.load_annos(idx) # [N, 5]
582 | masks, inst_overlap_masks = self.load_inst_masks(idx) # [H, W]x2
583 | masks = [masks, inst_overlap_masks]
584 | if self.load_sem_mask:
585 | sem_masks = self.load_sem_masks(idx) # [H, W]
586 | masks.insert(1, sem_masks) # [inst, sem, inst_overlap]
587 | masks = np.stack(masks, axis=-1) # [H, W, 2 or 3]
588 | else:
589 | annos = np.zeros((0, 5), dtype=np.float32)
590 | masks = np.zeros((H, W), dtype=np.int32)
591 |
592 | scale = np.array(1.0).astype(np.float32)
593 | size = np.array([image.shape[0], image.shape[1]]).astype(np.float32)
594 |
595 | sample = {
596 | "image": image,
597 | "masks": masks,
598 | # if load_sem_mask, will have a `sem_masks` key after collate_fn
599 | "annos": annos,
600 | "scale": scale,
601 | "size": size,
602 | }
603 | return self.coco_transforms(sample)
604 |
605 | def load_image(self, idx):
606 | """Read image."""
607 | file_name = self.coco.loadImgs(self.image_ids[idx])[0]["file_name"]
608 | image = cv2.imdecode(
609 | np.fromfile(os.path.join(self.image_dir, file_name), dtype=np.uint8),
610 | cv2.IMREAD_COLOR,
611 | )
612 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
613 | return image.astype(np.uint8)
614 |
615 | def load_annos(self, idx):
616 | """Load bbox and cls."""
617 | anno_ids = self.coco.getAnnIds(imgIds=self.image_ids[idx])
618 | annos = self.coco.loadAnns(anno_ids)
619 |
620 | image_info = self.coco.loadImgs(self.image_ids[idx])[0]
621 | image_h, image_w = image_info["height"], image_info["width"]
622 |
623 | targets = np.zeros((0, 5))
624 | if len(annos) == 0:
625 | return targets.astype(np.float32)
626 |
627 | # filter annos
628 | for anno in annos:
629 | if anno.get("ignore", False):
630 | continue
631 | if anno.get("iscrowd", False):
632 | continue
633 | if anno["category_id"] not in self.cat_ids:
634 | continue
635 |
636 | # bbox format: [x_min, y_min, w, h]
637 | bbox = anno["bbox"]
638 | inter_w = max(0, min(bbox[0] + bbox[2], image_w) - max(bbox[0], 0))
639 | inter_h = max(0, min(bbox[1] + bbox[3], image_h) - max(bbox[1], 0))
640 | if inter_w * inter_h == 0:
641 | continue
642 | if bbox[2] * bbox[3] < 1 or bbox[2] < 1 or bbox[3] < 1:
643 | continue
644 |
645 | target = np.zeros((1, 5))
646 | target[0, :4] = bbox
647 | target[0, 4] = self.cat_id_to_coco_label[anno["category_id"]]
648 | targets = np.append(targets, target, axis=0)
649 |
650 | # [x_min, y_min, w, h] --> [x_min, y_min, x_max, y_max]
651 | targets[:, 2] = targets[:, 0] + targets[:, 2]
652 | targets[:, 3] = targets[:, 1] + targets[:, 3]
653 |
654 | return targets.astype(np.float32) # [N, 5 (x1, y1, x2, y2, cat_id)]
655 |
656 | def load_inst_masks(self, idx):
657 | """Load instance seg_mask and merge them into an argmaxed mask."""
658 | anno_ids = self.coco.getAnnIds(imgIds=self.image_ids[idx])
659 | annos = self.coco.loadAnns(anno_ids)
660 |
661 | image_info = self.coco.loadImgs(self.image_ids[idx])[0]
662 | image_h, image_w = image_info["height"], image_info["width"]
663 |
664 | masks = np.zeros((image_h, image_w), dtype=np.int32)
665 | inst_overlap_masks = np.zeros_like(masks) # for overlap check
666 | for i, anno in enumerate(annos):
667 | if anno.get("ignore", False):
668 | continue
669 | if anno.get("iscrowd", False):
670 | continue
671 | if anno["category_id"] not in self.cat_ids:
672 | continue
673 | mask = self.coco.annToMask(anno)
674 | masks[mask > 0] = i + 1 # to put background as 0
675 | inst_overlap_masks[mask > 0] += 1
676 | # overlap value > 1 indicates overlap
677 | inst_overlap_masks = (inst_overlap_masks > 1).astype(np.int32)
678 | # [H, W], [H, W], 1 is overlapping pixels
679 | return masks, inst_overlap_masks
680 |
681 | def load_sem_masks(self, idx):
682 | """Load instance seg_mask and merge them into an argmaxed mask."""
683 | anno_ids = self.coco.getAnnIds(imgIds=self.image_ids[idx])
684 | annos = self.coco.loadAnns(anno_ids)
685 |
686 | image_info = self.coco.loadImgs(self.image_ids[idx])[0]
687 | image_h, image_w = image_info["height"], image_info["width"]
688 |
689 | masks = np.zeros((image_h, image_w), dtype=np.int32)
690 | for i, anno in enumerate(annos):
691 | if anno.get("ignore", False):
692 | continue
693 | if anno.get("iscrowd", False):
694 | continue
695 | if anno["category_id"] not in self.cat_ids:
696 | continue
697 | mask = self.coco.annToMask(anno)
698 | coco_lbl = self.cat_id_to_coco_label[anno["category_id"]]
699 | masks[mask > 0] = coco_lbl + 1 # to put background as 0
700 | # [H, W]
701 | return masks
702 |
703 |
704 | def build_coco_dataset(root_dir, resolution, load_anno=True, val_only=False):
705 | """Build COCO2017 dataset that load images."""
706 | val_transforms = COCOTransforms(
707 | resolution,
708 | val=True,
709 | )
710 | args = dict(
711 | data_root=root_dir,
712 | coco_transforms=val_transforms,
713 | split="val",
714 | load_anno=load_anno,
715 | )
716 | val_dataset = COCO2017Dataset(**args)
717 | if val_only:
718 | return val_dataset, COCOCollater()
719 | args["split"] = "train"
720 | args["load_anno"] = False
721 | args["coco_transforms"] = COCOTransforms(
722 | resolution,
723 | val=False,
724 | )
725 | train_dataset = COCO2017Dataset(**args)
726 | return train_dataset, val_dataset, COCOCollater()
727 |
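# Usage sketch (illustrative, not part of the original module): build the
# COCO 2017 validation set for evaluation and batch it with the collater.
# The data root is an assumption; it should contain images/val2017 and
# annotations/instances_val2017.json.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    val_dataset, collater = build_coco_dataset(
        "./data/COCO", resolution=(224, 224), val_only=True
    )
    loader = DataLoader(val_dataset, batch_size=4, collate_fn=collater)
    batch = next(iter(loader))
    print(batch["img"].shape, batch["masks"].shape)  # (4, 3, 224, 224), (4, 224, 224)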
--------------------------------------------------------------------------------
/source/data/datasets/objs/dsprites.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import numpy as np
4 | from torch.utils.data import Dataset
5 | import torch
6 | import torchvision
7 | from torchvision import transforms
8 | from PIL import Image
9 |
10 | from source.data.augs import get_color_distortion
11 | from source.data.datasets.objs.npdataset import NumpyDataset, PairDataset
12 | from source.data.augs import simclr_augmentation
13 |
14 |
15 | def get_dsprites_pair(root, split="train", imsize=64, hflip=False):
16 | path = Path(root, f"colored_on_grayscale_{split}.npz")
17 | return PairDataset(path, transform=simclr_augmentation(imsize=imsize, hflip=hflip))
18 |
19 |
20 | def get_dsprites(root, split="train", imsize=64):
21 | path = Path(root, f"colored_on_grayscale_{split}.npz")
22 | return NumpyDataset(
23 | path,
24 | transform=transforms.Compose(
25 | [transforms.Resize(imsize), transforms.ToTensor()]
26 | ),
27 | )
28 |
--------------------------------------------------------------------------------
/source/data/datasets/objs/imagenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | from source.data.augs import get_color_distortion
4 | from torchvision.datasets import ImageNet
5 | from torchvision import transforms
6 | from PIL import Image
7 |
8 |
9 | class ImageNetPair(ImageNet):
10 | """Generate mini-batche pairs on CIFAR10 training set."""
11 |
12 | def __getitem__(self, idx):
13 | path, target = self.imgs[idx][0], self.targets[idx]
14 | img = self.loader(path)
15 | imgs = [self.transform(img), self.transform(img)]
16 | return torch.stack(imgs)
17 |
18 |
19 | def get_imagenet(
20 | root,
21 | split="train",
22 | transform=transforms.Compose(
23 | [
24 | transforms.CenterCrop(256),
25 | transforms.ToTensor(),
26 | ]
27 | ),
28 | ):
29 | return ImageNet(
30 | root=root,
31 | split=split,
32 | transform=transform,
33 | )
34 |
35 |
36 | def get_imagenet_pair(root, split="train", imsize=256, hflip=False):
37 | from source.data.augs import simclr_augmentation
38 |
39 | return ImageNetPair(
40 | root=root,
41 | split=split,
42 | transform=simclr_augmentation(imsize=imsize, hflip=hflip),
43 | )
44 |
--------------------------------------------------------------------------------
/source/data/datasets/objs/load_data.py:
--------------------------------------------------------------------------------
1 | def load_data(data, data_root, data_imsize, is_eval=False):
2 | # Data loading
3 |
4 | collate_fn = None
5 |
6 | if data == "clevrtex_full" or data == "clevrtex_camo" or data == "clevrtex_outd":
7 | from source.data.datasets.objs.clevr_tex import get_clevrtex_pair, get_clevrtex, collate_fn
8 | if data == "clevrtex_camo":
9 | assert is_eval, "Camo dataset is only for evaluation"
10 | data_type = "camo"
11 | default_path = "./data/clevr_tex/clevrtex_camo"
12 | elif data == "clevrtex_outd":
13 | assert is_eval, "OOD dataset is only for evaluation"
14 | data_type = "outd"
15 | default_path = "./data/clevr_tex/clevrtex_outd"
16 | else:
17 | data_type = "full"
18 | default_path = "./data/clevr_tex/clevrtex_full"
19 |
20 | data_root = (
21 | default_path
22 | if data_root is None
23 | else data_root
24 | )
25 | imsize = 128 if data_imsize is None else data_imsize
26 | if is_eval:
27 | dataset = get_clevrtex(data_root, split="test", data_type=data_type, imsize=imsize, return_meta_data=True)
28 | else:
29 | dataset = get_clevrtex_pair(
30 | root=data_root,
31 | split="train",
32 | )
33 |
34 | elif data == "clevr":
35 | from source.data.datasets.objs.clevr import get_clevr_pair, get_clevr
36 | imsize = 128 if data_imsize is None else data_imsize
37 | data_root = "./data/clevr_with_masks/"
38 | if is_eval:
39 | dataset = get_clevr(data_root, split="test", imsize=imsize)
40 | else:
41 | dataset = get_clevr_pair(
42 | root="./data/clevr_with_masks/",
43 | split="train",
44 | )
45 |
46 | elif data == "imagenet":
47 | from source.data.datasets.objs.imagenet import get_imagenet_pair
48 | data_root = (
49 | "./data/ImageNet2012/"
50 | if data_root is None
51 | else data_root
52 | )
53 | dataset = get_imagenet_pair(
54 | root=data_root,
55 | split="train",
56 | hflip=True,
57 | imsize=256 if data_imsize is None else data_imsize,
58 | )
59 | elif data == "coco":
60 | imsize = 256 if data_imsize is None else data_imsize
61 |         from source.data.datasets.objs.coco import build_coco_dataset
62 | data_root = "./data/COCO"
63 | if not is_eval:
64 | raise ValueError("COCO dataset is only for evaluation")
65 | else:
66 |             dataset, collate_fn = build_coco_dataset(data_root, resolution=(imsize, imsize), val_only=True)
67 |
68 | elif data == "pascal":
69 | imsize = 256 if data_imsize is None else data_imsize
70 |         from source.data.datasets.objs.pascal import get_pascal
71 | data_root = "./data/VOCdevkit/VOC2012"
72 | if not is_eval:
73 | raise ValueError("Pascal dataset is only for evaluation")
74 | else:
75 |             dataset = get_pascal(data_root, split="test", imsize=imsize)
76 |
77 | elif data == "dsprites":
78 | imsize = 64 if data_imsize is None else data_imsize
79 | from source.data.datasets.objs.dsprites import get_dsprites_pair
80 | dataset = get_dsprites_pair(
81 | root="./data/multi_dsprites/",
82 | split="train",
83 | imsize=imsize,
84 | )
85 |
86 | elif data == "tetrominoes":
87 | from source.data.datasets.objs.tetrominoes import get_tetrominoes_pair
88 | imsize = 32 if data_imsize is None else data_imsize
89 | dataset = get_tetrominoes_pair(
90 | root="./data/tetrominoes/",
91 | split="train",
92 | imsize=imsize,
93 | )
94 |
95 | elif data == "Shapes":
96 | imsize = 40 if data_imsize is None else data_imsize
97 | from source.data.datasets.objs.shapes import get_shapes_pair
98 | dataset = get_shapes_pair(
99 | root="./data/Shapes/",
100 | split="train",
101 | imsize=imsize,
102 | )
103 | return dataset, imsize, collate_fn
104 |
105 |
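# Usage sketch (illustrative, not part of the original module): fetch the
# CLEVRTex evaluation set together with its collate_fn and wrap it in a
# DataLoader. The batch size is an assumption; data_root=None falls back
# to the default path above.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset, imsize, collate_fn = load_data(
        "clevrtex_full", data_root=None, data_imsize=None, is_eval=True
    )
    loader = DataLoader(dataset, batch_size=8, collate_fn=collate_fn)
    print(imsize, len(dataset))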
--------------------------------------------------------------------------------
/source/data/datasets/objs/npdataset.py:
--------------------------------------------------------------------------------
1 |
2 | from pathlib import Path
3 |
4 | import numpy as np
5 | from torch.utils.data import Dataset
6 | import torch
7 | import torchvision
8 | from torchvision import transforms
9 | from PIL import Image
10 |
11 | from source.data.augs import get_color_distortion
12 |
13 |
14 | class NumpyDataset(Dataset):
15 | """NpzDataset: loads a npz file as dataset."""
16 |
17 | def __init__(self, filename, transform=torchvision.transforms.ToTensor()):
18 | super().__init__()
19 |
20 | dataset = np.load(filename)
21 | self.images = dataset["images"].astype(np.float32)
22 | if self.images.shape[1] == 1:
23 | self.images = np.repeat(self.images, 3, axis=1)
24 | self.pixelwise_instance_labels = dataset["labels"]
25 |
26 | if "class_labels" in dataset:
27 | self.class_labels = dataset["class_labels"]
28 | else:
29 | self.class_labels = None
30 |
31 | if "pixelwise_class_labels" in dataset:
32 | self.pixelwise_class_labels = dataset["pixelwise_class_labels"]
33 | else:
34 | self.pixelwise_class_labels = None
35 |
36 | self.transform = transform
37 |
38 | def __len__(self):
39 | return self.images.shape[0]
40 |
41 | def __getitem__(self, idx):
42 | img = self.images[idx] # {"input_images": self.images[idx]}
43 | img = np.transpose(img, (1, 2, 0))
44 | img = (255 * img).astype(np.uint8)
45 | img = Image.fromarray(img) # .convert('RGB')
46 | labels = {"pixelwise_instance_labels": self.pixelwise_instance_labels[idx]}
47 |
48 | if self.class_labels is not None:
49 | labels["class_labels"] = self.class_labels[idx]
50 | labels["pixelwise_class_labels"] = self.pixelwise_class_labels[idx]
51 | return self.transform(img), labels
52 |
53 |
54 | class PairDataset(NumpyDataset):
55 | """Generate mini-batche pairs on CIFAR10 training set."""
56 |
57 | def __getitem__(self, idx):
58 | img = self.images[idx]
59 | img = np.transpose(img, (1, 2, 0))
60 | img = (255 * img).astype(np.uint8)
61 | img = Image.fromarray(img) # .convert('RGB')
62 | imgs = [self.transform(img), self.transform(img)]
63 | return torch.stack(imgs) # stack a positive pair
64 |
--------------------------------------------------------------------------------
/source/data/datasets/objs/pascal.py:
--------------------------------------------------------------------------------
1 | # Copied from https://github.com/loeweX/RotatingFeatures/blob/main/codebase/data/PascalDataset.py
2 |
3 | import os
4 | from typing import Tuple, Dict
5 |
6 | import torch
7 | import torchvision.transforms as transforms
8 | from PIL import Image
9 | from torch.utils.data import Dataset
10 |
11 |
12 | class PascalDataset(Dataset):
13 | def __init__(self, root_dir, partition: str, transform, transform_label) -> None:
14 | """
15 | Initialize the Pascal VOC 2012 Dataset.
16 | See http://host.robots.ox.ac.uk/pascal/VOC/ for more information about the dataset.
17 |
18 | We make use of the “trainaug” variant of this dataset, an unofficial split consisting of 10,582 images,
19 | which includes 1,464 images from the original segmentation train set and 9,118 images from the
20 | Semantic Boundaries dataset.
21 |
22 | Args:
23 |             root_dir (str): Path to the Pascal VOC 2012 root directory.
24 | partition (str): Dataset partition ("train", "val", or "test").
25 | """
26 | super(PascalDataset, self).__init__()
27 |
28 | self.partition = partition
29 | self.to_tensor = transforms.ToTensor()
30 | if self.partition == "train":
31 | self.partition = "trainaug"
32 | self.transform = transform
33 | self.transform_label = transform_label
34 |
35 | # As is common in the literature, we test our model on the validation set of the Pascal VOC dataset.
36 | # For validation, create a train/validation split of the official training set manually and
37 | # adjust the code accordingly.
38 | if self.partition == "test":
39 | self.partition = "val"
40 |
41 | # Load Pascal dataset.
42 | partition_dir = os.path.join(
43 | root_dir, "ImageSets", "Segmentation", f"{self.partition}.txt"
44 | )
45 |
46 | with open(partition_dir) as f:
47 | file_names = [x.strip() for x in f.readlines()]
48 |
49 | self.images = [
50 | os.path.join(root_dir, "JPEGImages", f"{x}.jpg") for x in file_names
51 | ]
52 | self.pixelwise_class_labels = [
53 | os.path.join(root_dir, "SegmentationClass", f"{x}.png") for x in file_names
54 | ]
55 | self.pixelwise_instance_labels = [
56 | os.path.join(root_dir, "SegmentationObject", f"{x}.png") for x in file_names
57 | ]
58 |
59 | # Normalize input images using mean and standard deviation of ImageNet.
60 | # self.normalize = transforms.Normalize(
61 | # (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
62 | # )
63 |
64 | self.num_classes = 20
65 |
66 | def __len__(self) -> int:
67 | """
68 | Get the number of images in the dataset.
69 |
70 | Returns:
71 | int: Number of images.
72 | """
73 | return len(self.images)
74 |
75 | @staticmethod
76 | def _preprocess_pascal_labels(labels: torch.Tensor) -> torch.Tensor:
77 | """
78 | Preprocess Pascal VOC labels by converting to integer 255-scale and
79 | marking object boundaries as "ignore" label.
80 |
81 | Args:
82 | labels (torch.Tensor): The input labels.
83 |
84 | Returns:
85 | torch.Tensor: The preprocessed labels.
86 | """
87 | labels = labels * 255
88 | labels[labels == 255] = -1 # "Ignore" label throughout the codebase is -1.
89 | return labels
90 |
91 | def __getitem__(self, idx: int) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
92 | """
93 | Get an item from the dataset.
94 |
95 | Args:
96 | idx (int): Index of the item to retrieve.
97 |
98 | Returns:
99 | tuple: A tuple containing the input image and corresponding gt_labels.
100 | """
101 | input_image = Image.open(self.images[idx]).convert("RGB")
102 | pixelwise_class_labels = Image.open(self.pixelwise_class_labels[idx])
103 |
104 | try:
105 | pixelwise_instance_labels = Image.open(self.pixelwise_instance_labels[idx])
106 | except FileNotFoundError as e:
107 | # Instance labels are not available for all images in the trainaug set.
108 | if self.partition != "trainaug":
109 | raise FileNotFoundError(
110 | "Instance labels should only be missing for the trainaug partition."
111 | ) from e
112 | # Create an empty target.
113 | pixelwise_instance_labels = Image.new(
114 | "L", (pixelwise_class_labels.width, pixelwise_class_labels.height), 0
115 | )
116 |
117 | pixelwise_class_labels = self._preprocess_pascal_labels(
118 | self.to_tensor(pixelwise_class_labels)
119 | )
120 | pixelwise_instance_labels = self._preprocess_pascal_labels(
121 | self.to_tensor(pixelwise_instance_labels)
122 | )
123 |
124 | input_image = self.transform(input_image)
125 | pixelwise_instance_labels = self.transform_label(pixelwise_instance_labels)[0]
126 | pixelwise_class_labels = self.transform_label(pixelwise_class_labels)[0]
127 |
128 | labels = {
129 | "pixelwise_class_labels": pixelwise_class_labels,
130 | "pixelwise_instance_labels": pixelwise_instance_labels,
131 | }
132 | return input_image, labels
133 |
134 |
135 | def get_pascal(root, split="train", imsize=256, imsize_label=320):
136 | transform = transforms.Compose(
137 | [
138 | transforms.Resize(
139 | imsize, interpolation=transforms.InterpolationMode.NEAREST
140 | ),
141 | transforms.CenterCrop(imsize),
142 | transforms.ToTensor(),
143 | ]
144 | )
145 |
146 | transform_label = transforms.Compose(
147 | [
148 | transforms.Resize(
149 | imsize_label, interpolation=transforms.InterpolationMode.NEAREST
150 | ),
151 | transforms.CenterCrop(imsize_label),
152 | ]
153 | )
154 | return PascalDataset(
155 | root, split, transform=transform, transform_label=transform_label
156 | )
157 |
--------------------------------------------------------------------------------
/source/data/datasets/objs/shapes.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import numpy as np
4 | from torch.utils.data import Dataset
5 | import torch
6 | import torchvision
7 | from torchvision import transforms
8 | from PIL import Image
9 |
10 | from source.data.augs import simclr_augmentation
11 | from source.data.datasets.objs.npdataset import NumpyDataset, PairDataset
12 |
13 |
14 | def get_shapes_pair(root, split="train", imsize=40):
15 |
16 | path = Path(root, f"{split}.npz")
17 | return PairDataset(path, transform=simclr_augmentation(imsize=imsize, hflip=False))
18 |
19 |
20 | def get_shapes(root, split="train", imsize=40):
21 | path = Path(root, f"{split}.npz")
22 | return NumpyDataset(
23 | path,
24 | transform=transforms.Compose(
25 | [
26 | transforms.Resize((imsize, imsize)),
27 | transforms.ToTensor(),
28 | ]
29 | ),
30 | )
31 |
--------------------------------------------------------------------------------
/source/data/datasets/objs/tetrominoes.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import numpy as np
4 | from torch.utils.data import Dataset
5 | import torch
6 | import torchvision
7 | from torchvision import transforms
8 | from PIL import Image
9 |
10 | from source.data.augs import simclr_augmentation
11 | from source.data.datasets.objs.npdataset import NumpyDataset, PairDataset
12 |
13 |
14 | def get_tetrominoes_pair(root, split="train", imsize=32, hflip=False):
15 | path = Path(root, f"tetrominoes_{split}.npz")
16 | return PairDataset(path, transform=simclr_augmentation(imsize=imsize, hflip=hflip))
17 |
18 |
19 | def get_tetrominoes(root, split="train"):
20 | path = Path(root, f"tetrominoes_{split}.npz")
21 | return NumpyDataset(path, transform=transforms.ToTensor())
22 |
--------------------------------------------------------------------------------
/source/data/datasets/sudoku/sudoku.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import numpy as np
4 | import pandas as pd
5 | import torch
6 |
7 |
8 | def convert_onehot_to_int(X):
9 | # [B, H, W, 9]->[B, H, W]
10 | is_input = X.sum(dim=-1)
11 | return (is_input * (X.argmax(-1) + 1)).to(torch.int32)
12 |
13 |
14 | # copied from https://github.com/yilundu/ired_code_release/blob/3d74b85fab7fcf5e28aaf15e9ed3bf51c1a1d545/sat_dataset.py#L17
15 | def load_rrn_dataset(data_dir, split):
16 | if not osp.exists(data_dir):
17 | raise ValueError(
18 |             f"Data directory {data_dir} does not exist. Run data/download_rrn.sh to download the dataset."
19 | )
20 |
21 | split_to_filename = {"train": "train.csv", "val": "valid.csv", "test": "test.csv"}
22 |
23 | filename = osp.join(data_dir, split_to_filename[split])
24 | df = pd.read_csv(filename, header=None)
25 |
26 | def str2onehot(x):
27 | x = np.array(list(map(int, x)), dtype="int64")
28 | y = np.zeros((len(x), 9), dtype="float32")
29 | idx = np.where(x > 0)[0]
30 | y[idx, x[idx] - 1] = 1
31 | return y.reshape((9, 9, 9))
32 |
33 | features = list()
34 | labels = list()
35 | for i in range(len(df)):
36 | inp = df.iloc[i][0]
37 | out = df.iloc[i][1]
38 | features.append(str2onehot(inp))
39 | labels.append(str2onehot(out))
40 |
41 | return torch.tensor(np.array(features)), torch.tensor(np.array(labels))
42 |
43 |
44 | def load_sat_dataset(path):
45 | with open(os.path.join(path, "features.pt"), "rb") as f:
46 | X = torch.load(f)
47 | with open(os.path.join(path, "labels.pt"), "rb") as f:
48 | Y = torch.load(f)
49 | with open(os.path.join(path, "perm.pt"), "rb") as f:
50 | perm = torch.load(f)
51 | return X, Y, perm
52 |
53 |
54 | class SudokuDataset:
55 |
56 | def __init__(self, path='./data/sudoku/', train=True):
57 |
58 | X, Y, _ = load_sat_dataset(path)
59 |
60 | is_input = X.sum(dim=3, keepdim=True).int()
61 |
62 | indices = torch.arange(0, 9000) if train else torch.arange(9000, 10000)
63 |
64 | self.X = X[indices]
65 | self.Y = Y[indices]
66 | self.is_input = is_input[indices]
67 |
68 | def __len__(self):
69 | return len(self.X)
70 |
71 | def __getitem__(self, idx):
72 | return self.X[idx], self.Y[idx], self.is_input[idx]
73 |
74 |
75 | class HardSudokuDataset:
76 | def __init__(self, path='./data/sudoku-rnn/', split="test"):
77 |
78 | X, Y = load_rrn_dataset(path, split)
79 |
80 | is_input = X.sum(dim=3, keepdim=True).int()
81 |
82 | self.X = X
83 | self.Y = Y
84 | self.is_input = is_input
85 |
86 | def __len__(self):
87 | return len(self.X)
88 |
89 | def __getitem__(self, idx):
90 | return self.X[idx], self.Y[idx], self.is_input[idx]
91 |
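A minimal usage sketch for the loaders above (the data path is an assumption; see data/download_satnet.sh and data/download_rrn.sh):

```
# usage sketch: load the SAT-Net Sudoku split and inspect one example
from torch.utils.data import DataLoader

train_set = SudokuDataset(path="./data/sudoku/", train=True)
X, Y, is_input = train_set[0]
print(X.shape, Y.shape, is_input.shape)  # [9, 9, 9], [9, 9, 9], [9, 9, 1]

loader = DataLoader(train_set, batch_size=100, shuffle=True)
```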
--------------------------------------------------------------------------------
/source/evals/objs/fgari.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.metrics import adjusted_rand_score
3 |
4 | def calc_fgari_score(
5 |     gt_labels: np.ndarray, pred_labels: np.ndarray
6 | ) -> list:
7 |     """
8 |     Calculate the foreground Adjusted Rand Index (FG-ARI) for object discovery evaluation.
9 |
10 |     Args:
11 |         gt_labels (np.ndarray): Ground truth labels, shape (b, h, w).
12 |         pred_labels (np.ndarray): Predicted labels, shape (b, h, w).
13 |
14 |     Returns:
15 |         list: Per-image ARI scores, computed on foreground
16 |             (non-background, non-ignore) pixels only.
17 |     """
18 | aris = []
19 | for idx in range(gt_labels.shape[0]):
20 | # Remove "ignore" (-1) and background (0) gt_labels.
21 | area_to_eval = np.where(gt_labels[idx] > 0)
22 |
23 | ari = adjusted_rand_score(
24 | gt_labels[idx][area_to_eval], pred_labels[idx][area_to_eval]
25 | )
26 | aris.append(ari)
27 | return aris
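A shape-only usage sketch (random labels, purely illustrative):

```
# usage sketch: FG-ARI on a batch of random label maps
import numpy as np

gt = np.random.randint(-1, 5, size=(8, 128, 128))   # -1 = ignore, 0 = background
pred = np.random.randint(0, 7, size=(8, 128, 128))
aris = calc_fgari_score(gt, pred)                    # list of per-image scores
print(float(np.mean(aris)))
```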
--------------------------------------------------------------------------------
/source/evals/objs/mbo.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def compute_iou_matrix(gt_labels: np.ndarray, pred_labels: np.ndarray) -> np.ndarray:
5 | """
6 | Compute the Intersection over Union (IoU) matrix between ground truth and predicted labels.
7 |
8 | Args:
9 | gt_labels (np.ndarray): Ground truth labels, shape (m, h, w).
10 | pred_labels (np.ndarray): Predicted labels, shape (o, h, w).
11 |
12 | Returns:
13 | np.ndarray: IoU matrix, shape (m, o).
14 | """
15 | intersection = np.logical_and(
16 | gt_labels[:, None, :, :], pred_labels[None, :, :, :]
17 | ).sum(axis=(2, 3))
18 | union = np.logical_or(gt_labels[:, None, :, :], pred_labels[None, :, :, :]).sum(
19 | axis=(2, 3)
20 | )
21 | return intersection / (union + 1e-9)
22 |
23 |
24 | def mean_best_overlap_single_sample(
25 | gt_labels: np.ndarray, pred_labels: np.ndarray
26 | ) -> float:
27 | """
28 | Compute the Mean Best Overlap (MBO) for a single sample between ground truth and predicted labels.
29 |
30 | Args:
31 | gt_labels (np.ndarray): Ground truth labels, shape (h, w).
32 | pred_labels (np.ndarray): Predicted labels, shape (h, w).
33 |
34 | Returns:
35 | float: MBO score for the sample.
36 | """
37 | from copy import deepcopy
38 |
39 | pred_labels = deepcopy(pred_labels)
40 |
41 | unique_gt_labels = np.unique(gt_labels)
42 | # Remove "ignore" (-1) label.
43 | unique_gt_labels = unique_gt_labels[unique_gt_labels != -1]
44 |
45 | # Mask areas with "ignore" gt_labels in pred_labels.
46 | pred_labels[np.where(gt_labels < 0)] = -1
47 |
48 | # Ignore background (0) gt_labels.
49 | unique_gt_labels = unique_gt_labels[unique_gt_labels != 0]
50 |
51 | if len(unique_gt_labels) == 0:
52 | return -1 # If no gt_labels left, skip this element.
53 |
54 | unique_pred_labels = np.unique(pred_labels)
55 |
56 | # Remove "ignore" (-1) label.
57 | unique_pred_labels = unique_pred_labels[unique_pred_labels != -1]
58 |
59 | gt_masks = np.equal(gt_labels[None, :, :], unique_gt_labels[:, None, None])
60 | pred_masks = np.equal(pred_labels[None, :, :], unique_pred_labels[:, None, None])
61 |
62 | iou_matrix = compute_iou_matrix(gt_masks, pred_masks)
63 | best_iou = np.max(iou_matrix, axis=1)
64 | return np.mean(best_iou)
65 |
66 |
67 | def calc_mean_best_overlap(gt_labels: np.ndarray, pred_labels: np.ndarray) -> tuple:
68 |     """
69 |     Calculate the Mean Best Overlap (MBO) for a batch of ground truth and predicted labels.
70 |
71 |     Args:
72 |         gt_labels (np.ndarray): Ground truth labels, shape (b, h, w).
73 |         pred_labels (np.ndarray): Predicted labels, shape (b, h, w).
74 |
75 |     Returns:
76 |         tuple: The MBO score averaged over valid samples (0.0 if none are valid),
77 |             and the per-sample MBO scores of shape (b,).
78 |     """
79 | mean_best_overlap = np.array(
80 | [
81 | mean_best_overlap_single_sample(gt_labels[b_idx], pred_labels[b_idx])
82 | for b_idx in range(gt_labels.shape[0])
83 | ]
84 | )
85 |
86 | if np.any(mean_best_overlap != -1):
87 | return np.mean(mean_best_overlap[mean_best_overlap != -1]), mean_best_overlap
88 | else:
89 | return 0.0, mean_best_overlap
90 |
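A shape-only usage sketch; note that the function returns both the batch mean and the per-sample scores:

```
# usage sketch: MBO on a batch of random label maps
import numpy as np

gt = np.random.randint(0, 5, size=(4, 64, 64))   # 0 = background
pred = np.random.randint(0, 8, size=(4, 64, 64))
mbo, per_sample = calc_mean_best_overlap(gt, pred)
print(mbo, per_sample.shape)                      # scalar, (4,)
```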
--------------------------------------------------------------------------------
/source/evals/sudoku/evals.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 |
4 | def compute_board_accuracy(pred, Y, is_input):
5 |     # pred: per-cell logits; Y: one-hot targets; is_input: mask of given cells
6 | B = pred.shape[0]
7 | pred = pred.reshape((B, -1, 9)).argmax(-1)
8 | Y = Y.argmax(dim=-1).reshape(B, -1)
9 | mask = 1 - is_input.reshape(B, -1)
10 |
11 | num_blanks = mask.sum(1)
12 | num_correct = (mask * (pred == Y)).sum(1)
13 | board_correct = (num_correct == num_blanks).int()
14 | return num_blanks, num_correct, board_correct
15 |
16 |
17 |
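A minimal usage sketch (random tensors, shapes only):

```
# usage sketch: board-level accuracy from per-cell logits
import torch

B = 4
pred = torch.randn(B, 9, 9, 9)                                   # model logits
Y = torch.nn.functional.one_hot(torch.randint(0, 9, (B, 9, 9)), 9).float()
is_input = torch.randint(0, 2, (B, 9, 9, 1))                     # 1 = given cell
num_blanks, num_correct, board_correct = compute_board_accuracy(pred, Y, is_input)
print(board_correct.float().mean())                              # fraction of solved boards
```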
--------------------------------------------------------------------------------
/source/layers/common_fns.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 |
4 | def positionalencoding2d(d_model, height, width):
5 | """
6 | :param d_model: dimension of the model
7 | :param height: height of the positions
8 | :param width: width of the positions
9 | :return: d_model*height*width position matrix
10 | """
11 | if d_model % 4 != 0:
12 |         raise ValueError("Cannot use sin/cos positional encoding with "
13 |                          "a dimension not divisible by 4 (got dim={:d})".format(d_model))
14 | pe = torch.zeros(d_model, height, width)
15 | # Each dimension use half of d_model
16 | d_model = int(d_model / 2)
17 | div_term = torch.exp(torch.arange(0., d_model, 2) *
18 | -(math.log(10000.0) / d_model))
19 | pos_w = torch.arange(0., width).unsqueeze(1)
20 | pos_h = torch.arange(0., height).unsqueeze(1)
21 | pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
22 | pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
23 | pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
24 | pe[d_model + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
25 |
26 | return pe
27 |
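A minimal usage sketch:

```
# usage sketch: sinusoidal positional grid for a 16x16 feature map
import torch

pe = positionalencoding2d(128, 16, 16)   # d_model must be divisible by 4
feat = torch.randn(8, 128, 16, 16)
feat = feat + pe[None]                   # broadcast over the batch dimension
```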
--------------------------------------------------------------------------------
/source/layers/common_layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 | import math
5 | import einops
6 | import numpy as np
7 |
8 | from source.layers.gta import (
9 | make_2dcoord,
10 | make_SO2mats,
11 | rep_mul_x,
12 | )
13 |
14 |
15 | class Interpolate(nn.Module):
16 |
17 | def __init__(self, r, mode="bilinear"):
18 | super().__init__()
19 | self.r = r
20 | self.mode = mode
21 |
22 | def forward(self, x):
23 | return F.interpolate(
24 | x, scale_factor=self.r, mode=self.mode, align_corners=False
25 | )
26 |
27 |
28 | class Reshape(nn.Module):
29 | def __init__(self, *args):
30 | super().__init__()
31 | self.shape = args
32 |
33 | def forward(self, x):
34 | return x.view(self.shape)
35 |
36 |
37 | class ResBlock(nn.Module):
38 |
39 | def __init__(self, fn):
40 | super().__init__()
41 | self.fn = fn
42 |
43 | def forward(self, x):
44 | return x + self.fn(x)
45 |
46 |
47 | class PatchEmbedding(nn.Module):
48 | def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=128):
49 | super(PatchEmbedding, self).__init__()
50 | self.patch_size = patch_size
51 | self.embed_dim = embed_dim
52 | self.num_patches = (img_size // patch_size) ** 2
53 | self.proj = nn.Conv2d(
54 | in_channels, embed_dim, kernel_size=patch_size, stride=patch_size
55 | )
56 |
57 | def forward(self, x):
58 | x = self.proj(x) # (B, embed_dim, H/patch_size, W/patch_size)
59 | x = x.flatten(2) # (B, embed_dim, num_patches)
60 | x = x.transpose(1, 2) # (B, num_patches, embed_dim)
61 | return x
62 |
63 |
64 | class ReadOutConv(nn.Module):
65 | def __init__(
66 | self,
67 | inch,
68 | outch,
69 | out_dim,
70 | kernel_size=1,
71 | stride=1,
72 | padding=0,
73 | ):
74 | super().__init__()
75 | self.outch = outch
76 | self.out_dim = out_dim
77 | self.invconv = nn.Conv2d(
78 | inch,
79 | outch * out_dim,
80 | kernel_size=kernel_size,
81 | stride=stride,
82 | padding=padding,
83 | )
84 | self.bias = nn.Parameter(torch.zeros(outch))
85 |
86 | def forward(self, x):
87 | x = self.invconv(x).unflatten(1, (self.outch, -1))
88 | x = torch.linalg.norm(x, dim=2) + self.bias[None, :, None, None]
89 | return x
90 |
91 |
92 | class BNReLUConv2d(nn.Module):
93 |
94 | def __init__(
95 | self,
96 | inch,
97 | outch,
98 | kernel_size=1,
99 | stride=1,
100 | padding=0,
101 | norm=None,
102 | act=nn.ReLU(),
103 | ):
104 | super().__init__()
105 | if norm == "gn":
106 | norm = lambda ch: nn.GroupNorm(8, ch)
107 | elif norm == "bn":
108 | norm = lambda ch: nn.BatchNorm2d(ch)
109 |         elif norm is None:
110 | norm = lambda ch: nn.Identity()
111 | else:
112 | raise NotImplementedError
113 |
114 | conv = nn.Conv2d(
115 | inch,
116 | outch,
117 | kernel_size=kernel_size,
118 | stride=stride,
119 | padding=padding,
120 | )
121 |
122 | self.fn = nn.Sequential(
123 | norm(inch),
124 | act,
125 | conv,
126 | )
127 |
128 | def forward(self, x):
129 | return self.fn(x)
130 |
131 |
132 | class FF(nn.Module):
133 |
134 | def __init__(
135 | self,
136 | inch,
137 | outch,
138 | hidch=None,
139 | kernel_size=1,
140 | stride=1,
141 | padding=0,
142 | norm=None,
143 | act=nn.ReLU(),
144 | ):
145 | super().__init__()
146 | if hidch is None:
147 | hidch = 4 * inch
148 | self.fn = nn.Sequential(
149 | BNReLUConv2d(
150 | inch,
151 | hidch,
152 | kernel_size=kernel_size,
153 | stride=stride,
154 | padding=padding,
155 | norm=norm,
156 | act=act,
157 | ),
158 | BNReLUConv2d(
159 | hidch,
160 | outch,
161 | kernel_size=kernel_size,
162 | stride=stride,
163 | padding=padding,
164 | norm=norm,
165 | act=act,
166 | ),
167 | )
168 |
169 | def forward(self, x):
170 | x = self.fn(x)
171 | return x
172 |
173 |
174 | class LayerNormForImage(nn.Module):
175 | def __init__(self, num_features, eps=1e-5):
176 | super().__init__()
177 | self.eps = eps
178 | self.gamma = nn.Parameter(torch.ones(1, num_features, 1, 1))
179 | self.beta = nn.Parameter(torch.zeros(1, num_features, 1, 1))
180 |
181 | def forward(self, x):
182 |         # x shape: [B, C, H, W] (or [B, C] for flattened features)
183 | mean = x.mean(dim=1, keepdim=True)
184 | var = x.var(dim=1, keepdim=True, unbiased=False)
185 | x_normalized = (x - mean) / torch.sqrt(var + self.eps)
186 | if x.ndim == 2:
187 | x_normalized = self.gamma[..., 0, 0] * x_normalized + self.beta[..., 0, 0]
188 | else:
189 | x_normalized = self.gamma * x_normalized + self.beta
190 | return x_normalized
191 |
192 |
193 | class ScaleAndBias(nn.Module):
194 | def __init__(self, num_channels, token_input=False):
195 | super().__init__()
196 | self.scale = nn.Parameter(torch.ones(num_channels))
197 | self.bias = nn.Parameter(torch.zeros(num_channels))
198 | self.token_input = token_input
199 |
200 | def forward(self, x):
201 | # Determine the shape for scale and bias based on input dimensions
202 | if self.token_input:
203 | # token input
204 | shape = [1, 1, -1]
205 | scale = self.scale.view(*shape)
206 | bias = self.bias.view(*shape)
207 | else:
208 | # image input
209 | shape = [1, -1] + [1] * (x.dim() - 2)
210 | scale = self.scale.view(*shape)
211 | bias = self.bias.view(*shape)
212 | return x * scale + bias
213 |
214 |
215 | class RGBNormalize(nn.Module):
216 | def __init__(self, mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)):
217 | super().__init__()
218 |
219 | self.mean = torch.tensor(mean).view(1, len(mean), 1, 1)
220 | self.std = torch.tensor(std).view(1, len(std), 1, 1)
221 |
222 | def forward(self, x):
223 | if x.device != self.mean.device:
224 | self.mean = self.mean.to(x.device)
225 | self.std = self.std.to(x.device)
226 | return (x - self.mean) / self.std
227 |
228 | def inverse(self, x):
229 | if x.device != self.mean.device:
230 | self.mean = self.mean.to(x.device)
231 | self.std = self.std.to(x.device)
232 | return (x * self.std) + self.mean
233 |
234 |
235 | class FeatureAttention(nn.Module):
236 | def __init__(self, n, ch):
237 | super().__init__()
238 | self.n = n
239 | self.ch = ch
240 | self.q_linear = nn.Linear(n, n)
241 | self.k_linear = nn.Linear(n, n)
242 | self.v_linear = nn.Linear(n, n)
243 | self.o_linear = nn.Linear(n, n)
244 |
245 | def forward(self, x):
246 | B = x.shape[0]
247 | q, k, v = map(lambda x: x.view(B, -1, self.n), (x, x, x))
248 | q = self.q_linear(q)
249 | k = self.k_linear(k)
250 | v = self.v_linear(v)
251 |
252 | o = F.scaled_dot_product_attention(q, k, v)
253 | return self.o_linear(o).view(B, -1)
254 |
255 |
256 | class Attention(nn.Module):
257 | def __init__(
258 | self,
259 | ch,
260 | heads=8,
261 | weight="conv",
262 | kernel_size=1,
263 | stride=1,
264 | padding=0,
265 | gta=False,
266 | rope=False,
267 | hw=None,
268 | ):
269 | super().__init__()
270 |
271 | self.heads = heads
272 | self.head_dim = ch // heads
273 | self.weight = weight
274 | self.stride = stride
275 |
276 | if weight == "conv":
277 | self.W_qkv = nn.Conv2d(
278 | ch,
279 | 3 * ch,
280 | kernel_size=kernel_size,
281 | stride=stride,
282 | padding=padding,
283 | )
284 | self.W_o = nn.Conv2d(
285 | ch,
286 | ch,
287 | kernel_size=kernel_size,
288 | stride=stride,
289 | padding=padding,
290 | )
291 | elif weight == "fc":
292 | self.W_qkv = nn.Linear(ch, 3 * ch)
293 | self.W_o = nn.Linear(ch, ch)
294 | else:
295 | raise ValueError("weight should be 'conv' or 'fc': {}".format(weight))
296 |
297 | self.gta = gta
298 | self.rope = rope
299 | assert (int(self.gta) + int(self.rope)) <= 1 # either gta or rope
300 |
301 | self.hw = hw
302 |
303 | if gta or rope:
304 | assert hw is not None
305 | F = self.head_dim // 4
306 | if self.head_dim % 4 != 0:
307 | F = F + 1
308 |
309 | if not isinstance(hw, list):
310 | coord = hw
311 | _mat = make_SO2mats(coord, F).flatten(1, 2) # [h*w, head_dim/2, 2, 2]
312 | else:
313 | coord = make_2dcoord(hw[0], hw[1])
314 | _mat = (
315 | make_SO2mats(coord, F).flatten(2, 3).flatten(0, 1)
316 | ) # [h*w, head_dim/2, 2, 2]
317 |
318 | _mat = _mat[..., : self.head_dim // 2, :, :]
319 |         # set identity matrix for additional tokens
320 |
321 | if gta:
322 | self.mat_q = nn.Parameter(_mat)
323 | self.mat_k = nn.Parameter(_mat)
324 | self.mat_v = nn.Parameter(_mat)
325 | self.mat_o = nn.Parameter(_mat.transpose(-2, -1))
326 | elif rope:
327 | self.mat_q = nn.Parameter(_mat)
328 | self.mat_k = nn.Parameter(_mat)
329 |
330 | def rescale_gta_mat(self, mat, hw):
331 | # _mat = [h*w, head_dim/2, 2, 2]
332 | if hw[0] == self.hw[0] and hw[1] == self.hw[1]:
333 | return mat
334 | else:
335 | f, c, d = mat.shape[1:]
336 | mat = einops.rearrange(
337 | mat, "(h w) f c d -> (f c d) h w", h=self.hw[0], w=self.hw[1]
338 | )
339 | mat = F.interpolate(mat[None], size=hw, mode="bilinear")[0]
340 | mat = einops.rearrange(mat, "(f c d) h w -> (h w) f c d", f=f, c=c, d=d)
341 | return mat
342 |
343 | def forward(self, x):
344 |
345 | if self.weight == "conv":
346 | h, w = x.shape[2] // self.stride, x.shape[3] // self.stride
347 | else:
348 | h, w = self.hw
349 |
350 | reshape_str = (
351 | "b (c nh) h w -> b nh (h w) c"
352 | if self.weight == "conv"
353 | else "b k (c nh) -> b nh k c"
354 | )
355 | dim = 1 if self.weight == "conv" else 2
356 | q, k, v = self.W_qkv(x).chunk(3, dim=dim)
357 | q, k, v = map(
358 | lambda x: einops.rearrange(x, reshape_str, nh=self.heads),
359 | (q, k, v),
360 | )
361 | if self.gta:
362 | q, k, v = map(
363 | lambda args: rep_mul_x(self.rescale_gta_mat(args[0], (h, w)), args[1]),
364 | ((self.mat_q, q), (self.mat_k, k), (self.mat_v, v)),
365 | )
366 | elif self.rope:
367 | q, k = map(
368 | lambda args: rep_mul_x(args[0], args[1]),
369 | ((self.mat_q, q), (self.mat_k, k)),
370 | )
371 |
372 | x = torch.nn.functional.scaled_dot_product_attention(
373 | q, k, v, attn_mask=self.mask if hasattr(self, "mask") else None
374 | )
375 |
376 | if self.gta:
377 | x = rep_mul_x(self.rescale_gta_mat(self.mat_o, (h, w)), x)
378 |
379 | if self.weight == "conv":
380 | x = einops.rearrange(x, "b nh (h w) c -> b (c nh) h w", h=h, w=w)
381 | else:
382 | x = einops.rearrange(x, "b nh k c -> b k (c nh)")
383 |
384 | x = self.W_o(x)
385 |
386 | return x
387 |
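A minimal usage sketch of the convolutional ("conv") attention with GTA enabled (sizes are illustrative):

```
# usage sketch: attention over an 8x8 feature map
import torch

attn = Attention(ch=256, heads=8, weight="conv", gta=True, hw=[8, 8])
x = torch.randn(2, 256, 8, 8)
print(attn(x).shape)   # torch.Size([2, 256, 8, 8])
```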
--------------------------------------------------------------------------------
/source/layers/gta.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import math
4 | from einops import rearrange
5 |
6 |
7 | def make_2dcoord(H, W, normalize=False):
8 | """
9 | Return(torch.Tensor): 2d coord values of shape [H, W, 2]
10 | """
11 | x = np.arange(H, dtype=np.float32) # [0, H)
12 | y = np.arange(W, dtype=np.float32) # [0, W)
13 | if normalize:
14 | x = x / H
15 | y = y / W
16 | x_grid, y_grid = np.meshgrid(x, y, indexing="ij")
17 | return torch.Tensor(
18 | np.stack([x_grid.flatten(), y_grid.flatten()], -1).reshape(H, W, 2)
19 | )
20 |
21 |
22 | def make_SO2mats(coord, nfreqs):
23 | """
24 | Args:
25 | coord: [..., 2 or 3]
26 |       nfreqs: number of frequency bands (int)
27 | Return:
28 | mats of shape [..., n_freqs, (2 or 3), 2, 2]
29 | """
30 | dim = coord.shape[-1]
31 | b = 10000.0
32 | freqs = torch.exp(torch.arange(0.0, 2 * nfreqs, 2) * -(math.log(b) / (2 * nfreqs)))
33 | grid_ths = [
34 | torch.einsum("...i,j->...ij", coord[..., d : d + 1], freqs).flatten(-2, -1)
35 | for d in range(dim)
36 | ]
37 |
38 | _mats = [
39 | [
40 | torch.cos(grid_ths[d]),
41 | -torch.sin(grid_ths[d]),
42 | torch.sin(grid_ths[d]),
43 | torch.cos(grid_ths[d]),
44 | ]
45 | for d in range(dim)
46 | ]
47 | mats = [
48 | rearrange(torch.stack(_mats[d], -1), "... (h w)->... h w", h=2, w=2)
49 | for d in range(dim)
50 | ]
51 | mat = torch.stack(mats, -3)
52 | return mat
53 |
54 |
55 | # GTA
56 | @torch.jit.script
57 | def rep_mul_x(rep, x):
58 | # rep.shape=[T, F, 2, 2], x.shape=[B, H, T, F*2]
59 | shape = x.shape
60 | d = rep.shape[-1]
61 | return (
62 | (rep[None, None] * (x.unflatten(-1, (-1, d))[..., None, :])).sum(-1).view(shape)
63 | )
64 |
65 |
66 | @torch.jit.script
67 | def rep_mul_qkv(rep, q, k, v):
68 | return rep_mul_x(rep, q), rep_mul_x(rep, k), rep_mul_x(rep, v)
69 |
70 |
71 | @torch.jit.script
72 | def rep_mul_qk(rep, q, k):
73 | return rep_mul_x(rep, q), rep_mul_x(rep, k)
74 |
75 |
76 | def embed_block_diagonal(M, n):
77 | """
78 |     Embed a [h*w, d/2, 2, 2] tensor M into a [h*w, d/(2n), 2n, 2n] tensor M'
79 | with block diagonal structure.
80 |
81 | Args:
82 | M (torch.Tensor): Tensor of shape [h*w, d/2, 2, 2]
83 | n (int): Number of blocks to embed into 2nx2n structure
84 |
85 | Returns:
86 |         torch.Tensor: Tensor of shape [h*w, d/(2n), 2n, 2n]
87 | """
88 | h_w, d_half, _, _ = M.shape
89 |
90 | # Initialize an empty tensor for the block diagonal tensor M'
91 |     M_prime = torch.zeros((h_w, d_half // n, 2 * n, 2 * n))
92 |
93 | # Embed M into the block diagonal structure of M_prime
94 | for t in range(h_w):
95 | for d in range(d_half // n):
96 | M_prime[t, d] = torch.block_diag(*[M[t, n * d + i] for i in range(n)])
97 | print(M_prime.shape)
98 | return M_prime
99 |
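A minimal sketch of how the rotation representations are built and applied (shapes follow the comments above):

```
# usage sketch: SO(2) rotations for a 4x4 grid applied to a query tensor
import torch

coord = make_2dcoord(4, 4)                 # [4, 4, 2]
mats = make_SO2mats(coord, nfreqs=8)       # [4, 4, 8, 2, 2, 2]
rep = mats.flatten(2, 3).flatten(0, 1)     # [16, 16, 2, 2] = [h*w, head_dim/2, 2, 2]
q = torch.randn(2, 1, 16, 32)              # [B, heads, h*w, head_dim]
print(rep_mul_x(rep, q).shape)             # torch.Size([2, 1, 16, 32])
```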
--------------------------------------------------------------------------------
/source/layers/klayer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.utils.parametrizations import weight_norm
4 |
5 | import numpy as np
6 |
7 | from source.layers.common_layers import (
8 | ScaleAndBias,
9 | Attention,
10 | )
11 |
12 | from source.layers.kutils import (
13 | reshape,
14 | reshape_back,
15 | normalize,
16 | )
17 |
18 | from einops.layers.torch import Rearrange
19 |
20 |
21 | class OmegaLayer(nn.Module):
22 |
23 | def __init__(self, n, ch, init_omg=0.1, global_omg=False, learn_omg=True):
24 | super().__init__()
25 | self.n = n
26 | self.ch = ch
27 | self.global_omg = global_omg
28 |
29 | if not learn_omg:
30 | print("Not learning omega")
31 |
32 | if n % 2 != 0:
33 | # n is odd
34 | raise NotImplementedError
35 | else:
36 | # n is even
37 | if global_omg:
38 | self.omg_param = nn.Parameter(
39 | init_omg * (1 / np.sqrt(2)) * torch.ones(2), requires_grad=learn_omg
40 | )
41 | else:
42 | self.omg_param = nn.Parameter(
43 | init_omg * (1 / np.sqrt(2)) * torch.ones(ch // 2, 2),
44 | requires_grad=learn_omg,
45 | )
46 |
47 | def forward(self, x):
48 | _x = reshape(x, 2)
49 | if self.global_omg:
50 | omg = torch.linalg.norm(self.omg_param).repeat(_x.shape[1])
51 | else:
52 | omg = torch.linalg.norm(self.omg_param, dim=1)
53 | omg = omg[None]
54 | for _ in range(_x.ndim - 3):
55 | omg = omg.unsqueeze(-1)
56 | omg_x = torch.stack([omg * _x[:, :, 1], -omg * _x[:, :, 0]], dim=2)
57 | omg_x = reshape_back(omg_x)
58 | return omg_x
59 |
60 |
61 | class KLayer(nn.Module): # Kuramoto layer
62 |
63 | def __init__(
64 | self,
65 | n,
66 | ch,
67 | J="conv",
68 | c_norm="gn",
69 | use_omega=False,
70 | init_omg=1.0,
71 | ksize=3,
72 | gta=False,
73 | hw=None,
74 | global_omg=False,
75 | heads=8,
76 | learn_omg=True,
77 | apply_proj=True,
78 | ):
79 |         # connectivity J is either 'conv' or 'attn'
80 | super().__init__()
81 | assert (ch % n) == 0
82 | self.n = n
83 | self.ch = ch
84 | self.use_omega = use_omega
85 | self.global_omg = global_omg
86 | self.apply_proj = apply_proj
87 |
88 | self.omg = (
89 | OmegaLayer(n, ch, init_omg, global_omg, learn_omg)
90 | if self.use_omega
91 | else nn.Identity()
92 | )
93 |
94 | if J == "conv":
95 | self.connectivity = nn.Conv2d(ch, ch, ksize, 1, ksize // 2)
96 | self.x_type = "image"
97 | elif J == "attn":
98 | self.connectivity = Attention(
99 | ch,
100 | heads=heads,
101 | weight="conv",
102 | kernel_size=1,
103 | stride=1,
104 | padding=0,
105 | gta=gta,
106 | hw=hw,
107 | )
108 | self.x_type = "image"
109 | else:
110 | raise NotImplementedError
111 |
112 | if c_norm == "gn":
113 | self.c_norm = nn.GroupNorm(ch // n, ch, affine=True)
114 | elif c_norm == "sandb":
115 | self.c_norm = ScaleAndBias(ch, token_input=False)
116 | elif c_norm is None or c_norm == "none":
117 | self.c_norm = nn.Identity()
118 | else:
119 | raise NotImplementedError
120 |
121 | def project(self, y, x):
122 | sim = x * y # similarity between update and current state
123 | yxx = torch.sum(sim, 2, keepdim=True) * x
124 | return y - yxx, sim
125 |
126 | def kupdate(self, x: torch.Tensor, c: torch.Tensor = None):
127 | # compute \sum_j[J_ij x_j]
128 | _y = self.connectivity(x)
129 | # add bias c.
130 | y = _y + c
131 |
132 |         if self.use_omega:
133 | omg_x = self.omg(x)
134 | else:
135 | omg_x = torch.zeros_like(x)
136 |
137 | y = reshape(y, self.n)
138 | x = reshape(x, self.n)
139 |
140 | # project y onto the tangent space
141 | if self.apply_proj:
142 | y_yxx, sim = self.project(y, x)
143 | else:
144 | y_yxx = y
145 | sim = y * x
146 |
147 | dxdt = omg_x + reshape_back(y_yxx)
148 | sim = reshape_back(sim)
149 |
150 | return dxdt, sim
151 |
152 | def forward(self, x: torch.Tensor, c: torch.Tensor, T: int, gamma):
153 | # x.shape = c.shape = [B, C,...] or [B, T, C]
154 | xs, es = [], []
155 | c = self.c_norm(c)
156 | x = normalize(x, self.n)
157 | es.append(torch.zeros(x.shape[0]).to(x.device))
158 | # Iterate kuramoto update with condition c
159 | for t in range(T):
160 | dxdt, _sim = self.kupdate(x, c)
161 | x = normalize(x + gamma * dxdt, self.n)
162 | xs.append(x)
163 | es.append((-_sim).reshape(x.shape[0], -1).sum(-1))
164 |
165 | return xs, es
166 |
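A minimal usage sketch of a single Kuramoto layer (sizes are illustrative):

```
# usage sketch: run T Kuramoto steps on oscillators x conditioned on a stimulus c
import torch

layer = KLayer(n=4, ch=64, J="attn", gta=True, hw=[8, 8])
x = torch.randn(2, 64, 8, 8)   # oscillator states
c = torch.randn(2, 64, 8, 8)   # conditional stimulus
xs, es = layer(x, c, T=8, gamma=1.0)
print(xs[-1].shape, len(es))   # torch.Size([2, 64, 8, 8]), 9 (initial entry + one per step)
```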
--------------------------------------------------------------------------------
/source/layers/kutils.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | from torch.nn import functional as F
5 | import math
6 | import einops
7 |
8 |
9 | def reshape(x: torch.Tensor, n: int):
10 | if x.ndim == 3: # x.shape = ([B, T, C ])
11 | return x.transpose(1, 2).unflatten(1, (-1, n))
12 | else: # x.shape = ([B, C, ..., ])
13 | return x.unflatten(1, (-1, n))
14 |
15 |
16 | def reshape_back(x):
17 | if x.ndim == 4: # Tokens
18 | return x.flatten(1, 2).transpose(1, 2)
19 | else:
20 | return x.flatten(1, 2)
21 |
22 |
23 | def _l2normalize(x):
24 | return torch.nn.functional.normalize(x, dim=2)
25 |
26 |
27 | def norm(n, x, dim=2, keepdim=True):
28 | return torch.linalg.norm(reshape(x, n), dim=dim, keepdim=keepdim)
29 |
30 |
31 | def normalize(x: torch.Tensor, n):
32 | x = reshape(x, n)
33 | x = _l2normalize(x)
34 | x = reshape_back(x)
35 | return x
36 |
37 | # currently not used
38 | def compute_exponential_map(n, x, dxdt, reshaped_inputs=False):
39 | if not reshaped_inputs:
40 | dxdt = reshape(dxdt, n)
41 | x = reshape(x, n)
42 | norm = torch.linalg.norm(dxdt, dim=2, keepdim=True)
43 | norm = torch.clip(norm, 0, math.pi)
44 | nx = torch.cos(norm) * x + torch.sin(norm) * (dxdt / (norm + 1e-5))
45 | if not reshaped_inputs:
46 | nx = reshape_back(nx)
47 | return nx
48 |
49 |
50 | class Normalize(nn.Module):
51 |
52 | def __init__(self, n):
53 | super().__init__()
54 | self.n = n
55 |
56 | def forward(self, x):
57 |         return normalize(x, self.n)
58 |
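A minimal sketch of the oscillator reshaping helpers (n = 4 rotating dimensions, sizes illustrative):

```
# usage sketch: treat every n=4 consecutive channels as one oscillator
import torch

x = torch.randn(2, 64, 8, 8)
xn = normalize(x, 4)            # unit norm per 4-dimensional oscillator
print(norm(4, xn).mean())       # ~1.0
print(reshape(xn, 4).shape)     # torch.Size([2, 16, 4, 8, 8])
```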
--------------------------------------------------------------------------------
/source/models/objs/knet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | from source.layers.klayer import (
7 | KLayer,
8 | )
9 |
10 |
11 | from source.layers.common_layers import (
12 | RGBNormalize,
13 | ReadOutConv,
14 | Reshape,
15 | )
16 |
17 | from source.layers.common_fns import (
18 | positionalencoding2d,
19 | )
20 |
21 |
22 | class AKOrN(nn.Module):
23 |
24 | def __init__(
25 | self,
26 | n=4,
27 | ch=256,
28 | L=1,
29 | T=8,
30 | psize=4,
31 | gta=True,
32 | J="attn",
33 | ksize=1,
34 | c_norm="gn",
35 | gamma=1.0,
36 | imsize=128,
37 | use_omega=False,
38 | init_omg=1.0,
39 | global_omg=False,
40 | maxpool=True,
41 | project=True,
42 | heads=8,
43 | use_ro_x=False,
44 | learn_omg=True,
45 | no_ro=False,
46 | autorescale=True,
47 | ):
48 | super().__init__()
49 | # assuming input's range is [0, 1]
50 | self.patchfy = nn.Sequential(
51 | RGBNormalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
52 | nn.Conv2d(3, ch, kernel_size=psize, stride=psize, padding=0),
53 | )
54 |
55 | if not gta:
56 | self.pos_enc = True
57 | self.pemb_x = nn.Parameter(
58 | positionalencoding2d(ch, imsize // psize, imsize // psize).reshape(
59 | -1, imsize // psize, imsize // psize
60 | )
61 | )
62 | self.pemb_c = nn.Parameter(
63 | positionalencoding2d(ch, imsize // psize, imsize // psize).reshape(
64 | -1, imsize // psize, imsize // psize
65 | )
66 | )
67 | else:
68 | self.pos_enc = False
69 |
70 | self.n = n
71 | self.ch = ch
72 | self.L = L
73 | if isinstance(T, int):
74 | self.T = [T] * L
75 | else:
76 | self.T = T
77 | if isinstance(J, str):
78 | self.J = [J] * L
79 | else:
80 | self.J = J
81 | self.gamma = torch.nn.Parameter(torch.Tensor([gamma]), requires_grad=False)
82 | self.psize = psize
83 | self.imsize = imsize
84 |
85 | self.layers = nn.ModuleList()
86 | feature_hw = imsize // psize
87 |
88 | feature_hws = [feature_hw] * self.L
89 | chs = [ch] * (self.L + 1)
90 |
91 | for l in range(self.L):
92 | ch = chs[l]
93 |             # every layer uses the same channel width in this implementation
94 |             ch_next = chs[l + 1]
95 |
96 |
97 |
98 | klayer = KLayer(
99 | n=n,
100 | ch=ch,
101 | J=self.J[l],
102 | gta=gta,
103 | c_norm=c_norm,
104 | use_omega=use_omega,
105 | init_omg=init_omg,
106 | global_omg=global_omg,
107 | heads=heads,
108 | learn_omg=learn_omg,
109 | ksize=ksize,
110 | hw=[feature_hws[l], feature_hws[l]],
111 | apply_proj=project,
112 | )
113 | readout = (
114 | ReadOutConv(ch, ch_next, self.n, 1, 1, 0)
115 | if not no_ro
116 | else nn.Identity()
117 | )
118 | linear_x = (
119 | nn.Conv2d(ch, ch_next, 1, 1, 0)
120 | if use_ro_x and l < self.L - 1
121 | else nn.Identity()
122 | )
123 | self.layers.append(nn.ModuleList([klayer, readout, linear_x]))
124 | ch = ch_next
125 |
126 | if maxpool:
127 | pool = nn.AdaptiveMaxPool2d((1, 1))
128 | else:
129 | pool = nn.AdaptiveAvgPool2d((1, 1))
130 |
131 | self.out = nn.Sequential(
132 | nn.Identity(),
133 | pool,
134 | Reshape(-1, ch),
135 | nn.Linear(ch, 4 * ch),
136 | nn.ReLU(),
137 | nn.Linear(4 * ch, ch),
138 | )
139 |
140 | self.fixed_ptb = False
141 | self.autorescale = autorescale
142 |
143 | def feature(self, inp):
144 | if self.autorescale and (
145 | inp.shape[2] != self.imsize or inp.shape[3] != self.imsize
146 | ):
147 | inp = F.interpolate(
148 | inp,
149 | (self.imsize, self.imsize),
150 | mode="bilinear",
151 | )
152 | c = self.patchfy(inp)
153 |
154 | if self.fixed_ptb:
155 | g = torch.Generator(device="cpu").manual_seed(1234)
156 | x = torch.randn(*(c.shape), generator=g).to(c.device)
157 | else:
158 | x = torch.randn_like(c)
159 |
160 | if self.pos_enc:
161 | c = c + self.pemb_c[None]
162 | x = x + self.pemb_x[None]
163 | xs = [x]
164 | es = [torch.zeros(x.shape[0], device=x.device)]
165 | for l, (kblock, ro, lin_x) in enumerate(self.layers):
166 | _xs, _es = kblock(x, c, T=self.T[l], gamma=self.gamma)
167 | x = _xs[-1]
168 | c = ro(x)
169 | x = lin_x(x)
170 | xs.append(_xs)
171 | es.append(_es)
172 |
173 | return c, x, xs, es
174 |
175 | def forward(self, input, return_xs=False, return_es=False):
176 | c, x, xs, es = self.feature(input)
177 | c = self.out(c)
178 |
179 | ret = [c]
180 | if return_xs:
181 | ret.append(xs)
182 | if return_es:
183 | ret.append(es)
184 |
185 | if len(ret) == 1:
186 | return ret[0]
187 | return ret
188 |
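A minimal forward-pass sketch of the object-centric model (hyperparameters are illustrative, not the paper's exact settings):

```
# usage sketch: AKOrN on a batch of RGB images in [0, 1]
import torch

model = AKOrN(n=4, ch=256, L=1, T=8, psize=8, J="attn", imsize=128)
img = torch.rand(2, 3, 128, 128)
z = model(img)                      # pooled feature, used e.g. for the SimCLR objective
print(z.shape)                      # torch.Size([2, 256])
c, x, xs, es = model.feature(img)   # patch features and oscillator states
print(c.shape)                      # torch.Size([2, 256, 16, 16])
```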
--------------------------------------------------------------------------------
/source/models/objs/patchconv.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/akorn/b018ff1ab4e1b3c32398d1e07d2ee9231dc425c1/source/models/objs/patchconv.py
--------------------------------------------------------------------------------
/source/models/objs/pretrained_vits.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import timm
4 | from timm.models import VisionTransformer
5 | from source.layers.common_layers import RGBNormalize
6 |
7 |
8 | class ViTWrapper(nn.Module):
9 | def __init__(self, model):
10 | super().__init__()
11 | self.model = model
12 | self.norm = RGBNormalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
13 |
14 | def forward(self, x):
15 | x = self.norm(x)
16 | return self.model(x)
17 |
18 |
19 | def load_dino():
20 | model = timm.create_model(
21 | "vit_base_patch16_224_dino", pretrained=True, img_size=256
22 | )
23 | model = ViTWrapper(model).cuda()
24 | model.psize = 16
25 | return model
26 |
27 |
28 | def load_dinov2(imsize=256):
29 | model = timm.create_model(
30 | "vit_large_patch14_dinov2.lvd142m", pretrained=True, img_size=imsize
31 | )
32 | model = ViTWrapper(model).cuda()
33 | model.psize = 16
34 | return model
35 |
36 |
37 | def load_mocov3():
38 | from timm.models.vision_transformer import vit_base_patch16_224
39 |
40 | model = vit_base_patch16_224(pretrained=False, dynamic_img_size=True)
41 | checkpoint = torch.hub.load_state_dict_from_url(
42 | "https://dl.fbaipublicfiles.com/moco-v3/vit-b-300ep/vit-b-300ep.pth.tar",
43 | map_location="cpu",
44 | )
45 | # Load the MoCo v3 state dict into the model
46 | state_dict = checkpoint["state_dict"]
47 | new_state_dict = {
48 | k.replace("module.base_encoder.", ""): v for k, v in state_dict.items()
49 | }
50 | model.load_state_dict(new_state_dict, strict=False)
51 | model.eval()
52 | model = ViTWrapper(model).cuda()
53 | model.img_size = 256
54 | model.psize = 16
55 | return model
56 |
57 |
58 | def load_mae():
59 | from timm.models.vision_transformer import vit_base_patch16_224
60 |
61 | model = vit_base_patch16_224(pretrained=False, dynamic_img_size=True)
62 | checkpoint = torch.hub.load_state_dict_from_url(
63 | "https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth",
64 | map_location="cpu",
65 |     )  # download the MAE pretraining checkpoint
66 |     model.load_state_dict(
67 |         checkpoint["model"], strict=False
68 |     )  # load the pretrained weights into the model
69 | model.eval()
70 | model = ViTWrapper(model).cuda()
71 | model.img_size = 256
72 | model.psize = 16
73 | return model
74 |
--------------------------------------------------------------------------------
/source/models/objs/vit.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import math
6 | import einops
7 | from einops.layers.torch import Rearrange, Reduce
8 | from source.layers.common_layers import PatchEmbedding, RGBNormalize, LayerNormForImage
9 | from source.layers.common_fns import positionalencoding2d
10 | from source.layers.common_layers import Attention
11 | from source.layers.common_layers import Reshape
12 |
13 |
14 | class TransformerBlock(nn.Module):
15 | def __init__(
16 | self,
17 | embed_dim,
18 | num_heads,
19 | mlp_dim,
20 | dropout=0.0,
21 | hw=None,
22 | gta=False,
23 | ):
24 | super().__init__()
25 | self.layernorm1 = LayerNormForImage(embed_dim)
26 | self.attn = Attention(
27 | embed_dim,
28 | num_heads,
29 | weight="conv",
30 | gta=gta,
31 | hw=hw,
32 | )
33 | self.layernorm2 = LayerNormForImage(embed_dim)
34 | self.mlp = nn.Sequential(
35 | nn.Conv2d(embed_dim, mlp_dim, 1, 1, 0),
36 | nn.GELU(),
37 | nn.Conv2d(mlp_dim, embed_dim, 1, 1, 0),
38 | nn.Dropout(dropout),
39 | )
40 |
41 | def forward(self, src, T):
42 | xs = []
43 | # Repeat attention T times
44 | for _ in range(T):
45 | src2 = self.layernorm1(src)
46 | src2 = self.attn(src2)
47 | src = src + src2
48 | xs.append(src)
49 |
50 | src2 = self.layernorm2(src)
51 | src2 = self.mlp(src2)
52 | src = src + src2
53 | return src, xs
54 |
55 |
56 | class ViT(nn.Module):
57 |     # ViT with iterative self-attention
58 | def __init__(
59 | self,
60 | imsize=128,
61 | psize=8,
62 | ch=128,
63 | blocks=1,
64 | heads=4,
65 | mlp_dim=256,
66 | T=8,
67 | maxpool=False,
68 | gta=False,
69 | autorescale=False,
70 | ):
71 | super().__init__()
72 | self.T = T
73 |         self.psize, self.imsize = psize, imsize
74 | self.autorescale = autorescale
75 | self.patchfy = nn.Sequential(
76 | RGBNormalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
77 | nn.Conv2d(3, ch, kernel_size=psize, stride=psize, padding=0),
78 | )
79 | if not gta:
80 | self.pos_embed = nn.Parameter(
81 | positionalencoding2d(ch, imsize // psize, imsize // psize)
82 | .reshape(-1, imsize // psize, imsize // psize)
83 | )
84 |
85 | self.transformer_encoder = nn.ModuleList(
86 | [
87 | TransformerBlock(
88 | ch,
89 | heads,
90 | mlp_dim,
91 | 0.0,
92 | hw=[imsize // psize, imsize // psize],
93 | gta=gta,
94 | )
95 | for _ in range(blocks)
96 | ]
97 | )
98 |
99 | self.out = torch.nn.Sequential(
100 | LayerNormForImage(ch),
101 | (
102 |                 nn.AdaptiveMaxPool2d((1, 1))
103 |                 if maxpool
104 |                 else nn.AdaptiveAvgPool2d((1, 1))
105 | ),
106 | Reshape(-1, ch),
107 | nn.Linear(ch, 4 * ch),
108 | nn.ReLU(),
109 | nn.Linear(4 * ch, ch),
110 | )
111 |
112 | def feature(self, x):
113 | if self.autorescale and (
114 | x.shape[2] != self.imsize or x.shape[3] != self.imsize
115 | ):
116 | x = F.interpolate(
117 | x,
118 | (self.imsize, self.imsize),
119 | mode="bilinear",
120 | )
121 | x = self.patchfy(x)
122 | if hasattr(self, "pos_embed"):
123 | x = x + self.pos_embed[None]
124 | xs = [x]
125 | for block in self.transformer_encoder:
126 | x, _xs = block(x, self.T)
127 | xs.append(_xs)
128 | return x, xs
129 |
130 | def forward(self, x, return_xs=False):
131 | x, xs = self.feature(x)
132 | x = self.out(x)
133 | if return_xs:
134 | return x, xs
135 | else:
136 | return x
137 |
--------------------------------------------------------------------------------
/source/models/sudoku/knet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from source.layers.klayer import (
4 | KLayer,
5 | )
6 | from source.layers.common_layers import (
7 | ReadOutConv,
8 | BNReLUConv2d,
9 | FF,
10 | ResBlock,
11 | )
12 | from source.layers.common_fns import positionalencoding2d
13 |
14 |
15 | from source.data.datasets.sudoku.sudoku import convert_onehot_to_int
16 |
17 |
18 | class SudokuAKOrN(nn.Module):
19 |
20 | def __init__(
21 | self,
22 | n,
23 | ch=64,
24 | L=1,
25 | T=16,
26 | gamma=1.0,
27 | J="attn",
28 | use_omega=True,
29 | global_omg=True,
30 | init_omg=0.1,
31 | learn_omg=False,
32 | nl=True,
33 | heads=8,
34 | ):
35 | super().__init__()
36 | self.n = n
37 | self.L = L
38 | self.ch = ch
39 | self.embedding = nn.Embedding(10, ch)
40 |
41 | hw = [9, 9]
42 |
43 | self.layers = nn.ModuleList()
44 | for l in range(self.L):
45 | self.layers.append(
46 | nn.ModuleList(
47 | [
48 | KLayer(
49 | n,
50 | ch,
51 | J,
52 | use_omega=use_omega,
53 | c_norm=None,
54 | hw=hw,
55 | global_omg=global_omg,
56 | init_omg=init_omg,
57 | heads=heads,
58 | learn_omg=learn_omg,
59 | gta=True,
60 | ),
61 | nn.Sequential(
62 | ReadOutConv(ch, ch, n, 1, 1, 0),
63 | ResBlock(FF(ch, ch, ch, 1, 1, 0)) if nl else nn.Identity(),
64 | BNReLUConv2d(ch, ch, 1, 1, 0) if nl else nn.Identity(),
65 | ),
66 | ]
67 | )
68 | )
69 |
70 | self.out = nn.Sequential(nn.ReLU(), nn.Conv2d(ch, 9, 1, 1, 0))
71 |
72 | self.T = T
73 | self.gamma = torch.nn.Parameter(torch.Tensor([gamma]))
74 | self.fixed_noise = False
75 | self.x0 = nn.Parameter(torch.randn(1, ch, 9, 9))
76 |
77 | def feature(self, inp, is_input):
78 |         # inp: torch.Tensor of shape [B, 9, 9, 9]; the last dim represents the digit in one-hot form.
79 | inp = convert_onehot_to_int(inp)
80 | c = self.embedding(inp).permute(0, 3, 1, 2)
81 | is_input = is_input.permute(0, 3, 1, 2)
82 | xs = []
83 | es = []
84 |
85 |         # generate random oscillators for the blank cells
86 | if self.fixed_noise:
87 | n = torch.randn(
88 | *(c.shape), generator=torch.Generator(device="cpu").manual_seed(42)
89 | ).to(c.device)
90 | x = is_input * c + (1 - is_input) * n
91 | else:
92 | n = torch.randn_like(c)
93 | x = is_input * c + (1 - is_input) * n
94 |
95 | for _, (klayer, readout) in enumerate(self.layers):
96 | # Process x and c.
97 | _xs, _es = klayer(
98 | x,
99 | c,
100 | self.T,
101 | self.gamma,
102 | )
103 | xs.append(_xs)
104 | es.append(_es)
105 |
106 | x = _xs[-1]
107 | c = readout(x)
108 |
109 | return c, xs, es
110 |
111 | def forward(self, c, is_input, return_xs=False, return_es=False):
112 | out, xs, es = self.feature(c, is_input)
113 | out = self.out(out).permute(0, 2, 3, 1)
114 | ret = [out]
115 | if return_xs:
116 | ret.append(xs)
117 | if return_es:
118 | ret.append(es)
119 | if len(ret) == 1:
120 | return ret[0]
121 | else:
122 | return ret
123 |
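A minimal forward-pass sketch (a nearly empty board, purely illustrative):

```
# usage sketch: SudokuAKOrN on one-hot boards
import torch

model = SudokuAKOrN(n=4, ch=64, L=1, T=16)
X = torch.zeros(2, 9, 9, 9)
X[:, 0, 0, 4] = 1                               # a given "5" in the top-left cell
is_input = X.sum(dim=-1, keepdim=True).int()    # 1 where a digit is given
out = model(X, is_input)                        # per-cell logits
print(out.shape)                                # torch.Size([2, 9, 9, 9])
```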
--------------------------------------------------------------------------------
/source/models/sudoku/transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import math
6 | from source.layers.common_layers import Attention
7 | from source.layers.common_fns import positionalencoding2d
8 |
9 | from source.data.datasets.sudoku.sudoku import convert_onehot_to_int
10 |
11 |
12 | class TransformerBlock(nn.Module):
13 | def __init__(self, embed_dim, num_heads, mlp_dim, dropout=0.0, hw=None, gta=True):
14 | super().__init__()
15 | self.layernorm1 = nn.LayerNorm(embed_dim)
16 | self.attn = Attention(embed_dim, num_heads, weight="fc", gta=gta, hw=hw)
17 | self.layernorm2 = nn.LayerNorm(embed_dim)
18 | self.mlp = nn.Sequential(
19 | nn.Linear(embed_dim, mlp_dim),
20 | nn.GELU(),
21 | nn.Linear(mlp_dim, embed_dim),
22 | nn.Dropout(dropout),
23 | )
24 |
25 | def forward(self, src, T):
26 | # Repeat attention T times
27 | for _ in range(T):
28 | src2 = self.layernorm1(src)
29 | src2 = self.attn(src2)
30 | src = src + src2
31 |
32 | src2 = self.layernorm2(src)
33 | src2 = self.mlp(src2)
34 | src = src + src2
35 | return src
36 |
37 |
38 | class SudokuTransformer(nn.Module):
39 | def __init__(
40 | self,
41 | ch=64,
42 | blocks=6,
43 | heads=4,
44 | mlp_dim=1024,
45 | T=16,
46 | gta=False,
47 | ):
48 | super().__init__()
49 | self.T = T
50 |
51 | self.embedding = nn.Embedding(10, ch)
52 | if not gta:
53 | self.pos_embed = nn.Parameter(
54 | positionalencoding2d(ch, 9, 9)
55 | .reshape(-1,9,9)
56 | .flatten(1, 2)
57 | .transpose(0, 1)
58 | )
59 |
60 | self.transformer_encoder = nn.ModuleList(
61 | [
62 | TransformerBlock(
63 | ch, heads, mlp_dim, 0.0, hw=(9, 9), gta=gta
64 | )
65 | for _ in range(blocks)
66 | ]
67 | )
68 |
69 | self.out = torch.nn.Sequential(nn.LayerNorm(ch), nn.Linear(ch, 9))
70 |
71 | def forward(self, x, is_input):
72 | B = x.size(0)
73 | x = convert_onehot_to_int(x)
74 | x = self.embedding(x)
75 | x = x.view(B, -1, x.shape[-1]) # [B, H*W, C]
76 | is_input = is_input.view(B, -1)
77 | if hasattr(self, "pos_embed"):
78 | x = x + self.pos_embed[None]
79 | for block in self.transformer_encoder:
80 | x = block(x, self.T)
81 | x = self.out(x)
82 | x = x.view(-1, 9, 9, 9)
83 | return x
84 |
--------------------------------------------------------------------------------
/source/training_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | from torch.optim.lr_scheduler import _LRScheduler
4 |
5 | def add_gradient_histograms(writer, model, epoch):
6 | for name, param in model.named_parameters():
7 | if param.grad is not None:
8 | writer.add_histogram(name + "/grad", param.grad, epoch)
9 |
10 |
11 | def save_model(model, epoch, checkpoint_dir, prefix="checkpoint"):
12 |
13 | if not os.path.exists(checkpoint_dir):
14 | os.makedirs(checkpoint_dir)
15 |
16 | checkpoint_path = os.path.join(checkpoint_dir, f"{prefix}_{epoch}.pth")
17 |
18 | torch.save(
19 | {
20 | "epoch": epoch,
21 | "model_state_dict": model.state_dict(),
22 | },
23 | checkpoint_path,
24 | )
25 |
26 | print(f"Model saved: {checkpoint_path}")
27 |
28 |
29 | def save_checkpoint(
30 | model, optimizer, epoch, loss, checkpoint_dir, max_checkpoints=None
31 | ):
32 |
33 | if not os.path.exists(checkpoint_dir):
34 | os.makedirs(checkpoint_dir)
35 |
36 | checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_{epoch}.pth")
37 |
38 | torch.save(
39 | {
40 | "epoch": epoch,
41 | "model_state_dict": model.state_dict(),
42 | "optimizer_state_dict": optimizer.state_dict(),
43 | "loss": loss,
44 | },
45 | checkpoint_path,
46 | )
47 |
48 | print(f"Checkpoint saved: {checkpoint_path}")
49 |
50 | manage_checkpoints(checkpoint_dir, max_checkpoints)
51 |
52 |
53 | def manage_checkpoints(checkpoint_dir, max_checkpoints):
54 | if max_checkpoints is None:
55 | return
56 | else:
57 | checkpoints = [
58 | f
59 | for f in os.listdir(checkpoint_dir)
60 | if f.startswith("checkpoint_") and f.endswith(".pth")
61 | ]
62 | checkpoints.sort(key=lambda f: int(f.split("_")[1].split(".")[0]))
63 |
64 | while len(checkpoints) > max_checkpoints:
65 | old_checkpoint = checkpoints.pop(0)
66 | os.remove(os.path.join(checkpoint_dir, old_checkpoint))
67 | print(f"Old checkpoint removed: {old_checkpoint}")
68 |
69 |
70 | class LinearWarmupScheduler(_LRScheduler):
71 | def __init__(self, optimizer, warmup_iters, last_iter=-1):
72 | self.warmup_iters = warmup_iters
73 | self.current_iter = 0 if last_iter == -1 else last_iter
74 | self.base_lrs = [group["lr"] for group in optimizer.param_groups]
75 | super(LinearWarmupScheduler, self).__init__(optimizer, last_epoch=last_iter)
76 |
77 | def get_lr(self):
78 | if self.current_iter < self.warmup_iters:
79 | # Linear warmup phase
80 | return [
81 | base_lr * (self.current_iter + 1) / self.warmup_iters
82 | for base_lr in self.base_lrs
83 | ]
84 | else:
85 | # Maintain the base learning rate after warmup
86 | return [base_lr for base_lr in self.base_lrs]
87 |
88 | def step(self, it=None):
89 | if it is None:
90 | it = self.current_iter + 1
91 | self.current_iter = it
92 | super(LinearWarmupScheduler, self).step(it)
93 |
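A minimal usage sketch of the warmup scheduler (model, optimizer, and iteration counts are placeholders):

```
# usage sketch: linear warmup for the first 1000 iterations, constant lr afterwards
import torch

model = torch.nn.Linear(10, 10)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = LinearWarmupScheduler(optimizer, warmup_iters=1000)

for it in range(2000):
    optimizer.step()
    scheduler.step()
```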
--------------------------------------------------------------------------------
/source/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import numpy as np
4 | import torch.nn.functional as F
5 |
6 | PLOTCOLORS = {
7 | "blue": "#377eb8",
8 | "orange": "#ff7f00",
9 | "green": "#4daf4a",
10 | "pink": "#f781bf",
11 | "brown": "#a65628",
12 | "purple": "#984ea3",
13 | "gray": "#999999",
14 | "red": "#e41a1c",
15 | "lightgray": "#d3d3d3",
16 | "lightgreen": "#90ee90",
17 | "yellow": "#dede00",
18 | }
19 |
20 |
21 | def ConvSingularValues(kernel, input_shape):
22 | transforms = torch.fft.fft2(kernel.permute(2, 3, 0, 1), input_shape, dim=[0, 1])
23 | print(transforms.shape)
24 | return torch.linalg.svd(transforms)
25 |
26 |
27 | def ConvSingularValuesNumpy(kernel, input_shape):
28 | kernel = kernel.detach().cpu().numpy()
29 | transforms = np.fft.fft2(kernel.transpose(2, 3, 0, 1), input_shape, axes=[0, 1])
30 | # transforms = np.fft.fft2(kernel.permute(2,3,0,1), input_shape, dim=[0,1])
31 | print(transforms.shape)
32 | return np.linalg.svd(transforms)
33 |
34 |
35 | def EigenValues(kernel, input_shape):
36 | # transforms = np.fft.fft2(kernel, input_shape, axes=[0, 1])
37 | transforms = torch.fft.fft2(kernel.permute(2, 3, 0, 1), input_shape, dim=[0, 1])
38 | print(transforms.shape)
39 | return torch.linalg.eig(transforms)
40 |
41 |
42 | def load_state_dict_ignore_size_mismatch(model, state_dict):
43 | model_state_dict = model.state_dict()
44 | matched_state_dict = {}
45 |
46 | for key, param in state_dict.items():
47 | if key in model_state_dict:
48 | if model_state_dict[key].shape == param.shape:
49 | matched_state_dict[key] = param
50 | else:
51 | print(
52 | f"Size mismatch for key '{key}': model {model_state_dict[key].shape}, checkpoint {param.shape}"
53 | )
54 | else:
55 | print(f"Key '{key}' not found in model state dict.")
56 |
57 | model_state_dict.update(matched_state_dict)
58 | model.load_state_dict(model_state_dict)
59 |
60 |
61 | def compare_optimizer_state_dicts(original, modified):
62 | diff = {}
63 | for key in original.keys():
64 | if key not in modified:
65 | diff[key] = "Removed"
66 | elif original[key] != modified[key]:
67 | diff[key] = {"Original": original[key], "Modified": modified[key]}
68 | for key in modified.keys():
69 | if key not in original:
70 | diff[key] = "Added"
71 | return diff
72 |
73 |
74 | def get_worker_init_fn(start, end):
75 | return lambda worker_id: os.sched_setaffinity(0, range(start, end))
76 |
77 |
78 | def str2bool(x):
79 | if isinstance(x, bool):
80 | return x
81 | x = x.lower()
82 | if x[0] in ["0", "n", "f"]:
83 | return False
84 | elif x[0] in ["1", "y", "t"]:
85 | return True
86 | raise ValueError("Invalid value: {}".format(x))
87 |
88 |
89 | def apply_pca(x, n_components=3):
90 | # x.shape = [B, C, H, W]
91 | from sklearn.decomposition import PCA
92 |
93 | pca = PCA(n_components)
94 | nx = []
95 | d = x.shape[1]
96 | for _x in x:
97 | _x = _x.permute(1, 2, 0).reshape(-1, d)
98 | _x = pca.fit_transform(_x)
99 | _x = _x.transpose(1, 0).reshape(n_components, x.shape[2], x.shape[3])
100 | nx.append(torch.tensor(_x))
101 | nx = torch.stack(nx, 0)
102 | # normalize to [0, 1]
103 | nx = (nx - nx.min()) / (nx.max() - nx.min())
104 | return nx
105 |
106 |
107 | def apply_pca_torch(x, n_components=3):
108 | # x: [B, C, H, W]
109 | B, C, H, W = x.shape
110 | N = H * W
111 |
112 | if n_components >= C:
113 | return x
114 |
115 | # Reshape to [B, N, C]
116 | x = x.permute(0, 2, 3, 1).reshape(B, N, C)
117 |
118 | # Center the data per sample
119 | x_mean = x.mean(dim=1, keepdim=True) # [B, 1, C]
120 | x_centered = x - x_mean # [B, N, C]
121 |
122 | # Compute covariance matrix per sample: [B, C, C]
123 | cov = torch.bmm(x_centered.transpose(1, 2), x_centered) / (N - 1)
124 |
125 | # Compute eigenvalues and eigenvectors per sample
126 | eigenvalues, eigenvectors = torch.linalg.eigh(
127 | cov
128 | ) # eigenvalues: [B, C], eigenvectors: [B, C, C]
129 |
130 | # Reverse the order of eigenvalues and eigenvectors to get descending order
131 | eigenvalues = eigenvalues.flip(dims=[1])
132 | eigenvectors = eigenvectors.flip(dims=[2])
133 |
134 | # Select the top 'dim' eigenvectors
135 | top_eigenvectors = eigenvectors[:, :, :n_components] # [B, C, dim]
136 |
137 | # Project the centered data onto the top eigenvectors
138 | x_pca = torch.bmm(x_centered, top_eigenvectors) # [B, N, dim]
139 |
140 | # Reshape back to [B, dim, H, W]
141 | x_pca = x_pca.transpose(1, 2).reshape(B, n_components, H, W)
142 |
143 | return x_pca
144 |
145 |
146 | def gen_saccade_imgs(img, psize, r):
147 | H, W = img.shape[-2:]
148 | img = F.interpolate(img, (H + psize - r, W + psize - r), mode="bicubic")
149 | imgs = []
150 | for h in range(0, psize, r):
151 | for w in range(0, psize, r):
152 | imgs.append(img[:, :, h : h + H, w : w + W])
153 | return imgs, img[:, :, psize // 2 : H + psize // 2, psize // 2 : W + psize // 2]
154 |
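A minimal usage sketch of the PCA visualization helper:

```
# usage sketch: project a feature map to 3 channels for visualization
import torch

feat = torch.randn(4, 256, 32, 32)
rgbish = apply_pca_torch(feat, n_components=3)
print(rgbish.shape)   # torch.Size([4, 3, 32, 32])
```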
--------------------------------------------------------------------------------
/train_obj.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import logging
4 |
5 | import torch
6 | import torch.distributed
7 | import torch.nn as nn
8 | from torch import optim
9 | from source.training_utils import save_checkpoint, save_model, LinearWarmupScheduler, add_gradient_histograms
10 | from source.data.datasets.objs.load_data import load_data
11 |
12 | from source.utils import str2bool
13 | from torch.utils.tensorboard import SummaryWriter
14 |
15 | import accelerate
16 | from accelerate import Accelerator
17 |
18 | # Visualization
19 | import torch.nn.functional as F
20 | from tqdm import tqdm
21 |
22 | # for distributed training
23 | from torch.distributed.nn.functional import all_gather
24 |
25 |
26 | def create_logger(logging_dir):
27 | """
28 | Create a logger that writes to a log file and stdout.
29 | """
30 | logging.basicConfig(
31 | level=logging.INFO,
32 | format="[\033[34m%(asctime)s\033[0m] %(message)s",
33 | datefmt="%Y-%m-%d %H:%M:%S",
34 | handlers=[
35 | logging.StreamHandler(),
36 | logging.FileHandler(f"{logging_dir}/log.txt"),
37 | ],
38 | )
39 | logger = logging.getLogger(__name__)
40 | return logger
41 |
42 |
43 | def simclr(zs, temperature=1.0, normalize=True, loss_type="ip"):
44 | # zs: list of tensors. Each tensor has shape (n, d)
45 | if normalize:
46 | zs = [F.normalize(z, p=2, dim=-1) for z in zs]
47 | if zs[0].dim() == 3:
48 | zs = [z.flatten(1, 2) for z in zs]
49 | m = len(zs)
50 | n = zs[0].shape[0]
51 | device = zs[0].device
52 | mask = torch.eye(n * m, device=device)
53 | label0 = torch.fmod(n + torch.arange(0, m * n, device=device), n * m)
54 | z = torch.cat(zs, 0)
55 | if loss_type == "euclid": # euclidean distance
56 | sim = -torch.cdist(z, z)
57 | elif loss_type == "sq": # squared euclidean distance
58 | sim = -(torch.cdist(z, z) ** 2)
59 | elif loss_type == "ip": # inner product
60 | sim = torch.matmul(z, z.transpose(0, 1))
61 | else:
62 | raise NotImplementedError
63 | logit_zz = sim / temperature
64 | logit_zz += mask * -1e8
65 | loss = nn.CrossEntropyLoss()(logit_zz, label0)
66 | return loss
67 |
68 |
69 | if __name__ == "__main__":
70 |
71 | parser = argparse.ArgumentParser()
72 |
73 | # Training options
74 | parser.add_argument("--exp_name", type=str, help="expname")
75 | parser.add_argument("--seed", type=int, default=1234)
76 | parser.add_argument("--beta", type=float, default=0.998, help="ema decay")
77 | parser.add_argument("--epochs", type=int, default=500, help="num of epochs")
78 | parser.add_argument(
79 | "--checkpoint_every",
80 | type=int,
81 | default=50,
82 | help="save checkpoint every specified epochs",
83 | )
84 | parser.add_argument("--lr", type=float, default=1e-3, help="lr")
85 | parser.add_argument("--warmup_iters", type=int, default=0)
86 | parser.add_argument(
87 | "--finetune",
88 | type=str,
89 | default=None,
90 | help="path to the checkpoint. Training starts from that checkpoint",
91 | )
92 |
93 | # Data loading
94 | parser.add_argument("--limit_cores_used", type=str2bool, default=False)
95 | parser.add_argument("--cpu_core_start", type=int, default=0, help="start core")
96 | parser.add_argument("--cpu_core_end", type=int, default=32, help="end core")
97 | parser.add_argument("--data", type=str, default="clevrtex")
98 | parser.add_argument(
99 | "--data_root",
100 | type=str,
101 | default=None,
102 | help="Optional. Specify the root dir of the dataset. If None, use a default path set for each dataset",
103 | )
104 | parser.add_argument("--batchsize", type=int, default=256)
105 | parser.add_argument("--num_workers", type=int, default=8)
106 | parser.add_argument(
107 | "--data_imsize",
108 | type=int,
109 | default=None,
110 | help="Image size. If None, use the default size of each dataset",
111 | )
112 |
113 | # Simclr options
114 | parser.add_argument("--normalize", type=str2bool, default=True)
115 | parser.add_argument("--temp", type=float, default=0.1, help="simclr temperature.")
116 |
117 | # General model options
118 | parser.add_argument("--model", type=str, default="akorn", help="model")
119 | parser.add_argument("--L", type=int, default=1, help="num of layers")
120 | parser.add_argument("--ch", type=int, default=256, help="num of channels")
121 | parser.add_argument(
122 | "--model_imsize",
123 | type=int,
124 | default=None,
125 | help=
126 | """
127 | Model's imsize. This is used when you want to finetune a pretrained model
128 | that was trained on images with a different resolution than the finetuning dataset.
129 | """
130 | )
131 | parser.add_argument("--autorescale", type=str2bool, default=False)
132 | parser.add_argument("--psize", type=int, default=8, help="patch size")
133 | parser.add_argument("--ksize", type=int, default=1, help="kernel size")
134 | parser.add_argument("--T", type=int, default=8, help="num of recurrence")
135 | parser.add_argument(
136 | "--maxpool", type=str2bool, default=True, help="max pooling or avg pooling"
137 | )
138 | parser.add_argument(
139 | "--heads", type=int, default=8, help="num of heads in self-attention"
140 | )
141 | parser.add_argument(
142 | "--gta",
143 | type=str2bool,
144 | default=True,
145 | help="""
146 | use Geometric Transform Attention (https://github.com/autonomousvision/gta) as positional encoding.
147 | Note that, unlike the original GTA, the rotation matrices are learnable.
148 | If False, use standard absolute positional encoding used in the original transformer paper.
149 | """,
150 | )
151 |
152 | # AKOrN options
153 | parser.add_argument("--N", type=int, default=4, help="num of rotating dimensions")
154 | parser.add_argument("--gamma", type=float, default=1.0, help="step size")
155 | parser.add_argument("--J", type=str, default="conv", help="connectivity")
156 | parser.add_argument("--use_omega", type=str2bool, default=False)
157 | parser.add_argument("--global_omg", type=str2bool, default=False)
158 | parser.add_argument(
159 | "--c_norm",
160 | type=str,
161 | default="gn",
162 | help="normalization. gn(GroupNorm), sandb(scale and bias), or none",
163 | )
164 | parser.add_argument(
165 | "--init_omg", type=float, default=0.01, help="initial omega length"
166 | )
167 | parser.add_argument("--learn_omg", type=str2bool, default=False)
168 | parser.add_argument(
169 | "--use_ro_x",
170 | type=str2bool,
171 | default=False,
172 | help="apply linear transform to oscillators between consecutive layers",
173 | )
174 |
175 | # ablation of some components in the AKOrN's block
176 | parser.add_argument(
177 | "--no_ro", type=str2bool, default=False, help="ablation: no use readout module"
178 | )
179 | parser.add_argument(
180 | "--project",
181 | type=str2bool,
182 | default=True,
183 | help="use projection or not in the Kuramoto layer",
184 | )
185 |
186 | args = parser.parse_args()
187 | torch.backends.cudnn.benchmark = True
188 | torch.backends.cuda.enable_flash_sdp(enabled=True)
189 | # Setup accelerator
190 | accelerator = Accelerator()
191 | device = accelerator.device
192 | accelerate.utils.set_seed(args.seed + accelerator.process_index)
193 |
194 | import random
195 | import numpy as np
196 | torch.manual_seed(args.seed)
197 | random.seed(args.seed)
198 | np.random.seed(args.seed)
199 |
200 | # Create job directory and logger
201 | jobdir = f"runs/{args.exp_name}/"
202 | if accelerator.is_main_process:
203 | if not os.path.exists(jobdir):
204 | os.makedirs(jobdir) # Make results folder (holds all experiment subfolders)
205 | logger = create_logger(jobdir)
206 | logger.info(f"Experiment directory created at {jobdir}")
207 | else:
208 | logger = create_logger(jobdir)
209 |
210 | if args.limit_cores_used:
211 | def worker_init_fn(worker_id):
212 | os.sched_setaffinity(0, range(args.cpu_core_start, args.cpu_core_end))
213 |
214 | else:
215 | worker_init_fn = None
216 |
217 | sstrainset, imsize, _ = load_data(args.data, args.data_root, args.data_imsize, False)
218 |
219 | if accelerator.is_main_process:
220 | logger.info(f"Dataset contains {len(sstrainset):,} images")
221 |
222 | ssloader = torch.utils.data.DataLoader(
223 | sstrainset,
224 | batch_size=int(args.batchsize // accelerator.num_processes),  # args.batchsize is the global batch size across all processes
225 | shuffle=True,
226 | num_workers=args.num_workers,
227 | worker_init_fn=worker_init_fn,
228 | )
229 |
230 | if accelerator.is_main_process:
231 | writer = SummaryWriter(jobdir)
232 |
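# train() runs one epoch: it forwards both augmented views of each image, gathers the
# features across processes, computes the SimCLR loss, steps the optimizer and LR
# scheduler, updates the EMA copy, and logs gradient/parameter-drift histograms and the
# running loss to TensorBoard.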
233 | def train(net, ema, opt, scheduler, loader, epoch):
234 | # snapshot of the initial weights, used below to log per-parameter drift histograms
235 | initial_params = {name: param.clone() for name, param in net.named_parameters()}
236 | running_loss = 0.0
237 | n = 0
238 |
239 | for i, data in tqdm(enumerate(loader, 0)):
240 | net.train()
241 | inputs = data.view(-1, 3, imsize, imsize).to(device)  # each sample holds two augmented views, so this is 2x batchsize images
242 |
243 | # forward
244 | outputs = net(inputs)
245 |
246 | # gather outputs from all processes (this all_gather is differentiable), since the SimCLR loss needs negatives from every process
247 | if accelerator.num_processes > 1:
248 | outputs = torch.cat(all_gather(outputs), 0)
249 | outputs = outputs.unflatten(0, (outputs.shape[0] // 2, 2))  # (2B, ...) -> (B, 2, ...): re-pair the two views of each sample
250 |
251 | loss = simclr(
252 | [outputs[:, 0], outputs[:, 1]],
253 | temperature=args.temp,
254 | normalize=args.normalize,
255 | loss_type="ip",
256 | )
257 |
258 | opt.zero_grad()
259 | accelerator.backward(loss)
260 | opt.step()
261 |
262 | scheduler.step()
263 |
264 | running_loss += loss.item() * inputs.shape[0]
265 | n += inputs.shape[0]
266 |
267 | ema.update()
268 |
269 | if accelerator.is_main_process:
270 | add_gradient_histograms(writer, net, epoch)
271 | for name, param in net.named_parameters():
272 | diff = param - initial_params[name]
273 | writer.add_histogram(f"{name}_diff", diff, epoch)
274 | if accelerator.is_main_process:
275 | logger.info(
276 | f"[Epoch {epoch + 1}, Batch {i + 1}] loss: {running_loss/n:.3f}"
277 | )
278 |
279 | total_loss = running_loss / n
280 | if accelerator.is_main_process:
281 | writer.add_scalar("training loss", total_loss, epoch)
282 |
283 | return total_loss
284 |
285 | if args.model == "akorn":
286 | from source.models.objs.knet import AKOrN
287 |
288 | net = AKOrN(
289 | args.N,
290 | ch=args.ch,
291 | L=args.L,
292 | T=args.T,
293 | gamma=args.gamma,
294 | J=args.J, # "conv" or "attn",
295 | use_omega=args.use_omega,
296 | global_omg=args.global_omg,
297 | c_norm=args.c_norm,
298 | psize=args.psize,
299 | imsize=imsize if args.model_imsize is None else args.model_imsize,
300 | autorescale=args.autorescale,
301 | init_omg=args.init_omg,
302 | learn_omg=args.learn_omg,
303 | maxpool=args.maxpool,
304 | project=args.project,
305 | heads=args.heads,
306 | use_ro_x=args.use_ro_x,
307 | no_ro=args.no_ro,
308 | gta=args.gta,
309 | ).to("cuda")
310 |
311 | elif args.model == "vit":
312 | from source.models.objs.vit import ViT
313 | # T=1: ViT. T > 1: ItrSA.
314 | net = ViT(
315 | psize=args.psize,
316 | imsize=imsize if args.model_imsize is None else args.model_imsize,
317 | autorescale=args.autorescale,
318 | ch=args.ch,
319 | blocks=args.L,
320 | heads=args.heads,
321 | mlp_dim=2 * args.ch,
322 | T=args.T,
323 | maxpool=args.maxpool,
324 | gta=args.gta,
325 | ).cuda()
326 | else:
327 | raise NotImplementedError
328 |
329 | total_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
330 | print(f"Total number of basemodel parameters: {total_params}")
331 |
332 | if args.finetune:
333 | if accelerator.is_main_process:
334 | logger.info("Loading checkpoint...")
335 | net.load_state_dict(torch.load(args.finetune)["model_state_dict"])
336 |
337 | optimizer = optim.Adam(net.parameters(), lr=args.lr, weight_decay=0.0)
338 |
339 | if args.finetune:
340 | if accelerator.is_main_process:
341 | logger.info("Loading optimizer state...")
342 | optimizer.load_state_dict(torch.load(args.finetune)["optimizer_state_dict"])
343 | for param_group in optimizer.param_groups:
344 | param_group["lr"] = args.lr
345 |
346 | from ema_pytorch import EMA
347 |
348 | ema = EMA(net, beta=args.beta, update_every=10, update_after_step=200)
349 |
350 | if args.finetune:
351 | if accelerator.is_main_process:
352 | logger.info("Loading checkpoint...")
353 | dir_name, file_name = os.path.split(args.finetune)
354 | file_name = file_name.replace("checkpoint", "ema")
355 | ema_path = os.path.join(dir_name, file_name)
356 | ema.load_state_dict(torch.load(ema_path)["model_state_dict"])
357 |
358 | if accelerator.is_main_process:
359 | logger.info(f"Training for {args.epochs} epochs...")
360 |
361 | net, optimizer, ssloader = accelerator.prepare(net, optimizer, ssloader)
362 |
363 | scheduler = LinearWarmupScheduler(optimizer, warmup_iters=args.warmup_iters)  # stepped once per batch in train(), so warmup_iters counts iterations, not epochs
364 |
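# Checkpoints (model/optimizer state plus the EMA weights, prefixed "ema") are written to
# runs/<exp_name>/ every --checkpoint_every epochs; the final weights are saved below as
# model.pth and ema_model.pth.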
365 | for epoch in range(0, args.epochs):
366 | total_loss = train(net, ema, optimizer, scheduler, ssloader, epoch)
367 | if (epoch + 1) % args.checkpoint_every == 0:
368 | if accelerator.is_main_process:
369 | save_checkpoint(
370 | accelerator.unwrap_model(net),
371 | optimizer,
372 | epoch,
373 | total_loss,
374 | checkpoint_dir=jobdir,
375 | )
376 | save_model(ema, epoch, checkpoint_dir=jobdir, prefix="ema")
377 | if accelerator.is_main_process:
378 | torch.save(
379 | accelerator.unwrap_model(net).state_dict(),
380 | os.path.join(jobdir, "model.pth"),
381 | )
382 | torch.save(ema.state_dict(), os.path.join(jobdir, "ema_model.pth"))
383 |
--------------------------------------------------------------------------------
/train_sudoku.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import sys, os
3 | import tqdm
4 | import argparse
5 |
6 | from source.models.sudoku.transformer import SudokuTransformer
7 |
8 | from source.training_utils import save_checkpoint, save_model
9 | from source.data.datasets.sudoku.sudoku import SudokuDataset, HardSudokuDataset
10 | from source.models.sudoku.knet import SudokuAKOrN
11 | from source.utils import str2bool
12 | import torch
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 | import math
16 | from ema_pytorch import EMA
17 |
18 | from torch.utils.tensorboard import SummaryWriter
19 |
20 |
21 | def apply_threshold(model, threshold):
22 | with torch.no_grad():
23 | for param in model.parameters():
24 | param.data = torch.where(
25 | param.abs() < threshold, torch.zeros_like(param), param.data
26 | )
27 |
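# NOTE: apply_threshold is a small magnitude-pruning helper (it zeroes parameters whose
# absolute value falls below `threshold`); it is defined here but not called in this script.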
28 |
29 | if __name__ == "__main__":
30 |
31 | parser = argparse.ArgumentParser()
32 |
33 | parser.add_argument("--exp_name", type=str, help="expname")
34 | parser.add_argument("--seed", type=int, default=None, help="seed")
35 | parser.add_argument("--epochs", type=int, default=100, help="num of epochs")
36 | parser.add_argument("--lr", type=float, default=1e-3, help="lr")
37 | parser.add_argument("--beta", type=float, default=0.995, help="ema decay")
38 | parser.add_argument(
39 | "--clip_grad_norm", type=float, default=1.0, help="clip grad norm"
40 | )
41 | parser.add_argument(
42 | "--checkpoint_every",
43 | type=int,
44 | default=100,
45 | help="save checkpoint every specified epochs",
46 | )
47 | parser.add_argument("--eval_freq", type=int, default=10, help="freqadv eval")
48 |
49 | # Data loading
50 | parser.add_argument("--limit_cores_used", type=str2bool, default=False)
51 | parser.add_argument("--cpu_core_start", type=int, default=0, help="start core")
52 | parser.add_argument("--cpu_core_end", type=int, default=16, help="end core")
53 | parser.add_argument(
54 | "--data_root",
55 | type=str,
56 | default=None,
57 | help="Optional. Specify the root dir of the dataset. If None, use a default path set for each dataset",
58 | )
59 | parser.add_argument("--batchsize", type=int, default=100)
60 | parser.add_argument("--num_workers", type=int, default=4)
61 |
62 | # General model options
63 | parser.add_argument("--model", type=str, default="akorn", help="model")
64 | parser.add_argument("--L", type=int, default=1, help="num of layers")
65 | parser.add_argument("--T", type=int, default=16, help="Timesteps")
66 | parser.add_argument("--ch", type=int, default=512, help="num of channels")
67 | parser.add_argument("--heads", type=int, default=8)
68 |
69 | # AKOrN options
70 | parser.add_argument("--N", type=int, default=4)
71 | parser.add_argument("--gamma", type=float, default=1.0, help="step size")
72 | parser.add_argument("--J", type=str, default="attn", help="connectivity")
73 | parser.add_argument("--use_omega", type=str2bool, default=True)
74 | parser.add_argument("--global_omg", type=str2bool, default=True)
75 | parser.add_argument("--learn_omg", type=str2bool, default=False)
76 | parser.add_argument("--init_omg", type=float, default=0.1)
77 | parser.add_argument("--nl", type=str2bool, default=True)
78 |
79 | parser.add_argument("--speed_test", action="store_true")
80 |
81 | args = parser.parse_args()
82 |
83 | print("Exp name: ", args.exp_name)
84 |
85 | torch.backends.cudnn.benchmark = True
86 | torch.backends.cuda.enable_flash_sdp(enabled=True)
87 |
88 | if args.seed is not None:
89 | import random
90 | import numpy as np
91 |
92 | torch.manual_seed(args.seed)
93 | random.seed(args.seed)
94 | np.random.seed(args.seed)
95 |
96 | def worker_init_fn(worker_id):
97 | if args.limit_cores_used: os.sched_setaffinity(0, range(args.cpu_core_start, args.cpu_core_end))  # only pin workers when --limit_cores_used is set
98 |
99 | if args.data_root is not None:
100 | rootdir = args.data_root
101 | else:
102 | rootdir = "./data/sudoku"
103 |
104 | trainloader = torch.utils.data.DataLoader(
105 | SudokuDataset(rootdir, train=True),
106 | batch_size=args.batchsize,
107 | shuffle=True,
108 | num_workers=args.num_workers,
109 | worker_init_fn=worker_init_fn,
110 | )
111 | testloader = torch.utils.data.DataLoader(
112 | SudokuDataset(rootdir, train=False),
113 | batch_size=100,
114 | shuffle=False,
115 | num_workers=args.num_workers,
116 | worker_init_fn=worker_init_fn,
117 | )
118 |
119 | jobdir = f"runs/{args.exp_name}/"
120 | writer = SummaryWriter(jobdir)
121 |
122 | # compute board-level accuracy (all blanks correct) and digit-level accuracy on the given cells
123 | from source.evals.sudoku.evals import compute_board_accuracy
124 | def compute_acc(net, loader):
125 | net.eval()
126 | correct = 0
127 | total = 0
128 | correct_input = 0
129 | total_input = 0
130 | for X, Y, is_input in loader:
131 | X, Y, is_input = X.to(torch.int32).cuda(), Y.cuda(), is_input.cuda()
132 |
133 | with torch.no_grad():
134 | out = net(X, is_input)
135 |
136 | _, _, board_accuracy = compute_board_accuracy(out, Y, is_input)
137 | correct += board_accuracy.sum().item()
138 | total += board_accuracy.shape[0]
139 |
140 | # digit-wise accuracy on the given (input) cells
141 | out = out.argmax(dim=-1)
142 | Y = Y.argmax(dim=-1)
143 | mask = (1 - is_input).view(out.shape)  # mask == 1 on blank cells, so (1 - mask) selects the given digits
144 | correct_input += ((1 - mask) * (out == Y)).sum().item()
145 | total_input += (1 - mask).sum().item()
146 |
147 | acc = correct / total
148 | input_acc = correct_input / total_input
149 | return acc, input_acc, (total, correct), (total_input, correct_input)
150 |
151 | if args.model == "akorn":
152 | print(
153 | f"n: {args.N}, ch: {args.ch}, L: {args.L}, T: {args.T}, type of J: {args.J}"
154 | )
155 | net = SudokuAKOrN(
156 | n=args.N,
157 | ch=args.ch,
158 | L=args.L,
159 | T=args.T,
160 | gamma=args.gamma,
161 | J=args.J,
162 | use_omega=args.use_omega,
163 | global_omg=args.global_omg,
164 | init_omg=args.init_omg,
165 | learn_omg=args.learn_omg,
166 | nl=args.nl,
167 | heads=args.heads,
168 | )
169 | elif args.model == "itrsa":
170 | net = SudokuTransformer(
171 | ch=args.ch,
172 | blocks=args.L,
173 | heads=args.heads,
174 | mlp_dim=args.ch * 2,
175 | T=args.T,
176 | gta=False,
177 | )
178 | else:
179 | raise NotImplementedError
180 |
181 | net.cuda()
182 |
183 | total_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
184 | print(f"Total number of parameters: {total_params}")
185 |
186 | optimizer = torch.optim.Adam(net.parameters(), lr=args.lr)
187 |
188 | ema = EMA(net, beta=args.beta, update_every=10, update_after_step=100)
189 |
190 | criterion = torch.nn.CrossEntropyLoss(reduction="none")
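# reduction="none" keeps the per-cell losses; the training loop below simply takes .mean(),
# which is equivalent to reduction="mean" here.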
191 |
192 | # Measure speed
193 | if args.speed_test:
194 | it_sp = 0
195 | time_per_iter = []
196 | import numpy as np
197 |
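# With --speed_test, CUDA events time each of the first 100 training iterations; the
# per-iteration times are saved to runs/<exp_name>/time.npy and the script exits.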
198 | for epoch in range(args.epochs):
199 | total_loss = 0
200 |
201 | for X, Y, is_input in tqdm.tqdm(trainloader):
202 | net.train()
203 | ema.train()
204 | X, Y, is_input = X.to(torch.int32).cuda(), Y.cuda(), is_input.cuda()
205 |
206 | if args.speed_test:
207 | start = torch.cuda.Event(enable_timing=True)
208 | end = torch.cuda.Event(enable_timing=True)
209 | start.record()
210 |
211 | out = net(X, is_input)
212 |
213 | out = out.reshape(-1, 9)
214 | Y = Y.argmax(dim=-1).reshape(-1)
215 |
216 | loss = criterion(out, Y).mean()
217 |
218 | optimizer.zero_grad()
219 | loss.backward()
220 | if args.clip_grad_norm > 0.:
221 | torch.nn.utils.clip_grad_norm_(net.parameters(), args.clip_grad_norm)
222 | optimizer.step()
223 |
224 | if args.speed_test:
225 | end.record()
226 | torch.cuda.synchronize()
227 | time_elapsed_per_iter = start.elapsed_time(end)
228 | time_per_iter.append(time_elapsed_per_iter)
229 | print(time_elapsed_per_iter)
230 | it_sp = it_sp + 1
231 | if it_sp == 100:
232 | np.save(os.path.join(jobdir, "time.npy"), np.array(time_per_iter))
233 | exit(0)
234 |
235 | total_loss += loss.item()
236 | ema.update()
237 |
238 | total_loss = total_loss / len(trainloader)
239 |
240 | writer.add_scalar("training loss", total_loss, epoch)
241 | print(f"Epoch [{epoch+1}/{args.epochs}], Loss: {total_loss:.4f}")
242 |
243 | if (epoch + 1) % args.eval_freq == 0:
244 |
245 | acc, input_acc, stats, stats_input = compute_acc(net, testloader)
246 | writer.add_scalar("test/accuracy", acc, epoch)
247 | writer.add_scalar("test/input_accuracy", input_acc, epoch)
248 | print(f"[Test]: Total blanks:{stats[0]}, Accuracy: {acc}")
249 | print(
250 | f"[Test]: Total given squares:{stats_input[0]}, Accuracy on given digits: {input_acc}"
251 | )
252 |
253 | # EMA evals
254 | acc, input_acc, stats, stats_input = compute_acc(ema.ema_model, testloader)
255 | writer.add_scalar("ema_test/accuracy", acc, epoch)
256 | writer.add_scalar("ema_test/input_accuracy", input_acc, epoch)
257 | print(f"[EMA Test]: Total blanks:{stats[0]}, Accuracy: {acc}")
258 | print(
259 | f"[EMA Test]: Total given squares:{stats_input[0]}, Accuracy on given digits: {input_acc}"
260 | )
261 |
262 | if (epoch + 1) % args.checkpoint_every == 0:
263 | save_checkpoint(net, optimizer, epoch, total_loss, checkpoint_dir=jobdir)
264 | save_model(ema, epoch, checkpoint_dir=jobdir, prefix="ema")
265 |
266 | torch.save(net.state_dict(), os.path.join(jobdir, "model.pth"))
267 | torch.save(ema.state_dict(), os.path.join(jobdir, "ema_model.pth"))
268 |
--------------------------------------------------------------------------------