├── .gitignore ├── Generators ├── SimMIMGen.py └── WDTaggerGen.py ├── Metrics ├── ConfusionMatrix.py ├── Precision.py └── Recall.py ├── Models ├── ConvNext.py ├── EVA02.py ├── HiViT.py ├── SimMIM.py ├── SwinV2.py ├── VGG.py ├── ViT.py └── __init__.py ├── Translators └── TFRecord.py ├── pretraining_loop.py └── training_loop.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /Generators/SimMIMGen.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | 5 | def sample_beta_distribution(size, concentration_0=0.2, concentration_1=0.2): 6 | gamma_1_sample = tf.random.gamma(shape=[size], alpha=concentration_1) 7 | gamma_2_sample = tf.random.gamma(shape=[size], alpha=concentration_0) 8 | return gamma_1_sample / (gamma_1_sample + gamma_2_sample) 9 | 10 | 11 | class DataGenerator: 12 | def __init__( 13 | self, 14 | records_path, 15 | num_classes=2380, 16 | image_size=320, 17 | batch_size=32, 18 | num_devices=0, 19 | noise_level=0, 20 | mixup_alpha=0.2, 21 | rotation_ratio=0.25, 22 | cutout_max_pct=0.25, 23 | cutout_patches=1, 24 | random_resize_method=True, 25 | mask_patch_size=32, 26 | model_patch_size=4, 27 | mask_ratio=0.6, 28 | repeat=True, 29 | ): 30 | """ 31 | Noise level 1: augmentations I will never train without 32 | unless I'm dealing with extremely small networks 33 | (Random rotation, random cropping and random flipping) 34 | 35 | Noise level 2: more advanced stuff (MixUp) 36 | """ 37 | 38 | self.records_path = records_path 39 | self.num_classes = num_classes 40 | self.image_size = image_size 41 | self.batch_size = batch_size 42 | self.num_devices = num_devices 43 | self.noise_level = noise_level 44 | self.mixup_alpha = mixup_alpha 45 | self.rotation_ratio = rotation_ratio 46 | self.random_resize_method = random_resize_method 47 | 48 | self.cutout_max_pct = cutout_max_pct 49 | self.cutout_replace = 127 50 | self.cutout_patches = cutout_patches 51 | 52 | self.mask_patch_size = mask_patch_size 53 | self.model_patch_size = model_patch_size 54 | self.mask_ratio = mask_ratio 55 | 56 | self.repeat = repeat 57 | 58 | def gen_mask(self, image, labels): 59 | rand_size = self.image_size // self.mask_patch_size 60 | scale = self.mask_patch_size // self.model_patch_size 61 | 62 | token_count = rand_size**2 63 | mask_count = tf.math.ceil(token_count * self.mask_ratio) 64 | mask_count = tf.cast(mask_count, tf.int32) 65 | 66 | mask_idx = tf.random.uniform((token_count,)) 67 | mask_idx = tf.argsort(mask_idx)[:mask_count] 68 | 69 | mask = tf.reduce_max(tf.one_hot(mask_idx, token_count, dtype=tf.uint8), axis=0) 70 | mask = tf.reshape(mask, (rand_size, rand_size)) 71 | mask = 
tf.repeat(mask, scale, axis=0) 72 | mask = tf.repeat(mask, scale, axis=1) 73 | return image, mask, labels 74 | 75 | def parse_single_record(self, example_proto): 76 | feature_description = { 77 | "image_id": tf.io.FixedLenFeature([], tf.int64), 78 | "image_bytes": tf.io.FixedLenFeature([], tf.string), 79 | "label_indexes": tf.io.VarLenFeature(tf.int64), 80 | } 81 | 82 | # Parse the input 'tf.train.Example' proto using the dictionary above. 83 | parsed_example = tf.io.parse_single_example(example_proto, feature_description) 84 | image_tensor = tf.io.decode_jpeg(parsed_example["image_bytes"], channels=3) 85 | 86 | # RGB -> BGR (legacy reasons) 87 | image_tensor = tf.gather(image_tensor, axis=2, indices=[2, 1, 0]) 88 | 89 | # The TFRecord only stores the label indices, to save space 90 | # Emulate MultiLabelBinarizer starting from the indices to obtain a tensor of only 0s and 1s 91 | label_indexes = tf.sparse.to_dense( 92 | parsed_example["label_indexes"], 93 | default_value=0, 94 | ) 95 | one_hots = tf.one_hot(label_indexes, self.num_classes) 96 | labels = tf.reduce_max(one_hots, axis=0) 97 | labels = tf.cast(labels, tf.float32) 98 | 99 | return image_tensor, labels 100 | 101 | def random_flip(self, image, labels): 102 | image = tf.image.random_flip_left_right(image) 103 | return image, labels 104 | 105 | def random_crop(self, image, labels): 106 | image_shape = tf.shape(image) 107 | height = image_shape[0] 108 | width = image_shape[1] 109 | 110 | factor = tf.random.uniform(shape=[], minval=0.87, maxval=0.998) 111 | 112 | # Assuming this is a standard 512x512 Danbooru20xx SFW image 113 | new_height = new_width = tf.cast(tf.cast(height, tf.float32) * factor, tf.int32) 114 | 115 | offset_height = tf.random.uniform( 116 | shape=[], 117 | minval=0, 118 | maxval=(height - new_height), 119 | dtype=tf.int32, 120 | ) 121 | offset_width = tf.random.uniform( 122 | shape=[], 123 | minval=0, 124 | maxval=(width - new_width), 125 | dtype=tf.int32, 126 | ) 127 | image = tf.image.crop_to_bounding_box( 128 | image, 129 | offset_height, 130 | offset_width, 131 | new_height, 132 | new_width, 133 | ) 134 | return image, labels 135 | 136 | def random_rotate(self, images, masks, labels): 137 | bs, h, w, c = tf.unstack(tf.shape(images)) 138 | 139 | h = tf.cast(h, tf.float32) 140 | w = tf.cast(w, tf.float32) 141 | radians = np.pi * self.rotation_ratio 142 | radians = tf.random.uniform(shape=(bs,), minval=-radians, maxval=radians) 143 | 144 | cos_angles = tf.math.cos(radians) 145 | sin_angles = tf.math.sin(radians) 146 | x_offset = ((w - 1) - (cos_angles * (w - 1) - sin_angles * (h - 1))) / 2.0 147 | y_offset = ((h - 1) - (sin_angles * (w - 1) + cos_angles * (h - 1))) / 2.0 148 | zeros = tf.zeros((bs,), tf.float32) 149 | 150 | transforms = [ 151 | cos_angles, 152 | -sin_angles, 153 | x_offset, 154 | sin_angles, 155 | cos_angles, 156 | y_offset, 157 | zeros, 158 | zeros, 159 | ] 160 | transforms = tf.transpose(transforms, (1, 0)) 161 | images = tf.raw_ops.ImageProjectiveTransformV3( 162 | images=images, 163 | transforms=transforms, 164 | output_shape=[h, w], 165 | fill_value=255, 166 | interpolation="BILINEAR", 167 | fill_mode="CONSTANT", 168 | ) 169 | return images, masks, labels 170 | 171 | def resize(self, image, labels): 172 | if self.random_resize_method: 173 | # During training mix algos up to make the model a bit more resilient 174 | # to the different image resizing implementations out there (TF, OpenCV, PIL, ...)
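# method_index selects the resize algorithm at random: 0 = area, 1 = bilinear, anything else = bicubic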
175 | method_index = tf.random.uniform((), minval=0, maxval=3, dtype=tf.int32) 176 | if method_index == 0: 177 | image = tf.image.resize( 178 | images=image, 179 | size=(self.image_size, self.image_size), 180 | method="area", 181 | antialias=True, 182 | ) 183 | elif method_index == 1: 184 | image = tf.image.resize( 185 | images=image, 186 | size=(self.image_size, self.image_size), 187 | method="bilinear", 188 | antialias=True, 189 | ) 190 | else: 191 | image = tf.image.resize( 192 | images=image, 193 | size=(self.image_size, self.image_size), 194 | method="bicubic", 195 | antialias=True, 196 | ) 197 | else: 198 | image = tf.image.resize( 199 | images=image, 200 | size=(self.image_size, self.image_size), 201 | method="area", 202 | antialias=True, 203 | ) 204 | image = tf.cast(tf.clip_by_value(image, 0, 255), tf.uint8) 205 | return image, labels 206 | 207 | # https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py 208 | def cutout(self, image, labels): 209 | """Apply cutout (https://arxiv.org/abs/1708.04552) to image. 210 | This operation applies a (2*pad_size x 2*pad_size) mask of zeros to 211 | a random location within `img`. The pixel values filled in will be of the 212 | value `replace`. The location where the mask will be applied is randomly 213 | chosen uniformly over the whole image. 214 | Args: 215 | image: An image Tensor of type uint8. 216 | pad_size: Specifies how big the zero mask that will be generated is, which 217 | is applied to the image. The mask will be of size 218 | (2*pad_size x 2*pad_size). 219 | replace: What pixel value to fill in the image in the area that has 220 | the cutout mask applied to it. 221 | Returns: 222 | An image Tensor that is of type uint8. 223 | """ 224 | pad_pct = self.cutout_max_pct 225 | replace = self.cutout_replace 226 | 227 | image_height = tf.shape(image)[0] 228 | image_width = tf.shape(image)[1] 229 | 230 | img_area = image_height * image_width 231 | pad_area = tf.cast(img_area, dtype=tf.float32) * pad_pct 232 | pad_size = tf.cast(tf.math.sqrt(pad_area) / 2, dtype=tf.int32) 233 | 234 | # Sample the center location in the image where the zero mask will be applied.
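# pad_size is derived from cutout_max_pct: the full (2*pad_size x 2*pad_size) patch covers roughly cutout_max_pct of the image area, since 2*pad_size = sqrt(area * pct)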
235 | cutout_center_height = tf.random.uniform( 236 | shape=[], 237 | minval=0, 238 | maxval=image_height, 239 | dtype=tf.int32, 240 | ) 241 | 242 | cutout_center_width = tf.random.uniform( 243 | shape=[], 244 | minval=0, 245 | maxval=image_width, 246 | dtype=tf.int32, 247 | ) 248 | 249 | lower_pad = tf.maximum(0, cutout_center_height - pad_size) 250 | upper_pad = tf.maximum(0, image_height - cutout_center_height - pad_size) 251 | left_pad = tf.maximum(0, cutout_center_width - pad_size) 252 | right_pad = tf.maximum(0, image_width - cutout_center_width - pad_size) 253 | 254 | cutout_shape = [ 255 | image_height - (lower_pad + upper_pad), 256 | image_width - (left_pad + right_pad), 257 | ] 258 | padding_dims = [[lower_pad, upper_pad], [left_pad, right_pad]] 259 | mask = tf.pad( 260 | tf.zeros(cutout_shape, dtype=image.dtype), 261 | padding_dims, 262 | constant_values=1, 263 | ) 264 | mask = tf.expand_dims(mask, -1) 265 | mask = tf.tile(mask, [1, 1, 3]) 266 | image = tf.where( 267 | tf.equal(mask, 0), 268 | tf.ones_like(image, dtype=image.dtype) * replace, 269 | image, 270 | ) 271 | return image, labels 272 | 273 | def mixup_single(self, images, masks, labels): 274 | alpha = self.mixup_alpha 275 | batch_size = tf.shape(images)[0] 276 | 277 | # Unpack one dataset, generate a second 278 | # by shuffling the input one on the batch axis 279 | images_one = tf.cast(images, tf.float32) 280 | labels_one = tf.cast(labels, tf.float32) 281 | 282 | idxs = tf.random.shuffle(tf.range(batch_size)) 283 | images_two = tf.gather(images_one, idxs, axis=0) 284 | labels_two = tf.gather(labels_one, idxs, axis=0) 285 | 286 | # Sample lambda and reshape it to do the mixup 287 | l = sample_beta_distribution(batch_size, alpha, alpha) 288 | x_l = tf.reshape(l, (batch_size, 1, 1, 1)) 289 | y_l = tf.reshape(l, (batch_size, 1)) 290 | 291 | # Perform mixup on both images and labels by combining a pair of images/labels 292 | # (one from each dataset) into one image/label 293 | images = images_one * x_l + images_two * (1 - x_l) 294 | labels = labels_one * y_l + labels_two * (1 - y_l) 295 | 296 | images = tf.cast(tf.clip_by_value(images, 0, 255), tf.uint8) 297 | return images, masks, labels 298 | 299 | def genDS(self): 300 | files = tf.data.Dataset.list_files(self.records_path) 301 | files = files.cache() 302 | 303 | if self.repeat: 304 | files = files.repeat() 305 | 306 | dataset = files.interleave( 307 | tf.data.TFRecordDataset, 308 | num_parallel_calls=tf.data.AUTOTUNE, 309 | deterministic=False, 310 | ) 311 | dataset = dataset.ignore_errors() 312 | dataset = dataset.shuffle(2 * self.batch_size) 313 | dataset = dataset.map( 314 | self.parse_single_record, 315 | num_parallel_calls=tf.data.AUTOTUNE, 316 | ) 317 | 318 | if self.noise_level >= 1: 319 | dataset = dataset.map(self.random_flip, num_parallel_calls=tf.data.AUTOTUNE) 320 | dataset = dataset.map(self.random_crop, num_parallel_calls=tf.data.AUTOTUNE) 321 | 322 | # Resize before batching. Especially important if random_crop is enabled 323 | dataset = dataset.map(self.resize, num_parallel_calls=tf.data.AUTOTUNE) 324 | 325 | if self.noise_level >= 2 and self.cutout_max_pct > 0.0: 326 | for _ in range(self.cutout_patches): 327 | dataset = dataset.map(self.cutout, num_parallel_calls=tf.data.AUTOTUNE) 328 | 329 | dataset = dataset.map(self.gen_mask, num_parallel_calls=tf.data.AUTOTUNE) 330 | 331 | dataset = dataset.batch( 332 | self.batch_size, 333 | drop_remainder=True, 334 | num_parallel_calls=tf.data.AUTOTUNE, 335 | ) 336 | 337 | # Rotation is very slow on CPU. 
Rotating a batch of resized images is much faster 338 | if self.noise_level >= 1 and self.rotation_ratio > 0.0: 339 | dataset = dataset.map( 340 | self.random_rotate, 341 | num_parallel_calls=tf.data.AUTOTUNE, 342 | ) 343 | 344 | if self.noise_level >= 2 and self.mixup_alpha > 0.0: 345 | dataset = dataset.map( 346 | self.mixup_single, 347 | num_parallel_calls=tf.data.AUTOTUNE, 348 | ) 349 | 350 | if self.num_devices > 0: 351 | dataset = dataset.batch( 352 | self.num_devices, 353 | drop_remainder=True, 354 | num_parallel_calls=tf.data.AUTOTUNE, 355 | ) 356 | 357 | dataset = dataset.map( 358 | lambda images, masks, labels: ( 359 | { 360 | "images": tf.cast(images, tf.float32) * (1.0 / 127.5) - 1, 361 | "masks": masks, 362 | } 363 | ), 364 | num_parallel_calls=tf.data.AUTOTUNE, 365 | ) 366 | 367 | dataset = dataset.prefetch(tf.data.AUTOTUNE) 368 | return dataset 369 | -------------------------------------------------------------------------------- /Generators/WDTaggerGen.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | 5 | def sample_beta_distribution(size, concentration_0=0.2, concentration_1=0.2): 6 | gamma_1_sample = tf.random.gamma(shape=[size], alpha=concentration_1) 7 | gamma_2_sample = tf.random.gamma(shape=[size], alpha=concentration_0) 8 | return gamma_1_sample / (gamma_1_sample + gamma_2_sample) 9 | 10 | 11 | class DataGenerator: 12 | def __init__( 13 | self, 14 | records_path, 15 | num_classes=2380, 16 | image_size=320, 17 | batch_size=32, 18 | num_devices=0, 19 | noise_level=0, 20 | mixup_alpha=0.2, 21 | rotation_ratio=0.25, 22 | cutout_max_pct=0.25, 23 | cutout_patches=1, 24 | random_resize_method=True, 25 | repeat=True, 26 | ): 27 | """ 28 | Noise level 1: augmentations I will never train without 29 | unless I'm dealing with extremely small networks 30 | (Random rotation, random cropping and random flipping) 31 | 32 | Noise level 2: more advanced stuff (MixUp) 33 | """ 34 | 35 | self.records_path = records_path 36 | self.num_classes = num_classes 37 | self.image_size = image_size 38 | self.batch_size = batch_size 39 | self.num_devices = num_devices 40 | self.noise_level = noise_level 41 | self.mixup_alpha = mixup_alpha 42 | self.rotation_ratio = rotation_ratio 43 | self.random_resize_method = random_resize_method 44 | 45 | self.cutout_max_pct = cutout_max_pct 46 | self.cutout_replace = 127 47 | self.cutout_patches = cutout_patches 48 | 49 | self.repeat = repeat 50 | 51 | def parse_single_record(self, example_proto): 52 | feature_description = { 53 | "image_id": tf.io.FixedLenFeature([], tf.int64), 54 | "image_bytes": tf.io.FixedLenFeature([], tf.string), 55 | "label_indexes": tf.io.VarLenFeature(tf.int64), 56 | } 57 | 58 | # Parse the input 'tf.train.Example' proto using the dictionary above. 
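# Each record stores the image as encoded JPEG bytes together with a variable-length list of label indices.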
59 | parsed_example = tf.io.parse_single_example(example_proto, feature_description) 60 | image_tensor = tf.io.decode_jpeg(parsed_example["image_bytes"], channels=3) 61 | 62 | # RGB -> BGR (legacy reasons) 63 | image_tensor = tf.gather(image_tensor, axis=2, indices=[2, 1, 0]) 64 | 65 | # The TFRecord only stores the label indices, to save space 66 | # Emulate MultiLabelBinarizer starting from the indices to obtain a tensor of only 0s and 1s 67 | label_indexes = tf.sparse.to_dense( 68 | parsed_example["label_indexes"], 69 | default_value=0, 70 | ) 71 | one_hots = tf.one_hot(label_indexes, self.num_classes) 72 | labels = tf.reduce_max(one_hots, axis=0) 73 | labels = tf.cast(labels, tf.float32) 74 | 75 | sample = { 76 | "image_ids": parsed_example["image_id"], 77 | "images": image_tensor, 78 | "labels": labels, 79 | } 80 | return sample 81 | 82 | def random_flip(self, sample): 83 | image = sample["images"] 84 | image = tf.image.random_flip_left_right(image) 85 | 86 | sample["images"] = image 87 | return sample 88 | 89 | def random_crop(self, sample): 90 | image = sample["images"] 91 | image_shape = tf.shape(image) 92 | height = image_shape[0] 93 | width = image_shape[1] 94 | 95 | factor = tf.random.uniform(shape=[], minval=0.87, maxval=0.998) 96 | 97 | # Assuming this is a standard 512x512 Danbooru20xx SFW image 98 | new_height = new_width = tf.cast(tf.cast(height, tf.float32) * factor, tf.int32) 99 | 100 | offset_height = tf.random.uniform( 101 | shape=[], 102 | minval=0, 103 | maxval=(height - new_height), 104 | dtype=tf.int32, 105 | ) 106 | offset_width = tf.random.uniform( 107 | shape=[], 108 | minval=0, 109 | maxval=(width - new_width), 110 | dtype=tf.int32, 111 | ) 112 | image = tf.image.crop_to_bounding_box( 113 | image, 114 | offset_height, 115 | offset_width, 116 | new_height, 117 | new_width, 118 | ) 119 | 120 | sample["images"] = image 121 | return sample 122 | 123 | def random_rotate(self, sample): 124 | images = sample["images"] 125 | bs, h, w, c = tf.unstack(tf.shape(images)) 126 | 127 | h = tf.cast(h, tf.float32) 128 | w = tf.cast(w, tf.float32) 129 | radians = np.pi * self.rotation_ratio 130 | radians = tf.random.uniform(shape=(bs,), minval=-radians, maxval=radians) 131 | 132 | cos_angles = tf.math.cos(radians) 133 | sin_angles = tf.math.sin(radians) 134 | x_offset = ((w - 1) - (cos_angles * (w - 1) - sin_angles * (h - 1))) / 2.0 135 | y_offset = ((h - 1) - (sin_angles * (w - 1) + cos_angles * (h - 1))) / 2.0 136 | zeros = tf.zeros((bs,), tf.float32) 137 | 138 | transforms = [ 139 | cos_angles, 140 | -sin_angles, 141 | x_offset, 142 | sin_angles, 143 | cos_angles, 144 | y_offset, 145 | zeros, 146 | zeros, 147 | ] 148 | transforms = tf.transpose(transforms, (1, 0)) 149 | images = tf.raw_ops.ImageProjectiveTransformV3( 150 | images=images, 151 | transforms=transforms, 152 | output_shape=[h, w], 153 | fill_value=255, 154 | interpolation="BILINEAR", 155 | fill_mode="CONSTANT", 156 | ) 157 | 158 | sample["images"] = images 159 | return sample 160 | 161 | def resize(self, sample): 162 | image = sample["images"] 163 | if self.random_resize_method: 164 | # During training mix algos up to make the model a bit more resilient 165 | # to the different image resizing implementations out there (TF, OpenCV, PIL, ...)
166 | method_index = tf.random.uniform((), minval=0, maxval=3, dtype=tf.int32) 167 | if method_index == 0: 168 | image = tf.image.resize( 169 | images=image, 170 | size=(self.image_size, self.image_size), 171 | method="area", 172 | antialias=True, 173 | ) 174 | elif method_index == 1: 175 | image = tf.image.resize( 176 | images=image, 177 | size=(self.image_size, self.image_size), 178 | method="bilinear", 179 | antialias=True, 180 | ) 181 | else: 182 | image = tf.image.resize( 183 | images=image, 184 | size=(self.image_size, self.image_size), 185 | method="bicubic", 186 | antialias=True, 187 | ) 188 | else: 189 | image = tf.image.resize( 190 | images=image, 191 | size=(self.image_size, self.image_size), 192 | method="area", 193 | antialias=True, 194 | ) 195 | image = tf.cast(tf.clip_by_value(image, 0, 255), tf.uint8) 196 | 197 | sample["images"] = image 198 | return sample 199 | 200 | # https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py 201 | def cutout(self, sample): 202 | """Apply cutout (https://arxiv.org/abs/1708.04552) to image. 203 | This operation applies a (2*pad_size x 2*pad_size) mask of zeros to 204 | a random location within `img`. The pixel values filled in will be of the 205 | value `replace`. The location where the mask will be applied is randomly 206 | chosen uniformly over the whole image. 207 | Args: 208 | image: An image Tensor of type uint8. 209 | pad_size: Specifies how big the zero mask that will be generated is, which 210 | is applied to the image. The mask will be of size 211 | (2*pad_size x 2*pad_size). 212 | replace: What pixel value to fill in the image in the area that has 213 | the cutout mask applied to it. 214 | Returns: 215 | An image Tensor that is of type uint8. 216 | """ 217 | image = sample["images"] 218 | pad_pct = self.cutout_max_pct 219 | replace = self.cutout_replace 220 | 221 | image_height = tf.shape(image)[0] 222 | image_width = tf.shape(image)[1] 223 | 224 | img_area = image_height * image_width 225 | pad_area = tf.cast(img_area, dtype=tf.float32) * pad_pct 226 | pad_size = tf.cast(tf.math.sqrt(pad_area) / 2, dtype=tf.int32) 227 | 228 | # Sample the center location in the image where the zero mask will be applied.
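# The tf.maximum calls below clip the patch at the image borders, so a patch centered near an edge covers less than the full square.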
229 | cutout_center_height = tf.random.uniform( 230 | shape=[], 231 | minval=0, 232 | maxval=image_height, 233 | dtype=tf.int32, 234 | ) 235 | 236 | cutout_center_width = tf.random.uniform( 237 | shape=[], 238 | minval=0, 239 | maxval=image_width, 240 | dtype=tf.int32, 241 | ) 242 | 243 | lower_pad = tf.maximum(0, cutout_center_height - pad_size) 244 | upper_pad = tf.maximum(0, image_height - cutout_center_height - pad_size) 245 | left_pad = tf.maximum(0, cutout_center_width - pad_size) 246 | right_pad = tf.maximum(0, image_width - cutout_center_width - pad_size) 247 | 248 | cutout_shape = [ 249 | image_height - (lower_pad + upper_pad), 250 | image_width - (left_pad + right_pad), 251 | ] 252 | padding_dims = [[lower_pad, upper_pad], [left_pad, right_pad]] 253 | mask = tf.pad( 254 | tf.zeros(cutout_shape, dtype=image.dtype), 255 | padding_dims, 256 | constant_values=1, 257 | ) 258 | mask = tf.expand_dims(mask, -1) 259 | mask = tf.tile(mask, [1, 1, 3]) 260 | image = tf.where( 261 | tf.equal(mask, 0), 262 | tf.ones_like(image, dtype=image.dtype) * replace, 263 | image, 264 | ) 265 | 266 | sample["images"] = image 267 | return sample 268 | 269 | def mixup_single(self, sample): 270 | images = sample["images"] 271 | labels = sample["labels"] 272 | alpha = self.mixup_alpha 273 | batch_size = tf.shape(images)[0] 274 | 275 | # Unpack one dataset, generate a second 276 | # by shuffling the input one on the batch axis 277 | images_one = tf.cast(images, tf.float32) 278 | labels_one = tf.cast(labels, tf.float32) 279 | 280 | idxs = tf.random.shuffle(tf.range(batch_size)) 281 | images_two = tf.gather(images_one, idxs, axis=0) 282 | labels_two = tf.gather(labels_one, idxs, axis=0) 283 | 284 | # Sample lambda and reshape it to do the mixup 285 | l = sample_beta_distribution(batch_size, alpha, alpha) 286 | x_l = tf.reshape(l, (batch_size, 1, 1, 1)) 287 | y_l = tf.reshape(l, (batch_size, 1)) 288 | 289 | # Perform mixup on both images and labels by combining a pair of images/labels 290 | # (one from each dataset) into one image/label 291 | images = images_one * x_l + images_two * (1 - x_l) 292 | labels = labels_one * y_l + labels_two * (1 - y_l) 293 | 294 | images = tf.cast(tf.clip_by_value(images, 0, 255), tf.uint8) 295 | 296 | sample["images"] = images 297 | sample["labels"] = labels 298 | return sample 299 | 300 | def genDS(self): 301 | files = tf.data.Dataset.list_files(self.records_path) 302 | files = files.cache() 303 | 304 | if self.repeat: 305 | files = files.repeat() 306 | 307 | dataset = files.interleave( 308 | tf.data.TFRecordDataset, 309 | num_parallel_calls=tf.data.AUTOTUNE, 310 | deterministic=False, 311 | ) 312 | dataset = dataset.ignore_errors() 313 | dataset = dataset.shuffle(2 * self.batch_size) 314 | dataset = dataset.map( 315 | self.parse_single_record, 316 | num_parallel_calls=tf.data.AUTOTUNE, 317 | ) 318 | 319 | if self.noise_level >= 1: 320 | dataset = dataset.map(self.random_flip, num_parallel_calls=tf.data.AUTOTUNE) 321 | dataset = dataset.map(self.random_crop, num_parallel_calls=tf.data.AUTOTUNE) 322 | 323 | # Resize before batching. 
Especially important if random_crop is enabled 324 | dataset = dataset.map(self.resize, num_parallel_calls=tf.data.AUTOTUNE) 325 | 326 | if self.noise_level >= 2 and self.cutout_max_pct > 0.0: 327 | for _ in range(self.cutout_patches): 328 | dataset = dataset.map(self.cutout, num_parallel_calls=tf.data.AUTOTUNE) 329 | 330 | dataset = dataset.batch( 331 | self.batch_size, 332 | drop_remainder=True, 333 | num_parallel_calls=tf.data.AUTOTUNE, 334 | ) 335 | 336 | # Rotation is very slow on CPU. Rotating a batch of resized images is much faster 337 | if self.noise_level >= 1 and self.rotation_ratio > 0.0: 338 | dataset = dataset.map( 339 | self.random_rotate, 340 | num_parallel_calls=tf.data.AUTOTUNE, 341 | ) 342 | 343 | if self.noise_level >= 2 and self.mixup_alpha > 0.0: 344 | dataset = dataset.map( 345 | self.mixup_single, 346 | num_parallel_calls=tf.data.AUTOTUNE, 347 | ) 348 | 349 | if self.num_devices > 0: 350 | dataset = dataset.batch( 351 | self.num_devices, 352 | drop_remainder=True, 353 | num_parallel_calls=tf.data.AUTOTUNE, 354 | ) 355 | 356 | dataset = dataset.map( 357 | lambda sample: ( 358 | { 359 | "image_ids": sample["image_ids"], 360 | "images": tf.cast(sample["images"], tf.float32) * (1.0 / 127.5) - 1, 361 | "labels": sample["labels"], 362 | } 363 | ), 364 | num_parallel_calls=tf.data.AUTOTUNE, 365 | ) 366 | 367 | dataset = dataset.prefetch(tf.data.AUTOTUNE) 368 | return dataset 369 | -------------------------------------------------------------------------------- /Metrics/ConfusionMatrix.py: -------------------------------------------------------------------------------- 1 | import flax 2 | import jax.numpy as jnp 3 | from clu import metrics 4 | 5 | 6 | @flax.struct.dataclass 7 | class ConfusionMatrix(metrics.Metric): 8 | true_positives: jnp.array 9 | true_negatives: jnp.array 10 | false_positives: jnp.array 11 | false_negatives: jnp.array 12 | 13 | @classmethod 14 | def empty(cls, averaging, num_classes): 15 | if averaging == "micro": 16 | shape = (1,) 17 | elif averaging == "macro": 18 | shape = (num_classes,) 19 | 20 | return cls( 21 | true_positives=jnp.zeros(shape, dtype=jnp.int32), 22 | true_negatives=jnp.zeros(shape, dtype=jnp.int32), 23 | false_positives=jnp.zeros(shape, dtype=jnp.int32), 24 | false_negatives=jnp.zeros(shape, dtype=jnp.int32), 25 | ) 26 | 27 | @classmethod 28 | def from_model_output( 29 | cls, 30 | *, 31 | logits: jnp.array, 32 | labels: jnp.array, 33 | from_logits: bool, 34 | threshold: float, 35 | averaging: str, 36 | **_ 37 | ): 38 | preds = logits 39 | 40 | if from_logits: 41 | preds = flax.linen.activation.sigmoid(preds) 42 | 43 | labels = labels > threshold 44 | preds = preds > threshold 45 | 46 | if averaging == "micro": 47 | axis = None 48 | elif averaging == "macro": 49 | axis = 0 50 | 51 | return cls( 52 | true_positives=((preds == 1) & (labels == 1)).sum(axis=axis), 53 | true_negatives=((preds == 0) & (labels == 0)).sum(axis=axis), 54 | false_positives=((preds == 1) & (labels == 0)).sum(axis=axis), 55 | false_negatives=((preds == 0) & (labels == 1)).sum(axis=axis), 56 | ) 57 | 58 | def merge(self, other: metrics.Metric): 59 | return type(self)( 60 | true_positives=self.true_positives + other.true_positives, 61 | true_negatives=self.true_negatives + other.true_negatives, 62 | false_positives=self.false_positives + other.false_positives, 63 | false_negatives=self.false_negatives + other.false_negatives, 64 | ) 65 | 66 | def compute(self): 67 | print("Must override compute()") 68 | raise NotImplementedError 69 | 70 | 71 | def mcc(threshold, 
num_classes, from_logits, averaging): 72 | @flax.struct.dataclass 73 | class MCC(ConfusionMatrix): 74 | """ 75 | Computes the Matthews correlation coefficient 76 | from model outputs 'logits' and 'labels'. 77 | 78 | The used formula helps avoiding overflow 79 | https://leimao.github.io/blog/Matthews-Correlation-Coefficient/ 80 | """ 81 | 82 | @classmethod 83 | def empty(cls): 84 | return super().empty(averaging, num_classes) 85 | 86 | @classmethod 87 | def from_model_output(cls, *, logits: jnp.array, labels: jnp.array, **_): 88 | return super().from_model_output( 89 | logits=logits, 90 | labels=labels, 91 | from_logits=from_logits, 92 | threshold=threshold, 93 | averaging=averaging, 94 | ) 95 | 96 | def compute(self): 97 | N = ( 98 | self.true_positives 99 | + self.false_negatives 100 | + self.false_positives 101 | + self.true_negatives 102 | ) 103 | S = (self.true_positives + self.false_negatives) / N 104 | P = (self.true_positives + self.false_positives) / N 105 | numerator = (self.true_positives / N) - (S * P) 106 | denominator = S * P * (1 - S) * (1 - P) 107 | denominator = jnp.maximum(denominator, 1e-12) 108 | denominator = jnp.sqrt(denominator) 109 | return jnp.mean(numerator / denominator) 110 | 111 | return MCC 112 | 113 | 114 | def f1score(threshold, num_classes, from_logits, averaging): 115 | @flax.struct.dataclass 116 | class F1Score(ConfusionMatrix): 117 | """ 118 | Computes the F1 score 119 | from model outputs 'logits' and 'labels'. 120 | """ 121 | 122 | @classmethod 123 | def empty(cls): 124 | return super().empty(averaging, num_classes) 125 | 126 | @classmethod 127 | def from_model_output(cls, *, logits: jnp.array, labels: jnp.array, **_): 128 | return super().from_model_output( 129 | logits=logits, 130 | labels=labels, 131 | from_logits=from_logits, 132 | threshold=threshold, 133 | averaging=averaging, 134 | ) 135 | 136 | def compute(self): 137 | numerator = 2 * self.true_positives 138 | denominator = ( 139 | (2 * self.true_positives) + self.false_positives + self.false_negatives 140 | ) 141 | 142 | idx = jnp.where(denominator == 0) 143 | numerator = numerator.at[idx].set(1) 144 | denominator = denominator.at[idx].set(1) 145 | 146 | return jnp.mean(numerator / denominator) 147 | 148 | return F1Score 149 | -------------------------------------------------------------------------------- /Metrics/Precision.py: -------------------------------------------------------------------------------- 1 | import flax 2 | import jax.numpy as jnp 3 | from clu import metrics 4 | 5 | 6 | @flax.struct.dataclass 7 | class Precision(metrics.Metric): 8 | """ 9 | Computes the micro-averaged precision 10 | from model outputs 'logits' and 'labels'. 
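Precision = true_positives / pred_positives, where pred_positives counts every prediction above the threshold; both counts are kept as integers so batch results can be merged.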
11 | """ 12 | 13 | @classmethod 14 | def with_config(cls, threshold: float, from_logits: bool): 15 | @flax.struct.dataclass 16 | class WithConfig(cls): 17 | true_positives: jnp.array 18 | pred_positives: jnp.array 19 | 20 | @classmethod 21 | def empty(cls): 22 | return cls( 23 | true_positives=jnp.array(0, jnp.int32), 24 | pred_positives=jnp.array(0, jnp.int32), 25 | ) 26 | 27 | @classmethod 28 | def from_model_output(cls, *, logits: jnp.array, labels: jnp.array, **_): 29 | preds = logits 30 | 31 | if from_logits: 32 | preds = flax.linen.activation.sigmoid(preds) 33 | 34 | labels = labels > threshold 35 | preds = preds > threshold 36 | 37 | return cls( 38 | true_positives=((preds == 1) & (labels == 1)).sum(), 39 | pred_positives=(preds == 1).sum(), 40 | ) 41 | 42 | def merge(self, other: metrics.Metric) -> metrics.Metric: 43 | return type(self)( 44 | true_positives=self.true_positives + other.true_positives, 45 | pred_positives=self.pred_positives + other.pred_positives, 46 | ) 47 | 48 | def compute(self): 49 | return self.true_positives / self.pred_positives 50 | 51 | return WithConfig 52 | -------------------------------------------------------------------------------- /Metrics/Recall.py: -------------------------------------------------------------------------------- 1 | import flax 2 | import jax.numpy as jnp 3 | from clu import metrics 4 | 5 | 6 | @flax.struct.dataclass 7 | class Recall(metrics.Metric): 8 | """ 9 | Computes the micro-averaged recall 10 | from model outputs 'logits' and 'labels'. 11 | """ 12 | 13 | @classmethod 14 | def with_config(cls, threshold: float, from_logits: bool): 15 | @flax.struct.dataclass 16 | class WithConfig(cls): 17 | true_positives: jnp.array 18 | false_negatives: jnp.array 19 | 20 | @classmethod 21 | def empty(cls): 22 | return cls( 23 | true_positives=jnp.array(0, jnp.int32), 24 | false_negatives=jnp.array(0, jnp.int32), 25 | ) 26 | 27 | @classmethod 28 | def from_model_output(cls, *, logits: jnp.array, labels: jnp.array, **_): 29 | preds = logits 30 | 31 | if from_logits: 32 | preds = flax.linen.activation.sigmoid(preds) 33 | 34 | labels = labels > threshold 35 | preds = preds > threshold 36 | 37 | return cls( 38 | true_positives=((preds == 1) & (labels == 1)).sum(), 39 | false_negatives=((preds == 0) & (labels == 1)).sum(), 40 | ) 41 | 42 | def merge(self, other: metrics.Metric) -> metrics.Metric: 43 | return type(self)( 44 | true_positives=self.true_positives + other.true_positives, 45 | false_negatives=self.false_negatives + other.false_negatives, 46 | ) 47 | 48 | def compute(self): 49 | return self.true_positives / ( 50 | self.true_positives + self.false_negatives 51 | ) 52 | 53 | return WithConfig 54 | -------------------------------------------------------------------------------- /Models/ConvNext.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from functools import partial 3 | from typing import Callable 4 | 5 | import jax.numpy as jnp 6 | import jax.typing as jt 7 | import numpy as np 8 | from flax import linen 9 | 10 | 11 | class LayerScale(linen.Module): 12 | dim: int 13 | layer_scale_init_value: float = 1e-6 14 | 15 | dtype: jt.DTypeLike = jnp.float32 16 | 17 | def setup(self): 18 | self.gamma = self.variable( 19 | "params", 20 | "gamma", 21 | lambda x: self.layer_scale_init_value * jnp.ones((x,)), 22 | self.dim, 23 | ).value 24 | 25 | def __call__(self, x): 26 | return x * self.gamma.astype(self.dtype) 27 | 28 | 29 | class ConvNextBlock(linen.Module): 30 | drop_path_ratio: 
float 31 | 32 | bottleneck_ratio: float = 4.0 33 | layer_scale_init_value: float = 1e-6 34 | use_conv_bias: bool = True 35 | 36 | norm_layer: Callable = linen.LayerNorm 37 | 38 | dtype: jt.DTypeLike = jnp.float32 39 | 40 | @linen.compact 41 | def __call__(self, x, train: bool = False): 42 | _, _, _, C = x.shape 43 | hidden_size = int(C * self.bottleneck_ratio) 44 | kernel_init = linen.initializers.truncated_normal(0.02) 45 | 46 | shortcut = x 47 | x = linen.Conv( 48 | features=C, 49 | kernel_size=(7, 7), 50 | feature_group_count=C, 51 | kernel_init=kernel_init, 52 | use_bias=self.use_conv_bias, 53 | dtype=self.dtype, 54 | )(x) 55 | x = self.norm_layer()(x) 56 | x = linen.Conv( 57 | features=hidden_size, 58 | kernel_size=(1, 1), 59 | kernel_init=kernel_init, 60 | use_bias=self.use_conv_bias, 61 | dtype=self.dtype, 62 | )(x) 63 | x = linen.gelu(x) 64 | x = linen.Conv( 65 | features=C, 66 | kernel_size=(1, 1), 67 | kernel_init=kernel_init, 68 | use_bias=self.use_conv_bias, 69 | dtype=self.dtype, 70 | )(x) 71 | x = LayerScale( 72 | dim=C, 73 | layer_scale_init_value=self.layer_scale_init_value, 74 | dtype=self.dtype, 75 | )(x) 76 | x = linen.Dropout( 77 | rate=self.drop_path_ratio, 78 | broadcast_dims=(1, 2, 3), 79 | )(x, deterministic=not train) 80 | x = shortcut + x 81 | return x 82 | 83 | 84 | class BasicLayer(linen.Module): 85 | depth: int 86 | embed_dim: int 87 | 88 | drop_path_ratio: tuple[float, ...] 89 | 90 | downsample: bool = True 91 | bottleneck_ratio: float = 4.0 92 | layer_scale_init_value: float = 1e-6 93 | use_conv_bias: bool = True 94 | 95 | norm_layer: Callable = linen.LayerNorm 96 | 97 | dtype: jt.DTypeLike = jnp.float32 98 | 99 | @linen.compact 100 | def __call__(self, x, train: bool = False): 101 | if self.downsample: 102 | kernel_init = linen.initializers.truncated_normal(0.02) 103 | 104 | x = self.norm_layer()(x) 105 | x = linen.Conv( 106 | features=self.embed_dim, 107 | kernel_size=(2, 2), 108 | strides=(2, 2), 109 | kernel_init=kernel_init, 110 | use_bias=self.use_conv_bias, 111 | dtype=self.dtype, 112 | )(x) 113 | 114 | for i in range(self.depth): 115 | x = ConvNextBlock( 116 | drop_path_ratio=self.drop_path_ratio[i], 117 | bottleneck_ratio=self.bottleneck_ratio, 118 | layer_scale_init_value=self.layer_scale_init_value, 119 | use_conv_bias=self.use_conv_bias, 120 | norm_layer=self.norm_layer, 121 | dtype=self.dtype, 122 | )(x, train=train) 123 | return x 124 | 125 | 126 | class PatchEmbed(linen.Module): 127 | r"""Image to Patch Embedding 128 | 129 | Args: 130 | patch_size (int): Patch token size. Default: 4. 131 | embed_dim (int): Number of linear projection output channels. Default: 96. 132 | norm_layer (nn.Module, optional): Normalization layer. Default: None 133 | """ 134 | 135 | patch_size: int = 4 136 | embed_dim: int = 96 137 | use_conv_bias: bool = True 138 | norm_layer: Callable = linen.LayerNorm 139 | dtype: jt.DTypeLike = jnp.float32 140 | 141 | @linen.compact 142 | def __call__(self, x): 143 | B, _, _, _ = x.shape 144 | patch_size = (self.patch_size, self.patch_size) 145 | 146 | kernel_init = linen.initializers.truncated_normal(0.02) 147 | x = linen.Conv( 148 | self.embed_dim, 149 | kernel_size=patch_size, 150 | strides=patch_size, 151 | kernel_init=kernel_init, 152 | use_bias=self.use_conv_bias, 153 | dtype=self.dtype, 154 | )(x) 155 | x = self.norm_layer()(x) 156 | return x 157 | 158 | 159 | # Cfr. arXiv:2103.17239. Found this to work better 160 | # than the paper default (1e-6) especially for the tiny variant. 
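# Deeper networks get a smaller initial gamma, keeping residual branches close to identity at the start of training.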
161 | def cait_layer_scale_eps(depth): 162 | if depth <= 18: 163 | return 0.1 164 | elif depth <= 24: 165 | return 1e-4 166 | else: 167 | return 1e-5 168 | 169 | 170 | class ConvNext(linen.Module): 171 | image_size: int = 224 172 | patch_size: int = 4 173 | num_classes: int = 1000 174 | 175 | depths: tuple[int, ...] = (3, 3, 27, 3) 176 | embed_dims: tuple[int, ...] = (128, 256, 512, 1024) 177 | 178 | drop_path_rate: float = 0.1 179 | 180 | use_norm_bias: bool = True 181 | use_conv_bias: bool = True 182 | 183 | norm_layer: Callable = linen.LayerNorm 184 | 185 | layer_norm_eps: float = 1e-6 186 | dtype: jt.DTypeLike = jnp.float32 187 | 188 | def setup(self): 189 | depths = self.depths 190 | num_layers = len(depths) 191 | norm_layer = partial( 192 | self.norm_layer, 193 | use_bias=self.use_norm_bias, 194 | epsilon=self.layer_norm_eps, 195 | dtype=self.dtype, 196 | ) 197 | 198 | layer_scale_init_value = cait_layer_scale_eps(sum(depths)) 199 | 200 | # split image into non-overlapping patches 201 | self.patch_embed = PatchEmbed( 202 | patch_size=self.patch_size, 203 | embed_dim=self.embed_dims[0], 204 | use_conv_bias=self.use_conv_bias, 205 | norm_layer=norm_layer, 206 | dtype=self.dtype, 207 | ) 208 | 209 | # stochastic depth with linear decay 210 | dpr = np.linspace(0, self.drop_path_rate, sum(depths)) 211 | dpr = [float(x) for x in dpr] 212 | 213 | # build layers 214 | convnext_body = [] 215 | for i_layer in range(num_layers): 216 | dpr_slice = tuple(dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])]) 217 | layer = BasicLayer( 218 | depth=depths[i_layer], 219 | embed_dim=self.embed_dims[i_layer], 220 | drop_path_ratio=dpr_slice, 221 | downsample=i_layer > 0, 222 | layer_scale_init_value=layer_scale_init_value, 223 | use_conv_bias=self.use_conv_bias, 224 | norm_layer=norm_layer, 225 | dtype=self.dtype, 226 | ) 227 | convnext_body.append(layer) 228 | self.convnext_body = convnext_body 229 | 230 | self.norm = norm_layer() 231 | self.head = ( 232 | linen.Dense(self.num_classes, dtype=self.dtype) 233 | if self.num_classes > 0 234 | else lambda x: x 235 | ) 236 | 237 | def __call__(self, x, train: bool = False): 238 | x = self.patch_embed(x) 239 | 240 | for layer in self.convnext_body: 241 | x = layer(x, train=train) 242 | 243 | x = jnp.mean(x, axis=(1, 2)) 244 | x = self.norm(x) 245 | x = self.head(x) 246 | return x 247 | 248 | @classmethod 249 | def build(cls, config, **kwargs): 250 | config = dataclasses.asdict(config) 251 | config = {key: kwargs[key] if key in kwargs else config[key] for key in config} 252 | return cls(**config) 253 | 254 | def extend_parser(self, parser): 255 | parser.set_defaults(image_size=self.image_size) 256 | parser.set_defaults(patch_size=self.patch_size) 257 | parser.add_argument( 258 | "--drop-path-rate", 259 | default=self.drop_path_rate, 260 | help="Stochastic depth rate", 261 | type=float, 262 | ) 263 | 264 | parser.add_argument( 265 | "--enable-conv-bias", 266 | dest="use_conv_bias", 267 | help="Enable conv layers bias", 268 | action="store_true", 269 | ) 270 | parser.add_argument( 271 | "--disable-conv-bias", 272 | dest="use_conv_bias", 273 | help="Disable conv layers bias", 274 | action="store_false", 275 | ) 276 | parser.set_defaults(use_conv_bias=self.use_conv_bias) 277 | 278 | parser.add_argument( 279 | "--enable-norm-bias", 280 | dest="use_norm_bias", 281 | help="Enable norm layers bias", 282 | action="store_true", 283 | ) 284 | parser.add_argument( 285 | "--disable-norm-bias", 286 | dest="use_norm_bias", 287 | help="Disable norm layers bias", 288 | 
action="store_false", 289 | ) 290 | parser.set_defaults(use_norm_bias=self.use_norm_bias) 291 | return parser 292 | 293 | @staticmethod 294 | def get_simmim_orbax_txs(): 295 | # SimMIM checkpoint have no head params - don't try to restore them. 296 | # All the other params we care about are under the "encoder" subsection 297 | regex = r"(?!model/params/head)model/params/(.*)" 298 | action = r"model/params/encoder/\1" 299 | return [(regex, action)] 300 | 301 | def should_decay(self, path, _): 302 | is_kernel = path[-1].key == "kernel" 303 | verdict = is_kernel 304 | return verdict 305 | 306 | 307 | def convnext_tiny(): 308 | config = { 309 | "embed_dims": (96, 192, 384, 768), 310 | "depths": (3, 3, 9, 3), 311 | } 312 | return ConvNext(**config) 313 | 314 | 315 | def convnext_small(): 316 | config = { 317 | "embed_dims": (96, 192, 384, 768), 318 | "depths": (3, 3, 27, 3), 319 | } 320 | return ConvNext(**config) 321 | 322 | 323 | def convnext_base(): 324 | config = { 325 | "embed_dims": (128, 256, 512, 1024), 326 | "depths": (3, 3, 27, 3), 327 | } 328 | return ConvNext(**config) 329 | -------------------------------------------------------------------------------- /Models/EVA02.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from functools import partial 3 | from typing import Callable, Optional 4 | 5 | import einops 6 | import jax.numpy as jnp 7 | import jax.typing as jt 8 | import numpy as np 9 | from flax import linen 10 | 11 | 12 | class LayerNorm(linen.Module): 13 | epsilon: float = 1e-6 14 | use_bias: bool = True 15 | force_float32_reductions: bool = True 16 | 17 | dtype: jt.DTypeLike = jnp.float32 18 | 19 | @linen.compact 20 | def __call__(self, x): 21 | scale = self.param("scale", linen.initializers.zeros_init(), (x.shape[-1])) 22 | 23 | dtype = self.dtype 24 | if self.force_float32_reductions: 25 | dtype = jnp.promote_types(dtype, jnp.float32) 26 | x = x.astype(dtype) 27 | 28 | mean = jnp.mean(x, axis=-1, keepdims=True) 29 | 30 | var = jnp.mean(jnp.square(x), axis=-1, keepdims=True) 31 | var = jnp.maximum(0.0, var - jnp.square(mean)) 32 | mul = jnp.reciprocal(jnp.sqrt(var + self.epsilon)) 33 | 34 | centered_inputs = x - mean 35 | normed_inputs = centered_inputs * mul 36 | 37 | scale = jnp.expand_dims(scale, axis=range(len(x.shape) - 1)) 38 | normed_inputs = normed_inputs * (1 + scale) 39 | if self.use_bias: 40 | bias = self.param("bias", linen.initializers.zeros_init(), (x.shape[-1])) 41 | bias = jnp.expand_dims(bias, axis=range(len(x.shape) - 1)) 42 | normed_inputs = normed_inputs + bias 43 | return normed_inputs.astype(self.dtype) 44 | 45 | 46 | class VisionRotaryEmbeddingFast(linen.Module): 47 | """Apply Rotary Position Embeddings (RoPE) 48 | 49 | Most of the code comes from the original repo: 50 | https://github.com/baaivision/EVA/blob/master/EVA-02/asuka/rope.py 51 | """ 52 | 53 | dim: int 54 | seq_len: int = 16 55 | theta: int = 10000 56 | 57 | @staticmethod 58 | def broadcat(tensors, dim=-1): 59 | num_tensors = len(tensors) 60 | shape_lens = set(list(map(lambda t: len(t.shape), tensors))) 61 | assert ( 62 | len(shape_lens) == 1 63 | ), "tensors must all have the same number of dimensions" 64 | shape_len = list(shape_lens)[0] 65 | dim = (dim + shape_len) if dim < 0 else dim 66 | dims = list(zip(*map(lambda t: list(t.shape), tensors))) 67 | expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] 68 | assert all( 69 | [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)] 70 | ), "invalid dimensions 
for broadcastable concatentation" 71 | max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) 72 | expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) 73 | expanded_dims.insert(dim, (dim, dims[dim])) 74 | expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) 75 | tensors = list( 76 | map(lambda t: np.broadcast_to(t[0], t[1]), zip(tensors, expandable_shapes)) 77 | ) 78 | return np.concatenate(tensors, axis=dim) 79 | 80 | @staticmethod 81 | def rotate_half(x): 82 | x = einops.rearrange(x, "... (d r) -> ... d r", r=2) 83 | x1, x2 = x[..., 0], x[..., 1] 84 | x = jnp.stack((-x2, x1), axis=-1) 85 | x = einops.rearrange(x, "... d r -> ... (d r)") 86 | return x 87 | 88 | def setup(self): 89 | exp = np.arange(0, self.dim, 2) / -self.dim 90 | freqs = self.theta**exp 91 | 92 | t = np.arange(self.seq_len) 93 | 94 | freqs = np.einsum("..., f -> ... f", t, freqs) 95 | freqs = einops.repeat(freqs, "... n -> ... (n r)", r=2) 96 | freqs = self.broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1) 97 | 98 | freqs_cos = np.reshape(np.cos(freqs), (-1, freqs.shape[-1])) 99 | freqs_sin = np.reshape(np.sin(freqs), (-1, freqs.shape[-1])) 100 | 101 | self.freqs_cos = self.variable( 102 | "eva02_constants", 103 | "freqs_cos", 104 | lambda: np.float32(freqs_cos), 105 | ).value 106 | self.freqs_sin = self.variable( 107 | "eva02_constants", 108 | "freqs_sin", 109 | lambda: np.float32(freqs_sin), 110 | ).value 111 | 112 | def __call__(self, x): 113 | return x * self.freqs_cos + self.rotate_half(x) * self.freqs_sin 114 | 115 | 116 | class Attention(linen.Module): 117 | num_extra_tokens: int 118 | 119 | dim: int 120 | num_heads: int 121 | 122 | rope: Callable 123 | 124 | qkv_bias: bool = True 125 | proj_bias: bool = True 126 | 127 | attn_drop_ratio: float = 0.0 128 | proj_drop_ratio: float = 0.0 129 | 130 | dtype: jt.DTypeLike = jnp.float32 131 | 132 | def setup(self): 133 | self.qkv = linen.Dense(self.dim * 3, use_bias=self.qkv_bias, dtype=self.dtype) 134 | self.attn_drop = linen.Dropout(self.attn_drop_ratio) 135 | self.proj = linen.Dense(self.dim, use_bias=self.proj_bias, dtype=self.dtype) 136 | self.proj_drop = linen.Dropout(self.proj_drop_ratio) 137 | self.softmax = partial(linen.softmax, axis=-1) 138 | 139 | def __call__(self, x, train: bool = False): 140 | B, N, C = x.shape 141 | qkv = self.qkv(x) 142 | qkv = jnp.reshape(qkv, (B, N, 3, self.num_heads, C // self.num_heads)) 143 | qkv = jnp.transpose(qkv, (2, 0, 3, 1, 4)) 144 | 145 | q, k, v = (qkv[0], qkv[1], qkv[2]) 146 | 147 | q_xtr = q[:, :, : self.num_extra_tokens, :] 148 | q_seq = q[:, :, self.num_extra_tokens :, :] 149 | q_seq = self.rope(q_seq).astype(v.dtype) 150 | q = jnp.concatenate([q_xtr, q_seq], axis=2) 151 | 152 | k_xtr = k[:, :, : self.num_extra_tokens, :] 153 | k_seq = k[:, :, self.num_extra_tokens :, :] 154 | k_seq = self.rope(k_seq).astype(v.dtype) 155 | k = jnp.concatenate([k_xtr, k_seq], axis=2) 156 | 157 | q = q / jnp.sqrt(q.shape[-1]).astype(q.dtype) 158 | attn = q @ jnp.transpose(k, (0, 1, 3, 2)) 159 | 160 | attn = self.softmax(attn.astype(jnp.float32)).astype(self.dtype) 161 | attn = self.attn_drop(attn, deterministic=not train) 162 | 163 | x = jnp.transpose(attn @ v, (0, 2, 1, 3)) 164 | x = jnp.reshape((x), (B, N, C)) 165 | x = self.proj(x) 166 | x = self.proj_drop(x, deterministic=not train) 167 | return x 168 | 169 | 170 | class SwiGLU(linen.Module): 171 | hidden_features: int 172 | scale_mlp: bool 173 | norm_layer: Callable 174 | 175 | use_bias: bool = True 176 | 177 | act_layer: 
Callable = linen.silu 178 | drop_ratio: float = 0.0 179 | 180 | dtype: jt.DTypeLike = jnp.float32 181 | 182 | @linen.compact 183 | def __call__(self, x, train: bool): 184 | out_dim = x.shape[-1] 185 | 186 | x = linen.Dense( 187 | self.hidden_features * 2, 188 | use_bias=self.use_bias, 189 | dtype=self.dtype, 190 | )(x) 191 | x1 = x[..., : self.hidden_features] 192 | x2 = x[..., self.hidden_features :] 193 | 194 | x = self.act_layer(x1) * x2 195 | x = linen.Dropout(self.drop_ratio)(x, deterministic=not train) 196 | 197 | if self.scale_mlp: 198 | x = self.norm_layer()(x) 199 | 200 | x = linen.Dense(out_dim, use_bias=self.use_bias, dtype=self.dtype)(x) 201 | return x 202 | 203 | 204 | class PosEmbed(linen.Module): 205 | dtype: jt.DTypeLike = jnp.float32 206 | 207 | @linen.compact 208 | def __call__(self, x): 209 | _, L, C = x.shape 210 | pos_emb_init = linen.initializers.normal(stddev=1 / np.sqrt(C)) 211 | pos_emb = self.param("pos_emb", pos_emb_init, (1, L, C)) 212 | pos_emb = pos_emb.astype(self.dtype) 213 | x = x + pos_emb 214 | return x 215 | 216 | 217 | class PatchEmbed(linen.Module): 218 | r"""Image to Patch Embedding 219 | 220 | Args: 221 | patch_size (int): Patch token size. Default: 16. 222 | embed_dim (int): Number of linear projection output channels. Default: 96. 223 | norm_layer (nn.Module, optional): Normalization layer. Default: None 224 | """ 225 | 226 | patch_size: int = 16 227 | embed_dim: int = 768 228 | 229 | use_bias: bool = True 230 | 231 | norm_layer: Optional[Callable] = None 232 | 233 | dtype: jt.DTypeLike = jnp.float32 234 | 235 | @linen.compact 236 | def __call__(self, x): 237 | B, _, _, _ = x.shape 238 | patch_size = (self.patch_size, self.patch_size) 239 | 240 | x = linen.Conv( 241 | self.embed_dim, 242 | kernel_size=patch_size, 243 | strides=patch_size, 244 | use_bias=self.use_bias, 245 | dtype=self.dtype, 246 | )(x) 247 | x = jnp.reshape(x, (B, -1, self.embed_dim)) 248 | if self.norm_layer is not None: 249 | x = self.norm_layer()(x) 250 | return x 251 | 252 | 253 | class EVA02TransformerBlock(linen.Module): 254 | num_extra_tokens: int 255 | 256 | mlp_dim: int 257 | num_heads: int 258 | drop_path_ratio: float 259 | 260 | scale_mlp: bool 261 | 262 | use_bias: bool 263 | 264 | norm_layer: Callable 265 | rope: Callable 266 | 267 | dtype: jt.DTypeLike = jnp.float32 268 | 269 | @linen.compact 270 | def __call__(self, x, train: bool = False): 271 | shortcut = x 272 | 273 | x = self.norm_layer()(x) 274 | x = Attention( 275 | num_extra_tokens=self.num_extra_tokens, 276 | dim=x.shape[-1], 277 | num_heads=self.num_heads, 278 | rope=self.rope, 279 | qkv_bias=self.use_bias, 280 | proj_bias=self.use_bias, 281 | dtype=self.dtype, 282 | )(x, train=train) 283 | x = linen.Dropout( 284 | rate=self.drop_path_ratio, 285 | broadcast_dims=(1, 2), 286 | )(x, deterministic=not train) 287 | x = shortcut + x 288 | 289 | shortcut = x 290 | x = self.norm_layer()(x) 291 | x = SwiGLU( 292 | hidden_features=self.mlp_dim, 293 | scale_mlp=self.scale_mlp, 294 | use_bias=self.use_bias, 295 | norm_layer=self.norm_layer, 296 | dtype=self.dtype, 297 | )(x, train=train) 298 | x = linen.Dropout( 299 | rate=self.drop_path_ratio, 300 | broadcast_dims=(1, 2), 301 | )(x, deterministic=not train) 302 | x = shortcut + x 303 | return x 304 | 305 | 306 | def make_norm_layer(layer_name): 307 | if layer_name == "reparam_layernorm": 308 | return LayerNorm 309 | elif layer_name == "linen_layernorm": 310 | return linen.LayerNorm 311 | 312 | 313 | class EVA02Transformer(linen.Module): 314 | image_size: int = 224 315 
| patch_size: int = 16 316 | num_classes: int = 1000 317 | 318 | use_pos_emb: bool = True 319 | use_cls_token: bool = True 320 | num_register_tokens: int = 0 321 | 322 | num_layers: int = 12 323 | embed_dim: int = 768 324 | mlp_dim: int = 3072 325 | num_heads: int = 12 326 | 327 | scale_mlp: bool = True 328 | 329 | drop_path_rate: float = 0.1 330 | 331 | use_norm_bias: bool = True 332 | use_linear_bias: bool = True 333 | 334 | norm_layer: str = "linen_layernorm" 335 | 336 | layer_norm_eps: float = 1e-6 337 | dtype: jt.DTypeLike = jnp.float32 338 | 339 | def setup(self): 340 | norm_layer = make_norm_layer(self.norm_layer) 341 | norm_layer = partial( 342 | norm_layer, 343 | epsilon=self.layer_norm_eps, 344 | use_bias=self.use_norm_bias, 345 | dtype=self.dtype, 346 | ) 347 | 348 | self.patch_embed = PatchEmbed( 349 | patch_size=self.patch_size, 350 | embed_dim=self.embed_dim, 351 | use_bias=self.use_linear_bias, 352 | dtype=self.dtype, 353 | ) 354 | 355 | if self.use_cls_token: 356 | cls_token_init = linen.initializers.truncated_normal(stddev=0.02) 357 | self.cls_token = self.param( 358 | "cls_token", 359 | cls_token_init, 360 | (1, 1, self.embed_dim), 361 | ) 362 | 363 | if self.num_register_tokens: 364 | reg_token_init = linen.initializers.truncated_normal(stddev=0.02) 365 | self.reg_token = self.param( 366 | "reg_token", 367 | reg_token_init, 368 | (1, self.num_register_tokens, self.embed_dim), 369 | ) 370 | 371 | self.num_extra_tokens = int(self.use_cls_token) + self.num_register_tokens 372 | 373 | self.pos_emb = PosEmbed(dtype=self.dtype) if self.use_pos_emb else lambda x: x 374 | 375 | half_head_dim = self.embed_dim // self.num_heads // 2 376 | hw_seq_len = self.image_size // self.patch_size 377 | 378 | self.rope_emb = VisionRotaryEmbeddingFast(dim=half_head_dim, seq_len=hw_seq_len) 379 | 380 | # stochastic depth with linear decay 381 | dpr = np.linspace(0, self.drop_path_rate, self.num_layers) 382 | dpr = [float(x) for x in dpr] 383 | 384 | eva02_body = [] 385 | for i_layer in range(self.num_layers): 386 | layer = EVA02TransformerBlock( 387 | num_extra_tokens=self.num_extra_tokens, 388 | mlp_dim=self.mlp_dim, 389 | num_heads=self.num_heads, 390 | scale_mlp=self.scale_mlp, 391 | drop_path_ratio=dpr[i_layer], 392 | use_bias=self.use_linear_bias, 393 | norm_layer=norm_layer, 394 | rope=self.rope_emb, 395 | dtype=self.dtype, 396 | ) 397 | eva02_body.append(layer) 398 | self.eva02_body = eva02_body 399 | 400 | self.norm = norm_layer() 401 | self.head = ( 402 | linen.Dense( 403 | self.num_classes, 404 | use_bias=self.use_linear_bias, 405 | dtype=self.dtype, 406 | ) 407 | if self.num_classes > 0 408 | else lambda x: x 409 | ) 410 | 411 | def __call__(self, x, train: bool = False): 412 | x = self.patch_embed(x) 413 | 414 | if self.use_cls_token: 415 | B, N, C = x.shape 416 | b_cls = self.cls_token.astype(x.dtype) 417 | b_cls = jnp.broadcast_to(b_cls, (B, 1, C)) 418 | x = jnp.concatenate([b_cls, x], axis=1) 419 | 420 | x = self.pos_emb(x) 421 | 422 | if self.num_register_tokens: 423 | B, N, C = x.shape 424 | b_reg = self.reg_token.astype(x.dtype) 425 | b_reg = jnp.broadcast_to(b_reg, (B, self.num_register_tokens, C)) 426 | x = jnp.concatenate([b_reg, x], axis=1) 427 | 428 | for layer in self.eva02_body: 429 | x = layer(x, train=train) 430 | 431 | x = x[:, self.num_extra_tokens :] 432 | x = jnp.mean(x, axis=(1,)) 433 | x = self.norm(x) 434 | x = self.head(x) 435 | return x 436 | 437 | @classmethod 438 | def build(cls, config, **kwargs): 439 | config = dataclasses.asdict(config) 440 | config = 
{key: kwargs[key] if key in kwargs else config[key] for key in config} 441 | return cls(**config) 442 | 443 | def extend_parser(self, parser): 444 | parser.set_defaults(patch_size=self.patch_size) 445 | parser.add_argument( 446 | "--drop-path-rate", 447 | default=self.drop_path_rate, 448 | help="Stochastic depth rate", 449 | type=float, 450 | ) 451 | 452 | parser.add_argument( 453 | "--enable-linear-bias", 454 | dest="use_linear_bias", 455 | help="Enable linear layers bias", 456 | action="store_true", 457 | ) 458 | parser.add_argument( 459 | "--disable-linear-bias", 460 | dest="use_linear_bias", 461 | help="Disable linear layers bias", 462 | action="store_false", 463 | ) 464 | parser.set_defaults(use_linear_bias=self.use_linear_bias) 465 | 466 | parser.add_argument( 467 | "--enable-norm-bias", 468 | dest="use_norm_bias", 469 | help="Enable norm layers bias", 470 | action="store_true", 471 | ) 472 | parser.add_argument( 473 | "--disable-norm-bias", 474 | dest="use_norm_bias", 475 | help="Disable norm layers bias", 476 | action="store_false", 477 | ) 478 | parser.set_defaults(use_norm_bias=self.use_norm_bias) 479 | 480 | parser.add_argument( 481 | "--enable-mlp-ln", 482 | dest="scale_mlp", 483 | help="Enable norm layer in SwiGLU block", 484 | action="store_true", 485 | ) 486 | parser.add_argument( 487 | "--disable-mlp-ln", 488 | dest="scale_mlp", 489 | help="Disable norm layer in SwiGLU block", 490 | action="store_false", 491 | ) 492 | parser.set_defaults(scale_mlp=self.scale_mlp) 493 | 494 | parser.add_argument( 495 | "--norm-layer", 496 | default=self.norm_layer, 497 | help="Normalization layer", 498 | type=str, 499 | ) 500 | 501 | parser.add_argument( 502 | "--enable-pos-emb", 503 | dest="use_pos_emb", 504 | help="Enable (learned) absolute positional embeddings", 505 | action="store_true", 506 | ) 507 | parser.add_argument( 508 | "--disable-pos-emb", 509 | dest="use_pos_emb", 510 | help="Disable (learned) absolute positional embeddings", 511 | action="store_false", 512 | ) 513 | parser.set_defaults(use_pos_emb=self.use_pos_emb) 514 | 515 | parser.add_argument( 516 | "--enable-cls-token", 517 | dest="use_cls_token", 518 | help="Enable cls token", 519 | action="store_true", 520 | ) 521 | parser.add_argument( 522 | "--disable-cls-token", 523 | dest="use_cls_token", 524 | help="Disable cls token", 525 | action="store_false", 526 | ) 527 | parser.set_defaults(use_cls_token=self.use_cls_token) 528 | 529 | parser.add_argument( 530 | "--num-register-tokens", 531 | default=self.num_register_tokens, 532 | help="Number of registers (arXiv:2309.16588)", 533 | type=int, 534 | ) 535 | return parser 536 | 537 | @staticmethod 538 | def get_simmim_orbax_txs(): 539 | # SimMIM checkpoint have no head params - don't try to restore them. 
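        # (These are (regex, replacement) pairs applied to parameter path strings;
        # the negative lookahead `(?!model/params/head)` makes the pattern skip the
        # classifier head.)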
540 | # All the other params we care about are under the "encoder" subsection 541 | regex = r"(?!model/params/head)model/params/(.*)" 542 | action = r"model/params/encoder/\1" 543 | return [(regex, action)] 544 | 545 | def should_decay(self, path, _): 546 | is_kernel = path[-1].key == "kernel" 547 | is_scale = path[-1].key == "scale" 548 | is_scale = is_scale and self.norm_layer == "reparam_layernorm" 549 | verdict = is_kernel or is_scale 550 | return verdict 551 | 552 | 553 | def eva02_small(): 554 | config = { 555 | "num_layers": 12, 556 | "embed_dim": 384, 557 | "mlp_dim": (384 * 4 * 2) // 3, 558 | "num_heads": 6, 559 | "scale_mlp": False, 560 | } 561 | return EVA02Transformer(**config) 562 | 563 | 564 | def eva02_base(): 565 | config = { 566 | "num_layers": 12, 567 | "embed_dim": 768, 568 | "mlp_dim": (768 * 4 * 2) // 3, 569 | "num_heads": 12, 570 | "scale_mlp": True, 571 | } 572 | return EVA02Transformer(**config) 573 | 574 | 575 | def eva02_large(): 576 | config = { 577 | "num_layers": 24, 578 | "embed_dim": 1024, 579 | "mlp_dim": (1024 * 4 * 2) // 3, 580 | "num_heads": 16, 581 | "scale_mlp": True, 582 | } 583 | return EVA02Transformer(**config) 584 | -------------------------------------------------------------------------------- /Models/HiViT.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from functools import partial 3 | from typing import Callable, Optional, Union 4 | 5 | import jax.numpy as jnp 6 | import jax.typing as jt 7 | import numpy as np 8 | from flax import linen 9 | 10 | 11 | class RelativePositionBias(linen.Module): 12 | input_size: int 13 | num_heads: int 14 | 15 | dtype: jt.DTypeLike = jnp.float32 16 | 17 | def get_relative_position_index(self): 18 | # get pair-wise relative position index for each token inside the window 19 | coords_h = np.arange(self.input_size) 20 | coords_w = np.arange(self.input_size) 21 | 22 | # 2, Wh, Ww 23 | coords = np.stack(np.meshgrid(coords_h, coords_w, indexing="ij")) 24 | 25 | # 2, Wh*Ww 26 | coords_flatten = np.reshape(coords, (2, -1)) 27 | 28 | # 2, Wh*Ww, Wh*Ww 29 | coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] 30 | 31 | # Wh*Ww, Wh*Ww, 2 32 | coords = np.transpose(coords, (1, 2, 0)) 33 | 34 | # shift to start from 0 35 | coords[:, :, 0] = coords[:, :, 0] + (self.input_size - 1) 36 | coords[:, :, 1] = coords[:, :, 1] + (self.input_size - 1) 37 | coords[:, :, 0] = coords[:, :, 0] * (2 * self.input_size - 1) 38 | 39 | # Wh*Ww, Wh*Ww 40 | position_index = np.sum(coords, axis=-1) 41 | return position_index 42 | 43 | def setup(self): 44 | self.relative_position_bias_table = self.param( 45 | "relative_position_bias_table", 46 | linen.initializers.truncated_normal(stddev=0.02), 47 | ((2 * self.input_size - 1) * (2 * self.input_size - 1), self.num_heads), 48 | ) 49 | 50 | self.relative_position_index = self.variable( 51 | "hivit_constants", 52 | "relative_position_index", 53 | self.get_relative_position_index, 54 | ).value 55 | 56 | def __call__(self, x): 57 | rpe_index = jnp.reshape(self.relative_position_index, (-1,)) 58 | relative_position_bias = self.relative_position_bias_table[rpe_index] 59 | 60 | relative_position_bias = jnp.reshape( 61 | relative_position_bias, 62 | (self.input_size**2, self.input_size**2, -1), 63 | ) 64 | relative_position_bias = jnp.transpose(relative_position_bias, (2, 0, 1)) 65 | relative_position_bias = jnp.expand_dims(relative_position_bias, 0) 66 | 67 | x = x + relative_position_bias 68 | return x 69 | 70 | 71 | class 
Attention(linen.Module): 72 | input_size: int 73 | dim: int 74 | num_heads: int 75 | qkv_bias: bool = True 76 | qk_scale: Union[None, float] = None 77 | attn_drop_ratio: float = 0.0 78 | proj_drop_ratio: float = 0.0 79 | rpe_enabled: bool = True 80 | 81 | dtype: jt.DTypeLike = jnp.float32 82 | 83 | def setup(self): 84 | self.attention_bias = ( 85 | RelativePositionBias( 86 | self.input_size, 87 | self.num_heads, 88 | dtype=self.dtype, 89 | ) 90 | if self.rpe_enabled 91 | else lambda x: x 92 | ) 93 | 94 | self.qkv = linen.Dense(self.dim * 3, use_bias=self.qkv_bias, dtype=self.dtype) 95 | self.attn_drop = linen.Dropout(self.attn_drop_ratio) 96 | self.proj = linen.Dense(self.dim, dtype=self.dtype) 97 | self.proj_drop = linen.Dropout(self.proj_drop_ratio) 98 | self.softmax = partial(linen.softmax, axis=-1) 99 | 100 | def __call__(self, x, train: bool = False): 101 | B, N, C = x.shape 102 | qkv = self.qkv(x) 103 | qkv = jnp.reshape(qkv, (B, N, 3, self.num_heads, C // self.num_heads)) 104 | qkv = jnp.transpose(qkv, (2, 0, 3, 1, 4)) 105 | 106 | q, k, v = (qkv[0], qkv[1], qkv[2]) 107 | 108 | q = q / jnp.sqrt(q.shape[-1]).astype(q.dtype) 109 | attn = q @ jnp.transpose(k, (0, 1, 3, 2)) 110 | 111 | attn = self.attention_bias(attn) 112 | 113 | attn = self.softmax(attn).astype(self.dtype) 114 | attn = self.attn_drop(attn, deterministic=not train) 115 | 116 | x = jnp.transpose(attn @ v, (0, 2, 1, 3)) 117 | x = jnp.reshape((x), (B, N, C)) 118 | x = self.proj(x) 119 | x = self.proj_drop(x, deterministic=not train) 120 | return x 121 | 122 | 123 | class PatchMerging(linen.Module): 124 | r"""Patch Merging Layer. 125 | 126 | Args: 127 | input_resolution (tuple[int]): Resolution of input feature. 128 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 129 | """ 130 | 131 | input_resolution: tuple[int, int] 132 | norm_layer: Callable = linen.LayerNorm 133 | 134 | dtype: jt.DTypeLike = jnp.float32 135 | 136 | @linen.compact 137 | def __call__(self, x): 138 | """ 139 | x: B, H*W, C 140 | """ 141 | H, W = self.input_resolution 142 | B, L, C = x.shape 143 | 144 | x = jnp.reshape(x, (B, H // 2, 2, W // 2, 2, C)) # B H/2 nH W/2 nW C 145 | x = jnp.transpose(x, (0, 1, 3, 2, 4, 5)) # B H/2 W/2 nH nW C 146 | x = jnp.reshape(x, (B, (H // 2) * (W // 2), 4 * C)) # B H/2*W/2 4*C 147 | 148 | x = self.norm_layer()(x) 149 | x = linen.Dense(2 * C, use_bias=False, dtype=self.dtype)(x) 150 | return x 151 | 152 | 153 | class MLP(linen.Module): 154 | hidden_features: int 155 | act_layer: Callable = linen.gelu 156 | drop_ratio: float = 0.0 157 | 158 | dtype: jt.DTypeLike = jnp.float32 159 | 160 | @linen.compact 161 | def __call__(self, x, train: bool): 162 | out_dim = x.shape[-1] 163 | 164 | x = linen.Dense(self.hidden_features, dtype=self.dtype)(x) 165 | x = self.act_layer(x) 166 | x = linen.Dropout(self.drop_ratio)(x, deterministic=not train) 167 | x = linen.Dense(out_dim, dtype=self.dtype)(x) 168 | return x 169 | 170 | 171 | class PosEmbed(linen.Module): 172 | dtype: jt.DTypeLike = jnp.float32 173 | 174 | @linen.compact 175 | def __call__(self, x): 176 | _, L, C = x.shape 177 | pos_emb_init = linen.initializers.normal(stddev=1 / np.sqrt(C)) 178 | pos_emb = self.param("pos_emb", pos_emb_init, (1, L, C)) 179 | pos_emb = pos_emb.astype(self.dtype) 180 | x = x + pos_emb 181 | return x 182 | 183 | 184 | class PatchEmbed(linen.Module): 185 | r"""Image to Patch Embedding 186 | 187 | Args: 188 | patch_size (int): Patch token size. Default: 4. 189 | embed_dim (int): Number of linear projection output channels. 
Default: 96. 190 | norm_layer (nn.Module, optional): Normalization layer. Default: None 191 | """ 192 | 193 | patch_size: int = 4 194 | embed_dim: int = 96 195 | internal_patches: int = 4 196 | 197 | norm_layer: Optional[Callable] = None 198 | 199 | dtype: jt.DTypeLike = jnp.float32 200 | 201 | def patches_reshape(self, x): 202 | B, H, W, C = x.shape 203 | nH = nW = self.internal_patches 204 | H = H // self.internal_patches 205 | W = W // self.internal_patches 206 | x = jnp.reshape(x, (B, H, nH, W, nW, C)) 207 | x = jnp.transpose(x, (0, 1, 3, 2, 4, 5)) 208 | x = jnp.reshape(x, (B, H * W * nH * nW, C)) 209 | return x 210 | 211 | @linen.compact 212 | def __call__(self, x): 213 | B, _, _, _ = x.shape 214 | patch_size = (self.patch_size, self.patch_size) 215 | 216 | x = linen.Conv( 217 | self.embed_dim, 218 | kernel_size=patch_size, 219 | strides=patch_size, 220 | dtype=self.dtype, 221 | )(x) 222 | 223 | x = self.patches_reshape(x) 224 | 225 | if self.norm_layer is not None: 226 | x = self.norm_layer()(x) 227 | return x 228 | 229 | 230 | class HierarchicalViTBlock(linen.Module): 231 | mlp_dim: int 232 | num_heads: Union[None, int] 233 | drop_path_ratio: float 234 | 235 | norm_layer: Callable 236 | 237 | dtype: jt.DTypeLike = jnp.float32 238 | 239 | @linen.compact 240 | def __call__(self, x, train: bool = False): 241 | shortcut = x 242 | 243 | x = self.norm_layer()(x) 244 | if self.num_heads: 245 | _, L, C = x.shape 246 | x = Attention( 247 | input_size=int(L**0.5), 248 | dim=C, 249 | num_heads=self.num_heads, 250 | dtype=self.dtype, 251 | )(x, train=train) 252 | else: 253 | x = MLP(hidden_features=self.mlp_dim, dtype=self.dtype)(x, train=train) 254 | 255 | x = linen.Dropout( 256 | rate=self.drop_path_ratio, 257 | broadcast_dims=(1, 2), 258 | )(x, deterministic=not train) 259 | x = shortcut + x 260 | 261 | shortcut = x 262 | x = self.norm_layer()(x) 263 | x = MLP(hidden_features=self.mlp_dim, dtype=self.dtype)(x, train=train) 264 | x = linen.Dropout( 265 | rate=self.drop_path_ratio, 266 | broadcast_dims=(1, 2), 267 | )(x, deterministic=not train) 268 | x = shortcut + x 269 | return x 270 | 271 | 272 | class BasicLayer(linen.Module): 273 | depth: int 274 | 275 | num_heads: Union[None, int] = 12 276 | mlp_ratio: float = 4.0 277 | pos_emb_enabled: bool = False 278 | drop_path_ratio: tuple[float, ...] = (0.0,) 279 | 280 | downsample: Union[None, Callable] = None 281 | 282 | norm_layer: Callable = linen.LayerNorm 283 | 284 | dtype: jt.DTypeLike = jnp.float32 285 | 286 | @linen.compact 287 | def __call__(self, x, train: bool = False): 288 | B, L, C = x.shape 289 | 290 | if self.pos_emb_enabled: 291 | x = PosEmbed(dtype=self.dtype)(x) 292 | 293 | mlp_dim = int(C * self.mlp_ratio) 294 | for i in range(self.depth): 295 | x = HierarchicalViTBlock( 296 | mlp_dim=mlp_dim, 297 | num_heads=self.num_heads, 298 | drop_path_ratio=self.drop_path_ratio[i], 299 | norm_layer=self.norm_layer, 300 | dtype=self.dtype, 301 | )(x, train=train) 302 | 303 | # patch merging layer 304 | if self.downsample is not None: 305 | H = W = int(L**0.5) 306 | x = self.downsample( 307 | input_resolution=(H, W), 308 | norm_layer=self.norm_layer, 309 | dtype=self.dtype, 310 | )(x) 311 | 312 | return x 313 | 314 | 315 | class HierarchicalViT(linen.Module): 316 | patch_size: int = 4 317 | num_classes: int = 1000 318 | 319 | depths: tuple[int, ...] = (2, 2, 20) 320 | embed_dim: int = 192 321 | mlp_ratio: tuple[float, ...] = (3.0, 3.0, 4.0) 322 | num_heads: tuple[Union[None, int], ...] 
= (None, None, 12) 323 | pos_emb_delay: int = 2 324 | 325 | drop_path_rate: float = 0.1 326 | 327 | norm_layer: Callable = linen.LayerNorm 328 | 329 | layer_norm_eps: float = 1e-6 330 | dtype: jt.DTypeLike = jnp.float32 331 | 332 | def setup(self): 333 | depths = self.depths 334 | num_layers = len(depths) 335 | norm_layer = partial( 336 | self.norm_layer, 337 | epsilon=self.layer_norm_eps, 338 | dtype=self.dtype, 339 | ) 340 | 341 | self.patch_embed = PatchEmbed( 342 | patch_size=self.patch_size, 343 | embed_dim=self.embed_dim, 344 | norm_layer=norm_layer, 345 | dtype=self.dtype, 346 | ) 347 | 348 | # stochastic depth with linear decay 349 | dpr = np.linspace(0, self.drop_path_rate, sum(depths)) 350 | dpr = [float(x) for x in dpr] 351 | 352 | hivit_body = [] 353 | for i_layer in range(num_layers): 354 | dpr_slice = tuple(dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])]) 355 | layer = BasicLayer( 356 | depth=depths[i_layer], 357 | mlp_ratio=self.mlp_ratio[i_layer], 358 | num_heads=self.num_heads[i_layer], 359 | pos_emb_enabled=i_layer == self.pos_emb_delay, 360 | drop_path_ratio=dpr_slice, 361 | norm_layer=norm_layer, 362 | downsample=PatchMerging if (i_layer < num_layers - 1) else None, 363 | dtype=self.dtype, 364 | ) 365 | hivit_body.append(layer) 366 | self.hivit_body = hivit_body 367 | 368 | self.norm = norm_layer() 369 | self.head = ( 370 | linen.Dense(self.num_classes, dtype=self.dtype) 371 | if self.num_classes > 0 372 | else lambda x: x 373 | ) 374 | 375 | def __call__(self, x, train: bool = False): 376 | x = self.patch_embed(x) 377 | 378 | for layer in self.hivit_body: 379 | x = layer(x, train=train) 380 | 381 | x = jnp.mean(x, axis=(1,)) 382 | x = self.norm(x) 383 | x = self.head(x) 384 | return x 385 | 386 | @classmethod 387 | def build(cls, config, **kwargs): 388 | config = dataclasses.asdict(config) 389 | config = {key: kwargs[key] if key in kwargs else config[key] for key in config} 390 | return cls(**config) 391 | 392 | def extend_parser(self, parser): 393 | parser.set_defaults(patch_size=self.patch_size) 394 | parser.add_argument( 395 | "--drop-path-rate", 396 | default=self.drop_path_rate, 397 | help="Stochastic depth rate", 398 | type=float, 399 | ) 400 | return parser 401 | 402 | @staticmethod 403 | def get_simmim_orbax_txs(): 404 | # SimMIM checkpoint have no head params - don't try to restore them. 
405 | # All the other params we care about are under the "encoder" subsection 406 | regex = r"(?!model/params/head)model/params/(.*)" 407 | action = r"model/params/encoder/\1" 408 | return [(regex, action)] 409 | 410 | def should_decay(self, path, _): 411 | is_kernel = path[-1].key == "kernel" 412 | verdict = is_kernel 413 | return verdict 414 | 415 | 416 | def hivit_tiny(): 417 | config = { 418 | "depths": (1, 1, 10), 419 | "embed_dim": 96, 420 | "mlp_ratio": (3.0, 3.0, 4.0), 421 | "num_heads": (None, None, 6), 422 | } 423 | return HierarchicalViT(**config) 424 | 425 | 426 | def hivit_small(): 427 | config = { 428 | "depths": (2, 2, 20), 429 | "embed_dim": 96, 430 | "mlp_ratio": (3.0, 3.0, 4.0), 431 | "num_heads": (None, None, 6), 432 | } 433 | return HierarchicalViT(**config) 434 | 435 | 436 | def hivit_base(): 437 | config = { 438 | "depths": (2, 2, 20), 439 | "embed_dim": 128, 440 | "mlp_ratio": (3.0, 3.0, 4.0), 441 | "num_heads": (None, None, 6), 442 | } 443 | return HierarchicalViT(**config) 444 | -------------------------------------------------------------------------------- /Models/SimMIM.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from typing import Any 3 | 4 | import einops 5 | import jax.numpy as jnp 6 | from flax import linen 7 | 8 | from .ConvNext import ConvNext 9 | from .EVA02 import EVA02Transformer 10 | from .HiViT import HierarchicalViT 11 | from .SwinV2 import SwinTransformerV2 12 | from .ViT import VisionTransformer 13 | 14 | 15 | class WindowedNorm(linen.Module): 16 | target_size: tuple[int, int] 17 | window_size: int = 47 18 | 19 | def get_targets_count(self): 20 | window_shape = (self.window_size, self.window_size) 21 | padding = ( 22 | (self.window_size // 2, self.window_size // 2), 23 | (self.window_size // 2, self.window_size // 2), 24 | ) 25 | 26 | targets_count = jnp.ones((1, self.target_size[0], self.target_size[1], 1)) 27 | 28 | targets_count = linen.avg_pool( 29 | targets_count, 30 | window_shape=window_shape, 31 | strides=(1, 1), 32 | padding=padding, 33 | count_include_pad=True, 34 | ) 35 | targets_count = targets_count * jnp.power(self.window_size, 2.0) 36 | targets_count = jnp.int32(jnp.rint(targets_count)) 37 | return targets_count 38 | 39 | def setup(self): 40 | self.targets_count = self.variable( 41 | "simmim_constants", 42 | "targets_count", 43 | self.get_targets_count, 44 | ).value 45 | 46 | def __call__(self, targets): 47 | window_size = self.window_size 48 | 49 | window_shape = (window_size, window_size) 50 | padding = ( 51 | (window_size // 2, window_size // 2), 52 | (window_size // 2, window_size // 2), 53 | ) 54 | 55 | targets_ = targets 56 | 57 | targets_square = jnp.power(targets, 2.0) 58 | 59 | targets_mean = linen.avg_pool( 60 | targets, 61 | window_shape=window_shape, 62 | strides=(1, 1), 63 | padding=padding, 64 | count_include_pad=False, 65 | ) 66 | targets_square_mean = linen.avg_pool( 67 | targets_square, 68 | window_shape=window_shape, 69 | strides=(1, 1), 70 | padding=padding, 71 | count_include_pad=False, 72 | ) 73 | 74 | targets_var = targets_square_mean - jnp.power(targets_mean, 2.0) 75 | targets_var = targets_var * (self.targets_count / (self.targets_count - 1)) 76 | targets_var = jnp.maximum(targets_var, 0.0) 77 | 78 | targets_ = (targets_ - targets_mean) / jnp.sqrt(targets_var + 1.0e-6) 79 | 80 | return targets_ 81 | 82 | 83 | class SwinTransformerV2ForSimMIM(SwinTransformerV2): 84 | def setup(self): 85 | super().setup() 86 | 87 | token_init = 
linen.initializers.normal(0.02) 88 | self.mask_token = self.param("mask_token", token_init, (1, 1, self.embed_dim)) 89 | 90 | def __call__(self, x, mask, train: bool = False): 91 | x = self.patch_embed(x) 92 | 93 | B, L, _ = x.shape 94 | mask_token = self.mask_token.astype(self.dtype) 95 | mask_tokens = jnp.broadcast_to(mask_token, (B, L, self.embed_dim)) 96 | mask = jnp.reshape(mask, (B, L, 1)).astype(mask_tokens.dtype) 97 | x = x * (1.0 - mask) + mask_tokens * mask 98 | 99 | x = self.pos_drop(x, deterministic=not train) 100 | 101 | for layer in self.swin_body: 102 | x = layer(x, train=train) 103 | 104 | x = self.norm(x) 105 | 106 | B, L, C = x.shape 107 | H = W = int(L**0.5) 108 | x = jnp.reshape(x, (B, H, W, C)) 109 | return x 110 | 111 | def get_stride(self): 112 | return self.patch_size * 2 ** (len(self.depths) - 1) 113 | 114 | 115 | class VisionTransformerForSimMIM(VisionTransformer): 116 | def setup(self): 117 | super().setup() 118 | 119 | token_init = linen.initializers.normal(0.02) 120 | self.mask_token = self.param("mask_token", token_init, (1, 1, self.embed_dim)) 121 | 122 | def __call__(self, x, mask, train: bool = False): 123 | x = self.patch_embed(x) 124 | 125 | B, L, _ = x.shape 126 | mask_tokens = jnp.broadcast_to(self.mask_token, (B, L, self.embed_dim)) 127 | mask = jnp.reshape(mask, (B, L, 1)).astype(mask_tokens.dtype) 128 | x = x * (1.0 - mask) + mask_tokens * mask 129 | 130 | x = self.pos_emb(x) 131 | 132 | for layer in self.vit_body: 133 | x = layer(x, train=train) 134 | 135 | x = self.norm(x) 136 | 137 | B, L, C = x.shape 138 | H = W = int(L**0.5) 139 | x = jnp.reshape(x, (B, H, W, C)) 140 | return x 141 | 142 | def get_stride(self): 143 | return self.patch_size 144 | 145 | 146 | class HierarchicalViTForSimMIM(HierarchicalViT): 147 | def setup(self): 148 | super().setup() 149 | 150 | token_init = linen.initializers.normal(0.02) 151 | self.mask_token = self.param("mask_token", token_init, (1, 1, self.embed_dim)) 152 | 153 | def __call__(self, x, mask, train: bool = False): 154 | x = self.patch_embed(x) 155 | 156 | B, L, _ = x.shape 157 | H = W = int(L**0.5) 158 | mask_token = self.mask_token.astype(self.dtype) 159 | mask_tokens = jnp.broadcast_to(mask_token, (B, L, self.embed_dim)) 160 | mask = jnp.reshape(mask, (B, H, W, 1)).astype(mask_tokens.dtype) 161 | mask = self.patch_embed.patches_reshape(mask) 162 | x = x * (1.0 - mask) + mask_tokens * mask 163 | 164 | for layer in self.hivit_body: 165 | x = layer(x, train=train) 166 | 167 | x = self.norm(x) 168 | 169 | B, L, C = x.shape 170 | H = W = int(L**0.5) 171 | x = jnp.reshape(x, (B, H, W, C)) 172 | return x 173 | 174 | def get_stride(self): 175 | return 16 176 | 177 | 178 | class ConvNextForSimMIM(ConvNext): 179 | def setup(self): 180 | super().setup() 181 | 182 | token_init = linen.initializers.normal(0.02) 183 | self.mask_token = self.param("mask_token", token_init, (self.embed_dims[0],)) 184 | 185 | def __call__(self, x, mask, train: bool = False): 186 | x = self.patch_embed(x) 187 | 188 | B, H, W, _ = x.shape 189 | mask_tokens = jnp.broadcast_to(self.mask_token, (B, H, W, self.embed_dims[0])) 190 | mask = jnp.reshape(mask, (B, H, W, 1)).astype(mask_tokens.dtype) 191 | x = x * (1.0 - mask) + mask_tokens * mask 192 | 193 | for layer in self.convnext_body: 194 | x = layer(x, train=train) 195 | 196 | x = self.norm(x) 197 | return x 198 | 199 | def get_stride(self): 200 | return 32 201 | 202 | 203 | class EVA02ForSimMIM(EVA02Transformer): 204 | def setup(self): 205 | super().setup() 206 | 207 | token_init = 
linen.initializers.normal(0.02) 208 | self.mask_token = self.param("mask_token", token_init, (1, 1, self.embed_dim)) 209 | 210 | def __call__(self, x, mask, train: bool = False): 211 | x = self.patch_embed(x) 212 | 213 | B, L, _ = x.shape 214 | mask_tokens = jnp.broadcast_to(self.mask_token, (B, L, self.embed_dim)) 215 | mask = jnp.reshape(mask, (B, L, 1)).astype(mask_tokens.dtype) 216 | x = x * (1.0 - mask) + mask_tokens * mask 217 | 218 | if self.use_cls_token: 219 | B, L, C = x.shape 220 | b_cls = self.cls_token.astype(x.dtype) 221 | b_cls = jnp.broadcast_to(b_cls, (B, 1, C)) 222 | x = jnp.concatenate([b_cls, x], axis=1) 223 | 224 | x = self.pos_emb(x) 225 | 226 | if self.num_register_tokens: 227 | B, L, C = x.shape 228 | b_reg = self.reg_token.astype(x.dtype) 229 | b_reg = jnp.broadcast_to(b_reg, (B, self.num_register_tokens, C)) 230 | x = jnp.concatenate([b_reg, x], axis=1) 231 | 232 | for layer in self.eva02_body: 233 | x = layer(x, train=train) 234 | 235 | x = self.norm(x) 236 | x = x[:, self.num_extra_tokens :] 237 | 238 | B, L, C = x.shape 239 | H = W = int(L**0.5) 240 | x = jnp.reshape(x, (B, H, W, C)) 241 | return x 242 | 243 | def get_stride(self): 244 | return self.patch_size 245 | 246 | 247 | class SimMIM(linen.Module): 248 | encoder: linen.Module = SwinTransformerV2ForSimMIM 249 | encoder_stride: int = 32 250 | 251 | patch_size: int = 4 252 | 253 | enable_windowed_norm: bool = False 254 | norm_patch_size: int = 47 255 | 256 | dtype: Any = jnp.float32 257 | 258 | @linen.compact 259 | def __call__(self, x, mask, train: bool = False): 260 | z = self.encoder(x, mask, train) 261 | x_rec = linen.Conv( 262 | features=self.encoder_stride**2 * 3, 263 | kernel_size=(1, 1), 264 | dtype=self.dtype, 265 | )(z) 266 | x_rec = einops.rearrange( 267 | x_rec, 268 | pattern="... h w (c b1 b2) -> ... 
(h b1) (w b2) c", 269 | b1=self.encoder_stride, 270 | b2=self.encoder_stride, 271 | ) 272 | 273 | mask = jnp.expand_dims( 274 | jnp.repeat( 275 | jnp.repeat(mask, self.patch_size, axis=1), 276 | self.patch_size, 277 | axis=2, 278 | ), 279 | axis=-1, 280 | ) 281 | 282 | B, H, W, C = x.shape 283 | if self.enable_windowed_norm: 284 | x = WindowedNorm(target_size=(H, W), window_size=self.norm_patch_size)(x) 285 | 286 | x_rec = x_rec.astype(x.dtype) 287 | loss_recon = jnp.abs(x - x_rec) 288 | loss = jnp.sum(loss_recon * mask) / (jnp.sum(mask) + 1e-5) / C 289 | 290 | return loss, x_rec 291 | 292 | @classmethod 293 | def build(cls, config, **kwargs): 294 | encoder = config.encoder.build(config.encoder, **kwargs) 295 | 296 | config = dataclasses.asdict(config) 297 | config = {key: kwargs[key] if key in kwargs else config[key] for key in config} 298 | config["encoder"] = encoder 299 | config["encoder_stride"] = encoder.get_stride() 300 | return cls(**config) 301 | 302 | def extend_parser(self, parser): 303 | parser = self.encoder.extend_parser(parser) 304 | parser.add_argument( 305 | "--enable-windowed-norm", 306 | action="store_true", 307 | help="Use windowed norm of input images as reconstruction target in SimMIM", 308 | ) 309 | return parser 310 | 311 | def should_decay(self, path, _): 312 | if path[0].key == "encoder": 313 | return self.encoder.should_decay(path[1:], _) 314 | 315 | is_kernel = path[-1].key == "kernel" 316 | verdict = is_kernel 317 | return verdict 318 | 319 | 320 | def simmim_swinv2_tiny(): 321 | config = { 322 | "embed_dim": 96, 323 | "depths": (2, 2, 6, 2), 324 | "num_heads": (3, 6, 12, 24), 325 | } 326 | encoder = SwinTransformerV2ForSimMIM(**config) 327 | 328 | config = { 329 | "encoder": encoder, 330 | "encoder_stride": encoder.get_stride(), 331 | "patch_size": encoder.patch_size, 332 | } 333 | return SimMIM(**config) 334 | 335 | 336 | def simmim_swinv2_base(): 337 | config = { 338 | "embed_dim": 128, 339 | "depths": (2, 2, 18, 2), 340 | "num_heads": (4, 8, 16, 32), 341 | } 342 | encoder = SwinTransformerV2ForSimMIM(**config) 343 | 344 | config = { 345 | "encoder": encoder, 346 | "encoder_stride": encoder.get_stride(), 347 | "patch_size": encoder.patch_size, 348 | } 349 | return SimMIM(**config) 350 | 351 | 352 | def simmim_swinv2_large(): 353 | config = { 354 | "embed_dim": 192, 355 | "depths": (2, 2, 18, 2), 356 | "num_heads": (6, 12, 24, 48), 357 | } 358 | encoder = SwinTransformerV2ForSimMIM(**config) 359 | 360 | config = { 361 | "encoder": encoder, 362 | "encoder_stride": encoder.get_stride(), 363 | "patch_size": encoder.patch_size, 364 | } 365 | return SimMIM(**config) 366 | 367 | 368 | def simmim_vit_small(): 369 | config = { 370 | "num_layers": 12, 371 | "embed_dim": 384, 372 | "mlp_dim": 1536, 373 | "num_heads": 6, 374 | } 375 | encoder = VisionTransformerForSimMIM(**config) 376 | 377 | config = { 378 | "encoder": encoder, 379 | "encoder_stride": encoder.get_stride(), 380 | "patch_size": encoder.patch_size, 381 | } 382 | return SimMIM(**config) 383 | 384 | 385 | def simmim_vit_base(): 386 | config = { 387 | "num_layers": 12, 388 | "embed_dim": 768, 389 | "mlp_dim": 3072, 390 | "num_heads": 12, 391 | } 392 | encoder = VisionTransformerForSimMIM(**config) 393 | 394 | config = { 395 | "encoder": encoder, 396 | "encoder_stride": encoder.patch_size, 397 | "patch_size": encoder.patch_size, 398 | } 399 | return SimMIM(**config) 400 | 401 | 402 | def simmim_vit_large(): 403 | config = { 404 | "num_layers": 24, 405 | "embed_dim": 1024, 406 | "mlp_dim": 4096, 407 | 
"num_heads": 16, 408 | } 409 | encoder = VisionTransformerForSimMIM(**config) 410 | 411 | config = { 412 | "encoder": encoder, 413 | "encoder_stride": encoder.patch_size, 414 | "patch_size": encoder.patch_size, 415 | } 416 | return SimMIM(**config) 417 | 418 | 419 | def simmim_hivit_tiny(): 420 | config = { 421 | "depths": (1, 1, 10), 422 | "embed_dim": 96, 423 | "mlp_ratio": (3.0, 3.0, 4.0), 424 | "num_heads": (None, None, 6), 425 | } 426 | encoder = HierarchicalViTForSimMIM(**config) 427 | 428 | config = { 429 | "encoder": encoder, 430 | "encoder_stride": encoder.get_stride(), 431 | "patch_size": encoder.patch_size, 432 | } 433 | return SimMIM(**config) 434 | 435 | 436 | def simmim_hivit_small(): 437 | config = { 438 | "depths": (2, 2, 20), 439 | "embed_dim": 96, 440 | "mlp_ratio": (3.0, 3.0, 4.0), 441 | "num_heads": (None, None, 6), 442 | } 443 | encoder = HierarchicalViTForSimMIM(**config) 444 | 445 | config = { 446 | "encoder": encoder, 447 | "encoder_stride": encoder.get_stride(), 448 | "patch_size": encoder.patch_size, 449 | } 450 | return SimMIM(**config) 451 | 452 | 453 | def simmim_convnext_tiny(): 454 | config = { 455 | "embed_dims": (96, 192, 384, 768), 456 | "depths": (3, 3, 9, 3), 457 | } 458 | encoder = ConvNextForSimMIM(**config) 459 | 460 | config = { 461 | "encoder": encoder, 462 | "encoder_stride": encoder.get_stride(), 463 | "patch_size": encoder.patch_size, 464 | } 465 | return SimMIM(**config) 466 | 467 | 468 | def simmim_convnext_small(): 469 | config = { 470 | "embed_dims": (96, 192, 384, 768), 471 | "depths": (3, 3, 27, 3), 472 | } 473 | encoder = ConvNextForSimMIM(**config) 474 | 475 | config = { 476 | "encoder": encoder, 477 | "encoder_stride": encoder.get_stride(), 478 | "patch_size": encoder.patch_size, 479 | } 480 | return SimMIM(**config) 481 | 482 | 483 | def simmim_convnext_base(): 484 | config = { 485 | "embed_dims": (128, 256, 512, 1024), 486 | "depths": (3, 3, 27, 3), 487 | } 488 | encoder = ConvNextForSimMIM(**config) 489 | 490 | config = { 491 | "encoder": encoder, 492 | "encoder_stride": encoder.get_stride(), 493 | "patch_size": encoder.patch_size, 494 | } 495 | return SimMIM(**config) 496 | 497 | 498 | def simmim_eva02_small(): 499 | config = { 500 | "num_layers": 12, 501 | "embed_dim": 384, 502 | "mlp_dim": (384 * 4 * 2) // 3, 503 | "num_heads": 6, 504 | "scale_mlp": False, 505 | } 506 | encoder = EVA02ForSimMIM(**config) 507 | 508 | config = { 509 | "encoder": encoder, 510 | "encoder_stride": encoder.get_stride(), 511 | "patch_size": encoder.patch_size, 512 | } 513 | return SimMIM(**config) 514 | 515 | 516 | def simmim_eva02_base(): 517 | config = { 518 | "num_layers": 12, 519 | "embed_dim": 768, 520 | "mlp_dim": (768 * 4 * 2) // 3, 521 | "num_heads": 12, 522 | "scale_mlp": True, 523 | } 524 | encoder = EVA02ForSimMIM(**config) 525 | 526 | config = { 527 | "encoder": encoder, 528 | "encoder_stride": encoder.patch_size, 529 | "patch_size": encoder.patch_size, 530 | } 531 | return SimMIM(**config) 532 | 533 | 534 | def simmim_eva02_large(): 535 | config = { 536 | "num_layers": 24, 537 | "embed_dim": 1024, 538 | "mlp_dim": (1024 * 4 * 2) // 3, 539 | "num_heads": 16, 540 | "scale_mlp": True, 541 | } 542 | encoder = EVA02ForSimMIM(**config) 543 | 544 | config = { 545 | "encoder": encoder, 546 | "encoder_stride": encoder.patch_size, 547 | "patch_size": encoder.patch_size, 548 | } 549 | return SimMIM(**config) 550 | -------------------------------------------------------------------------------- /Models/SwinV2.py: 
-------------------------------------------------------------------------------- 1 | import dataclasses 2 | from functools import partial 3 | from typing import Callable, Optional, Union 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | import jax.typing as jt 8 | import numpy as np 9 | from flax import linen 10 | 11 | 12 | class MLP(linen.Module): 13 | hidden_features: int 14 | act_layer: Callable = linen.gelu 15 | drop_ratio: float = 0.0 16 | dtype: jt.DTypeLike = jnp.float32 17 | 18 | @linen.compact 19 | def __call__(self, x, train: bool): 20 | out_dim = x.shape[-1] 21 | 22 | x = linen.Dense(self.hidden_features, dtype=self.dtype)(x) 23 | x = self.act_layer(x) 24 | x = linen.Dropout(self.drop_ratio)(x, deterministic=not train) 25 | x = linen.Dense(out_dim, dtype=self.dtype)(x) 26 | x = linen.Dropout(self.drop_ratio)(x, deterministic=not train) 27 | return x 28 | 29 | 30 | class RelativePositionBias(linen.Module): 31 | window_size: tuple[int, int] 32 | num_heads: int 33 | pretrained_window_size: tuple[int, int] 34 | dtype: jt.DTypeLike = jnp.float32 35 | 36 | def get_relative_coords_table(self): 37 | coords_h = np.arange(-(self.window_size[0] - 1), self.window_size[0]) 38 | coords_w = np.arange(-(self.window_size[1] - 1), self.window_size[1]) 39 | 40 | # 1, 2*Wh-1, 2*Ww-1, 2 41 | coords_table = np.meshgrid(coords_h, coords_w, indexing="ij") 42 | coords_table = np.stack(coords_table) 43 | coords_table = np.transpose(coords_table, (1, 2, 0)) 44 | coords_table = np.expand_dims(coords_table, 0) 45 | coords_table = coords_table.astype(np.float32) 46 | 47 | if self.pretrained_window_size[0] > 0: 48 | coords_table[:, :, :, 0] = coords_table[:, :, :, 0] / ( 49 | self.pretrained_window_size[0] - 1 50 | ) 51 | coords_table[:, :, :, 1] = coords_table[:, :, :, 1] / ( 52 | self.pretrained_window_size[1] - 1 53 | ) 54 | else: 55 | coords_table[:, :, :, 0] = coords_table[:, :, :, 0] / ( 56 | self.window_size[0] - 1 57 | ) 58 | coords_table[:, :, :, 1] = coords_table[:, :, :, 1] / ( 59 | self.window_size[1] - 1 60 | ) 61 | 62 | # normalize to -8, 8 63 | coords_table = coords_table * 8 64 | coord_table_sign = np.sign(coords_table) 65 | coords_table = np.log2(np.abs(coords_table) + 1.0) 66 | coords_table = coord_table_sign * coords_table / np.log2(8) 67 | return coords_table 68 | 69 | def get_relative_position_index(self): 70 | # get pair-wise relative position index for each token inside the window 71 | coords_h = np.arange(self.window_size[0]) 72 | coords_w = np.arange(self.window_size[1]) 73 | 74 | # 2, Wh, Ww 75 | coords = np.stack(np.meshgrid(coords_h, coords_w, indexing="ij")) 76 | 77 | # 2, Wh*Ww 78 | coords_flatten = np.reshape(coords, (2, -1)) 79 | 80 | # 2, Wh*Ww, Wh*Ww 81 | coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] 82 | 83 | # Wh*Ww, Wh*Ww, 2 84 | coords = np.transpose(coords, (1, 2, 0)) 85 | 86 | # shift to start from 0 87 | coords[:, :, 0] = coords[:, :, 0] + (self.window_size[0] - 1) 88 | coords[:, :, 1] = coords[:, :, 1] + (self.window_size[1] - 1) 89 | coords[:, :, 0] = coords[:, :, 0] * (2 * self.window_size[1] - 1) 90 | 91 | # Wh*Ww, Wh*Ww 92 | position_index = np.sum(coords, axis=-1) 93 | return position_index 94 | 95 | def setup(self): 96 | self.relative_coords_table = self.variable( 97 | "swinv2_constants", 98 | "relative_coords_table", 99 | self.get_relative_coords_table, 100 | ).value 101 | 102 | self.relative_position_index = self.variable( 103 | "swinv2_constants", 104 | "relative_position_index", 105 | self.get_relative_position_index, 106 | ).value 107 | 
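        # relative_coords_table enumerates every (dy, dx) offset between two tokens
        # in a window, rescaled with a signed log so values land roughly in [-1, 1];
        # relative_position_index maps each token pair (Wh*Ww x Wh*Ww) to the
        # matching row of the bias table produced by the MLP below.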
108 | # mlp to generate continuous relative position bias 109 | self.cpb_mlp = linen.Sequential( 110 | [ 111 | linen.Dense(512, use_bias=True, dtype=self.dtype), 112 | linen.relu, 113 | linen.Dense(self.num_heads, use_bias=False, dtype=self.dtype), 114 | ], 115 | ) 116 | 117 | def __call__(self, x): 118 | relative_position_bias_table = self.cpb_mlp(self.relative_coords_table) 119 | rpe_index = jnp.reshape(self.relative_position_index, (-1,)) 120 | 121 | # Wh*Ww,Wh*Ww,nH 122 | relative_position_bias_table = jnp.reshape( 123 | relative_position_bias_table, (-1, self.num_heads) 124 | ) 125 | relative_position_bias = jnp.reshape( 126 | relative_position_bias_table[rpe_index], 127 | ( 128 | self.window_size[0] * self.window_size[1], 129 | self.window_size[0] * self.window_size[1], 130 | -1, 131 | ), 132 | ) 133 | 134 | # nH, Wh*Ww, Wh*Ww 135 | relative_position_bias = jnp.transpose(relative_position_bias, (2, 0, 1)) 136 | relative_position_bias = 16 * linen.sigmoid(relative_position_bias) 137 | relative_position_bias = jnp.expand_dims(relative_position_bias, 0) 138 | x = x + relative_position_bias 139 | return x 140 | 141 | 142 | def l2_normalize(x): 143 | rnorm = jax.lax.rsqrt(jnp.maximum(jnp.sum((x * x), axis=-1, keepdims=True), 1e-12)) 144 | return x * rnorm 145 | 146 | 147 | class WindowAttention(linen.Module): 148 | r"""Window based multi-head self attention (W-MSA) module with relative position bias. 149 | It supports both of shifted and non-shifted window. 150 | 151 | Args: 152 | dim (int): Number of input channels. 153 | window_size (tuple[int]): The height and width of the window. 154 | num_heads (int): Number of attention heads. 155 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 156 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 157 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0 158 | pretrained_window_size (tuple[int]): The height and width of the window in pre-training. 
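    Note: attention logits are cosine similarities between L2-normalized queries
    and keys, scaled by a learnable per-head temperature that is clamped at
    log(100) before exponentiation, rather than dot products scaled by
    1/sqrt(head_dim).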
159 | """ 160 | 161 | dim: int 162 | window_size: tuple[int, int] # Wh, Ww 163 | num_heads: int 164 | qkv_bias: bool = True 165 | attn_drop_ratio: float = 0.0 166 | proj_drop_ratio: float = 0.0 167 | pretrained_window_size: tuple[int, int] = (0, 0) 168 | dtype: jt.DTypeLike = jnp.float32 169 | 170 | def setup(self): 171 | self.logit_scale = self.variable( 172 | "params", 173 | "logit_scale", 174 | lambda x: jnp.log(10 * jnp.ones((x, 1, 1))), 175 | self.num_heads, 176 | ).value 177 | 178 | self.qkv = linen.Dense(self.dim * 3, use_bias=False, dtype=self.dtype) 179 | if self.qkv_bias: 180 | bias_init = linen.initializers.zeros_init() 181 | self.q_bias = self.param("q_bias", bias_init, (self.dim,)) 182 | self.v_bias = self.param("v_bias", bias_init, (self.dim,)) 183 | 184 | self.attention_bias = RelativePositionBias( 185 | self.window_size, 186 | self.num_heads, 187 | self.pretrained_window_size, 188 | dtype=self.dtype, 189 | ) 190 | 191 | self.attn_drop = linen.Dropout(self.attn_drop_ratio) 192 | self.proj = linen.Dense(self.dim, dtype=self.dtype) 193 | self.proj_drop = linen.Dropout(self.proj_drop_ratio) 194 | self.softmax = partial(linen.softmax, axis=-1) 195 | 196 | def __call__(self, x, train: bool, mask=None): 197 | """ 198 | Args: 199 | x: input features with shape of (num_windows*B, N, C) 200 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 201 | """ 202 | B_, N, C = x.shape 203 | 204 | qkv = self.qkv(x) 205 | if self.qkv_bias: 206 | q_bias = self.q_bias.astype(self.dtype) 207 | v_bias = self.v_bias.astype(self.dtype) 208 | qkv_bias = jnp.concatenate( 209 | ( 210 | q_bias, 211 | jnp.zeros_like(v_bias), 212 | v_bias, 213 | ) 214 | ) 215 | qkv = qkv + qkv_bias 216 | 217 | qkv = jnp.reshape(qkv, (B_, N, 3, self.num_heads, -1)) 218 | qkv = jnp.transpose(qkv, (2, 0, 3, 1, 4)) 219 | q, k, v = (qkv[0], qkv[1], qkv[2]) 220 | 221 | q_norm = l2_normalize(q) 222 | k_norm = l2_normalize(k) 223 | attn = q_norm @ jnp.transpose(k_norm, (0, 1, 3, 2)) 224 | 225 | logit_scale = jnp.minimum(self.logit_scale, np.log(100.0)) 226 | logit_scale = logit_scale.astype(self.dtype) 227 | logit_scale = jnp.exp(logit_scale) 228 | attn = attn * logit_scale 229 | 230 | attn = self.attention_bias(attn) 231 | 232 | if mask is not None: 233 | nW = mask.shape[0] 234 | mask = mask.astype(self.dtype) 235 | mask = jnp.expand_dims(jnp.expand_dims(mask, 1), 0) 236 | attn = jnp.reshape(attn, (B_ // nW, nW, self.num_heads, N, N)) + mask 237 | attn = jnp.reshape(attn, (-1, self.num_heads, N, N)) 238 | attn = self.softmax(attn) 239 | else: 240 | attn = self.softmax(attn) 241 | 242 | attn = self.attn_drop(attn, deterministic=not train) 243 | 244 | x = jnp.transpose(attn @ v, (0, 2, 1, 3)) 245 | x = jnp.reshape(x, (B_, N, C)) 246 | x = self.proj(x) 247 | x = self.proj_drop(x, deterministic=not train) 248 | return x 249 | 250 | 251 | def window_partition(x, window_size): 252 | B, H, W, C = x.shape 253 | windows = jnp.reshape( 254 | x, (B, H // window_size, window_size, W // window_size, window_size, C) 255 | ) 256 | windows = jnp.transpose(windows, (0, 1, 3, 2, 4, 5)) 257 | windows = jnp.reshape(windows, (-1, window_size, window_size, C)) 258 | return windows 259 | 260 | 261 | def window_reverse(windows, window_size, H, W): 262 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 263 | x = jnp.reshape( 264 | windows, (B, H // window_size, W // window_size, window_size, window_size, -1) 265 | ) 266 | x = jnp.transpose(x, (0, 1, 3, 2, 4, 5)) 267 | x = jnp.reshape(x, (B, H, W, -1)) 268 | return x 269 
| 270 | 271 | class SwinTransformerBlock(linen.Module): 272 | r"""Swin Transformer Block. 273 | 274 | Args: 275 | dim (int): Number of input channels. 276 | input_resolution (tuple[int]): Input resulotion. 277 | num_heads (int): Number of attention heads. 278 | window_size (int): Window size. 279 | shift_size (int): Shift size for SW-MSA. 280 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 281 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 282 | drop (float, optional): Dropout rate. Default: 0.0 283 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 284 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 285 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 286 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 287 | pretrained_window_size (int): Window size in pre-training. 288 | """ 289 | 290 | dim: int 291 | input_resolution: tuple[int, int] 292 | num_heads: int 293 | window_size: int = 7 294 | shift_size: int = 0 295 | mlp_ratio: float = 4.0 296 | qkv_bias: bool = True 297 | drop_ratio: float = 0.0 298 | attn_drop_ratio: float = 0.0 299 | drop_path_ratio: float = 0.0 300 | act_layer: Callable = linen.gelu 301 | norm_layer: Callable = linen.LayerNorm 302 | pretrained_window_size: int = 0 303 | dtype: jt.DTypeLike = jnp.float32 304 | 305 | def setup(self): 306 | self.norm1 = self.norm_layer() 307 | self.attn = WindowAttention( 308 | self.dim, 309 | window_size=(self.window_size, self.window_size), 310 | num_heads=self.num_heads, 311 | qkv_bias=self.qkv_bias, 312 | attn_drop_ratio=self.attn_drop_ratio, 313 | proj_drop_ratio=self.drop_ratio, 314 | pretrained_window_size=( 315 | self.pretrained_window_size, 316 | self.pretrained_window_size, 317 | ), 318 | dtype=self.dtype, 319 | ) 320 | 321 | self.drop_path = linen.Dropout(rate=self.drop_path_ratio, broadcast_dims=(1, 2)) 322 | self.norm2 = self.norm_layer() 323 | mlp_hidden_dim = int(self.dim * self.mlp_ratio) 324 | self.mlp = MLP( 325 | hidden_features=mlp_hidden_dim, 326 | act_layer=self.act_layer, 327 | drop_ratio=self.drop_ratio, 328 | dtype=self.dtype, 329 | ) 330 | 331 | if self.shift_size > 0: 332 | # calculate attention mask for SW-MSA 333 | H, W = self.input_resolution 334 | img_mask = jnp.zeros((1, H, W, 1)) # 1 H W 1 335 | h_slices = ( 336 | slice(0, -self.window_size), 337 | slice(-self.window_size, -self.shift_size), 338 | slice(-self.shift_size, None), 339 | ) 340 | w_slices = ( 341 | slice(0, -self.window_size), 342 | slice(-self.window_size, -self.shift_size), 343 | slice(-self.shift_size, None), 344 | ) 345 | cnt = 0 346 | for h in h_slices: 347 | for w in w_slices: 348 | img_mask = img_mask.at[:, h, w, :].set(cnt) 349 | cnt += 1 350 | 351 | # nW, window_size, window_size, 1 352 | mask_windows = window_partition(img_mask, self.window_size) 353 | mask_windows = jnp.reshape( 354 | mask_windows, (-1, self.window_size * self.window_size) 355 | ) 356 | attn_mask = jnp.expand_dims(mask_windows, 1) - jnp.expand_dims( 357 | mask_windows, 2 358 | ) 359 | attn_mask = jnp.where(attn_mask != 0, float(-100.0), attn_mask) 360 | attn_mask = jnp.where(attn_mask == 0, float(0.0), attn_mask) 361 | else: 362 | attn_mask = None 363 | 364 | self.attn_mask = self.variable( 365 | "swinv2_constants", "attn_mask", lambda: attn_mask 366 | ).value 367 | 368 | def __call__(self, x, train: bool): 369 | H, W = self.input_resolution 370 | B, L, C = x.shape 371 | 372 | shortcut = x 373 | x = jnp.reshape(x, (B, H, 
W, C)) 374 | 375 | # cyclic shift 376 | if self.shift_size > 0: 377 | shifted_x = jnp.roll( 378 | x, 379 | shift=(-self.shift_size, -self.shift_size), 380 | axis=(1, 2), 381 | ) 382 | else: 383 | shifted_x = x 384 | 385 | # partition windows 386 | # nW*B, window_size, window_size, C 387 | x_windows = window_partition(shifted_x, self.window_size) 388 | 389 | # nW*B, window_size*window_size, C 390 | x_windows = jnp.reshape(x_windows, (-1, self.window_size * self.window_size, C)) 391 | 392 | # W-MSA/SW-MSA 393 | # nW*B, window_size*window_size, C 394 | attn_windows = self.attn(x_windows, train=train, mask=self.attn_mask) 395 | 396 | # merge windows 397 | attn_windows = jnp.reshape( 398 | attn_windows, 399 | (-1, self.window_size, self.window_size, C), 400 | ) 401 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C 402 | 403 | # reverse cyclic shift 404 | if self.shift_size > 0: 405 | x = jnp.roll( 406 | shifted_x, 407 | shift=(self.shift_size, self.shift_size), 408 | axis=(1, 2), 409 | ) 410 | else: 411 | x = shifted_x 412 | x = jnp.reshape(x, (B, H * W, C)) 413 | x = shortcut + self.drop_path(self.norm1(x), deterministic=not train) 414 | 415 | # FFN 416 | x = x + self.drop_path(self.norm2(self.mlp(x, train)), deterministic=not train) 417 | return x 418 | 419 | 420 | class PatchMerging(linen.Module): 421 | r"""Patch Merging Layer. 422 | 423 | Args: 424 | input_resolution (tuple[int]): Resolution of input feature. 425 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 426 | """ 427 | 428 | input_resolution: tuple[int, int] 429 | norm_layer: Callable = linen.LayerNorm 430 | dtype: jt.DTypeLike = jnp.float32 431 | 432 | @linen.compact 433 | def __call__(self, x): 434 | """ 435 | x: B, H*W, C 436 | """ 437 | H, W = self.input_resolution 438 | B, L, C = x.shape 439 | 440 | x = jnp.reshape(x, (B, H // 2, 2, W // 2, 2, C)) # B H/2 nH W/2 nW C 441 | x = jnp.transpose(x, (0, 1, 3, 4, 2, 5)) # B H/2 W/2 nW nH C 442 | x = jnp.reshape(x, (B, (H // 2) * (W // 2), 4 * C)) # B H/2*W/2 4*C 443 | 444 | x = linen.Dense(2 * C, use_bias=False, dtype=self.dtype)(x) 445 | x = self.norm_layer()(x) 446 | 447 | return x 448 | 449 | 450 | class BasicLayer(linen.Module): 451 | """A basic Swin Transformer layer for one stage. 452 | 453 | Args: 454 | dim (int): Number of input channels. 455 | input_resolution (tuple[int]): Input resolution. 456 | depth (int): Number of blocks. 457 | num_heads (int): Number of attention heads. 458 | window_size (int): Local window size. 459 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 460 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 461 | drop (float, optional): Dropout rate. Default: 0.0 462 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 463 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 464 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 465 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 466 | pretrained_window_size (int): Local window size in pre-training. 
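    Note: blocks alternate between regular (shift_size=0) and shifted
    (shift_size=window_size//2) windows; when the input resolution is no larger
    than the window, shifting is disabled and the window size is clamped to the
    input resolution.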
467 | """ 468 | 469 | dim: int 470 | input_resolution: tuple[int, int] 471 | depth: int 472 | num_heads: int 473 | window_size: int 474 | mlp_ratio: float = 4.0 475 | qkv_bias: bool = True 476 | drop_ratio: float = 0.0 477 | attn_drop_ratio: float = 0.0 478 | drop_path_ratio: Union[float, tuple[float, ...]] = 0.0 479 | norm_layer: Callable = linen.LayerNorm 480 | downsample: Optional[Callable] = None 481 | pretrained_window_size: int = 0 482 | dtype: jt.DTypeLike = jnp.float32 483 | 484 | @linen.compact 485 | def __call__(self, x, train: bool): 486 | for i in range(self.depth): 487 | window_size = self.window_size 488 | shift_size = 0 if (i % 2 == 0) else window_size // 2 489 | drop_path_ratio = ( 490 | self.drop_path_ratio[i] 491 | if isinstance(self.drop_path_ratio, tuple) 492 | else self.drop_path_ratio 493 | ) 494 | 495 | # if window size is larger than input resolution, we don't partition windows 496 | if min(self.input_resolution) <= window_size: 497 | shift_size = 0 498 | window_size = min(self.input_resolution) 499 | 500 | x = SwinTransformerBlock( 501 | dim=self.dim, 502 | input_resolution=self.input_resolution, 503 | num_heads=self.num_heads, 504 | window_size=window_size, 505 | shift_size=shift_size, 506 | mlp_ratio=self.mlp_ratio, 507 | qkv_bias=self.qkv_bias, 508 | drop_ratio=self.drop_ratio, 509 | attn_drop_ratio=self.attn_drop_ratio, 510 | drop_path_ratio=drop_path_ratio, 511 | norm_layer=self.norm_layer, 512 | pretrained_window_size=self.pretrained_window_size, 513 | dtype=self.dtype, 514 | )(x, train) 515 | 516 | # patch merging layer 517 | if self.downsample is not None: 518 | x = self.downsample( 519 | self.input_resolution, 520 | norm_layer=self.norm_layer, 521 | dtype=self.dtype, 522 | )(x) 523 | return x 524 | 525 | 526 | class PatchEmbed(linen.Module): 527 | r"""Image to Patch Embedding 528 | 529 | Args: 530 | patch_size (int): Patch token size. Default: 4. 531 | embed_dim (int): Number of linear projection output channels. Default: 96. 532 | norm_layer (nn.Module, optional): Normalization layer. Default: None 533 | """ 534 | 535 | patch_size: int = 4 536 | embed_dim: int = 96 537 | norm_layer: Optional[Callable] = None 538 | dtype: jt.DTypeLike = jnp.float32 539 | 540 | @linen.compact 541 | def __call__(self, x): 542 | B, _, _, _ = x.shape 543 | patch_size = (self.patch_size, self.patch_size) 544 | 545 | x = linen.Conv( 546 | self.embed_dim, 547 | kernel_size=patch_size, 548 | strides=patch_size, 549 | dtype=self.dtype, 550 | )(x) 551 | x = jnp.reshape(x, (B, -1, self.embed_dim)) 552 | if self.norm_layer is not None: 553 | x = self.norm_layer()(x) 554 | return x 555 | 556 | 557 | class SwinTransformerV2(linen.Module): 558 | r"""Swin Transformer 559 | A JAX/Flax impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - 560 | https://arxiv.org/pdf/2103.14030 561 | 562 | Args: 563 | image_size (int | tuple(int)): Input image size. Default 224 564 | patch_size (int | tuple(int)): Patch size. Default: 4 565 | in_chans (int): Number of input image channels. Default: 3 566 | num_classes (int): Number of classes for classification head. Default: 1000 567 | embed_dim (int): Patch embedding dimension. Default: 96 568 | depths (tuple(int)): Depth of each Swin Transformer layer. 569 | num_heads (tuple(int)): Number of attention heads in different layers. 570 | window_size (int): Window size. Default: 7 571 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 
Default: 4 572 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True 573 | drop_rate (float): Dropout rate. Default: 0 574 | attn_drop_rate (float): Attention dropout rate. Default: 0 575 | drop_path_rate (float): Stochastic depth rate. Default: 0.1 576 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 577 | patch_norm (bool): If True, add normalization after patch embedding. Default: True 578 | pretrained_window_sizes (tuple(int)): Pretrained window sizes of each layer. 579 | """ 580 | 581 | image_size: int = 224 582 | patch_size: int = 4 583 | in_chans: int = 3 584 | num_classes: int = 1000 585 | 586 | embed_dim: int = 96 587 | depths: tuple[int, ...] = (2, 2, 6, 2) 588 | num_heads: tuple[int, ...] = (3, 6, 12, 24) 589 | 590 | window_size: int = 7 591 | mlp_ratio: float = 4.0 592 | qkv_bias: bool = True 593 | 594 | drop_rate: float = 0.0 595 | attn_drop_rate: float = 0.0 596 | drop_path_rate: float = 0.1 597 | 598 | norm_layer: Callable = linen.LayerNorm 599 | patch_norm: bool = True 600 | 601 | pretrained_window_sizes: tuple[int, ...] = (0, 0, 0, 0) 602 | 603 | layer_norm_eps: float = 1e-5 604 | dtype: jt.DTypeLike = jnp.float32 605 | 606 | def setup(self): 607 | depths = self.depths 608 | num_layers = len(depths) 609 | norm_layer = partial( 610 | self.norm_layer, 611 | epsilon=self.layer_norm_eps, 612 | dtype=self.dtype, 613 | ) 614 | 615 | patch_resolution = self.image_size // self.patch_size 616 | patches_resolution = [patch_resolution, patch_resolution] 617 | 618 | # split image into non-overlapping patches 619 | self.patch_embed = PatchEmbed( 620 | patch_size=self.patch_size, 621 | embed_dim=self.embed_dim, 622 | norm_layer=norm_layer if self.patch_norm else None, 623 | dtype=self.dtype, 624 | ) 625 | 626 | self.pos_drop = linen.Dropout(rate=self.drop_rate) 627 | 628 | # stochastic depth with linear decay 629 | dpr = [float(x) for x in np.linspace(0, self.drop_path_rate, sum(depths))] 630 | 631 | # build layers 632 | swin_body = [] 633 | for i_layer in range(num_layers): 634 | dpr_slice = tuple(dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])]) 635 | layer = BasicLayer( 636 | dim=int(self.embed_dim * 2**i_layer), 637 | input_resolution=( 638 | patches_resolution[0] // (2**i_layer), 639 | patches_resolution[1] // (2**i_layer), 640 | ), 641 | depth=depths[i_layer], 642 | num_heads=self.num_heads[i_layer], 643 | window_size=self.window_size, 644 | mlp_ratio=self.mlp_ratio, 645 | qkv_bias=self.qkv_bias, 646 | drop_ratio=self.drop_rate, 647 | attn_drop_ratio=self.attn_drop_rate, 648 | drop_path_ratio=dpr_slice, 649 | norm_layer=norm_layer, 650 | downsample=PatchMerging if (i_layer < num_layers - 1) else None, 651 | pretrained_window_size=self.pretrained_window_sizes[i_layer], 652 | dtype=self.dtype, 653 | ) 654 | swin_body.append(layer) 655 | self.swin_body = swin_body 656 | 657 | self.norm = norm_layer() 658 | self.head = ( 659 | linen.Dense(self.num_classes, dtype=self.dtype) 660 | if self.num_classes > 0 661 | else lambda x: x 662 | ) 663 | 664 | def __call__(self, x, train: bool = False): 665 | x = self.patch_embed(x) 666 | x = self.pos_drop(x, deterministic=not train) 667 | 668 | for layer in self.swin_body: 669 | x = layer(x, train=train) 670 | 671 | x = self.norm(x) 672 | x = jnp.mean(x, axis=(1,)) 673 | x = self.head(x) 674 | return x 675 | 676 | @classmethod 677 | def build(cls, config, **kwargs): 678 | config = dataclasses.asdict(config) 679 | config = {key: kwargs[key] if key in kwargs else config[key] for key in 
config} 680 | return cls(**config) 681 | 682 | def extend_parser(self, parser): 683 | parser.set_defaults(image_size=self.image_size) 684 | parser.set_defaults(patch_size=self.patch_size) 685 | parser.add_argument( 686 | "--window-size", 687 | default=self.window_size, 688 | help="SwinV2 window size", 689 | type=int, 690 | ) 691 | parser.add_argument( 692 | "--drop-path-rate", 693 | default=self.drop_path_rate, 694 | help="Stochastic depth rate", 695 | type=float, 696 | ) 697 | return parser 698 | 699 | @staticmethod 700 | def get_simmim_orbax_txs(): 701 | # SimMIM checkpoint have no head params - don't try to restore them. 702 | # All the other params we care about are under the "encoder" subsection 703 | regex = r"(?!model/params/head)model/params/(.*)" 704 | action = r"model/params/encoder/\1" 705 | return [(regex, action)] 706 | 707 | def should_decay(self, path, _): 708 | is_kernel = path[-1].key == "kernel" 709 | verdict = is_kernel 710 | return verdict 711 | 712 | 713 | def swinv2_tiny(): 714 | config = { 715 | "embed_dim": 96, 716 | "depths": (2, 2, 6, 2), 717 | "num_heads": (3, 6, 12, 24), 718 | } 719 | return SwinTransformerV2(**config) 720 | 721 | 722 | def swinv2_base(): 723 | config = { 724 | "embed_dim": 128, 725 | "depths": (2, 2, 18, 2), 726 | "num_heads": (4, 8, 16, 32), 727 | } 728 | return SwinTransformerV2(**config) 729 | 730 | 731 | def swinv2_large(): 732 | config = { 733 | "embed_dim": 192, 734 | "depths": (2, 2, 18, 2), 735 | "num_heads": (6, 12, 24, 48), 736 | } 737 | return SwinTransformerV2(**config) 738 | -------------------------------------------------------------------------------- /Models/VGG.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | 3 | import jax.numpy as jnp 4 | import jax.typing as jt 5 | from flax import linen 6 | 7 | 8 | class VGGStage(linen.Module): 9 | filters: int = 64 10 | kernel_sizes: tuple[int] = (3,) 11 | dtype: jt.DTypeLike = jnp.float32 12 | 13 | @linen.compact 14 | def __call__(self, x): 15 | for kernel_size in self.kernel_sizes: 16 | k_init = linen.initializers.normal(stddev=(10e-2) ** 0.5) 17 | x = linen.Conv( 18 | self.filters, 19 | kernel_size=(kernel_size, kernel_size), 20 | strides=(1, 1), 21 | kernel_init=k_init, 22 | dtype=self.dtype, 23 | )(x) 24 | x = linen.relu(x) 25 | x = linen.max_pool(x, window_shape=(2, 2), strides=(2, 2)) 26 | return x 27 | 28 | 29 | class VGGNetwork(linen.Module): 30 | patch_size: int = 3 31 | num_classes: int = 1000 32 | 33 | filters: tuple[int] = (64, 128, 256, 512, 512) 34 | kernel_sizes: tuple[tuple[int]] = ((3,), (3,), (3, 3), (3, 3), (3, 3)) 35 | 36 | dtype: jt.DTypeLike = jnp.float32 37 | 38 | def setup(self): 39 | if self.num_classes > 0: 40 | k_init = linen.initializers.normal(stddev=(10e-2) ** 0.5) 41 | self.head = linen.Dense( 42 | self.num_classes, 43 | kernel_init=k_init, 44 | dtype=self.dtype, 45 | ) 46 | else: 47 | self.head = lambda x: x 48 | 49 | @linen.compact 50 | def __call__(self, x, train: bool = False): 51 | for filters, kernel_sizes in zip(self.filters, self.kernel_sizes): 52 | x = VGGStage( 53 | filters=filters, 54 | kernel_sizes=kernel_sizes, 55 | dtype=self.dtype, 56 | )(x) 57 | 58 | b, h, w, c = x.shape 59 | x = jnp.reshape(x, (b, h * w * c)) 60 | 61 | k_init = linen.initializers.normal(stddev=(10e-2) ** 0.5) 62 | x = linen.Dense(4096, kernel_init=k_init, dtype=self.dtype)(x) 63 | x = linen.relu(x) 64 | x = linen.Dropout(0.5, deterministic=not train)(x) 65 | x = linen.Dense(4096, kernel_init=k_init, 
-------------------------------------------------------------------------------- /Models/VGG.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | 3 | import jax.numpy as jnp 4 | import jax.typing as jt 5 | from flax import linen 6 | 7 | 8 | class VGGStage(linen.Module): 9 | filters: int = 64 10 | kernel_sizes: tuple[int] = (3,) 11 | dtype: jt.DTypeLike = jnp.float32 12 | 13 | @linen.compact 14 | def __call__(self, x): 15 | for kernel_size in self.kernel_sizes: 16 | k_init = linen.initializers.normal(stddev=(10e-2) ** 0.5) 17 | x = linen.Conv( 18 | self.filters, 19 | kernel_size=(kernel_size, kernel_size), 20 | strides=(1, 1), 21 | kernel_init=k_init, 22 | dtype=self.dtype, 23 | )(x) 24 | x = linen.relu(x) 25 | x = linen.max_pool(x, window_shape=(2, 2), strides=(2, 2)) 26 | return x 27 | 28 | 29 | class VGGNetwork(linen.Module): 30 | patch_size: int = 3 31 | num_classes: int = 1000 32 | 33 | filters: tuple[int] = (64, 128, 256, 512, 512) 34 | kernel_sizes: tuple[tuple[int]] = ((3,), (3,), (3, 3), (3, 3), (3, 3)) 35 | 36 | dtype: jt.DTypeLike = jnp.float32 37 | 38 | def setup(self): 39 | if self.num_classes > 0: 40 | k_init = linen.initializers.normal(stddev=(10e-2) ** 0.5) 41 | self.head = linen.Dense( 42 | self.num_classes, 43 | kernel_init=k_init, 44 | dtype=self.dtype, 45 | ) 46 | else: 47 | self.head = lambda x: x 48 | 49 | @linen.compact 50 | def __call__(self, x, train: bool = False): 51 | for filters, kernel_sizes in zip(self.filters, self.kernel_sizes): 52 | x = VGGStage( 53 | filters=filters, 54 | kernel_sizes=kernel_sizes, 55 | dtype=self.dtype, 56 | )(x) 57 | 58 | b, h, w, c = x.shape 59 | x = jnp.reshape(x, (b, h * w * c)) 60 | 61 | k_init = linen.initializers.normal(stddev=(10e-2) ** 0.5) 62 | x = linen.Dense(4096, kernel_init=k_init, dtype=self.dtype)(x) 63 | x = linen.relu(x) 64 | x = linen.Dropout(0.5, deterministic=not train)(x) 65 | x = linen.Dense(4096, kernel_init=k_init, dtype=self.dtype)(x) 66 | x = linen.relu(x) 67 | x = linen.Dropout(0.5, deterministic=not train)(x) 68 | 69 | x = self.head(x) 70 | return x 71 | 72 | @classmethod 73 | def build(cls, config, **kwargs): 74 | config = dataclasses.asdict(config) 75 | config = {key: kwargs[key] if key in kwargs else config[key] for key in config} 76 | return cls(**config) 77 | 78 | def extend_parser(self, parser): 79 | parser.set_defaults(patch_size=self.patch_size) 80 | return parser 81 | 82 | @staticmethod 83 | def get_simmim_orbax_txs(): 84 | # SimMIM checkpoints have no head params - don't try to restore them. 85 | # All the other params we care about are under the "encoder" subsection 86 | regex = r"(?!model/params/head)model/params/(.*)" 87 | action = r"model/params/encoder/\1" 88 | return [(regex, action)] 89 | 90 | def should_decay(self, path, _): 91 | is_kernel = path[-1].key == "kernel" 92 | verdict = is_kernel 93 | return verdict 94 | 95 | 96 | def vgg11(): 97 | config = { 98 | "kernel_sizes": ( 99 | (3,), 100 | (3,), 101 | (3, 3), 102 | (3, 3), 103 | (3, 3), 104 | ) 105 | } 106 | return VGGNetwork(**config) 107 | 108 | 109 | def vgg13(): 110 | config = { 111 | "kernel_sizes": ( 112 | (3, 3), 113 | (3, 3), 114 | (3, 3), 115 | (3, 3), 116 | (3, 3), 117 | ) 118 | } 119 | return VGGNetwork(**config) 120 | 121 | 122 | def vgg19(): 123 | config = { 124 | "kernel_sizes": ( 125 | (3, 3), 126 | (3, 3), 127 | (3, 3, 3, 3), 128 | (3, 3, 3, 3), 129 | (3, 3, 3, 3), 130 | ) 131 | } 132 | return VGGNetwork(**config) 133 | -------------------------------------------------------------------------------- /Models/ViT.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from functools import partial 3 | from typing import Callable, Optional 4 | 5 | import jax.numpy as jnp 6 | import jax.typing as jt 7 | import numpy as np 8 | from flax import linen 9 | 10 | 11 | class LayerNorm(linen.Module): 12 | epsilon: float = 1e-6 13 | use_bias: bool = True 14 | force_float32_reductions: bool = True 15 | 16 | dtype: jt.DTypeLike = jnp.float32 17 | 18 | @linen.compact 19 | def __call__(self, x): 20 | scale = self.param("scale", linen.initializers.zeros_init(), (x.shape[-1])) 21 | 22 | dtype = self.dtype 23 | if self.force_float32_reductions: 24 | dtype = jnp.promote_types(dtype, jnp.float32) 25 | x = x.astype(dtype) 26 | 27 | mean = jnp.mean(x, axis=-1, keepdims=True) 28 | 29 | var = jnp.mean(jnp.square(x), axis=-1, keepdims=True) 30 | var = jnp.maximum(0.0, var - jnp.square(mean)) 31 | mul = jnp.reciprocal(jnp.sqrt(var + self.epsilon)) 32 | 33 | centered_inputs = x - mean 34 | normed_inputs = centered_inputs * mul 35 | 36 | scale = jnp.expand_dims(scale, axis=range(len(x.shape) - 1)) 37 | normed_inputs = normed_inputs * (1 + scale) 38 | if self.use_bias: 39 | bias = self.param("bias", linen.initializers.zeros_init(), (x.shape[-1])) 40 | bias = jnp.expand_dims(bias, axis=range(len(x.shape) - 1)) 41 | normed_inputs = normed_inputs + bias 42 | return normed_inputs.astype(self.dtype) 43 | 44 | 45 | class Attention(linen.Module): 46 | dim: int 47 | num_heads: int 48 | qkv_bias: bool = True 49 | attn_drop_ratio: float = 0.0 50 | proj_drop_ratio: float = 0.0 51 | 52 | dtype: jt.DTypeLike = jnp.float32 53 | 54 | def setup(self): 55 | self.qkv = linen.Dense(self.dim * 3, use_bias=self.qkv_bias, dtype=self.dtype) 56 | self.attn_drop = linen.Dropout(self.attn_drop_ratio) 57 | self.proj = linen.Dense(self.dim, dtype=self.dtype) 58 | self.proj_drop =
linen.Dropout(self.proj_drop_ratio) 59 | self.softmax = partial(linen.softmax, axis=-1) 60 | 61 | def __call__(self, x, train: bool = False): 62 | B, N, C = x.shape 63 | qkv = self.qkv(x) 64 | qkv = jnp.reshape(qkv, (B, N, 3, self.num_heads, C // self.num_heads)) 65 | qkv = jnp.transpose(qkv, (2, 0, 3, 1, 4)) 66 | 67 | q, k, v = (qkv[0], qkv[1], qkv[2]) 68 | 69 | q = q / jnp.sqrt(q.shape[-1]).astype(q.dtype) 70 | attn = q @ jnp.transpose(k, (0, 1, 3, 2)) 71 | 72 | attn = self.softmax(attn.astype(jnp.float32)).astype(self.dtype) 73 | attn = self.attn_drop(attn, deterministic=not train) 74 | 75 | x = jnp.transpose(attn @ v, (0, 2, 1, 3)) 76 | x = jnp.reshape((x), (B, N, C)) 77 | x = self.proj(x) 78 | x = self.proj_drop(x, deterministic=not train) 79 | return x 80 | 81 | 82 | class MLP(linen.Module): 83 | hidden_features: int 84 | act_layer: Callable = linen.gelu 85 | drop_ratio: float = 0.0 86 | 87 | dtype: jt.DTypeLike = jnp.float32 88 | 89 | @linen.compact 90 | def __call__(self, x, train: bool): 91 | out_dim = x.shape[-1] 92 | 93 | x = linen.Dense(self.hidden_features, dtype=self.dtype)(x) 94 | x = self.act_layer(x) 95 | x = linen.Dropout(self.drop_ratio)(x, deterministic=not train) 96 | x = linen.Dense(out_dim, dtype=self.dtype)(x) 97 | return x 98 | 99 | 100 | class PosEmbed(linen.Module): 101 | dtype: jt.DTypeLike = jnp.float32 102 | 103 | @linen.compact 104 | def __call__(self, x): 105 | _, L, C = x.shape 106 | pos_emb_init = linen.initializers.normal(stddev=1 / np.sqrt(C)) 107 | pos_emb = self.param("pos_emb", pos_emb_init, (1, L, C)) 108 | pos_emb = pos_emb.astype(self.dtype) 109 | x = x + pos_emb 110 | return x 111 | 112 | 113 | class PatchEmbed(linen.Module): 114 | r"""Image to Patch Embedding 115 | 116 | Args: 117 | patch_size (int): Patch token size. Default: 16. 118 | embed_dim (int): Number of linear projection output channels. Default: 96. 119 | norm_layer (nn.Module, optional): Normalization layer. 
Default: None 120 | """ 121 | 122 | patch_size: int = 16 123 | embed_dim: int = 768 124 | 125 | norm_layer: Optional[Callable] = None 126 | 127 | dtype: jt.DTypeLike = jnp.float32 128 | 129 | @linen.compact 130 | def __call__(self, x): 131 | B, _, _, _ = x.shape 132 | patch_size = (self.patch_size, self.patch_size) 133 | 134 | x = linen.Conv( 135 | self.embed_dim, 136 | kernel_size=patch_size, 137 | strides=patch_size, 138 | dtype=self.dtype, 139 | )(x) 140 | x = jnp.reshape(x, (B, -1, self.embed_dim)) 141 | if self.norm_layer is not None: 142 | x = self.norm_layer()(x) 143 | return x 144 | 145 | 146 | class VisionTransformerBlock(linen.Module): 147 | mlp_dim: int 148 | num_heads: int 149 | drop_path_ratio: float 150 | 151 | norm_layer: Callable 152 | 153 | dtype: jt.DTypeLike = jnp.float32 154 | 155 | @linen.compact 156 | def __call__(self, x, train: bool = False): 157 | shortcut = x 158 | 159 | x = self.norm_layer()(x) 160 | x = Attention( 161 | dim=x.shape[-1], 162 | num_heads=self.num_heads, 163 | dtype=self.dtype, 164 | )(x, train=train) 165 | x = linen.Dropout( 166 | rate=self.drop_path_ratio, 167 | broadcast_dims=(1, 2), 168 | )(x, deterministic=not train) 169 | x = shortcut + x 170 | 171 | shortcut = x 172 | x = self.norm_layer()(x) 173 | x = MLP(hidden_features=self.mlp_dim, dtype=self.dtype)(x, train=train) 174 | x = linen.Dropout( 175 | rate=self.drop_path_ratio, 176 | broadcast_dims=(1, 2), 177 | )(x, deterministic=not train) 178 | x = shortcut + x 179 | return x 180 | 181 | 182 | def make_norm_layer(layer_name): 183 | if layer_name == "reparam_layernorm": 184 | return LayerNorm 185 | elif layer_name == "linen_layernorm": 186 | return linen.LayerNorm 187 | 188 | 189 | class VisionTransformer(linen.Module): 190 | patch_size: int = 16 191 | num_classes: int = 1000 192 | 193 | num_layers: int = 12 194 | embed_dim: int = 768 195 | mlp_dim: int = 3072 196 | num_heads: int = 12 197 | 198 | drop_path_rate: float = 0.1 199 | 200 | norm_layer: str = "linen_layernorm" 201 | 202 | layer_norm_eps: float = 1e-5 203 | dtype: jt.DTypeLike = jnp.float32 204 | 205 | def setup(self): 206 | norm_layer = make_norm_layer(self.norm_layer) 207 | norm_layer = partial( 208 | norm_layer, 209 | epsilon=self.layer_norm_eps, 210 | dtype=self.dtype, 211 | ) 212 | 213 | self.patch_embed = PatchEmbed( 214 | patch_size=self.patch_size, 215 | embed_dim=self.embed_dim, 216 | dtype=self.dtype, 217 | ) 218 | 219 | self.pos_emb = PosEmbed(dtype=self.dtype) 220 | 221 | # stochastic depth with linear decay 222 | dpr = np.linspace(0, self.drop_path_rate, self.num_layers) 223 | dpr = [float(x) for x in dpr] 224 | 225 | vit_body = [] 226 | for i in range(self.num_layers): 227 | layer = VisionTransformerBlock( 228 | mlp_dim=self.mlp_dim, 229 | num_heads=self.num_heads, 230 | drop_path_ratio=dpr[i], 231 | norm_layer=norm_layer, 232 | dtype=self.dtype, 233 | ) 234 | vit_body.append(layer) 235 | self.vit_body = vit_body 236 | 237 | self.norm = norm_layer() 238 | self.head = ( 239 | linen.Dense(self.num_classes, dtype=self.dtype) 240 | if self.num_classes > 0 241 | else lambda x: x 242 | ) 243 | 244 | def __call__(self, x, train: bool = False): 245 | x = self.patch_embed(x) 246 | 247 | x = self.pos_emb(x) 248 | 249 | for layer in self.vit_body: 250 | x = layer(x, train=train) 251 | 252 | x = self.norm(x) 253 | x = jnp.mean(x, axis=(1,)) 254 | x = self.head(x) 255 | return x 256 | 257 | @classmethod 258 | def build(cls, config, **kwargs): 259 | config = dataclasses.asdict(config) 260 | config = {key: kwargs[key] if key in 
kwargs else config[key] for key in config} 261 | return cls(**config) 262 | 263 | def extend_parser(self, parser): 264 | parser.set_defaults(patch_size=self.patch_size) 265 | parser.add_argument( 266 | "--drop-path-rate", 267 | default=self.drop_path_rate, 268 | help="Stochastic depth rate", 269 | type=float, 270 | ) 271 | 272 | parser.add_argument( 273 | "--norm-layer", 274 | default=self.norm_layer, 275 | help="Normalization layer", 276 | type=str, 277 | ) 278 | return parser 279 | 280 | @staticmethod 281 | def get_simmim_orbax_txs(): 282 | # SimMIM checkpoints have no head params - don't try to restore them. 283 | # All the other params we care about are under the "encoder" subsection 284 | regex = r"(?!model/params/head)model/params/(.*)" 285 | action = r"model/params/encoder/\1" 286 | return [(regex, action)] 287 | 288 | def should_decay(self, path, _): 289 | is_kernel = path[-1].key == "kernel" 290 | is_scale = path[-1].key == "scale" 291 | is_scale = is_scale and self.norm_layer == "reparam_layernorm" 292 | verdict = is_kernel or is_scale 293 | return verdict 294 | 295 | 296 | def vit_small(): 297 | config = { 298 | "num_layers": 12, 299 | "embed_dim": 384, 300 | "mlp_dim": 1536, 301 | "num_heads": 6, 302 | } 303 | return VisionTransformer(**config) 304 | 305 | 306 | def vit_base(): 307 | config = { 308 | "num_layers": 12, 309 | "embed_dim": 768, 310 | "mlp_dim": 3072, 311 | "num_heads": 12, 312 | } 313 | return VisionTransformer(**config) 314 | 315 | 316 | def vit_large(): 317 | config = { 318 | "num_layers": 24, 319 | "embed_dim": 1024, 320 | "mlp_dim": 4096, 321 | "num_heads": 16, 322 | } 323 | return VisionTransformer(**config) 324 |
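# Illustrative note on the weight-decay mask above: should_decay() marks the
# "scale" parameters for decay only under the "reparam_layernorm" option. That
# appears to be because the custom LayerNorm in this file initialises scale at
# zero and applies it as (1 + scale), so decaying it pulls the effective gain
# toward 1 (the identity), whereas decaying the standard linen.LayerNorm scale
# (initialised at one and applied multiplicatively) would pull the gain toward 0.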
-------------------------------------------------------------------------------- /Models/__init__.py: -------------------------------------------------------------------------------- 1 | from .ConvNext import convnext_base, convnext_small, convnext_tiny 2 | from .EVA02 import eva02_base, eva02_large, eva02_small 3 | from .HiViT import hivit_base, hivit_small, hivit_tiny 4 | from .SimMIM import ( 5 | simmim_convnext_base, 6 | simmim_convnext_small, 7 | simmim_convnext_tiny, 8 | simmim_eva02_base, 9 | simmim_eva02_large, 10 | simmim_eva02_small, 11 | simmim_hivit_small, 12 | simmim_hivit_tiny, 13 | simmim_swinv2_base, 14 | simmim_swinv2_large, 15 | simmim_swinv2_tiny, 16 | simmim_vit_base, 17 | simmim_vit_large, 18 | simmim_vit_small, 19 | ) 20 | from .SwinV2 import swinv2_base, swinv2_large, swinv2_tiny 21 | from .VGG import vgg11, vgg13, vgg19 22 | from .ViT import vit_base, vit_large, vit_small 23 | 24 | model_registry = { 25 | "swinv2_tiny": swinv2_tiny, 26 | "swinv2_base": swinv2_base, 27 | "swinv2_large": swinv2_large, 28 | "hivit_tiny": hivit_tiny, 29 | "hivit_small": hivit_small, 30 | "hivit_base": hivit_base, 31 | "vit_small": vit_small, 32 | "vit_base": vit_base, 33 | "vit_large": vit_large, 34 | "eva02_small": eva02_small, 35 | "eva02_base": eva02_base, 36 | "eva02_large": eva02_large, 37 | "convnext_tiny": convnext_tiny, 38 | "convnext_small": convnext_small, 39 | "convnext_base": convnext_base, 40 | "vgg11": vgg11, 41 | "vgg13": vgg13, 42 | "vgg19": vgg19, 43 | "simmim_swinv2_tiny": simmim_swinv2_tiny, 44 | "simmim_swinv2_base": simmim_swinv2_base, 45 | "simmim_swinv2_large": simmim_swinv2_large, 46 | "simmim_vit_small": simmim_vit_small, 47 | "simmim_vit_base": simmim_vit_base, 48 | "simmim_vit_large": simmim_vit_large, 49 | "simmim_eva02_small": simmim_eva02_small, 50 | "simmim_eva02_base": simmim_eva02_base, 51 | "simmim_eva02_large": simmim_eva02_large, 52 | "simmim_hivit_tiny": simmim_hivit_tiny, 53 | "simmim_hivit_small": simmim_hivit_small, 54 | "simmim_convnext_tiny": simmim_convnext_tiny, 55 | "simmim_convnext_small": simmim_convnext_small, 56 | "simmim_convnext_base": simmim_convnext_base, 57 | } 58 |
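# Illustrative sketch of how the registry is consumed (argument values here are
# made up; the training scripts pass their own CLI arguments): each entry is a
# zero-argument factory returning a template module, which is then rebuilt with
# overrides through the dataclass-based .build() helper, e.g.
#
#   import Models
#   builder = Models.model_registry["vit_small"]()
#   model = builder.build(config=builder, num_classes=100, drop_path_rate=0.2)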
-------------------------------------------------------------------------------- /Translators/TFRecord.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | from PIL import Image, ImageOps 8 | 9 | 10 | def _bytes_feature(value): 11 | """Returns a bytes_list from a string / byte.""" 12 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[tf.io.encode_jpeg(value).numpy()])) 13 | 14 | 15 | def _int64_feature(value): 16 | """Returns an int64_list from a bool / enum / int / uint.""" 17 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 18 | 19 | 20 | def _int64_list_feature(value): 21 | """Returns an int64_list from a bool / enum / int / uint.""" 22 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) 23 | 24 | 25 | def load_existing_labels(label_mapping_filename): 26 | """Load existing label to index mapping from a file.""" 27 | if os.path.exists(label_mapping_filename): 28 | with open(label_mapping_filename, "r") as mapping_file: 29 | return json.load(mapping_file), True 30 | return {}, False 31 | 32 | 33 | def save_label_mapping(label_mapping_filename, label_to_index): 34 | """Save label to index mapping to a file.""" 35 | with open(label_mapping_filename, "w") as mapping_file: 36 | json.dump(label_to_index, mapping_file, indent=4) 37 | 38 | 39 | def prepare_image(image_path, target_size): 40 | """Handle EXIF orientation, colorspace, alpha, resizing""" 41 | image = Image.open(image_path) 42 | image = ImageOps.exif_transpose(image) 43 | image = image.convert(mode="RGBA") 44 | 45 | canvas = Image.new("RGBA", image.size, (255, 255, 255)) 46 | canvas.alpha_composite(image) 47 | image = canvas.convert("RGB") 48 | 49 | # Pad image to square 50 | max_dim = max(image.size) 51 | pad_left = (max_dim - image.size[0]) // 2 52 | pad_top = (max_dim - image.size[1]) // 2 53 | 54 | padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255)) 55 | padded_image.paste(image, (pad_left, pad_top)) 56 | 57 | # Resize 58 | if max_dim != target_size: 59 | padded_image = padded_image.resize((target_size, target_size), Image.LANCZOS) 60 | return padded_image 61 | 62 | 63 | def create_tfrecord(dataset_folder, output_path, split_ratio=0.7, img_size=512): 64 | """Create a TFRecord file from images and label files and generate dataset JSON file.""" 65 | dataset_name = os.path.basename(os.path.normpath(dataset_folder)) 66 | 67 | os.makedirs(f"{output_path}/record_shards_train", exist_ok=True) 68 | os.makedirs(f"{output_path}/record_shards_val", exist_ok=True) 69 | 70 | train_writer = tf.io.TFRecordWriter(f"{output_path}/record_shards_train/{dataset_name}_train.tfrecord") 71 | val_writer = tf.io.TFRecordWriter(f"{output_path}/record_shards_val/{dataset_name}_val.tfrecord") 72 | 73 | image_files = [f for f in os.listdir(dataset_folder) if f.lower().endswith(("png", "jpg", "jpeg"))] 74 | label_files = [f for f in os.listdir(dataset_folder) if f.lower().endswith("txt")] 75 | 76 | image_files.sort() 77 | label_files.sort() 78 | 79 | # Create a set of image filenames without extensions for quick lookup 80 | image_file_set = set(os.path.splitext(f)[0].lower() for f in image_files) 81 | 82 | # Load existing label mapping 83 | label_mapping_filename = f"{output_path}/{dataset_name}_labels.json" 84 | label_to_index, mapping_exists = load_existing_labels(label_mapping_filename) 85 | index_to_label = [None] * (len(label_to_index) + 1) 86 | 87 | if mapping_exists: 88 | for label, index in label_to_index.items(): 89 | index_to_label[index] = label 90 | 91 | # Collect new labels and update mapping 92 | new_labels = set() 93 | 94 | for label_file in label_files: 95 | label_path = os.path.join(dataset_folder, label_file) 96 | image_name = os.path.splitext(label_file)[0].lower() 97 | 98 | # Check if there's a corresponding image file 99 | if image_name not in image_file_set: 100 | print(f"Skipping label file {image_name} because no corresponding image file was found.") 101 | continue 102 | 103 | # Read labels and collect new labels 104 | with open(label_path, "r") as f: 105 | labels = f.read().strip().split(", ") 106 | new_labels.update(labels) 107 | 108 | # Update label to index mapping 109 | for label in new_labels: 110 | if label not in label_to_index: 111 | new_index = len(label_to_index) 112 | label_to_index[label] = new_index 113 | index_to_label.append(label) 114 | 115 | # Create a set of label filenames (without extension) for quick lookup 116 | label_file_set = set(os.path.splitext(f)[0].lower() for f in label_files) 117 | 118 | # Number of unique tags 119 | num_classes = len(label_to_index) 120 | 121 | # Number of valid samples 122 | num_samples = len([f for f in image_files if os.path.splitext(f)[0].lower() in label_file_set]) 123 | 124 | # Number of training and validation samples 125 | num_train_samples = int(num_samples * split_ratio) 126 | num_val_samples = num_samples - num_train_samples 127 | 128 | for idx, image_file in enumerate(image_files): 129 | image_name = os.path.splitext(image_file)[0].lower() 130 | label_file = f"{image_name}.txt" 131 | label_path = os.path.join(dataset_folder, label_file) 132 | 133 | # Check if there's a corresponding label file 134 | if image_name not in label_file_set: 135 | print(f"Skipping image file {image_name} because no corresponding label file was found.") 136 | continue 137 | 138 | # Read labels and convert to indices 139 | with open(label_path, "r") as f: 140 | labels = f.read().strip().split(", ") 141 | label_indices = [label_to_index[label] for label in labels if label in label_to_index] 142 | 143 | # Read image 144 | image_path = os.path.join(dataset_folder, image_file) 145 | image = prepare_image(image_path, img_size) 146 | image_np = np.array(image) 147 | 148 | image_id = hash(image_name) % 2**63 149 | 150 | # Create a feature 151 | feature = { 152 | "image_id": _int64_feature(image_id), 153 | "image_bytes": _bytes_feature(image_np), 154 | "label_indexes": _int64_list_feature(label_indices), 155 | } 156 | 157 | # Protocol buffers 158 | protobuf = tf.train.Example(features=tf.train.Features(feature=feature)) 159 | serialized_example = protobuf.SerializeToString() 160 | 161 | # Write the buffer to the TFRecord files 162 | if idx < num_train_samples: 163 | train_writer.write(serialized_example) 164 | else: 165 | val_writer.write(serialized_example) 166 | 167 | train_writer.close() 168 | val_writer.close() 169 | 170 | dataset_info = { 171 | "num_classes": num_classes, 172 | "train_samples": num_train_samples, 173 | "val_samples": num_val_samples, 174 | } 175 | 176 | json_filename = f"{output_path}/{dataset_name}.json" 177 | with open(json_filename, "w") as json_file: 178 | json.dump(dataset_info, json_file,
indent=4) 179 | 180 | # Save updated label to index mapping 181 | save_label_mapping(label_mapping_filename, label_to_index) 182 | 183 | print(f"TFRecord files saved to {output_path}/record_shards_train and {output_path}/record_shards_val") 184 | print(f"Dataset JSON file saved to {json_filename}") 185 | print(f"Label mapping JSON file saved to {label_mapping_filename}") 186 | 187 | 188 | if __name__ == "__main__": 189 | parser = argparse.ArgumentParser(description="Create TFRecord file from images and label files") 190 | parser.add_argument( 191 | "--dataset-folder", 192 | type=str, 193 | help="Path to dataset folder containing both images and labels", 194 | ) 195 | parser.add_argument( 196 | "--output-path", 197 | type=str, 198 | help='Path to output files. Will place TFRecords into "record_shards_train" and "record_shards_val" folders', 199 | ) 200 | parser.add_argument( 201 | "--split-ratio", 202 | type=float, 203 | default=0.7, 204 | help="Ratio of training to total samples (default: 0.7)", 205 | ) 206 | parser.add_argument( 207 | "--img-size", 208 | type=int, 209 | default=512, 210 | help="Image size to resize all images to (default: 512)", 211 | ) 212 | 213 | args = parser.parse_args() 214 | 215 | # Use dataset folder as output if empty 216 | if args.output_path is None: 217 | args.output_path = args.dataset_folder 218 | 219 | create_tfrecord(args.dataset_folder, args.output_path, args.split_ratio, args.img_size) 220 | -------------------------------------------------------------------------------- /pretraining_loop.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from datetime import datetime 4 | from typing import Any, Callable, Union 5 | 6 | import flax 7 | import jax 8 | import jax.numpy as jnp 9 | import optax 10 | import orbax.checkpoint 11 | import tensorflow as tf 12 | import wandb 13 | from clu import metrics 14 | from flax import jax_utils 15 | from flax.training import orbax_utils, train_state 16 | from tqdm import tqdm 17 | 18 | import Models 19 | from Generators.SimMIMGen import DataGenerator 20 | 21 | 22 | @flax.struct.dataclass 23 | class Metrics(metrics.Collection): 24 | loss: metrics.Metric 25 | 26 | 27 | class TrainState(train_state.TrainState): 28 | metrics: Metrics 29 | constants: Any 30 | 31 | 32 | def create_train_state( 33 | module, 34 | params_key, 35 | target_size: int, 36 | mask_input_size: int, 37 | num_classes: int, 38 | learning_rate: Union[float, Callable], 39 | optimizer_eps: float, 40 | grad_clip: float, 41 | weight_decay: float, 42 | ): 43 | """Creates an initial 'TrainState'.""" 44 | # initialize parameters by passing a template image 45 | variables = module.init( 46 | params_key, 47 | jnp.ones([1, target_size, target_size, 3]), 48 | mask=jnp.ones([1, mask_input_size, mask_input_size]), 49 | train=False, 50 | ) 51 | params = variables["params"] 52 | del variables["params"] 53 | constants = variables 54 | 55 | loss = metrics.Average.from_output("loss") 56 | collection = Metrics.create(loss=loss) 57 | 58 | wd_mask = jax.tree_util.tree_map_with_path(module.should_decay, params) 59 | tx = optax.lamb( 60 | learning_rate, 61 | weight_decay=weight_decay, 62 | eps=optimizer_eps, 63 | mask=wd_mask, 64 | ) 65 | tx = optax.chain(optax.clip_by_global_norm(grad_clip), tx) 66 | return TrainState.create( 67 | apply_fn=module.apply, 68 | params=params, 69 | tx=tx, 70 | metrics=collection.empty(), 71 | constants=constants, 72 | ) 73 | 74 | 75 | def train_step(state, batch, dropout_key): 76 | 
"""Train for a single step.""" 77 | dropout_train_key = jax.random.fold_in(key=dropout_key, data=state.step) 78 | 79 | def loss_fn(params, **kwargs): 80 | loss, _ = state.apply_fn( 81 | {"params": params, **kwargs}, 82 | batch["images"], 83 | mask=batch["masks"], 84 | train=True, 85 | rngs={"dropout": dropout_train_key}, 86 | ) 87 | return loss 88 | 89 | grad_fn = jax.value_and_grad(loss_fn) 90 | loss, grads = grad_fn(state.params, **state.constants) 91 | grads = jax.lax.pmean(grads, axis_name="batch") 92 | state = state.apply_gradients(grads=grads) 93 | 94 | metric_updates = state.metrics.gather_from_model_output(loss=loss) 95 | metrics = state.metrics.merge(metric_updates) 96 | state = state.replace(metrics=metrics) 97 | return state 98 | 99 | 100 | def eval_step(*, state, batch): 101 | loss, _ = state.apply_fn( 102 | {"params": state.params, **state.constants}, 103 | batch["images"], 104 | mask=batch["masks"], 105 | train=False, 106 | ) 107 | 108 | metric_updates = state.metrics.gather_from_model_output(loss=loss) 109 | metrics = state.metrics.merge(metric_updates) 110 | state = state.replace(metrics=metrics) 111 | return state 112 | 113 | 114 | model_parser = argparse.ArgumentParser( 115 | description="Model variant to train", 116 | add_help=False, 117 | ) 118 | model_parser.add_argument( 119 | "--model-name", 120 | default="simmim_vit_small", 121 | help="Model variant to train", 122 | type=str, 123 | ) 124 | 125 | parser = argparse.ArgumentParser(description="Train a network") 126 | parser.add_argument( 127 | "--run-name", 128 | default=None, 129 | help="Run name. If left empty it gets autogenerated", 130 | type=str, 131 | ) 132 | parser.add_argument( 133 | "--wandb-project", 134 | default="tpu-tracking", 135 | help="WandB project", 136 | type=str, 137 | ) 138 | parser.add_argument( 139 | "--wandb-run-id", 140 | default=None, 141 | help="WandB run ID (8 chars code) to resume interrupted run", 142 | type=str, 143 | ) 144 | parser.add_argument( 145 | "--wandb-tags", 146 | nargs="*", 147 | help="Space separated list of tags for WandB", 148 | ) 149 | parser.add_argument( 150 | "--restore-params-ckpt", 151 | default="", 152 | help="Restore the parameters from the last step of the given orbax checkpoint. Must be an absolute path. WARNING: restores params only!", 153 | type=str, 154 | ) 155 | parser.add_argument( 156 | "--dataset-file", 157 | default="datasets/aibooru.json", 158 | help="JSON file with dataset specs", 159 | type=str, 160 | ) 161 | parser.add_argument( 162 | "--dataset-root", 163 | default="/home/smilingwolf/datasets", 164 | help="Dataset root, where the record_shards_train and record_shards_val folders are stored", 165 | type=str, 166 | ) 167 | parser.add_argument( 168 | "--checkpoints-root", 169 | default="/mnt/c/Users/SmilingWolf/Desktop/TFKeras/JAX/checkpoints", 170 | help="Checkpoints root, where the checkpoints will be stored following a // structure", 171 | type=str, 172 | ) 173 | parser.add_argument( 174 | "--checkpoints-keep", 175 | default=2, 176 | help="Number of best (by val_loss) checkpoints to keep. 
-1 to always keep the last checkpoint", 177 | type=int, 178 | ) 179 | parser.add_argument( 180 | "--epochs", 181 | default=50, 182 | help="Number of epochs to train for", 183 | type=int, 184 | ) 185 | parser.add_argument( 186 | "--warmup-epochs", 187 | default=5, 188 | help="Number of epochs to dedicate to linear warmup", 189 | type=int, 190 | ) 191 | parser.add_argument( 192 | "--batch-size", 193 | default=64, 194 | help="Per-device batch size", 195 | type=int, 196 | ) 197 | parser.add_argument( 198 | "--image-size", 199 | default=256, 200 | help="Image resolution in input to the network", 201 | type=int, 202 | ) 203 | parser.add_argument( 204 | "--patch-size", 205 | default=16, 206 | help="Size of the image patches", 207 | type=int, 208 | ) 209 | parser.add_argument( 210 | "--mask-patch-size", 211 | default=32, 212 | help="Size of the mask patches for SimMIM", 213 | type=int, 214 | ) 215 | parser.add_argument( 216 | "--learning-rate", 217 | default=0.001, 218 | help="Max learning rate", 219 | type=float, 220 | ) 221 | parser.add_argument( 222 | "--optimizer-eps", 223 | default=1e-6, 224 | help="Optimizer epsilon", 225 | type=float, 226 | ) 227 | parser.add_argument( 228 | "--grad-clip", 229 | default=1.0, 230 | help="Gradient clipping", 231 | type=float, 232 | ) 233 | parser.add_argument( 234 | "--weight-decay", 235 | default=0.0001, 236 | help="Weight decay", 237 | type=float, 238 | ) 239 | parser.add_argument( 240 | "--mixup-alpha", 241 | default=0.8, 242 | help="MixUp alpha", 243 | type=float, 244 | ) 245 | parser.add_argument( 246 | "--rotation-ratio", 247 | default=0.0, 248 | help="Rotation ratio as a fraction of PI", 249 | type=float, 250 | ) 251 | parser.add_argument( 252 | "--cutout-max-pct", 253 | default=0.1, 254 | help="Cutout area as a fraction of the total image area", 255 | type=float, 256 | ) 257 | parser.add_argument( 258 | "--cutout-patches", 259 | default=1, 260 | help="Number of cutout patches", 261 | type=int, 262 | ) 263 | model_arg, remaining = model_parser.parse_known_args() 264 | 265 | model_name = model_arg.model_name 266 | model_builder = Models.model_registry[model_name]() 267 | parser = model_builder.extend_parser(parser=parser) 268 | 269 | args = parser.parse_args(remaining) 270 | 271 | run_name = args.run_name 272 | if run_name is None: 273 | now = datetime.now() 274 | date_time = now.strftime("%Y_%m_%d_%Hh%Mm%Ss") 275 | run_name = f"{model_name}_{date_time}" 276 | 277 | checkpoints_root = args.checkpoints_root 278 | dataset_root = args.dataset_root 279 | with open(args.dataset_file) as f: 280 | dataset_specs = json.load(f) 281 | 282 | # Run params 283 | num_epochs = args.epochs 284 | warmup_epochs = args.warmup_epochs 285 | batch_size = args.batch_size 286 | compute_units = jax.device_count() 287 | global_batch_size = batch_size * compute_units 288 | restore_params_ckpt = args.restore_params_ckpt 289 | 290 | # Dataset params 291 | image_size = args.image_size 292 | num_classes = 0 293 | train_samples = dataset_specs["train_samples"] 294 | val_samples = dataset_specs["val_samples"] 295 | 296 | # Model hyperparams 297 | patch_size = args.patch_size 298 | learning_rate = args.learning_rate 299 | optimizer_eps = args.optimizer_eps 300 | grad_clip = args.grad_clip 301 | weight_decay = args.weight_decay 302 | 303 | # Augmentations hyperparams 304 | noise_level = 2 305 | mixup_alpha = args.mixup_alpha 306 | rotation_ratio = args.rotation_ratio 307 | cutout_max_pct = args.cutout_max_pct 308 | cutout_patches = args.cutout_patches 309 | random_resize_method = True 
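# Illustrative arithmetic for the SimMIM mask sizes set just below: with the
# argparse defaults above (--image-size 256, --patch-size 16, --mask-patch-size 32)
# the mask handed to the model spans image_size // patch_size = 256 // 16 = 16
# tokens per side, i.e. a 16x16 grid, while masking decisions are drawn at the
# coarser 32-pixel mask-patch granularity.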
310 | mask_patch_size = args.mask_patch_size 311 | model_patch_size = patch_size 312 | mask_input_size = image_size // model_patch_size 313 | 314 | # WandB tracking 315 | train_config = {} 316 | train_config["model_name"] = model_name 317 | train_config["checkpoints_root"] = checkpoints_root 318 | train_config["dataset_root"] = dataset_root 319 | train_config["dataset_file"] = args.dataset_file 320 | train_config["num_epochs"] = num_epochs 321 | train_config["warmup_epochs"] = warmup_epochs 322 | train_config["batch_size"] = batch_size 323 | train_config["compute_units"] = compute_units 324 | train_config["global_batch_size"] = global_batch_size 325 | train_config["image_size"] = image_size 326 | train_config["num_classes"] = num_classes 327 | train_config["train_samples"] = train_samples 328 | train_config["val_samples"] = val_samples 329 | train_config["patch_size"] = patch_size 330 | train_config["learning_rate"] = learning_rate 331 | train_config["optimizer_eps"] = optimizer_eps 332 | train_config["grad_clip"] = grad_clip 333 | train_config["weight_decay"] = weight_decay 334 | train_config["noise_level"] = noise_level 335 | train_config["mixup_alpha"] = mixup_alpha 336 | train_config["rotation_ratio"] = rotation_ratio 337 | train_config["cutout_max_pct"] = cutout_max_pct 338 | train_config["cutout_patches"] = cutout_patches 339 | train_config["random_resize_method"] = random_resize_method 340 | train_config["mask_patch_size"] = mask_patch_size 341 | train_config["model_patch_size"] = model_patch_size 342 | train_config["mask_input_size"] = mask_input_size 343 | train_config["restore_params_ckpt"] = restore_params_ckpt 344 | 345 | # Add model specific arguments to WandB dict 346 | args_dict = vars(args) 347 | model_config = {key: args_dict[key] for key in args_dict if key not in train_config} 348 | del model_config["wandb_tags"] 349 | del model_config["run_name"] 350 | del model_config["epochs"] 351 | train_config.update(model_config) 352 | 353 | # WandB tracking 354 | wandb_args = dict( 355 | entity="smilingwolf", 356 | project=args.wandb_project, 357 | config=train_config, 358 | name=run_name, 359 | tags=args.wandb_tags, 360 | ) 361 | 362 | if args.wandb_run_id: 363 | wandb_args["id"] = args.wandb_run_id 364 | wandb_args["resume"] = "must" 365 | 366 | wandb_entity = wandb_args["entity"] 367 | wandb_project = wandb_args["project"] 368 | wandb_run_id = wandb_args["id"] 369 | wandb_run_path = f"{wandb_entity}/{wandb_project}/{wandb_run_id}" 370 | run_name = wandb.Api().run(wandb_run_path).name 371 | wandb_args["name"] = run_name 372 | 373 | wandb.init(**wandb_args) 374 | 375 | tf.random.set_seed(0) 376 | root_key = jax.random.key(0) 377 | params_key, dropout_key = jax.random.split(key=root_key, num=2) 378 | dropout_keys = jax.random.split(key=dropout_key, num=jax.device_count()) 379 | del root_key, dropout_key 380 | 381 | training_generator = DataGenerator( 382 | f"{dataset_root}/record_shards_train/*", 383 | num_classes=num_classes, 384 | image_size=image_size, 385 | batch_size=batch_size, 386 | num_devices=compute_units, 387 | noise_level=noise_level, 388 | mixup_alpha=mixup_alpha, 389 | rotation_ratio=rotation_ratio, 390 | cutout_max_pct=cutout_max_pct, 391 | cutout_patches=cutout_patches, 392 | random_resize_method=random_resize_method, 393 | mask_patch_size=mask_patch_size, 394 | model_patch_size=model_patch_size, 395 | mask_ratio=0.6, 396 | ) 397 | train_ds = training_generator.genDS() 398 | train_ds = jax_utils.prefetch_to_device(train_ds.as_numpy_iterator(), size=2) 399 | 400 | 
validation_generator = DataGenerator( 401 | f"{dataset_root}/record_shards_val/*", 402 | num_classes=num_classes, 403 | image_size=image_size, 404 | batch_size=batch_size, 405 | num_devices=compute_units, 406 | noise_level=0, 407 | mixup_alpha=0.0, 408 | rotation_ratio=0.0, 409 | cutout_max_pct=0.0, 410 | random_resize_method=False, 411 | mask_patch_size=mask_patch_size, 412 | model_patch_size=model_patch_size, 413 | mask_ratio=0.6, 414 | ) 415 | val_ds = validation_generator.genDS() 416 | val_ds = jax_utils.prefetch_to_device(val_ds.as_numpy_iterator(), size=2) 417 | 418 | model = model_builder.build( 419 | config=model_builder, 420 | image_size=image_size, 421 | patch_size=patch_size, 422 | num_classes=num_classes, 423 | dtype=jnp.bfloat16, 424 | **model_config, 425 | ) 426 | # tab_img = jnp.ones([1, image_size, image_size, 3]) 427 | # tab_mask = jnp.ones([1, mask_input_size, mask_input_size]) 428 | # print(model.tabulate(jax.random.key(0), tab_img, tab_mask, train=False)) 429 | 430 | num_steps_per_epoch = train_samples // global_batch_size 431 | learning_rate = optax.warmup_cosine_decay_schedule( 432 | init_value=learning_rate * 0.1, 433 | peak_value=learning_rate, 434 | warmup_steps=num_steps_per_epoch * warmup_epochs, 435 | decay_steps=num_steps_per_epoch * num_epochs, 436 | end_value=learning_rate * 0.01, 437 | ) 438 | 439 | state = create_train_state( 440 | model, 441 | params_key, 442 | image_size, 443 | mask_input_size, 444 | 0, 445 | learning_rate, 446 | optimizer_eps, 447 | grad_clip, 448 | weight_decay, 449 | ) 450 | del params_key 451 | 452 | metrics_history = {"train_loss": [], "val_loss": []} 453 | ckpt = {"model": state, "metrics_history": metrics_history} 454 | 455 | options_dict = dict( 456 | max_to_keep=args.checkpoints_keep, 457 | best_fn=lambda metrics: metrics["val_loss"], 458 | best_mode="min", 459 | ) 460 | if args.checkpoints_keep == -1: 461 | options_dict = dict(max_to_keep=1) 462 | 463 | orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer() 464 | options = orbax.checkpoint.CheckpointManagerOptions( 465 | **options_dict, 466 | create=True, 467 | ) 468 | checkpoint_manager = orbax.checkpoint.CheckpointManager( 469 | f"{checkpoints_root}/{run_name}", 470 | orbax_checkpointer, 471 | options, 472 | ) 473 | 474 | if restore_params_ckpt: 475 | throwaway_manager = orbax.checkpoint.CheckpointManager( 476 | restore_params_ckpt, 477 | orbax_checkpointer, 478 | ) 479 | latest_epoch = throwaway_manager.latest_step() 480 | restored = throwaway_manager.restore(latest_epoch, items=ckpt) 481 | state = state.replace(params=restored["model"].params) 482 | del throwaway_manager 483 | 484 | latest_epoch = checkpoint_manager.latest_step() 485 | if latest_epoch is not None: 486 | restored = checkpoint_manager.restore(latest_epoch, items=ckpt) 487 | state = restored["model"] 488 | metrics_history = restored["metrics_history"] 489 | state = state.replace(metrics=state.metrics.empty()) 490 | else: 491 | latest_epoch = 0 492 | 493 | step = int(state.step) 494 | state = jax_utils.replicate(state) 495 | p_train_step = jax.pmap(train_step, axis_name="batch") 496 | p_eval_step = jax.pmap(eval_step, axis_name="batch") 497 | 498 | epochs = step // num_steps_per_epoch 499 | pbar = tqdm(total=num_steps_per_epoch) 500 | for batch in train_ds: 501 | # Run optimization steps over training batches and compute batch metrics 502 | # get updated train state (which contains the updated parameters) 503 | state = p_train_step(state=state, batch=batch, dropout_key=dropout_keys) 504 | 505 | if step % 224 
== 0: 506 | merged_metrics = jax_utils.unreplicate(state.metrics) 507 | merged_metrics = jax.device_get(merged_metrics.loss.compute()) 508 | pbar.set_postfix(loss=f"{merged_metrics:.04f}") 509 | 510 | pbar.update(1) 511 | 512 | # one training epoch has passed 513 | if (step + 1) % num_steps_per_epoch == 0: 514 | # compute metrics 515 | merged_metrics = jax_utils.unreplicate(state.metrics) 516 | merged_metrics = jax.device_get(merged_metrics.compute()) 517 | for metric, value in merged_metrics.items(): 518 | # record metrics 519 | metrics_history[f"train_{metric}"].append(value) 520 | 521 | # reset train_metrics for validation 522 | empty_metrics = state.metrics.empty() 523 | empty_metrics = jax_utils.replicate(empty_metrics) 524 | state = state.replace(metrics=empty_metrics) 525 | 526 | # Compute metrics on the validation set after each training epoch 527 | for val_step, val_batch in enumerate(val_ds): 528 | state = p_eval_step(state=state, batch=val_batch) 529 | if val_step == val_samples // global_batch_size: 530 | break 531 | 532 | merged_metrics = jax_utils.unreplicate(state.metrics) 533 | merged_metrics = jax.device_get(merged_metrics.compute()) 534 | for metric, value in merged_metrics.items(): 535 | metrics_history[f"val_{metric}"].append(value) 536 | 537 | print( 538 | f"train epoch: {(step+1) // num_steps_per_epoch}, " 539 | f"loss: {metrics_history['train_loss'][-1]:.04f}" 540 | ) 541 | print( 542 | f"val epoch: {(step+1) // num_steps_per_epoch}, " 543 | f"loss: {metrics_history['val_loss'][-1]:.04f}" 544 | ) 545 | 546 | # Log Metrics to Weights & Biases 547 | wandb.log( 548 | { 549 | "train_loss": metrics_history["train_loss"][-1], 550 | "val_loss": metrics_history["val_loss"][-1], 551 | }, 552 | step=(step + 1) // num_steps_per_epoch, 553 | commit=True, 554 | ) 555 | 556 | if args.checkpoints_keep > 0: 557 | ckpt["model"] = jax.device_get(jax_utils.unreplicate(state)) 558 | ckpt["metrics_history"] = metrics_history 559 | save_args = orbax_utils.save_args_from_target(ckpt) 560 | checkpoint_manager.save( 561 | epochs, 562 | ckpt, 563 | save_kwargs={"save_args": save_args}, 564 | metrics={"val_loss": float(metrics_history["val_loss"][-1])}, 565 | ) 566 | 567 | # reset train_metrics for next training epoch 568 | empty_metrics = state.metrics.empty() 569 | empty_metrics = jax_utils.replicate(empty_metrics) 570 | state = state.replace(metrics=empty_metrics) 571 | 572 | epochs += 1 573 | if epochs == num_epochs: 574 | break 575 | 576 | pbar.reset() 577 | step += 1 578 | 579 | pbar.close() 580 | -------------------------------------------------------------------------------- /training_loop.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from datetime import datetime 4 | from typing import Any, Callable, Union 5 | 6 | import flax 7 | import jax 8 | import jax.numpy as jnp 9 | import numpy as np 10 | import optax 11 | import orbax.checkpoint 12 | import tensorflow as tf 13 | import wandb 14 | from clu import metrics 15 | from flax import jax_utils 16 | from flax.training import orbax_utils, train_state 17 | from tqdm import tqdm 18 | 19 | import Models 20 | from Generators.WDTaggerGen import DataGenerator 21 | from Metrics.ConfusionMatrix import f1score, mcc 22 | 23 | 24 | @flax.struct.dataclass 25 | class Metrics(metrics.Collection): 26 | loss: metrics.Metric 27 | f1score: metrics.Metric 28 | mcc: metrics.Metric 29 | 30 | 31 | class TrainState(train_state.TrainState): 32 | metrics: Metrics 33 | constants: Any 
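# Clarifying note (inferred from the code below): "constants" holds every variable
# collection other than "params" returned by module.init(); train_step and
# eval_step splice it back into apply_fn alongside the trainable parameters, so
# models that register non-trainable state still see it at apply time.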
34 | 35 | 36 | def create_optimizer_tx( 37 | module, 38 | params, 39 | learning_rate: Union[float, Callable], 40 | optimizer_eps: float, 41 | grad_clip: float, 42 | weight_decay: float, 43 | freeze_model_body: bool, 44 | ): 45 | def should_freeze(path, _): 46 | return "trainable" if "head" in path else "frozen" 47 | 48 | wd_mask = jax.tree_util.tree_map_with_path(module.should_decay, params) 49 | tx = optax.lamb( 50 | learning_rate, 51 | weight_decay=weight_decay, 52 | eps=optimizer_eps, 53 | mask=wd_mask, 54 | ) 55 | tx = optax.chain(optax.clip_by_global_norm(grad_clip), tx) 56 | 57 | if freeze_model_body: 58 | partition_optimizers = {"trainable": tx, "frozen": optax.set_to_zero()} 59 | param_partitions = flax.traverse_util.path_aware_map(should_freeze, params) 60 | tx = optax.multi_transform(partition_optimizers, param_partitions) 61 | return tx 62 | 63 | 64 | def create_train_state( 65 | module, 66 | params_key, 67 | target_size: int, 68 | num_classes: int, 69 | learning_rate: Union[float, Callable], 70 | optimizer_eps: float, 71 | grad_clip: float, 72 | weight_decay: float, 73 | freeze_model_body: bool = False, 74 | ): 75 | """Creates an initial 'TrainState'.""" 76 | # initialize parameters by passing a template image 77 | variables = module.init( 78 | params_key, 79 | jnp.ones([1, target_size, target_size, 3]), 80 | train=False, 81 | ) 82 | params = variables["params"] 83 | del variables["params"] 84 | constants = variables 85 | 86 | loss = metrics.Average.from_output("loss") 87 | f1score_metric = f1score( 88 | threshold=0.4, 89 | averaging="macro", 90 | num_classes=num_classes, 91 | from_logits=True, 92 | ) 93 | mcc_metric = mcc( 94 | threshold=0.4, 95 | averaging="macro", 96 | num_classes=num_classes, 97 | from_logits=True, 98 | ) 99 | collection = Metrics.create(loss=loss, f1score=f1score_metric, mcc=mcc_metric) 100 | 101 | tx = create_optimizer_tx( 102 | module, 103 | params, 104 | learning_rate, 105 | optimizer_eps, 106 | grad_clip, 107 | weight_decay, 108 | freeze_model_body, 109 | ) 110 | 111 | return TrainState.create( 112 | apply_fn=module.apply, 113 | params=params, 114 | tx=tx, 115 | metrics=collection.empty(), 116 | constants=constants, 117 | ) 118 | 119 | 120 | def train_step(state, batch, weights, dropout_key): 121 | """Train for a single step.""" 122 | dropout_train_key = jax.random.fold_in(key=dropout_key, data=state.step) 123 | 124 | def loss_fn(params, weights, **kwargs): 125 | logits = state.apply_fn( 126 | {"params": params, **kwargs}, 127 | batch["images"], 128 | train=True, 129 | rngs={"dropout": dropout_train_key}, 130 | ) 131 | loss = optax.sigmoid_binary_cross_entropy(logits=logits, labels=batch["labels"]) 132 | loss = loss * weights 133 | loss = loss.sum() / batch["labels"].shape[0] 134 | return loss, logits 135 | 136 | grad_fn = jax.value_and_grad(loss_fn, has_aux=True) 137 | (loss, logits), grads = grad_fn(state.params, weights, **state.constants) 138 | grads = jax.lax.pmean(grads, axis_name="batch") 139 | state = state.apply_gradients(grads=grads) 140 | 141 | metric_updates = state.metrics.gather_from_model_output( 142 | logits=logits, 143 | labels=batch["labels"], 144 | loss=loss, 145 | ) 146 | metrics = state.metrics.merge(metric_updates) 147 | state = state.replace(metrics=metrics) 148 | return state 149 | 150 | 151 | def eval_step(*, state, batch): 152 | logits = state.apply_fn( 153 | {"params": state.params, **state.constants}, 154 | batch["images"], 155 | train=False, 156 | ) 157 | 158 | loss = optax.sigmoid_binary_cross_entropy(logits=logits, 
labels=batch["labels"]) 159 | loss = loss.sum() / batch["labels"].shape[0] 160 | metric_updates = state.metrics.gather_from_model_output( 161 | logits=logits, 162 | labels=batch["labels"], 163 | loss=loss, 164 | ) 165 | metrics = state.metrics.merge(metric_updates) 166 | state = state.replace(metrics=metrics) 167 | return state 168 | 169 | 170 | model_parser = argparse.ArgumentParser( 171 | description="Model variant to train", 172 | add_help=False, 173 | ) 174 | model_parser.add_argument( 175 | "--model-name", 176 | default="vit_small", 177 | help="Model variant to train", 178 | type=str, 179 | ) 180 | 181 | parser = argparse.ArgumentParser(description="Train a network") 182 | parser.add_argument( 183 | "--run-name", 184 | default=None, 185 | help="Run name. If left empty it gets autogenerated", 186 | type=str, 187 | ) 188 | parser.add_argument( 189 | "--wandb-project", 190 | default="tpu-tracking", 191 | help="WandB project", 192 | type=str, 193 | ) 194 | parser.add_argument( 195 | "--wandb-run-id", 196 | default=None, 197 | help="WandB run ID (8 chars code) to resume interrupted run", 198 | type=str, 199 | ) 200 | parser.add_argument( 201 | "--wandb-tags", 202 | nargs="*", 203 | help="Space separated list of tags for WandB", 204 | ) 205 | parser.add_argument( 206 | "--restore-params-ckpt", 207 | default="", 208 | help="Restore the parameters from the last step of the given orbax checkpoint. Must be an absolute path. WARNING: restores params only!", 209 | type=str, 210 | ) 211 | parser.add_argument( 212 | "--restore-simmim-ckpt", 213 | default="", 214 | help="Restore the parameters from the last step of the given SimMIM-pretrained orbax checkpoint. Must be an absolute path", 215 | type=str, 216 | ) 217 | parser.add_argument( 218 | "--freeze-model-body", 219 | action="store_true", 220 | help="Freeze the feature extraction layers, train classifier head only", 221 | ) 222 | parser.add_argument( 223 | "--dataset-file", 224 | default="datasets/aibooru.json", 225 | help="JSON file with dataset specs", 226 | type=str, 227 | ) 228 | parser.add_argument( 229 | "--dataset-root", 230 | default="/home/smilingwolf/datasets", 231 | help="Dataset root, where the record_shards_train and record_shards_val folders are stored", 232 | type=str, 233 | ) 234 | parser.add_argument( 235 | "--checkpoints-root", 236 | default="/mnt/c/Users/SmilingWolf/Desktop/TFKeras/JAX/checkpoints", 237 | help="Checkpoints root, where the checkpoints will be stored following a // structure", 238 | type=str, 239 | ) 240 | parser.add_argument( 241 | "--checkpoints-keep", 242 | default=2, 243 | help="Number of best (by val_loss) checkpoints to keep. 
-1 to always keep the last checkpoint", 244 | type=int, 245 | ) 246 | parser.add_argument( 247 | "--epochs", 248 | default=50, 249 | help="Number of epochs to train for", 250 | type=int, 251 | ) 252 | parser.add_argument( 253 | "--warmup-epochs", 254 | default=5, 255 | help="Number of epochs to dedicate to linear warmup", 256 | type=int, 257 | ) 258 | parser.add_argument( 259 | "--batch-size", 260 | default=64, 261 | help="Per-device batch size", 262 | type=int, 263 | ) 264 | parser.add_argument( 265 | "--image-size", 266 | default=256, 267 | help="Image resolution in input to the network", 268 | type=int, 269 | ) 270 | parser.add_argument( 271 | "--patch-size", 272 | default=16, 273 | help="Size of the image patches", 274 | type=int, 275 | ) 276 | parser.add_argument( 277 | "--learning-rate", 278 | default=0.001, 279 | help="Max learning rate", 280 | type=float, 281 | ) 282 | parser.add_argument( 283 | "--optimizer-eps", 284 | default=1e-6, 285 | help="Optimizer epsilon", 286 | type=float, 287 | ) 288 | parser.add_argument( 289 | "--grad-clip", 290 | default=1.0, 291 | help="Gradient clipping", 292 | type=float, 293 | ) 294 | parser.add_argument( 295 | "--weight-decay", 296 | default=0.0001, 297 | help="Weight decay", 298 | type=float, 299 | ) 300 | parser.add_argument( 301 | "--loss-weights-file", 302 | default=None, 303 | help="Numpy dump of weights to apply to the training loss", 304 | type=str, 305 | ) 306 | parser.add_argument( 307 | "--mixup-alpha", 308 | default=0.8, 309 | help="MixUp alpha", 310 | type=float, 311 | ) 312 | parser.add_argument( 313 | "--rotation-ratio", 314 | default=0.0, 315 | help="Rotation ratio as a fraction of PI", 316 | type=float, 317 | ) 318 | parser.add_argument( 319 | "--cutout-max-pct", 320 | default=0.1, 321 | help="Cutout area as a fraction of the total image area", 322 | type=float, 323 | ) 324 | parser.add_argument( 325 | "--cutout-patches", 326 | default=1, 327 | help="Number of cutout patches", 328 | type=int, 329 | ) 330 | model_arg, remaining = model_parser.parse_known_args() 331 | 332 | model_name = model_arg.model_name 333 | model_builder = Models.model_registry[model_name]() 334 | parser = model_builder.extend_parser(parser=parser) 335 | 336 | args = parser.parse_args(remaining) 337 | 338 | run_name = args.run_name 339 | if run_name is None: 340 | now = datetime.now() 341 | date_time = now.strftime("%Y_%m_%d_%Hh%Mm%Ss") 342 | run_name = f"{model_name}_{date_time}" 343 | 344 | checkpoints_root = args.checkpoints_root 345 | dataset_root = args.dataset_root 346 | with open(args.dataset_file) as f: 347 | dataset_specs = json.load(f) 348 | 349 | # Run params 350 | num_epochs = args.epochs 351 | warmup_epochs = args.warmup_epochs 352 | batch_size = args.batch_size 353 | compute_units = jax.device_count() 354 | global_batch_size = batch_size * compute_units 355 | restore_params_ckpt = args.restore_params_ckpt 356 | restore_simmim_ckpt = args.restore_simmim_ckpt 357 | 358 | # Dataset params 359 | image_size = args.image_size 360 | num_classes = dataset_specs["num_classes"] 361 | train_samples = dataset_specs["train_samples"] 362 | val_samples = dataset_specs["val_samples"] 363 | 364 | # Model hyperparams 365 | patch_size = args.patch_size 366 | learning_rate = args.learning_rate 367 | optimizer_eps = args.optimizer_eps 368 | grad_clip = args.grad_clip 369 | weight_decay = args.weight_decay 370 | loss_weights_file = args.loss_weights_file 371 | freeze_model_body = args.freeze_model_body 372 | 373 | # Augmentations hyperparams 374 | noise_level = 2 375 
| mixup_alpha = args.mixup_alpha 376 | rotation_ratio = args.rotation_ratio 377 | cutout_max_pct = args.cutout_max_pct 378 | cutout_patches = args.cutout_patches 379 | random_resize_method = True 380 | 381 | # WandB tracking 382 | train_config = {} 383 | train_config["model_name"] = model_name 384 | train_config["checkpoints_root"] = checkpoints_root 385 | train_config["dataset_root"] = dataset_root 386 | train_config["dataset_file"] = args.dataset_file 387 | train_config["num_epochs"] = num_epochs 388 | train_config["warmup_epochs"] = warmup_epochs 389 | train_config["batch_size"] = batch_size 390 | train_config["compute_units"] = compute_units 391 | train_config["global_batch_size"] = global_batch_size 392 | train_config["image_size"] = image_size 393 | train_config["num_classes"] = num_classes 394 | train_config["train_samples"] = train_samples 395 | train_config["val_samples"] = val_samples 396 | train_config["patch_size"] = patch_size 397 | train_config["learning_rate"] = learning_rate 398 | train_config["optimizer_eps"] = optimizer_eps 399 | train_config["grad_clip"] = grad_clip 400 | train_config["weight_decay"] = weight_decay 401 | train_config["loss_weights_file"] = loss_weights_file 402 | train_config["noise_level"] = noise_level 403 | train_config["mixup_alpha"] = mixup_alpha 404 | train_config["rotation_ratio"] = rotation_ratio 405 | train_config["cutout_max_pct"] = cutout_max_pct 406 | train_config["cutout_patches"] = cutout_patches 407 | train_config["random_resize_method"] = random_resize_method 408 | train_config["restore_params_ckpt"] = restore_params_ckpt 409 | train_config["restore_simmim_ckpt"] = restore_simmim_ckpt 410 | train_config["freeze_model_body"] = freeze_model_body 411 | 412 | # Add model specific arguments to WandB dict 413 | args_dict = vars(args) 414 | model_config = {key: args_dict[key] for key in args_dict if key not in train_config} 415 | del model_config["wandb_tags"] 416 | del model_config["run_name"] 417 | del model_config["epochs"] 418 | train_config.update(model_config) 419 | 420 | # WandB tracking 421 | wandb_args = dict( 422 | entity="smilingwolf", 423 | project=args.wandb_project, 424 | config=train_config, 425 | name=run_name, 426 | tags=args.wandb_tags, 427 | ) 428 | 429 | if args.wandb_run_id: 430 | wandb_args["id"] = args.wandb_run_id 431 | wandb_args["resume"] = "must" 432 | 433 | wandb_entity = wandb_args["entity"] 434 | wandb_project = wandb_args["project"] 435 | wandb_run_id = wandb_args["id"] 436 | wandb_run_path = f"{wandb_entity}/{wandb_project}/{wandb_run_id}" 437 | run_name = wandb.Api().run(wandb_run_path).name 438 | wandb_args["name"] = run_name 439 | 440 | wandb.init(**wandb_args) 441 | 442 | tf.random.set_seed(0) 443 | root_key = jax.random.key(0) 444 | params_key, dropout_key = jax.random.split(key=root_key, num=2) 445 | dropout_keys = jax.random.split(key=dropout_key, num=jax.device_count()) 446 | del root_key, dropout_key 447 | 448 | training_generator = DataGenerator( 449 | f"{dataset_root}/record_shards_train/*", 450 | num_classes=num_classes, 451 | image_size=image_size, 452 | batch_size=batch_size, 453 | num_devices=compute_units, 454 | noise_level=noise_level, 455 | mixup_alpha=mixup_alpha, 456 | rotation_ratio=rotation_ratio, 457 | cutout_max_pct=cutout_max_pct, 458 | cutout_patches=cutout_patches, 459 | random_resize_method=random_resize_method, 460 | ) 461 | train_ds = training_generator.genDS() 462 | train_ds = jax_utils.prefetch_to_device(train_ds.as_numpy_iterator(), size=2) 463 | 464 | validation_generator = 
DataGenerator( 465 | f"{dataset_root}/record_shards_val/*", 466 | num_classes=num_classes, 467 | image_size=image_size, 468 | batch_size=batch_size, 469 | num_devices=compute_units, 470 | noise_level=0, 471 | mixup_alpha=0.0, 472 | rotation_ratio=0.0, 473 | cutout_max_pct=0.0, 474 | random_resize_method=False, 475 | ) 476 | val_ds = validation_generator.genDS() 477 | val_ds = jax_utils.prefetch_to_device(val_ds.as_numpy_iterator(), size=2) 478 | 479 | model = model_builder.build( 480 | config=model_builder, 481 | image_size=image_size, 482 | patch_size=patch_size, 483 | num_classes=num_classes, 484 | dtype=jnp.bfloat16, 485 | **model_config, 486 | ) 487 | # tab_img = jnp.ones([1, image_size, image_size, 3]) 488 | # print(model.tabulate(jax.random.key(0), tab_img, train=False)) 489 | 490 | num_steps_per_epoch = train_samples // global_batch_size 491 | learning_rate = optax.warmup_cosine_decay_schedule( 492 | init_value=learning_rate * 0.1, 493 | peak_value=learning_rate, 494 | warmup_steps=num_steps_per_epoch * warmup_epochs, 495 | decay_steps=num_steps_per_epoch * num_epochs, 496 | end_value=learning_rate * 0.01, 497 | ) 498 | 499 | state = create_train_state( 500 | model, 501 | params_key, 502 | image_size, 503 | num_classes, 504 | learning_rate, 505 | optimizer_eps, 506 | grad_clip, 507 | weight_decay, 508 | freeze_model_body, 509 | ) 510 | del params_key 511 | 512 | metrics_history = { 513 | "train_loss": [], 514 | "train_f1score": [], 515 | "train_mcc": [], 516 | "val_loss": [], 517 | "val_f1score": [], 518 | "val_mcc": [], 519 | } 520 | ckpt = {"model": state, "metrics_history": metrics_history} 521 | 522 | options_dict = dict( 523 | max_to_keep=args.checkpoints_keep, 524 | best_fn=lambda metrics: metrics["val_loss"], 525 | best_mode="min", 526 | ) 527 | if args.checkpoints_keep == -1: 528 | options_dict = dict(max_to_keep=1) 529 | 530 | orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer() 531 | options = orbax.checkpoint.CheckpointManagerOptions( 532 | **options_dict, 533 | create=True, 534 | ) 535 | checkpoint_manager = orbax.checkpoint.CheckpointManager( 536 | f"{checkpoints_root}/{run_name}", 537 | orbax_checkpointer, 538 | options, 539 | ) 540 | 541 | if restore_params_ckpt or restore_simmim_ckpt: 542 | ckpt_path = restore_params_ckpt if restore_params_ckpt else restore_simmim_ckpt 543 | 544 | throwaway_manager = orbax.checkpoint.CheckpointManager( 545 | ckpt_path, 546 | orbax_checkpointer, 547 | ) 548 | latest_epoch = throwaway_manager.latest_step() 549 | restored = throwaway_manager.restore(latest_epoch) 550 | 551 | transforms = {} 552 | if restore_simmim_ckpt: 553 | tx_pairs = model.get_simmim_orbax_txs() 554 | for tx_regex, tx_action in tx_pairs: 555 | tx_action = orbax.checkpoint.Transform(original_key=tx_action) 556 | transforms[tx_regex] = tx_action 557 | 558 | restored = orbax.checkpoint.apply_transformations(restored, transforms, ckpt) 559 | 560 | state = state.replace(params=restored["model"].params) 561 | del throwaway_manager 562 | 563 | latest_epoch = checkpoint_manager.latest_step() 564 | if latest_epoch is not None: 565 | restored = checkpoint_manager.restore(latest_epoch, items=ckpt) 566 | state = restored["model"] 567 | metrics_history = restored["metrics_history"] 568 | state = state.replace(metrics=state.metrics.empty()) 569 | else: 570 | latest_epoch = 0 571 | 572 | # TODO: maybe the weights should be included in the TrainState? 
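# Illustrative note on the branch below: label_weights is either a scalar [1.0]
# or an array loaded from --loss-weights-file (presumably one weight per class);
# train_step multiplies it elementwise into the per-class sigmoid cross-entropy
# before the sum, so individual tags can be up- or down-weighted in the loss.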
573 | if loss_weights_file:
574 |     label_weights = np.load(loss_weights_file, allow_pickle=False)
575 | else:
576 |     label_weights = np.array([1.0]).astype(np.float32)
577 | label_weights = jax_utils.replicate(label_weights)
578 | 
579 | step = int(state.step)
580 | state = jax_utils.replicate(state)
581 | p_train_step = jax.pmap(train_step, axis_name="batch")
582 | p_eval_step = jax.pmap(eval_step, axis_name="batch")
583 | 
584 | epochs = step // num_steps_per_epoch
585 | pbar = tqdm(total=num_steps_per_epoch)
586 | for batch in train_ds:
587 |     # Run optimization steps over training batches and compute batch metrics
588 |     # get updated train state (which contains the updated parameters)
589 |     state = p_train_step(
590 |         state=state,
591 |         batch=batch,
592 |         weights=label_weights,
593 |         dropout_key=dropout_keys,
594 |     )
595 | 
596 |     if step % 224 == 0:
597 |         merged_metrics = jax_utils.unreplicate(state.metrics)
598 |         merged_metrics = jax.device_get(merged_metrics.loss.compute())
599 |         pbar.set_postfix(loss=f"{merged_metrics:.04f}")
600 | 
601 |     pbar.update(1)
602 | 
603 |     # one training epoch has passed
604 |     if (step + 1) % num_steps_per_epoch == 0:
605 |         # compute metrics
606 |         merged_metrics = jax_utils.unreplicate(state.metrics)
607 |         merged_metrics = jax.device_get(merged_metrics.compute())
608 |         for metric, value in merged_metrics.items():
609 |             # record metrics
610 |             metrics_history[f"train_{metric}"].append(value)
611 | 
612 |         # reset train_metrics for validation
613 |         empty_metrics = state.metrics.empty()
614 |         empty_metrics = jax_utils.replicate(empty_metrics)
615 |         state = state.replace(metrics=empty_metrics)
616 | 
617 |         # Compute metrics on the validation set after each training epoch
618 |         for val_step, val_batch in enumerate(val_ds):
619 |             state = p_eval_step(state=state, batch=val_batch)
620 |             if val_step == val_samples // global_batch_size:
621 |                 break
622 | 
623 |         merged_metrics = jax_utils.unreplicate(state.metrics)
624 |         merged_metrics = jax.device_get(merged_metrics.compute())
625 |         for metric, value in merged_metrics.items():
626 |             metrics_history[f"val_{metric}"].append(value)
627 | 
628 |         print(
629 |             f"train epoch: {(step+1) // num_steps_per_epoch}, "
630 |             f"loss: {metrics_history['train_loss'][-1]:.04f}, "
631 |             f"f1score: {metrics_history['train_f1score'][-1]*100:.02f}, "
632 |             f"mcc: {metrics_history['train_mcc'][-1]*100:.02f}"
633 |         )
634 |         print(
635 |             f"val epoch: {(step+1) // num_steps_per_epoch}, "
636 |             f"loss: {metrics_history['val_loss'][-1]:.04f}, "
637 |             f"f1score: {metrics_history['val_f1score'][-1]*100:.02f}, "
638 |             f"mcc: {metrics_history['val_mcc'][-1]*100:.02f}"
639 |         )
640 | 
641 |         # Log Metrics to Weights & Biases
642 |         wandb.log(
643 |             {
644 |                 "train_loss": metrics_history["train_loss"][-1],
645 |                 "train_f1score": metrics_history["train_f1score"][-1] * 100,
646 |                 "train_mcc": metrics_history["train_mcc"][-1] * 100,
647 |                 "val_loss": metrics_history["val_loss"][-1],
648 |                 "val_f1score": metrics_history["val_f1score"][-1] * 100,
649 |                 "val_mcc": metrics_history["val_mcc"][-1] * 100,
650 |             },
651 |             step=(step + 1) // num_steps_per_epoch,
652 |             commit=True,
653 |         )
654 | 
655 |         if args.checkpoints_keep > 0:
656 |             ckpt["model"] = jax.device_get(jax_utils.unreplicate(state))
657 |             ckpt["metrics_history"] = metrics_history
658 |             save_args = orbax_utils.save_args_from_target(ckpt)
659 |             checkpoint_manager.save(
660 |                 epochs,
661 |                 ckpt,
662 |                 save_kwargs={"save_args": save_args},
663 |                 metrics={"val_loss": float(metrics_history["val_loss"][-1])},
664 |             )
665 | 
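        # NOTE (annotation, not part of the original file): checkpoints are
        # only written when args.checkpoints_keep > 0; the CheckpointManager
        # configured earlier ranks saved steps by the reported val_loss
        # (best_fn / best_mode="min"), so the best-scoring epochs are the ones
        # retained up to max_to_keep.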
666 |         # reset train_metrics for next training epoch
667 |         empty_metrics = state.metrics.empty()
668 |         empty_metrics = jax_utils.replicate(empty_metrics)
669 |         state = state.replace(metrics=empty_metrics)
670 | 
671 |         epochs += 1
672 |         if epochs == num_epochs:
673 |             break
674 | 
675 |         pbar.reset()
676 |     step += 1
677 | 
678 | pbar.close()
679 | 
--------------------------------------------------------------------------------