├── .gitattributes ├── .gitignore ├── README.md ├── docs ├── Documentation.md └── img │ ├── flax.png │ └── gpt2_diagram.png ├── flaxmodels ├── __init__.py ├── few_shot_gan_adaption │ ├── README.md │ ├── __init__.py │ ├── discriminator.py │ ├── generator.py │ ├── images │ │ ├── AmedeoModigliani.jpg │ │ ├── Babies.jpg │ │ ├── OttoDix.jpg │ │ ├── Rafael.jpg │ │ └── Sketches.jpg │ └── ops.py ├── gpt2 │ ├── README.md │ ├── __init__.py │ ├── gpt2.py │ ├── gpt2_demo.ipynb │ ├── ops.py │ ├── third_party │ │ ├── __init__.py │ │ └── huggingface_transformers │ │ │ ├── __init__.py │ │ │ ├── configuration_gpt2.py │ │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── file_utils.py │ │ │ ├── hf_api.py │ │ │ ├── logging.py │ │ │ ├── tokenization_utils.py │ │ │ ├── tokenization_utils_base.py │ │ │ └── versions.py │ └── tokenizer.py ├── resnet │ ├── README.md │ ├── __init__.py │ ├── ops.py │ ├── resnet.py │ └── resnet_demo.ipynb ├── stylegan2 │ ├── README.md │ ├── __init__.py │ ├── discriminator.py │ ├── generator.py │ ├── images │ │ ├── afhqcat.jpg │ │ ├── afhqdog.jpg │ │ ├── afhqwild.jpg │ │ ├── brecahad.jpg │ │ ├── car.jpg │ │ ├── cat.jpg │ │ ├── church.jpg │ │ ├── cifar10.jpg │ │ ├── ffhq.jpg │ │ ├── gen_images_w_trunc.jpg │ │ ├── gen_images_with_labels.jpg │ │ ├── gen_images_wo_trunc.jpg │ │ ├── horse.jpg │ │ ├── metfaces.jpg │ │ ├── style_mixing.jpg │ │ └── title.jpg │ ├── ops.py │ └── stylegan2_demo.ipynb ├── utils.py └── vgg │ ├── README.md │ ├── __init__.py │ ├── vgg.py │ └── vgg_demo.ipynb ├── setup.py ├── tests ├── aux_files │ └── elefant.jpg ├── gpt2 │ ├── aux_files │ │ ├── gpt2-large │ │ │ ├── gpt2-large_lmhead_input_embds_input.npy │ │ │ ├── gpt2-large_lmhead_input_embds_labels.npy │ │ │ ├── gpt2-large_lmhead_input_embds_logits_ref.npy │ │ │ ├── gpt2-large_lmhead_input_embds_loss_ref.npy │ │ │ ├── gpt2-large_lmhead_input_ids_input.npy │ │ │ ├── gpt2-large_lmhead_input_ids_labels.npy │ │ │ ├── gpt2-large_lmhead_input_ids_logits_ref.npy │ │ │ ├── gpt2-large_lmhead_input_ids_loss_ref.npy │ │ │ ├── gpt2-large_model_input_embds_input.npy │ │ │ ├── gpt2-large_model_input_embds_output_ref.npy │ │ │ ├── gpt2-large_model_input_ids_input.npy │ │ │ └── gpt2-large_model_input_ids_output_ref.npy │ │ ├── gpt2-medium │ │ │ ├── gpt2-medium_lmhead_input_embds_input.npy │ │ │ ├── gpt2-medium_lmhead_input_embds_labels.npy │ │ │ ├── gpt2-medium_lmhead_input_embds_logits_ref.npy │ │ │ ├── gpt2-medium_lmhead_input_embds_loss_ref.npy │ │ │ ├── gpt2-medium_lmhead_input_ids_input.npy │ │ │ ├── gpt2-medium_lmhead_input_ids_labels.npy │ │ │ ├── gpt2-medium_lmhead_input_ids_logits_ref.npy │ │ │ ├── gpt2-medium_lmhead_input_ids_loss_ref.npy │ │ │ ├── gpt2-medium_model_input_embds_input.npy │ │ │ ├── gpt2-medium_model_input_embds_output_ref.npy │ │ │ ├── gpt2-medium_model_input_ids_input.npy │ │ │ └── gpt2-medium_model_input_ids_output_ref.npy │ │ ├── gpt2-xl │ │ │ ├── gpt2-xl_lmhead_input_embds_input.npy │ │ │ ├── gpt2-xl_lmhead_input_embds_labels.npy │ │ │ ├── gpt2-xl_lmhead_input_embds_logits_ref.npy │ │ │ ├── gpt2-xl_lmhead_input_embds_loss_ref.npy │ │ │ ├── gpt2-xl_lmhead_input_ids_input.npy │ │ │ ├── gpt2-xl_lmhead_input_ids_labels.npy │ │ │ ├── gpt2-xl_lmhead_input_ids_logits_ref.npy │ │ │ ├── gpt2-xl_lmhead_input_ids_loss_ref.npy │ │ │ ├── gpt2-xl_model_input_embds_input.npy │ │ │ ├── gpt2-xl_model_input_embds_output_ref.npy │ │ │ ├── gpt2-xl_model_input_ids_input.npy │ │ │ └── gpt2-xl_model_input_ids_output_ref.npy │ │ └── gpt2 │ │ │ ├── gpt2_lmhead_input_embds_input.npy │ │ │ ├── gpt2_lmhead_input_embds_labels.npy │ │ 
│ ├── gpt2_lmhead_input_embds_logits_ref.npy │ │ │ ├── gpt2_lmhead_input_embds_loss_ref.npy │ │ │ ├── gpt2_lmhead_input_ids_input.npy │ │ │ ├── gpt2_lmhead_input_ids_labels.npy │ │ │ ├── gpt2_lmhead_input_ids_logits_ref.npy │ │ │ ├── gpt2_lmhead_input_ids_loss_ref.npy │ │ │ ├── gpt2_model_input_embds_input.npy │ │ │ ├── gpt2_model_input_embds_output_ref.npy │ │ │ ├── gpt2_model_input_ids_input.npy │ │ │ └── gpt2_model_input_ids_output_ref.npy │ ├── test_gpt2.py │ ├── test_gpt2_large.py │ ├── test_gpt2_medium.py │ └── test_gpt2_xl.py ├── resnet │ ├── aux_files │ │ ├── resnet101_elefant_output_ref.npy │ │ ├── resnet152_elefant_output_ref.npy │ │ ├── resnet18_elefant_output_ref.npy │ │ ├── resnet34_elefant_output_ref.npy │ │ └── resnet50_elefant_output_ref.npy │ ├── test_resnet101.py │ ├── test_resnet152.py │ ├── test_resnet18.py │ ├── test_resnet34.py │ └── test_resnet50.py ├── stylegan2 │ ├── discriminator │ │ ├── aux_files │ │ │ ├── afhqcat_input_img.npy │ │ │ ├── afhqcat_output_ref.npy │ │ │ ├── afhqdog_input_img.npy │ │ │ ├── afhqdog_output_ref.npy │ │ │ ├── afhqwild_input_img.npy │ │ │ ├── afhqwild_output_ref.npy │ │ │ ├── brecahad_input_img.npy │ │ │ ├── brecahad_output_ref.npy │ │ │ ├── car_input_img.npy │ │ │ ├── car_output_ref.npy │ │ │ ├── cat_input_img.npy │ │ │ ├── cat_output_ref.npy │ │ │ ├── church_input_img.npy │ │ │ ├── church_output_ref.npy │ │ │ ├── cifar10_input_img.npy │ │ │ ├── cifar10_input_label.npy │ │ │ ├── cifar10_output_ref.npy │ │ │ ├── ffhq_input_img.npy │ │ │ ├── ffhq_output_ref.npy │ │ │ ├── horse_input_img.npy │ │ │ ├── horse_output_ref.npy │ │ │ ├── metfaces_input_img.npy │ │ │ └── metfaces_output_ref.npy │ │ └── test_discriminator.py │ └── generator │ │ ├── aux_files │ │ ├── afhqcat_input_z.npy │ │ ├── afhqcat_output_ref.npy │ │ ├── afhqdog_input_z.npy │ │ ├── afhqdog_output_ref.npy │ │ ├── afhqwild_input_z.npy │ │ ├── afhqwild_output_ref.npy │ │ ├── brecahad_input_z.npy │ │ ├── brecahad_output_ref.npy │ │ ├── car_input_z.npy │ │ ├── car_output_ref.npy │ │ ├── cat_input_z.npy │ │ ├── cat_output_ref.npy │ │ ├── church_input_z.npy │ │ ├── church_output_ref.npy │ │ ├── cifar10_input_label.npy │ │ ├── cifar10_input_z.npy │ │ ├── cifar10_output_ref.npy │ │ ├── ffhq_input_z.npy │ │ ├── ffhq_output_ref.npy │ │ ├── horse_input_z.npy │ │ ├── horse_output_ref.npy │ │ ├── metfaces_input_z.npy │ │ └── metfaces_output_ref.npy │ │ └── test_generator.py └── vgg │ ├── aux_files │ ├── vgg16_elefant_output_ref.npy │ └── vgg19_elefant_output_ref.npy │ ├── test_vgg16.py │ └── test_vgg19.py └── training ├── few_shot_gan_adaption ├── README.md ├── checkpoint.py ├── data_pipeline.py ├── dataset_utils │ └── images_to_tfrecords.py ├── datatest.py ├── fid │ ├── __init__.py │ ├── core.py │ ├── inception.py │ └── utils.py ├── images │ └── overview.jpg ├── main.py ├── training.py ├── training_steps.py └── training_utils.py ├── resnet ├── README.md ├── main.py ├── requirements.txt └── training.py ├── stylegan2 ├── README.md ├── checkpoint.py ├── data_pipeline.py ├── dataset_utils │ ├── crop_image_borders.py │ └── images_to_tfrecords.py ├── fid │ ├── __init__.py │ ├── core.py │ ├── inception.py │ └── utils.py ├── generate_images.py ├── images │ ├── anime_grid.jpg │ ├── anime_overview.jpg │ ├── anime_style_mixing.jpg │ ├── ffhq_grid.jpg │ ├── ffhq_overview.jpg │ └── ffhq_style_mixing.jpg ├── main.py ├── requirements.txt ├── style_mixing.py ├── training.py ├── training_steps.py └── training_utils.py └── vgg ├── README.md ├── main.py ├── requirements.txt └── training.py /.gitattributes: 
-------------------------------------------------------------------------------- 1 | *.ipynb linguist-vendored 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.swp 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
flax 2 | Flax Models 3 | A collection of pretrained models in Flax. 4 | 5 |
6 | 7 | 8 | ### About 9 | The goal of this project is to make current deep learning models more easily available for the awesome Jax/Flax ecosystem. 10 | 11 | ### Models 12 | * GPT2 [[model](flaxmodels/gpt2)] 13 | * StyleGAN2 [[model](flaxmodels/stylegan2)] [[training](training/stylegan2)] 14 | * ResNet{18, 34, 50, 101, 152} [[model](flaxmodels/resnet)] [[training](training/resnet)] 15 | * VGG{16, 19} [[model](flaxmodels/vgg)] [[training](training/vgg)] 16 | * FewShotGanAdaption [[model](flaxmodels/few_shot_gan_adaption)] [[training](training/few_shot_gan_adaption)] 17 | 18 | 19 | ### Installation 20 | You will need Python 3.7 or later. 21 | 22 | 1. For GPU usage, follow the Jax installation with CUDA. 23 | 2. Then install: 24 | ```sh 25 | > pip install --upgrade git+https://github.com/matthias-wright/flaxmodels.git 26 | ``` 27 | For CPU-only you can skip step 1. 28 | 29 | ### Documentation 30 | The documentation for the models can be found [here](docs/Documentation.md#models). 31 | 32 | ### Checkpoints 33 | The checkpoints are taken from the repositories that are referenced on the model pages. The processing steps and the format of the checkpoints are documented [here](docs/Documentation.md#1-checkpoints). 34 | 35 | ### Testing 36 | To run the tests, pytest needs to be installed. 37 | ```sh 38 | > git clone https://github.com/matthias-wright/flaxmodels.git 39 | > cd flaxmodels 40 | > python -m pytest tests/ 41 | ``` 42 | See [here](docs/Documentation.md#2-testing) for an explanation of the testing strategy. 43 | 44 | 45 | ### Acknowledgments 46 | Thank you to the developers of Jax and Flax. The title image is a photograph of a flax flower, kindly made available by Marta Matyszczyk. 47 | 48 | ### License 49 | Each model has an individual license. 50 | -------------------------------------------------------------------------------- /docs/img/flax.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/docs/img/flax.png -------------------------------------------------------------------------------- /docs/img/gpt2_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/docs/img/gpt2_diagram.png -------------------------------------------------------------------------------- /flaxmodels/__init__.py: -------------------------------------------------------------------------------- 1 | from .vgg import * 2 | from .resnet import * 3 | from . import stylegan2 4 | from . import gpt2 5 | from . import few_shot_gan_adaption 6 | 7 | __version__ = '0.1.2' 8 | -------------------------------------------------------------------------------- /flaxmodels/few_shot_gan_adaption/README.md: -------------------------------------------------------------------------------- 1 | # Few-shot Image Generation via Cross-domain Correspondence 2 | 3 |
img
4 | 5 | Paper: https://arxiv.org/abs/2104.06820 6 | Repository: https://github.com/utkarshojha/few-shot-gan-adaptation 7 | 8 | ##### Table of Contents 9 | * [1. Basic Usage](#usage) 10 | * [2. Checkpoints](#checkpoints) 11 | * [3. Documentation](#documentation) 12 | * [4. Training](#training) 13 | * [5. Images](#images) 14 | * [6. License](#license) 15 | 16 | 17 | 18 | ## 1. Basic Usage 19 | 20 | ```python 21 | import jax 22 | import numpy as np 23 | import dill as pickle 24 | from PIL import Image 25 | 26 | import flaxmodels as fm 27 | 28 | ckpt = pickle.load(open('sketches.pickle', 'rb')) 29 | params = ckpt['params_ema_G'] 30 | 31 | generator = fm.few_shot_gan_adaption.Generator() 32 | 33 | # Seed 34 | key = jax.random.PRNGKey(0) 35 | 36 | # Input noise 37 | z = jax.random.normal(key, shape=(4, 512)) 38 | 39 | # Generate images 40 | images, _ = generator.apply(params, z, truncation_psi=0.5, train=False, noise_mode='const') 41 | 42 | # Normalize images to be in range [0, 1] 43 | images = (images - np.min(images)) / (np.max(images) - np.min(images)) 44 | 45 | # Save images 46 | for i in range(images.shape[0]): 47 | Image.fromarray(np.uint8(images[i] * 255)).save(f'image_{i}.jpg') 48 | 49 | ``` 50 | 51 | 52 | ## 2. Checkpoints 53 | * [Sketches](https://www.dropbox.com/s/azr6b316juhme6c/sketches.pickle?dl=1) (357,2 MB) 54 | * [Amedeo Modigliani](https://www.dropbox.com/s/xrh4a7wt2kggn4v/amedeo_modigliani.pickle?dl=1) (357,2 MB) 55 | * [Babies](https://www.dropbox.com/s/ntictyzrisqg5zh/babies.pickle?dl=1) (357,2 MB) 56 | * [Otto Dix](https://www.dropbox.com/s/u1g18nv73uac21m/otto_dix.pickle?dl=1) (357,2 MB) 57 | * [Rafael](https://www.dropbox.com/s/b8w928s4wffuo2c/raphael.pickle?dl=1) (357,2 MB) 58 | 59 | 60 | 61 | ## 3. Documentation 62 | The documentation can be found [here](../../docs/Documentation.md#few_shot_gan_adaption). 63 | 64 | 65 | ## 4. Training 66 | If you want to train this model in Jax/Flax, go [here](https://github.com/matthias-wright/flaxmodels/tree/main/training/few_shot_gan_adaption). 67 | 68 | 69 | ## 5. Images 70 | 71 | ### Sketches 72 |
img
73 | 74 | ### Amedeo Modigliani 75 |
img
76 | 77 | ### Babies 78 |
img
79 | 80 | ### Otto Dix 81 |
img
82 | 83 | ### Rafael 84 |
img
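As a side note to the checkpoint list in Section 2: the files can also be fetched programmatically with the package's own download helper from `flaxmodels/utils.py` and then loaded exactly as in the Basic Usage example. This is a hedged sketch (the helper is an internal utility, not a documented API), shown here only for convenience:

```python
import dill as pickle
from flaxmodels import utils

# Download the 'Sketches' checkpoint from Section 2.
# ckpt_dir=None stores it in a 'flaxmodels' folder inside the temp directory;
# pass a path instead to choose the location yourself.
ckpt_file = utils.download(ckpt_dir=None,
                           url='https://www.dropbox.com/s/azr6b316juhme6c/sketches.pickle?dl=1')

ckpt = pickle.load(open(ckpt_file, 'rb'))
params = ckpt['params_ema_G']  # generator EMA weights, as used in Section 1
```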
85 | 86 | 87 | 88 | ## 6. License 89 | MIT License 90 | 91 | 92 | -------------------------------------------------------------------------------- /flaxmodels/few_shot_gan_adaption/__init__.py: -------------------------------------------------------------------------------- 1 | from .generator import MappingNetwork 2 | from .generator import SynthesisNetwork 3 | from .generator import Generator 4 | from .discriminator import Discriminator 5 | -------------------------------------------------------------------------------- /flaxmodels/few_shot_gan_adaption/images/AmedeoModigliani.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/few_shot_gan_adaption/images/AmedeoModigliani.jpg -------------------------------------------------------------------------------- /flaxmodels/few_shot_gan_adaption/images/Babies.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/few_shot_gan_adaption/images/Babies.jpg -------------------------------------------------------------------------------- /flaxmodels/few_shot_gan_adaption/images/OttoDix.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/few_shot_gan_adaption/images/OttoDix.jpg -------------------------------------------------------------------------------- /flaxmodels/few_shot_gan_adaption/images/Rafael.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/few_shot_gan_adaption/images/Rafael.jpg -------------------------------------------------------------------------------- /flaxmodels/few_shot_gan_adaption/images/Sketches.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/few_shot_gan_adaption/images/Sketches.jpg -------------------------------------------------------------------------------- /flaxmodels/gpt2/README.md: -------------------------------------------------------------------------------- 1 | # Better Language Models and Their Implications (GPT2) 2 | 3 | 4 | Paper: https://openai.com/blog/better-language-models/ 5 | Repository: https://github.com/huggingface/transformers/tree/master/src/transformers/models/gpt2 6 | 7 | 8 | ##### Table of Contents 9 | * [1. Models](#models) 10 | * [2. Basic Usage](#usage) 11 | * [3. Documentation](#documentation) 12 | * [4. Acknowledgments](#ack) 13 | * [5. License](#license) 14 | 15 | 16 | 17 | ## 1. Models 18 | 19 | | Model | Parameters | Size | URL | 20 | | ------------- | ------------- | ------------- | ------------- | 21 | | gpt2 | ~ 120 Million | ~ 500 MB | https://huggingface.co/gpt2 | 22 | | gpt2-medium | ~ 350 Million | ~ 1.5 GB | https://huggingface.co/gpt2-medium | 23 | | gpt2-large | ~ 800 Million | ~ 3 GB | https://huggingface.co/gpt2-large | 24 | | gpt2-xl | ~ 1.5 Billion | ~ 6 GB | https://huggingface.co/gpt2-xl | 25 | 26 | 27 | 28 | ## 2. Basic Usage 29 | For more usage examples check out this [Colab](gpt2_demo.ipynb). 30 | 31 | This is very simple greedy text generation. 
There are more sophisticated methods out there. 32 | ```python 33 | import jax 34 | import jax.numpy as jnp 35 | import flaxmodels as fm 36 | 37 | key = jax.random.PRNGKey(0) 38 | 39 | # Initialize tokenizer 40 | tokenizer = fm.gpt2.get_tokenizer() 41 | 42 | # Encode start sequence 43 | generated = tokenizer.encode('The Manhattan bridge') 44 | 45 | context = jnp.array([generated]) 46 | past = None 47 | 48 | # Initialize model 49 | # Models to choose from ['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'] 50 | model = fm.gpt2.GPT2LMHeadModel(pretrained='gpt2') 51 | params = model.init(key, input_ids=context, past_key_values=past) 52 | 53 | for i in range(20): 54 | # Predict next token in sequence 55 | output = model.apply(params, input_ids=context, past_key_values=past, use_cache=True) 56 | token = jnp.argmax(output['logits'][..., -1, :]) 57 | context = jnp.expand_dims(token, axis=0) 58 | # Add token to sequence 59 | generated += [token] 60 | # Update past keys and values 61 | past = output['past_key_values'] 62 | 63 | # Decode sequence of tokens 64 | sequence = tokenizer.decode(generated) 65 | print(sequence) 66 | ``` 67 | 68 | 69 | ## 3. Documentation 70 | The documentation can be found [here](../../docs/Documentation.md#gpt2). 71 | 72 | 73 | ## 4. Acknowledgments 74 | The tokenizer is taken from Huggingface. 75 | 76 | 77 | ## 5. License 78 | Apache-2.0 License 79 | 80 | 81 | -------------------------------------------------------------------------------- /flaxmodels/gpt2/__init__.py: -------------------------------------------------------------------------------- 1 | from .gpt2 import GPT2Model 2 | from .gpt2 import GPT2LMHeadModel 3 | from .tokenizer import * 4 | -------------------------------------------------------------------------------- /flaxmodels/gpt2/third_party/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/gpt2/third_party/__init__.py -------------------------------------------------------------------------------- /flaxmodels/gpt2/third_party/huggingface_transformers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/gpt2/third_party/huggingface_transformers/__init__.py -------------------------------------------------------------------------------- /flaxmodels/gpt2/third_party/huggingface_transformers/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/gpt2/third_party/huggingface_transformers/utils/__init__.py -------------------------------------------------------------------------------- /flaxmodels/gpt2/third_party/huggingface_transformers/utils/hf_api.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019-present, the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | import io 18 | import os 19 | from os.path import expanduser 20 | from typing import Dict, List, Optional, Tuple 21 | 22 | from tqdm import tqdm 23 | 24 | import requests 25 | 26 | 27 | ENDPOINT = "https://huggingface.co" 28 | 29 | 30 | class RepoObj: 31 | """ 32 | HuggingFace git-based system, data structure that represents a file belonging to the current user. 33 | """ 34 | 35 | def __init__(self, filename: str, lastModified: str, commit: str, size: int, **kwargs): 36 | self.filename = filename 37 | self.lastModified = lastModified 38 | self.commit = commit 39 | self.size = size 40 | 41 | 42 | class ModelSibling: 43 | """ 44 | Data structure that represents a public file inside a model, accessible from huggingface.co 45 | """ 46 | 47 | def __init__(self, rfilename: str, **kwargs): 48 | self.rfilename = rfilename # filename relative to the model root 49 | for k, v in kwargs.items(): 50 | setattr(self, k, v) 51 | 52 | 53 | class ModelInfo: 54 | """ 55 | Info about a public model accessible from huggingface.co 56 | """ 57 | 58 | def __init__( 59 | self, 60 | modelId: Optional[str] = None, # id of model 61 | tags: List[str] = [], 62 | pipeline_tag: Optional[str] = None, 63 | siblings: Optional[List[Dict]] = None, # list of files that constitute the model 64 | **kwargs 65 | ): 66 | self.modelId = modelId 67 | self.tags = tags 68 | self.pipeline_tag = pipeline_tag 69 | self.siblings = [ModelSibling(**x) for x in siblings] if siblings is not None else None 70 | for k, v in kwargs.items(): 71 | setattr(self, k, v) 72 | 73 | 74 | class HfApi: 75 | def __init__(self, endpoint=None): 76 | self.endpoint = endpoint if endpoint is not None else ENDPOINT 77 | 78 | def login(self, username: str, password: str) -> str: 79 | """ 80 | Call HF API to sign in a user and get a token if credentials are valid. 81 | 82 | Outputs: token if credentials are valid 83 | 84 | Throws: requests.exceptions.HTTPError if credentials are invalid 85 | """ 86 | path = f"{self.endpoint}/api/login" 87 | r = requests.post(path, json={"username": username, "password": password}) 88 | r.raise_for_status() 89 | d = r.json() 90 | return d["token"] 91 | 92 | def whoami(self, token: str) -> Tuple[str, List[str]]: 93 | """ 94 | Call HF API to know "whoami" 95 | """ 96 | path = f"{self.endpoint}/api/whoami" 97 | r = requests.get(path, headers={"authorization": f"Bearer {token}"}) 98 | r.raise_for_status() 99 | d = r.json() 100 | return d["user"], d["orgs"] 101 | 102 | def logout(self, token: str) -> None: 103 | """ 104 | Call HF API to log out. 
105 | """ 106 | path = f"{self.endpoint}/api/logout" 107 | r = requests.post(path, headers={"authorization": f"Bearer {token}"}) 108 | r.raise_for_status() 109 | 110 | def model_list(self) -> List[ModelInfo]: 111 | """ 112 | Get the public list of all the models on huggingface.co 113 | """ 114 | path = f"{self.endpoint}/api/models" 115 | r = requests.get(path) 116 | r.raise_for_status() 117 | d = r.json() 118 | return [ModelInfo(**x) for x in d] 119 | 120 | def list_repos_objs(self, token: str, organization: Optional[str] = None) -> List[RepoObj]: 121 | """ 122 | HuggingFace git-based system, used for models. 123 | 124 | Call HF API to list all stored files for user (or one of their organizations). 125 | """ 126 | path = f"{self.endpoint}/api/repos/ls" 127 | params = {"organization": organization} if organization is not None else None 128 | r = requests.get(path, params=params, headers={"authorization": f"Bearer {token}"}) 129 | r.raise_for_status() 130 | d = r.json() 131 | return [RepoObj(**x) for x in d] 132 | 133 | def create_repo( 134 | self, 135 | token: str, 136 | name: str, 137 | organization: Optional[str] = None, 138 | private: Optional[bool] = None, 139 | exist_ok=False, 140 | lfsmultipartthresh: Optional[int] = None, 141 | ) -> str: 142 | """ 143 | HuggingFace git-based system, used for models. 144 | 145 | Call HF API to create a whole repo. 146 | 147 | Params: 148 | private: Whether the model repo should be private (requires a paid huggingface.co account) 149 | 150 | exist_ok: Do not raise an error if repo already exists 151 | 152 | lfsmultipartthresh: Optional: internal param for testing purposes. 153 | """ 154 | path = f"{self.endpoint}/api/repos/create" 155 | json = {"name": name, "organization": organization, "private": private} 156 | if lfsmultipartthresh is not None: 157 | json["lfsmultipartthresh"] = lfsmultipartthresh 158 | r = requests.post( 159 | path, 160 | headers={"authorization": f"Bearer {token}"}, 161 | json=json, 162 | ) 163 | if exist_ok and r.status_code == 409: 164 | return "" 165 | r.raise_for_status() 166 | d = r.json() 167 | return d["url"] 168 | 169 | def delete_repo(self, token: str, name: str, organization: Optional[str] = None): 170 | """ 171 | HuggingFace git-based system, used for models. 172 | 173 | Call HF API to delete a whole repo. 174 | 175 | CAUTION(this is irreversible). 176 | """ 177 | path = f"{self.endpoint}/api/repos/delete" 178 | r = requests.delete( 179 | path, 180 | headers={"authorization": f"Bearer {token}"}, 181 | json={"name": name, "organization": organization}, 182 | ) 183 | r.raise_for_status() 184 | 185 | 186 | class TqdmProgressFileReader: 187 | """ 188 | Wrap an io.BufferedReader `f` (such as the output of `open(…, "rb")`) and override `f.read()` so as to display a 189 | tqdm progress bar. 190 | 191 | see github.com/huggingface/transformers/pull/2078#discussion_r354739608 for implementation details. 192 | """ 193 | 194 | def __init__(self, f: io.BufferedReader): 195 | self.f = f 196 | self.total_size = os.fstat(f.fileno()).st_size 197 | self.pbar = tqdm(total=self.total_size, leave=False) 198 | self.read = f.read 199 | f.read = self._read 200 | 201 | def _read(self, n=-1): 202 | self.pbar.update(n) 203 | return self.read(n) 204 | 205 | def close(self): 206 | self.pbar.close() 207 | 208 | 209 | class HfFolder: 210 | path_token = expanduser("~/.huggingface/token") 211 | 212 | @classmethod 213 | def save_token(cls, token): 214 | """ 215 | Save token, creating folder as needed. 
216 | """ 217 | os.makedirs(os.path.dirname(cls.path_token), exist_ok=True) 218 | with open(cls.path_token, "w+") as f: 219 | f.write(token) 220 | 221 | @classmethod 222 | def get_token(cls): 223 | """ 224 | Get token or None if not existent. 225 | """ 226 | try: 227 | with open(cls.path_token, "r") as f: 228 | return f.read() 229 | except FileNotFoundError: 230 | pass 231 | 232 | @classmethod 233 | def delete_token(cls): 234 | """ 235 | Delete token. Do not fail if token does not exist. 236 | """ 237 | try: 238 | os.remove(cls.path_token) 239 | except FileNotFoundError: 240 | pass 241 | -------------------------------------------------------------------------------- /flaxmodels/gpt2/third_party/huggingface_transformers/utils/logging.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 Optuna, Hugging Face 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ Logging utilities. """ 16 | 17 | import logging 18 | import os 19 | import sys 20 | import threading 21 | from logging import CRITICAL # NOQA 22 | from logging import DEBUG # NOQA 23 | from logging import ERROR # NOQA 24 | from logging import FATAL # NOQA 25 | from logging import INFO # NOQA 26 | from logging import NOTSET # NOQA 27 | from logging import WARN # NOQA 28 | from logging import WARNING # NOQA 29 | from typing import Optional 30 | 31 | 32 | _lock = threading.Lock() 33 | _default_handler: Optional[logging.Handler] = None 34 | 35 | log_levels = { 36 | "debug": logging.DEBUG, 37 | "info": logging.INFO, 38 | "warning": logging.WARNING, 39 | "error": logging.ERROR, 40 | "critical": logging.CRITICAL, 41 | } 42 | 43 | _default_log_level = logging.WARNING 44 | 45 | 46 | def _get_default_logging_level(): 47 | """ 48 | If TRANSFORMERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is 49 | not - fall back to ``_default_log_level`` 50 | """ 51 | env_level_str = os.getenv("TRANSFORMERS_VERBOSITY", None) 52 | if env_level_str: 53 | if env_level_str in log_levels: 54 | return log_levels[env_level_str] 55 | else: 56 | logging.getLogger().warning( 57 | f"Unknown option TRANSFORMERS_VERBOSITY={env_level_str}, " 58 | f"has to be one of: { ', '.join(log_levels.keys()) }" 59 | ) 60 | return _default_log_level 61 | 62 | 63 | def _get_library_name() -> str: 64 | 65 | return __name__.split(".")[0] 66 | 67 | 68 | def _get_library_root_logger() -> logging.Logger: 69 | 70 | return logging.getLogger(_get_library_name()) 71 | 72 | 73 | def _configure_library_root_logger() -> None: 74 | 75 | global _default_handler 76 | 77 | with _lock: 78 | if _default_handler: 79 | # This library has already configured the library root logger. 80 | return 81 | _default_handler = logging.StreamHandler() # Set sys.stderr as stream. 82 | _default_handler.flush = sys.stderr.flush 83 | 84 | # Apply our default configuration to the library root logger. 
85 | library_root_logger = _get_library_root_logger() 86 | library_root_logger.addHandler(_default_handler) 87 | library_root_logger.setLevel(_get_default_logging_level()) 88 | library_root_logger.propagate = False 89 | 90 | 91 | def _reset_library_root_logger() -> None: 92 | 93 | global _default_handler 94 | 95 | with _lock: 96 | if not _default_handler: 97 | return 98 | 99 | library_root_logger = _get_library_root_logger() 100 | library_root_logger.removeHandler(_default_handler) 101 | library_root_logger.setLevel(logging.NOTSET) 102 | _default_handler = None 103 | 104 | 105 | def get_logger(name: Optional[str] = None) -> logging.Logger: 106 | """ 107 | Return a logger with the specified name. 108 | 109 | This function is not supposed to be directly accessed unless you are writing a custom transformers module. 110 | """ 111 | 112 | if name is None: 113 | name = _get_library_name() 114 | 115 | _configure_library_root_logger() 116 | return logging.getLogger(name) 117 | 118 | 119 | def get_verbosity() -> int: 120 | """ 121 | Return the current level for the 🤗 Transformers's root logger as an int. 122 | 123 | Returns: 124 | :obj:`int`: The logging level. 125 | 126 | .. note:: 127 | 128 | 🤗 Transformers has following logging levels: 129 | 130 | - 50: ``transformers.logging.CRITICAL`` or ``transformers.logging.FATAL`` 131 | - 40: ``transformers.logging.ERROR`` 132 | - 30: ``transformers.logging.WARNING`` or ``transformers.logging.WARN`` 133 | - 20: ``transformers.logging.INFO`` 134 | - 10: ``transformers.logging.DEBUG`` 135 | """ 136 | 137 | _configure_library_root_logger() 138 | return _get_library_root_logger().getEffectiveLevel() 139 | 140 | 141 | def set_verbosity(verbosity: int) -> None: 142 | """ 143 | Set the vebosity level for the 🤗 Transformers's root logger. 
144 | 145 | Args: 146 | verbosity (:obj:`int`): 147 | Logging level, e.g., one of: 148 | 149 | - ``transformers.logging.CRITICAL`` or ``transformers.logging.FATAL`` 150 | - ``transformers.logging.ERROR`` 151 | - ``transformers.logging.WARNING`` or ``transformers.logging.WARN`` 152 | - ``transformers.logging.INFO`` 153 | - ``transformers.logging.DEBUG`` 154 | """ 155 | 156 | _configure_library_root_logger() 157 | _get_library_root_logger().setLevel(verbosity) 158 | 159 | 160 | def set_verbosity_info(): 161 | """Set the verbosity to the :obj:`INFO` level.""" 162 | return set_verbosity(INFO) 163 | 164 | 165 | def set_verbosity_warning(): 166 | """Set the verbosity to the :obj:`WARNING` level.""" 167 | return set_verbosity(WARNING) 168 | 169 | 170 | def set_verbosity_debug(): 171 | """Set the verbosity to the :obj:`DEBUG` level.""" 172 | return set_verbosity(DEBUG) 173 | 174 | 175 | def set_verbosity_error(): 176 | """Set the verbosity to the :obj:`ERROR` level.""" 177 | return set_verbosity(ERROR) 178 | 179 | 180 | def disable_default_handler() -> None: 181 | """Disable the default handler of the HuggingFace Transformers's root logger.""" 182 | 183 | _configure_library_root_logger() 184 | 185 | assert _default_handler is not None 186 | _get_library_root_logger().removeHandler(_default_handler) 187 | 188 | 189 | def enable_default_handler() -> None: 190 | """Enable the default handler of the HuggingFace Transformers's root logger.""" 191 | 192 | _configure_library_root_logger() 193 | 194 | assert _default_handler is not None 195 | _get_library_root_logger().addHandler(_default_handler) 196 | 197 | 198 | def add_handler(handler: logging.Handler) -> None: 199 | """adds a handler to the HuggingFace Transformers's root logger.""" 200 | 201 | _configure_library_root_logger() 202 | 203 | assert handler is not None 204 | _get_library_root_logger().addHandler(handler) 205 | 206 | 207 | def remove_handler(handler: logging.Handler) -> None: 208 | """removes given handler from the HuggingFace Transformers's root logger.""" 209 | 210 | _configure_library_root_logger() 211 | 212 | assert handler is not None and handler not in _get_library_root_logger().handlers 213 | _get_library_root_logger().removeHandler(handler) 214 | 215 | 216 | def disable_propagation() -> None: 217 | """ 218 | Disable propagation of the library log outputs. Note that log propagation is disabled by default. 219 | """ 220 | 221 | _configure_library_root_logger() 222 | _get_library_root_logger().propagate = False 223 | 224 | 225 | def enable_propagation() -> None: 226 | """ 227 | Enable propagation of the library log outputs. Please disable the HuggingFace Transformers's default handler to 228 | prevent double logging if the root logger has been configured. 229 | """ 230 | 231 | _configure_library_root_logger() 232 | _get_library_root_logger().propagate = True 233 | 234 | 235 | def enable_explicit_format() -> None: 236 | """ 237 | Enable explicit formatting for every HuggingFace Transformers's logger. The explicit formatter is as follows: 238 | 239 | :: 240 | 241 | [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE 242 | 243 | All handlers currently bound to the root logger are affected by this method. 
244 | """ 245 | handlers = _get_library_root_logger().handlers 246 | 247 | for handler in handlers: 248 | formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s") 249 | handler.setFormatter(formatter) 250 | 251 | 252 | def reset_format() -> None: 253 | """ 254 | Resets the formatting for HuggingFace Transformers's loggers. 255 | 256 | All handlers currently bound to the root logger are affected by this method. 257 | """ 258 | handlers = _get_library_root_logger().handlers 259 | 260 | for handler in handlers: 261 | handler.setFormatter(None) 262 | -------------------------------------------------------------------------------- /flaxmodels/gpt2/third_party/huggingface_transformers/utils/versions.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Utilities for working with package versions 16 | """ 17 | 18 | import operator 19 | import re 20 | import sys 21 | from typing import Optional 22 | 23 | from packaging import version 24 | 25 | 26 | # The package importlib_metadata is in a different place, depending on the python version. 27 | if sys.version_info < (3, 8): 28 | import importlib_metadata 29 | else: 30 | import importlib.metadata as importlib_metadata 31 | 32 | 33 | ops = { 34 | "<": operator.lt, 35 | "<=": operator.le, 36 | "==": operator.eq, 37 | "!=": operator.ne, 38 | ">=": operator.ge, 39 | ">": operator.gt, 40 | } 41 | 42 | 43 | def _compare_versions(op, got_ver, want_ver, requirement, pkg, hint): 44 | if got_ver is None: 45 | raise ValueError("got_ver is None") 46 | if want_ver is None: 47 | raise ValueError("want_ver is None") 48 | if not ops[op](version.parse(got_ver), version.parse(want_ver)): 49 | raise ImportError( 50 | f"{requirement} is required for a normal functioning of this module, but found {pkg}=={got_ver}.{hint}" 51 | ) 52 | 53 | 54 | def require_version(requirement: str, hint: Optional[str] = None) -> None: 55 | """ 56 | Perform a runtime check of the dependency versions, using the exact same syntax used by pip. 57 | 58 | The installed module version comes from the `site-packages` dir via `importlib_metadata`. 
59 | 60 | Args: 61 | requirement (:obj:`str`): pip style definition, e.g., "tokenizers==0.9.4", "tqdm>=4.27", "numpy" 62 | hint (:obj:`str`, `optional`): what suggestion to print in case of requirements not being met 63 | 64 | Example:: 65 | 66 | require_version("pandas>1.1.2") 67 | require_version("numpy>1.18.5", "this is important to have for whatever reason") 68 | 69 | """ 70 | 71 | hint = f"\n{hint}" if hint is not None else "" 72 | 73 | # non-versioned check 74 | if re.match(r"^[\w_\-\d]+$", requirement): 75 | pkg, op, want_ver = requirement, None, None 76 | else: 77 | match = re.findall(r"^([^!=<>\s]+)([\s!=<>]{1,2}.+)", requirement) 78 | if not match: 79 | raise ValueError( 80 | f"requirement needs to be in the pip package format, .e.g., package_a==1.23, or package_b>=1.23, but got {requirement}" 81 | ) 82 | pkg, want_full = match[0] 83 | want_range = want_full.split(",") # there could be multiple requirements 84 | wanted = {} 85 | for w in want_range: 86 | match = re.findall(r"^([\s!=<>]{1,2})(.+)", w) 87 | if not match: 88 | raise ValueError( 89 | f"requirement needs to be in the pip package format, .e.g., package_a==1.23, or package_b>=1.23, but got {requirement}" 90 | ) 91 | op, want_ver = match[0] 92 | wanted[op] = want_ver 93 | if op not in ops: 94 | raise ValueError(f"{requirement}: need one of {list(ops.keys())}, but got {op}") 95 | 96 | # special case 97 | if pkg == "python": 98 | got_ver = ".".join([str(x) for x in sys.version_info[:3]]) 99 | for op, want_ver in wanted.items(): 100 | _compare_versions(op, got_ver, want_ver, requirement, pkg, hint) 101 | return 102 | 103 | # check if any version is installed 104 | try: 105 | got_ver = importlib_metadata.version(pkg) 106 | except importlib_metadata.PackageNotFoundError: 107 | raise importlib_metadata.PackageNotFoundError( 108 | f"The '{requirement}' distribution was not found and is required by this application. {hint}" 109 | ) 110 | 111 | # check that the right version is installed if version number or a range was provided 112 | if want_ver is not None: 113 | for op, want_ver in wanted.items(): 114 | _compare_versions(op, got_ver, want_ver, requirement, pkg, hint) 115 | 116 | 117 | def require_version_core(requirement): 118 | """ require_version wrapper which emits a core-specific hint on failure """ 119 | hint = "Try: pip install transformers -U or pip install -e '.[dev]' if you're working with git master" 120 | return require_version(requirement, hint) 121 | 122 | 123 | def require_version_examples(requirement): 124 | """ require_version wrapper which emits examples-specific hint on failure """ 125 | hint = "Try: pip install -r examples/requirements.txt" 126 | return require_version(requirement, hint) 127 | -------------------------------------------------------------------------------- /flaxmodels/gpt2/tokenizer.py: -------------------------------------------------------------------------------- 1 | from .third_party.huggingface_transformers.configuration_gpt2 import GPT2Tokenizer 2 | from .. import utils 3 | 4 | 5 | def get_tokenizer(errors='replace', 6 | unk_token='<|endoftext|>', 7 | bos_token='<|endoftext|>', 8 | eos_token='<|endoftext|>', 9 | add_prefix_space=False, 10 | ckpt_dir=None): 11 | """ 12 | Returns the GPT2Tokenizer from Huggingface with loaded merges and vocab files. 13 | See: https://huggingface.co/transformers/model_doc/gpt2.html#gpt2tokenizer 14 | 15 | Args: 16 | errors (str): Paradigm to follow when decoding bytes to UTF-8. 17 | unk_token (str): The unknown token. 
A token that is not in the 18 | vocabulary cannot be converted to an ID and is set to be this token instead. 19 | bos_token (str): The beginning of sequence token. 20 | eos_token (str): The end of sequence token. 21 | add_prefix_space (bool): Whether or not to add an initial space to the input. 22 | This allows to treat the leading word just as any other word. 23 | ckpt_dir (str): Path to directory, where merges and vocab files are downloaded to. 24 | If None, the files will be downloaded to a temp directory. 25 | 26 | Returns: 27 | (GPT2Tokenizer): GPT2 Tokenizer. 28 | 29 | """ 30 | merges_file = utils.download(ckpt_dir, 'https://www.dropbox.com/s/7f5n1gf348sy1mt/merges.txt?dl=1') 31 | vocab_file = utils.download(ckpt_dir, 'https://www.dropbox.com/s/s93xkhgcac5nbmn/vocab.json?dl=1') 32 | 33 | return GPT2Tokenizer(vocab_file=vocab_file, 34 | merges_file=merges_file, 35 | errors=errors, 36 | unk_token=unk_token, 37 | bos_token=bos_token, 38 | eos_token=eos_token, 39 | add_prefix_space=add_prefix_space) 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /flaxmodels/resnet/README.md: -------------------------------------------------------------------------------- 1 | # Deep Residual Learning for Image Recognition 2 | 3 | Paper: https://arxiv.org/abs/1512.03385 4 | Repository: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 5 | 6 | ##### Table of Contents 7 | * [1. Important Note](#note) 8 | * [2. Basic Usage](#usage) 9 | * [3. Documentation](#documentation) 10 | * [4. Training](#training) 11 | * [5. License](#license) 12 | 13 | 14 | ## 1. Important Note 15 | Images must be in range [0, 1]. If the pretrained ImageNet weights are selected, the images are internally normalized with the ImageNet mean and standard deviation. If you don't want the images to be normalized, use `normalize=False` (see [here](https://github.com/matthias-wright/flaxmodels/blob/main/docs/Documentation.md#33-resnet18-34-50-101-152) for details). 16 | 17 | 18 | ## 2. Basic Usage 19 | For more usage examples check out this [Colab](resnet_demo.ipynb). 20 | 21 | ```python 22 | from PIL import Image 23 | import jax 24 | import jax.numpy as jnp 25 | import flaxmodels as fm 26 | 27 | key = jax.random.PRNGKey(0) 28 | 29 | # Load image 30 | img = Image.open('example.jpg') 31 | # Image should be in range [0, 1] 32 | x = jnp.array(img, dtype=jnp.float32) / 255.0 33 | # Add batch dimension 34 | x = jnp.expand_dims(x, axis=0) 35 | 36 | resnet18 = fm.ResNet18(output='logits', pretrained='imagenet') 37 | params = resnet18.init(key, x) 38 | # Shape [1, 1000] 39 | out = resnet18.apply(params, x, train=False) 40 | 41 | ``` 42 | Usage is equivalent for ResNet34, ResNet50, ResNet101, and Resnet152. 43 | 44 | 45 | ## 3. Documentation 46 | The documentation can be found [here](../../docs/Documentation.md#resnet). 47 | 48 | 49 | ## 4. Training 50 | If you want to train ResNet in Jax/Flax, go [here](https://github.com/matthias-wright/flaxmodels/tree/main/training/resnet). 51 | 52 | 53 | ## 5. 
License 54 | MIT License 55 | 56 | 57 | -------------------------------------------------------------------------------- /flaxmodels/resnet/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import ResNet18 2 | from .resnet import ResNet34 3 | from .resnet import ResNet50 4 | from .resnet import ResNet101 5 | from .resnet import ResNet152 6 | -------------------------------------------------------------------------------- /flaxmodels/resnet/ops.py: -------------------------------------------------------------------------------- 1 | import flax.linen as nn 2 | from flax.core import FrozenDict 3 | import jax.numpy as jnp 4 | 5 | from flax.linen.module import compact, merge_param 6 | from typing import (Any, Callable, Optional, Tuple) 7 | from jax.nn import initializers 8 | from jax import lax 9 | 10 | PRNGKey = Any 11 | Array = Any 12 | Shape = Tuple[int] 13 | Dtype = Any 14 | 15 | 16 | #---------------------------------------------------------------# 17 | # Normalization 18 | #---------------------------------------------------------------# 19 | def batch_norm(x, train, epsilon=1e-05, momentum=0.99, params=None, dtype='float32'): 20 | if params is None: 21 | x = BatchNorm(epsilon=epsilon, 22 | momentum=momentum, 23 | use_running_average=not train, 24 | dtype=dtype)(x) 25 | else: 26 | x = BatchNorm(epsilon=epsilon, 27 | momentum=momentum, 28 | bias_init=lambda *_ : jnp.array(params['bias']), 29 | scale_init=lambda *_ : jnp.array(params['scale']), 30 | mean_init=lambda *_ : jnp.array(params['mean']), 31 | var_init=lambda *_ : jnp.array(params['var']), 32 | use_running_average=not train, 33 | dtype=dtype)(x) 34 | return x 35 | 36 | 37 | def _absolute_dims(rank, dims): 38 | return tuple([rank + dim if dim < 0 else dim for dim in dims]) 39 | 40 | 41 | class BatchNorm(nn.Module): 42 | """BatchNorm Module. 43 | 44 | Taken from: https://github.com/google/flax/blob/master/flax/linen/normalization.py 45 | 46 | Attributes: 47 | use_running_average: if True, the statistics stored in batch_stats 48 | will be used instead of computing the batch statistics on the input. 49 | axis: the feature or non-batch axis of the input. 50 | momentum: decay rate for the exponential moving average of the batch statistics. 51 | epsilon: a small float added to variance to avoid dividing by zero. 52 | dtype: the dtype of the computation (default: float32). 53 | use_bias: if True, bias (beta) is added. 54 | use_scale: if True, multiply by scale (gamma). 55 | When the next layer is linear (also e.g. nn.relu), this can be disabled 56 | since the scaling will be done by the next layer. 57 | bias_init: initializer for bias, by default, zero. 58 | scale_init: initializer for scale, by default, one. 59 | axis_name: the axis name used to combine batch statistics from multiple 60 | devices. See `jax.pmap` for a description of axis names (default: None). 61 | axis_index_groups: groups of axis indices within that named axis 62 | representing subsets of devices to reduce over (default: None). For 63 | example, `[[0, 1], [2, 3]]` would independently batch-normalize over 64 | the examples on the first two and last two devices. See `jax.lax.psum` 65 | for more details. 
66 | """ 67 | use_running_average: Optional[bool] = None 68 | axis: int = -1 69 | momentum: float = 0.99 70 | epsilon: float = 1e-5 71 | dtype: Dtype = jnp.float32 72 | use_bias: bool = True 73 | use_scale: bool = True 74 | bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros 75 | scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones 76 | mean_init: Callable[[Shape], Array] = lambda s: jnp.zeros(s, jnp.float32) 77 | var_init: Callable[[Shape], Array] = lambda s: jnp.ones(s, jnp.float32) 78 | axis_name: Optional[str] = None 79 | axis_index_groups: Any = None 80 | 81 | @compact 82 | def __call__(self, x, use_running_average: Optional[bool] = None): 83 | """Normalizes the input using batch statistics. 84 | 85 | NOTE: 86 | During initialization (when parameters are mutable) the running average 87 | of the batch statistics will not be updated. Therefore, the inputs 88 | fed during initialization don't need to match that of the actual input 89 | distribution and the reduction axis (set with `axis_name`) does not have 90 | to exist. 91 | Args: 92 | x: the input to be normalized. 93 | use_running_average: if true, the statistics stored in batch_stats 94 | will be used instead of computing the batch statistics on the input. 95 | Returns: 96 | Normalized inputs (the same shape as inputs). 97 | """ 98 | use_running_average = merge_param( 99 | 'use_running_average', self.use_running_average, use_running_average) 100 | x = jnp.asarray(x, jnp.float32) 101 | axis = self.axis if isinstance(self.axis, tuple) else (self.axis,) 102 | axis = _absolute_dims(x.ndim, axis) 103 | feature_shape = tuple(d if i in axis else 1 for i, d in enumerate(x.shape)) 104 | reduced_feature_shape = tuple(d for i, d in enumerate(x.shape) if i in axis) 105 | reduction_axis = tuple(i for i in range(x.ndim) if i not in axis) 106 | 107 | # see NOTE above on initialization behavior 108 | initializing = self.is_mutable_collection('params') 109 | 110 | ra_mean = self.variable('batch_stats', 'mean', 111 | self.mean_init, 112 | reduced_feature_shape) 113 | ra_var = self.variable('batch_stats', 'var', 114 | self.var_init, 115 | reduced_feature_shape) 116 | 117 | if use_running_average: 118 | mean, var = ra_mean.value, ra_var.value 119 | else: 120 | mean = jnp.mean(x, axis=reduction_axis, keepdims=False) 121 | mean2 = jnp.mean(lax.square(x), axis=reduction_axis, keepdims=False) 122 | if self.axis_name is not None and not initializing: 123 | concatenated_mean = jnp.concatenate([mean, mean2]) 124 | mean, mean2 = jnp.split( 125 | lax.pmean( 126 | concatenated_mean, 127 | axis_name=self.axis_name, 128 | axis_index_groups=self.axis_index_groups), 2) 129 | var = mean2 - lax.square(mean) 130 | 131 | if not initializing: 132 | ra_mean.value = self.momentum * ra_mean.value + (1 - self.momentum) * mean 133 | ra_var.value = self.momentum * ra_var.value + (1 - self.momentum) * var 134 | 135 | y = x - mean.reshape(feature_shape) 136 | mul = lax.rsqrt(var + self.epsilon) 137 | if self.use_scale: 138 | scale = self.param('scale', 139 | self.scale_init, 140 | reduced_feature_shape).reshape(feature_shape) 141 | mul = mul * scale 142 | y = y * mul 143 | if self.use_bias: 144 | bias = self.param('bias', 145 | self.bias_init, 146 | reduced_feature_shape).reshape(feature_shape) 147 | y = y + bias 148 | return jnp.asarray(y, self.dtype) 149 | 150 | 151 | -------------------------------------------------------------------------------- /flaxmodels/stylegan2/README.md: 
-------------------------------------------------------------------------------- 1 | # Analyzing and Improving the Image Quality of StyleGAN 2 | 3 |
img
4 | 5 | Paper: https://arxiv.org/abs/1912.04958 6 | Repository: https://github.com/NVlabs/stylegan2 and https://github.com/NVlabs/stylegan2-ada 7 | 8 | ##### Table of Contents 9 | * [1. Basic Usage](#usage) 10 | * [2. Documentation](#documentation) 11 | * [3. Training](#training) 12 | * [4. Pretrained Models](#models) 13 | * [5. License](#license) 14 | 15 | 16 | 17 | ## 1. Basic Usage 18 | For more usage examples check out this [Colab](stylegan2_demo.ipynb). 19 | 20 | ```python 21 | import numpy as np 22 | from PIL import Image 23 | import jax 24 | import jax.numpy as jnp 25 | import flaxmodels as fm 26 | 27 | # Seed 28 | key = jax.random.PRNGKey(0) 29 | 30 | # Input noise 31 | z = jax.random.normal(key, shape=(4, 512)) 32 | 33 | generator = fm.stylegan2.Generator(pretrained='metfaces') 34 | params = generator.init(key, z) 35 | images = generator.apply(params, z, train=False) 36 | 37 | # Normalize images to be in range [0, 1] 38 | images = (images - jnp.min(images)) / (jnp.max(images) - jnp.min(images)) 39 | 40 | # Save images 41 | for i in range(images.shape[0]): 42 | Image.fromarray(np.uint8(images[i] * 255)).save(f'image_{i}.jpg') 43 | 44 | ``` 45 | 46 |
img
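Because `generator.apply` is a pure function of `params` and `z`, the forward pass above can be wrapped in `jax.jit` when generating many batches. The snippet below is a minimal sketch; the `truncation_psi` keyword is an assumption carried over from the few-shot GAN adaption example in this repository, so check the documentation for the exact `Generator.apply` signature:

```python
import jax
import jax.numpy as jnp
import flaxmodels as fm

key = jax.random.PRNGKey(0)
generator = fm.stylegan2.Generator(pretrained='metfaces')

z = jax.random.normal(key, shape=(4, 512))
params = generator.init(key, z)

# Compile the forward pass once; subsequent calls with new z reuse the compiled code.
# truncation_psi < 1 trades sample diversity for fidelity (assumed keyword argument).
generate = jax.jit(lambda params, z: generator.apply(params, z, truncation_psi=0.7, train=False))

images = generate(params, z)
# Normalize images to be in range [0, 1]
images = (images - jnp.min(images)) / (jnp.max(images) - jnp.min(images))
```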
47 | 48 | 49 | ## 2. Documentation 50 | The documentation can be found [here](../../docs/Documentation.md#stylegan2). 51 | 52 | 53 | ## 3. Training 54 | If you want to train StyleGAN2 in Jax/Flax, go [here](https://github.com/matthias-wright/flaxmodels/tree/main/training/stylegan2). 55 | 56 | 57 | ## 4. Pretrained Models 58 | 59 | ### Metfaces 60 |
img
61 | 62 | ### FFHQ 63 |
img
64 | 65 | ### AFHQ Wild 66 |
img
67 | 68 | ### AFHQ Dog 69 |
img
70 | 71 | ### AFHQ Cat 72 |
img
73 | 74 | ### LSUN Cat 75 |
img
76 | 77 | ### LSUN Horse 78 |
img
79 | 80 | ### LSUN Car 81 |
img
82 | 83 | ### BreCaHAD 84 |
img
85 | 86 | ### CIFAR-10 87 |
img
88 | 89 | ### LSUN Church 90 |
img
91 | 92 | 93 | 94 | ## 5. License 95 | Nvidia Source Code License-NC 96 | 97 | 98 | -------------------------------------------------------------------------------- /flaxmodels/stylegan2/__init__.py: -------------------------------------------------------------------------------- 1 | from .generator import SynthesisNetwork 2 | from .generator import MappingNetwork 3 | from .generator import Generator 4 | from .discriminator import Discriminator 5 | 6 | -------------------------------------------------------------------------------- /flaxmodels/stylegan2/images/afhqcat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/stylegan2/images/afhqcat.jpg -------------------------------------------------------------------------------- /flaxmodels/stylegan2/images/afhqdog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/stylegan2/images/afhqdog.jpg -------------------------------------------------------------------------------- /flaxmodels/stylegan2/images/afhqwild.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/stylegan2/images/afhqwild.jpg -------------------------------------------------------------------------------- /flaxmodels/stylegan2/images/brecahad.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/stylegan2/images/brecahad.jpg -------------------------------------------------------------------------------- /flaxmodels/stylegan2/images/car.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/stylegan2/images/car.jpg -------------------------------------------------------------------------------- /flaxmodels/stylegan2/images/cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/stylegan2/images/cat.jpg -------------------------------------------------------------------------------- /flaxmodels/stylegan2/images/church.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/stylegan2/images/church.jpg -------------------------------------------------------------------------------- /flaxmodels/stylegan2/images/cifar10.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/stylegan2/images/cifar10.jpg -------------------------------------------------------------------------------- /flaxmodels/stylegan2/images/ffhq.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/stylegan2/images/ffhq.jpg 
-------------------------------------------------------------------------------- /flaxmodels/stylegan2/images/gen_images_w_trunc.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/stylegan2/images/gen_images_w_trunc.jpg -------------------------------------------------------------------------------- /flaxmodels/stylegan2/images/gen_images_with_labels.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/stylegan2/images/gen_images_with_labels.jpg -------------------------------------------------------------------------------- /flaxmodels/stylegan2/images/gen_images_wo_trunc.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/stylegan2/images/gen_images_wo_trunc.jpg -------------------------------------------------------------------------------- /flaxmodels/stylegan2/images/horse.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/stylegan2/images/horse.jpg -------------------------------------------------------------------------------- /flaxmodels/stylegan2/images/metfaces.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/stylegan2/images/metfaces.jpg -------------------------------------------------------------------------------- /flaxmodels/stylegan2/images/style_mixing.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/stylegan2/images/style_mixing.jpg -------------------------------------------------------------------------------- /flaxmodels/stylegan2/images/title.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/stylegan2/images/title.jpg -------------------------------------------------------------------------------- /flaxmodels/utils.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import requests 3 | import os 4 | import tempfile 5 | 6 | 7 | def download(ckpt_dir, url): 8 | name = url[url.rfind('/') + 1 : url.rfind('?')] 9 | if ckpt_dir is None: 10 | ckpt_dir = tempfile.gettempdir() 11 | ckpt_dir = os.path.join(ckpt_dir, 'flaxmodels') 12 | ckpt_file = os.path.join(ckpt_dir, name) 13 | if not os.path.exists(ckpt_file): 14 | print(f'Downloading: \"{url[:url.rfind("?")]}\" to {ckpt_file}') 15 | if not os.path.exists(ckpt_dir): 16 | os.makedirs(ckpt_dir) 17 | 18 | response = requests.get(url, stream=True) 19 | total_size_in_bytes = int(response.headers.get('content-length', 0)) 20 | progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) 21 | 22 | # first create temp file, in case the download fails 23 | ckpt_file_temp = os.path.join(ckpt_dir, name + '.temp') 24 | with 
open(ckpt_file_temp, 'wb') as file: 25 | for data in response.iter_content(chunk_size=1024): 26 | progress_bar.update(len(data)) 27 | file.write(data) 28 | progress_bar.close() 29 | 30 | if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: 31 | print('An error occurred while downloading, please try again.') 32 | if os.path.exists(ckpt_file_temp): 33 | os.remove(ckpt_file_temp) 34 | else: 35 | # if download was successful, rename the temp file 36 | os.rename(ckpt_file_temp, ckpt_file) 37 | return ckpt_file 38 | -------------------------------------------------------------------------------- /flaxmodels/vgg/README.md: -------------------------------------------------------------------------------- 1 | # Very Deep Convolutional Networks for Large-Scale Image Recognition 2 | 3 | Paper: https://arxiv.org/abs/1409.1556 4 | Project Page: https://www.robots.ox.ac.uk/~vgg/research/very_deep/ 5 | Repository: https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py 6 | 7 | ##### Table of Contents 8 | * [1. Important Note](#note) 9 | * [2. Basic Usage](#usage) 10 | * [3. Documentation](#documentation) 11 | * [4. Training](#training) 12 | * [5. License](#license) 13 | 14 | 15 | 16 | ## 1. Important Note 17 | Images must be in range [0, 1]. If the pretrained ImageNet weights are selected, the images are internally normalized with the ImageNet mean and standard deviation. If you don't want the images to be normalized, use `normalize=False` (see [here](https://github.com/matthias-wright/flaxmodels/blob/main/docs/Documentation.md#34-vgg16-19) for details). 18 | 19 | 20 | ## 2. Basic Usage 21 | For more usage examples, check out this [Colab](vgg_demo.ipynb). 22 | 23 | ```python 24 | from PIL import Image 25 | import jax 26 | import jax.numpy as jnp 27 | import flaxmodels as fm 28 | 29 | key = jax.random.PRNGKey(0) 30 | 31 | # Load image 32 | img = Image.open('example.jpg') 33 | # Image must be 224x224 if the classification head is included 34 | img = img.resize((224, 224)) 35 | # Image should be in range [0, 1] 36 | x = jnp.array(img, dtype=jnp.float32) / 255.0 37 | # Add batch dimension 38 | x = jnp.expand_dims(x, axis=0) 39 | 40 | vgg16 = fm.VGG16(output='logits', pretrained='imagenet') 41 | params = vgg16.init(key, x) 42 | out = vgg16.apply(params, x, train=False) 43 | 44 | ``` 45 | Usage is equivalent for VGG19. 46 | 47 | 48 | ## 3. Documentation 49 | The documentation can be found [here](../../docs/Documentation.md#vgg). 50 | 51 | 52 | ## 4. Training 53 | If you want to train VGG in Jax/Flax, go [here](https://github.com/matthias-wright/flaxmodels/tree/main/training/vgg). 54 | 55 | 56 | ## 5.
License 57 | Creative Commons Attribution License 58 | -------------------------------------------------------------------------------- /flaxmodels/vgg/__init__.py: -------------------------------------------------------------------------------- 1 | from .vgg import VGG16 2 | from .vgg import VGG19 3 | 4 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | import os 3 | 4 | 5 | directory = os.path.abspath(os.path.dirname(__file__)) 6 | with open(os.path.join(directory, 'README.md'), encoding='utf-8') as f: 7 | long_description = f.read() 8 | 9 | setup(name='flaxmodels', 10 | version='0.1.3', 11 | url='https://github.com/matthias-wright/flaxmodels', 12 | author='Matthias Wright', 13 | packages=find_packages(), 14 | install_requires=['h5py>=2.10.0', 15 | 'numpy>=1.19.5', 16 | 'requests>=2.23.0', 17 | 'packaging>=20.9', 18 | 'dataclasses>=0.6', 19 | 'filelock>=3.0.12', 20 | 'jax>=0.3', 21 | 'jaxlib', 22 | 'flax>=0.4.0', 23 | 'Pillow>=7.1.2', 24 | 'regex>=2021.4.4', 25 | 'tqdm>=4.60.0'], 26 | extras_require={ 27 | 'testing': ['pytest'], 28 | }, 29 | python_requires='>=3.6', 30 | license='Each model has an individual license.', 31 | description='A collection of pretrained models in Flax.', 32 | long_description=long_description, 33 | long_description_content_type='text/markdown') 34 | -------------------------------------------------------------------------------- /tests/aux_files/elefant.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/aux_files/elefant.jpg -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_embds_input.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_embds_input.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_embds_labels.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_embds_labels.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_embds_logits_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_embds_logits_ref.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_embds_loss_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_embds_loss_ref.npy -------------------------------------------------------------------------------- 
/tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_ids_input.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_ids_input.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_ids_labels.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_ids_labels.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_ids_logits_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_ids_logits_ref.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_ids_loss_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_ids_loss_ref.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2-large/gpt2-large_model_input_embds_input.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-large/gpt2-large_model_input_embds_input.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2-large/gpt2-large_model_input_embds_output_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-large/gpt2-large_model_input_embds_output_ref.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2-large/gpt2-large_model_input_ids_input.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-large/gpt2-large_model_input_ids_input.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2-large/gpt2-large_model_input_ids_output_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-large/gpt2-large_model_input_ids_output_ref.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_embds_input.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_embds_input.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_embds_labels.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_embds_labels.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_embds_logits_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_embds_logits_ref.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_embds_loss_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_embds_loss_ref.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_ids_input.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_ids_input.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_ids_labels.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_ids_labels.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_ids_logits_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_ids_logits_ref.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_ids_loss_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_ids_loss_ref.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2-medium/gpt2-medium_model_input_embds_input.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-medium/gpt2-medium_model_input_embds_input.npy 
-------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2-medium/gpt2-medium_model_input_embds_output_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-medium/gpt2-medium_model_input_embds_output_ref.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2-medium/gpt2-medium_model_input_ids_input.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-medium/gpt2-medium_model_input_ids_input.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2-medium/gpt2-medium_model_input_ids_output_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-medium/gpt2-medium_model_input_ids_output_ref.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_embds_input.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_embds_input.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_embds_labels.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_embds_labels.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_embds_logits_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_embds_logits_ref.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_embds_loss_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_embds_loss_ref.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_ids_input.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_ids_input.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_ids_labels.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_ids_labels.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_ids_logits_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_ids_logits_ref.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_ids_loss_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_ids_loss_ref.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2-xl/gpt2-xl_model_input_embds_input.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-xl/gpt2-xl_model_input_embds_input.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2-xl/gpt2-xl_model_input_embds_output_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-xl/gpt2-xl_model_input_embds_output_ref.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2-xl/gpt2-xl_model_input_ids_input.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-xl/gpt2-xl_model_input_ids_input.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2-xl/gpt2-xl_model_input_ids_output_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-xl/gpt2-xl_model_input_ids_output_ref.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_embds_input.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_embds_input.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_embds_labels.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_embds_labels.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_embds_logits_ref.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_embds_logits_ref.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_embds_loss_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_embds_loss_ref.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_ids_input.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_ids_input.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_ids_labels.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_ids_labels.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_ids_logits_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_ids_logits_ref.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_ids_loss_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_ids_loss_ref.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2/gpt2_model_input_embds_input.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2/gpt2_model_input_embds_input.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2/gpt2_model_input_embds_output_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2/gpt2_model_input_embds_output_ref.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2/gpt2_model_input_ids_input.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2/gpt2_model_input_ids_input.npy -------------------------------------------------------------------------------- /tests/gpt2/aux_files/gpt2/gpt2_model_input_ids_output_ref.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2/gpt2_model_input_ids_output_ref.npy -------------------------------------------------------------------------------- /tests/gpt2/test_gpt2.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import flaxmodels as fm 4 | 5 | 6 | def test_reference_output_lm_head_input_ids(): 7 | input_ids = jnp.load(f'tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_ids_input.npy') 8 | labels = jnp.load(f'tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_ids_labels.npy') 9 | 10 | ref_logits = jnp.load(f'tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_ids_logits_ref.npy') 11 | ref_loss = jnp.load(f'tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_ids_loss_ref.npy') 12 | 13 | key = jax.random.PRNGKey(0) 14 | model = fm.gpt2.GPT2LMHeadModel(pretrained='gpt2') 15 | 16 | params = model.init(key, input_ids=input_ids, labels=labels) 17 | output = model.apply(params, input_ids=input_ids, labels=labels) 18 | logits, loss = output['logits'], output['loss'] 19 | 20 | diff_logits = jnp.mean(jnp.abs(ref_logits - logits)) 21 | diff_loss = jnp.mean(jnp.abs(ref_loss - loss)) 22 | 23 | assert diff_logits < 1e-4 and diff_loss < 1e-4 24 | 25 | 26 | def test_reference_output_lm_head_input_embds(): 27 | input_embds = jnp.load(f'tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_embds_input.npy') 28 | labels = jnp.load(f'tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_embds_labels.npy') 29 | 30 | ref_logits = jnp.load(f'tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_embds_logits_ref.npy') 31 | ref_loss = jnp.load(f'tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_embds_loss_ref.npy') 32 | 33 | key = jax.random.PRNGKey(0) 34 | model = fm.gpt2.GPT2LMHeadModel(pretrained='gpt2') 35 | 36 | params = model.init(key, input_embds=input_embds, labels=labels) 37 | output = model.apply(params, input_embds=input_embds, labels=labels) 38 | logits, loss = output['logits'], output['loss'] 39 | 40 | diff_logits = jnp.mean(jnp.abs(ref_logits - logits)) 41 | diff_loss = jnp.mean(jnp.abs(ref_loss - loss)) 42 | 43 | assert diff_logits < 1e-4 and diff_loss < 1e-4 44 | 45 | 46 | def test_reference_output_model_input_ids(): 47 | input_ids = jnp.load(f'tests/gpt2/aux_files/gpt2/gpt2_model_input_ids_input.npy') 48 | 49 | ref_output = jnp.load(f'tests/gpt2/aux_files/gpt2/gpt2_model_input_ids_output_ref.npy') 50 | 51 | key = jax.random.PRNGKey(0) 52 | model = fm.gpt2.GPT2Model(pretrained='gpt2') 53 | 54 | params = model.init(key, input_ids=input_ids) 55 | output = model.apply(params, input_ids=input_ids) 56 | 57 | diff = jnp.mean(jnp.abs(ref_output - output['last_hidden_state'])) 58 | 59 | assert diff < 1e-4 60 | 61 | 62 | def test_reference_output_model_input_embds(): 63 | input_embds = jnp.load(f'tests/gpt2/aux_files/gpt2/gpt2_model_input_embds_input.npy') 64 | 65 | ref_output = jnp.load(f'tests/gpt2/aux_files/gpt2/gpt2_model_input_embds_output_ref.npy') 66 | 67 | key = jax.random.PRNGKey(0) 68 | model = fm.gpt2.GPT2Model(pretrained='gpt2') 69 | 70 | params = model.init(key, input_embds=input_embds) 71 | output = model.apply(params, input_embds=input_embds) 72 | 73 | diff = jnp.mean(jnp.abs(ref_output - output['last_hidden_state'])) 74 | 75 | assert diff < 1e-4 76 | 77 | -------------------------------------------------------------------------------- /tests/gpt2/test_gpt2_large.py: 
-------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import flaxmodels as fm 4 | 5 | 6 | def test_reference_output_lm_head_input_ids(): 7 | input_ids = jnp.load(f'tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_ids_input.npy') 8 | labels = jnp.load(f'tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_ids_labels.npy') 9 | 10 | ref_logits = jnp.load(f'tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_ids_logits_ref.npy') 11 | ref_loss = jnp.load(f'tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_ids_loss_ref.npy') 12 | 13 | key = jax.random.PRNGKey(0) 14 | model = fm.gpt2.GPT2LMHeadModel(pretrained='gpt2-large') 15 | 16 | params = model.init(key, input_ids=input_ids, labels=labels) 17 | output = model.apply(params, input_ids=input_ids, labels=labels) 18 | logits, loss = output['logits'], output['loss'] 19 | 20 | diff_logits = jnp.mean(jnp.abs(ref_logits - logits)) 21 | diff_loss = jnp.mean(jnp.abs(ref_loss - loss)) 22 | 23 | assert diff_logits < 1e-3 and diff_loss < 1e-3 24 | 25 | 26 | def test_reference_output_lm_head_input_embds(): 27 | input_embds = jnp.load(f'tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_embds_input.npy') 28 | labels = jnp.load(f'tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_embds_labels.npy') 29 | 30 | ref_logits = jnp.load(f'tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_embds_logits_ref.npy') 31 | ref_loss = jnp.load(f'tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_embds_loss_ref.npy') 32 | 33 | key = jax.random.PRNGKey(0) 34 | model = fm.gpt2.GPT2LMHeadModel(pretrained='gpt2-large') 35 | 36 | params = model.init(key, input_embds=input_embds, labels=labels) 37 | output = model.apply(params, input_embds=input_embds, labels=labels) 38 | logits, loss = output['logits'], output['loss'] 39 | 40 | diff_logits = jnp.mean(jnp.abs(ref_logits - logits)) 41 | diff_loss = jnp.mean(jnp.abs(ref_loss - loss)) 42 | 43 | assert diff_logits < 1e-3 and diff_loss < 1e-3 44 | 45 | 46 | def test_reference_output_model_input_ids(): 47 | input_ids = jnp.load(f'tests/gpt2/aux_files/gpt2-large/gpt2-large_model_input_ids_input.npy') 48 | 49 | ref_output = jnp.load(f'tests/gpt2/aux_files/gpt2-large/gpt2-large_model_input_ids_output_ref.npy') 50 | 51 | key = jax.random.PRNGKey(0) 52 | model = fm.gpt2.GPT2Model(pretrained='gpt2-large') 53 | 54 | params = model.init(key, input_ids=input_ids) 55 | output = model.apply(params, input_ids=input_ids) 56 | 57 | diff = jnp.mean(jnp.abs(ref_output - output['last_hidden_state'])) 58 | 59 | assert diff < 1e-4 60 | 61 | 62 | def test_reference_output_model_input_embds(): 63 | input_embds = jnp.load(f'tests/gpt2/aux_files/gpt2-large/gpt2-large_model_input_embds_input.npy') 64 | 65 | ref_output = jnp.load(f'tests/gpt2/aux_files/gpt2-large/gpt2-large_model_input_embds_output_ref.npy') 66 | 67 | key = jax.random.PRNGKey(0) 68 | model = fm.gpt2.GPT2Model(pretrained='gpt2-large') 69 | 70 | params = model.init(key, input_embds=input_embds) 71 | output = model.apply(params, input_embds=input_embds) 72 | 73 | diff = jnp.mean(jnp.abs(ref_output - output['last_hidden_state'])) 74 | 75 | assert diff < 1e-4 76 | 77 | -------------------------------------------------------------------------------- /tests/gpt2/test_gpt2_medium.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import flaxmodels as fm 4 | 5 | 6 | def test_reference_output_lm_head_input_ids(): 
7 | input_ids = jnp.load(f'tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_ids_input.npy') 8 | labels = jnp.load(f'tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_ids_labels.npy') 9 | 10 | ref_logits = jnp.load(f'tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_ids_logits_ref.npy') 11 | ref_loss = jnp.load(f'tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_ids_loss_ref.npy') 12 | 13 | key = jax.random.PRNGKey(0) 14 | model = fm.gpt2.GPT2LMHeadModel(pretrained='gpt2-medium') 15 | 16 | params = model.init(key, input_ids=input_ids, labels=labels) 17 | output = model.apply(params, input_ids=input_ids, labels=labels) 18 | logits, loss = output['logits'], output['loss'] 19 | 20 | diff_logits = jnp.mean(jnp.abs(ref_logits - logits)) 21 | diff_loss = jnp.mean(jnp.abs(ref_loss - loss)) 22 | 23 | assert diff_logits < 1e-3 and diff_loss < 1e-3 24 | 25 | 26 | def test_reference_output_lm_head_input_embds(): 27 | input_embds = jnp.load(f'tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_embds_input.npy') 28 | labels = jnp.load(f'tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_embds_labels.npy') 29 | 30 | ref_logits = jnp.load(f'tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_embds_logits_ref.npy') 31 | ref_loss = jnp.load(f'tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_embds_loss_ref.npy') 32 | 33 | key = jax.random.PRNGKey(0) 34 | model = fm.gpt2.GPT2LMHeadModel(pretrained='gpt2-medium') 35 | 36 | params = model.init(key, input_embds=input_embds, labels=labels) 37 | output = model.apply(params, input_embds=input_embds, labels=labels) 38 | logits, loss = output['logits'], output['loss'] 39 | 40 | diff_logits = jnp.mean(jnp.abs(ref_logits - logits)) 41 | diff_loss = jnp.mean(jnp.abs(ref_loss - loss)) 42 | 43 | assert diff_logits < 1e-3 and diff_loss < 1e-3 44 | 45 | 46 | def test_reference_output_model_input_ids(): 47 | input_ids = jnp.load(f'tests/gpt2/aux_files/gpt2-medium/gpt2-medium_model_input_ids_input.npy') 48 | 49 | ref_output = jnp.load(f'tests/gpt2/aux_files/gpt2-medium/gpt2-medium_model_input_ids_output_ref.npy') 50 | 51 | key = jax.random.PRNGKey(0) 52 | model = fm.gpt2.GPT2Model(pretrained='gpt2-medium') 53 | 54 | params = model.init(key, input_ids=input_ids) 55 | output = model.apply(params, input_ids=input_ids) 56 | 57 | diff = jnp.mean(jnp.abs(ref_output - output['last_hidden_state'])) 58 | 59 | assert diff < 1e-4 60 | 61 | 62 | def test_reference_output_model_input_embds(): 63 | input_embds = jnp.load(f'tests/gpt2/aux_files/gpt2-medium/gpt2-medium_model_input_embds_input.npy') 64 | 65 | ref_output = jnp.load(f'tests/gpt2/aux_files/gpt2-medium/gpt2-medium_model_input_embds_output_ref.npy') 66 | 67 | key = jax.random.PRNGKey(0) 68 | model = fm.gpt2.GPT2Model(pretrained='gpt2-medium') 69 | 70 | params = model.init(key, input_embds=input_embds) 71 | output = model.apply(params, input_embds=input_embds) 72 | 73 | diff = jnp.mean(jnp.abs(ref_output - output['last_hidden_state'])) 74 | 75 | assert diff < 1e-4 76 | 77 | -------------------------------------------------------------------------------- /tests/gpt2/test_gpt2_xl.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import flaxmodels as fm 4 | 5 | 6 | def test_reference_output_lm_head_input_ids(): 7 | input_ids = jnp.load(f'tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_ids_input.npy') 8 | labels = jnp.load(f'tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_ids_labels.npy') 9 
| 10 | ref_logits = jnp.load(f'tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_ids_logits_ref.npy') 11 | ref_loss = jnp.load(f'tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_ids_loss_ref.npy') 12 | 13 | key = jax.random.PRNGKey(0) 14 | model = fm.gpt2.GPT2LMHeadModel(pretrained='gpt2-xl') 15 | 16 | params = model.init(key, input_ids=input_ids, labels=labels) 17 | output = model.apply(params, input_ids=input_ids, labels=labels) 18 | logits, loss = output['logits'], output['loss'] 19 | 20 | diff_logits = jnp.mean(jnp.abs(ref_logits - logits)) 21 | diff_loss = jnp.mean(jnp.abs(ref_loss - loss)) 22 | 23 | assert diff_logits < 1e-3 and diff_loss < 1e-3 24 | 25 | 26 | def test_reference_output_lm_head_input_embds(): 27 | input_embds = jnp.load(f'tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_embds_input.npy') 28 | labels = jnp.load(f'tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_embds_labels.npy') 29 | 30 | ref_logits = jnp.load(f'tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_embds_logits_ref.npy') 31 | ref_loss = jnp.load(f'tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_embds_loss_ref.npy') 32 | 33 | key = jax.random.PRNGKey(0) 34 | model = fm.gpt2.GPT2LMHeadModel(pretrained='gpt2-xl') 35 | 36 | params = model.init(key, input_embds=input_embds, labels=labels) 37 | output = model.apply(params, input_embds=input_embds, labels=labels) 38 | logits, loss = output['logits'], output['loss'] 39 | 40 | diff_logits = jnp.mean(jnp.abs(ref_logits - logits)) 41 | diff_loss = jnp.mean(jnp.abs(ref_loss - loss)) 42 | 43 | assert diff_logits < 1e-3 and diff_loss < 1e-3 44 | 45 | 46 | def test_reference_output_model_input_ids(): 47 | input_ids = jnp.load(f'tests/gpt2/aux_files/gpt2-xl/gpt2-xl_model_input_ids_input.npy') 48 | 49 | ref_output = jnp.load(f'tests/gpt2/aux_files/gpt2-xl/gpt2-xl_model_input_ids_output_ref.npy') 50 | 51 | key = jax.random.PRNGKey(0) 52 | model = fm.gpt2.GPT2Model(pretrained='gpt2-xl') 53 | 54 | params = model.init(key, input_ids=input_ids) 55 | output = model.apply(params, input_ids=input_ids) 56 | 57 | diff = jnp.mean(jnp.abs(ref_output - output['last_hidden_state'])) 58 | 59 | assert diff < 1e-4 60 | 61 | 62 | def test_reference_output_model_input_embds(): 63 | input_embds = jnp.load(f'tests/gpt2/aux_files/gpt2-xl/gpt2-xl_model_input_embds_input.npy') 64 | 65 | ref_output = jnp.load(f'tests/gpt2/aux_files/gpt2-xl/gpt2-xl_model_input_embds_output_ref.npy') 66 | 67 | key = jax.random.PRNGKey(0) 68 | model = fm.gpt2.GPT2Model(pretrained='gpt2-xl') 69 | 70 | params = model.init(key, input_embds=input_embds) 71 | output = model.apply(params, input_embds=input_embds) 72 | 73 | diff = jnp.mean(jnp.abs(ref_output - output['last_hidden_state'])) 74 | 75 | assert diff < 1e-4 76 | 77 | -------------------------------------------------------------------------------- /tests/resnet/aux_files/resnet101_elefant_output_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/resnet/aux_files/resnet101_elefant_output_ref.npy -------------------------------------------------------------------------------- /tests/resnet/aux_files/resnet152_elefant_output_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/resnet/aux_files/resnet152_elefant_output_ref.npy 
-------------------------------------------------------------------------------- /tests/resnet/aux_files/resnet18_elefant_output_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/resnet/aux_files/resnet18_elefant_output_ref.npy -------------------------------------------------------------------------------- /tests/resnet/aux_files/resnet34_elefant_output_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/resnet/aux_files/resnet34_elefant_output_ref.npy -------------------------------------------------------------------------------- /tests/resnet/aux_files/resnet50_elefant_output_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/resnet/aux_files/resnet50_elefant_output_ref.npy -------------------------------------------------------------------------------- /tests/resnet/test_resnet101.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import numpy as np 4 | import flaxmodels as fm 5 | from PIL import Image 6 | 7 | 8 | def test_output_softmax(): 9 | # If output='softmax', the output should be in range [0, 1] 10 | key = jax.random.PRNGKey(0) 11 | x = jax.random.uniform(key, shape=(1, 224, 224, 3), minval=0, maxval=1) 12 | 13 | resnet101 = fm.ResNet101(output='softmax', pretrained=None) 14 | params = resnet101.init(key, x) 15 | out = resnet101.apply(params, x, train=False) 16 | 17 | assert jnp.min(out) >= 0.0 and jnp.max(out) <= 1.0 18 | 19 | 20 | def test_output_activations(): 21 | # If output='activations', the output should be a dict. 
22 | key = jax.random.PRNGKey(0) 23 | x = jax.random.uniform(key, shape=(1, 224, 224, 3), minval=0, maxval=1) 24 | 25 | resnet101 = fm.ResNet101(output='activations', pretrained=None) 26 | params = resnet101.init(key, x) 27 | out = resnet101.apply(params, x, train=False) 28 | 29 | assert isinstance(out, dict) 30 | 31 | 32 | def test_reference_output(): 33 | key = jax.random.PRNGKey(0) 34 | img = Image.open('tests/aux_files/elefant.jpg') 35 | img = img.resize((224, 224)) 36 | x = jnp.array(img, dtype=jnp.float32) / 255.0 37 | x = jnp.expand_dims(x, axis=0) 38 | 39 | resnet101 = fm.ResNet101(output='logits', pretrained='imagenet') 40 | params = resnet101.init(key, x) 41 | out = resnet101.apply(params, x, train=False) 42 | 43 | out_ref = jnp.load('tests/resnet/aux_files/resnet101_elefant_output_ref.npy') 44 | diff = jnp.mean(jnp.abs(out - out_ref)) 45 | 46 | assert diff < 1e-5 47 | -------------------------------------------------------------------------------- /tests/resnet/test_resnet152.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import numpy as np 4 | import flaxmodels as fm 5 | from PIL import Image 6 | 7 | 8 | def test_output_softmax(): 9 | # If output='softmax', the output should be in range [0, 1] 10 | key = jax.random.PRNGKey(0) 11 | x = jax.random.uniform(key, shape=(1, 224, 224, 3), minval=0, maxval=1) 12 | 13 | resnet152 = fm.ResNet152(output='softmax', pretrained=None) 14 | params = resnet152.init(key, x) 15 | out = resnet152.apply(params, x, train=False) 16 | 17 | assert jnp.min(out) >= 0.0 and jnp.max(out) <= 1.0 18 | 19 | 20 | def test_output_activations(): 21 | # If output='activations', the output should be a dict. 22 | key = jax.random.PRNGKey(0) 23 | x = jax.random.uniform(key, shape=(1, 224, 224, 3), minval=0, maxval=1) 24 | 25 | resnet152 = fm.ResNet152(output='activations', pretrained=None) 26 | params = resnet152.init(key, x) 27 | out = resnet152.apply(params, x, train=False) 28 | 29 | assert isinstance(out, dict) 30 | 31 | 32 | def test_reference_output(): 33 | key = jax.random.PRNGKey(0) 34 | img = Image.open('tests/aux_files/elefant.jpg') 35 | img = img.resize((224, 224)) 36 | x = jnp.array(img, dtype=jnp.float32) / 255.0 37 | x = jnp.expand_dims(x, axis=0) 38 | 39 | resnet152 = fm.ResNet152(output='logits', pretrained='imagenet') 40 | params = resnet152.init(key, x) 41 | out = resnet152.apply(params, x, train=False) 42 | 43 | out_ref = jnp.load('tests/resnet/aux_files/resnet152_elefant_output_ref.npy') 44 | diff = jnp.mean(jnp.abs(out - out_ref)) 45 | 46 | assert diff < 1e-5 47 | 48 | -------------------------------------------------------------------------------- /tests/resnet/test_resnet18.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import numpy as np 4 | import flaxmodels as fm 5 | from PIL import Image 6 | 7 | 8 | def test_output_softmax(): 9 | # If output='softmax', the output should be in range [0, 1] 10 | key = jax.random.PRNGKey(0) 11 | x = jax.random.uniform(key, shape=(1, 224, 224, 3), minval=0, maxval=1) 12 | 13 | resnet18 = fm.ResNet18(output='softmax', pretrained=None) 14 | params = resnet18.init(key, x, train=False) 15 | out, _ = resnet18.apply(params, x, mutable=['batch_stats']) 16 | 17 | assert jnp.min(out) >= 0.0 and jnp.max(out) <= 1.0 18 | 19 | 20 | def test_output_activations(): 21 | # If output='activations', the output should be a dict. 
22 | key = jax.random.PRNGKey(0) 23 | x = jax.random.uniform(key, shape=(1, 224, 224, 3), minval=0, maxval=1) 24 | 25 | resnet18 = fm.ResNet18(output='activations', pretrained=None) 26 | params = resnet18.init(key, x, train=False) 27 | out, _ = resnet18.apply(params, x, mutable=['batch_stats']) 28 | 29 | assert isinstance(out, dict) 30 | 31 | 32 | def test_reference_output(): 33 | key = jax.random.PRNGKey(0) 34 | img = Image.open('tests/aux_files/elefant.jpg') 35 | img = img.resize((224, 224)) 36 | x = jnp.array(img, dtype=jnp.float32) / 255.0 37 | x = jnp.expand_dims(x, axis=0) 38 | 39 | resnet18 = fm.ResNet18(output='logits', pretrained='imagenet') 40 | params = resnet18.init(key, x) 41 | out = resnet18.apply(params, x, train=False) 42 | 43 | out_ref = jnp.load('tests/resnet/aux_files/resnet18_elefant_output_ref.npy') 44 | diff = jnp.mean(jnp.abs(out - out_ref)) 45 | 46 | assert diff < 1e-5 47 | 48 | -------------------------------------------------------------------------------- /tests/resnet/test_resnet34.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import numpy as np 4 | import flaxmodels as fm 5 | from PIL import Image 6 | 7 | 8 | def test_output_softmax(): 9 | # If output='softmax', the output should be in range [0, 1] 10 | key = jax.random.PRNGKey(0) 11 | x = jax.random.uniform(key, shape=(1, 224, 224, 3), minval=0, maxval=1) 12 | 13 | resnet34 = fm.ResNet34(output='softmax', pretrained=None) 14 | params = resnet34.init(key, x) 15 | out = resnet34.apply(params, x, train=False) 16 | 17 | assert jnp.min(out) >= 0.0 and jnp.max(out) <= 1.0 18 | 19 | 20 | def test_output_activations(): 21 | # If output='activations', the output should be a dict. 22 | key = jax.random.PRNGKey(0) 23 | x = jax.random.uniform(key, shape=(1, 224, 224, 3), minval=0, maxval=1) 24 | 25 | resnet34 = fm.ResNet34(output='activations', pretrained=None) 26 | params = resnet34.init(key, x) 27 | out = resnet34.apply(params, x, train=False) 28 | 29 | assert isinstance(out, dict) 30 | 31 | 32 | def test_reference_output(): 33 | key = jax.random.PRNGKey(0) 34 | img = Image.open('tests/aux_files/elefant.jpg') 35 | img = img.resize((224, 224)) 36 | x = jnp.array(img, dtype=jnp.float32) / 255.0 37 | x = jnp.expand_dims(x, axis=0) 38 | 39 | resnet34 = fm.ResNet34(output='logits', pretrained='imagenet') 40 | params = resnet34.init(key, x) 41 | out = resnet34.apply(params, x, train=False) 42 | 43 | out_ref = jnp.load('tests/resnet/aux_files/resnet34_elefant_output_ref.npy') 44 | diff = jnp.mean(jnp.abs(out - out_ref)) 45 | 46 | assert diff < 1e-5 47 | 48 | -------------------------------------------------------------------------------- /tests/resnet/test_resnet50.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import numpy as np 4 | import flaxmodels as fm 5 | from PIL import Image 6 | 7 | 8 | def test_output_softmax(): 9 | # If output='softmax', the output should be in range [0, 1] 10 | key = jax.random.PRNGKey(0) 11 | x = jax.random.uniform(key, shape=(1, 224, 224, 3), minval=0, maxval=1) 12 | 13 | resnet50 = fm.ResNet50(output='softmax', pretrained=None) 14 | params = resnet50.init(key, x) 15 | out = resnet50.apply(params, x, train=False) 16 | 17 | assert jnp.min(out) >= 0.0 and jnp.max(out) <= 1.0 18 | 19 | 20 | def test_output_activations(): 21 | # If output='activations', the output should be a dict. 
22 | key = jax.random.PRNGKey(0) 23 | x = jax.random.uniform(key, shape=(1, 224, 224, 3), minval=0, maxval=1) 24 | 25 | resnet50 = fm.ResNet50(output='activations', pretrained=None) 26 | params = resnet50.init(key, x) 27 | out = resnet50.apply(params, x, train=False) 28 | 29 | assert isinstance(out, dict) 30 | 31 | 32 | def test_reference_output(): 33 | key = jax.random.PRNGKey(0) 34 | img = Image.open('tests/aux_files/elefant.jpg') 35 | img = img.resize((224, 224)) 36 | x = jnp.array(img, dtype=jnp.float32) / 255.0 37 | x = jnp.expand_dims(x, axis=0) 38 | 39 | resnet50 = fm.ResNet50(output='logits', pretrained='imagenet') 40 | params = resnet50.init(key, x) 41 | out = resnet50.apply(params, x, train=False) 42 | 43 | out_ref = jnp.load('tests/resnet/aux_files/resnet50_elefant_output_ref.npy') 44 | diff = jnp.mean(jnp.abs(out - out_ref)) 45 | 46 | assert diff < 1e-5 47 | -------------------------------------------------------------------------------- /tests/stylegan2/discriminator/aux_files/afhqcat_input_img.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/afhqcat_input_img.npy -------------------------------------------------------------------------------- /tests/stylegan2/discriminator/aux_files/afhqcat_output_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/afhqcat_output_ref.npy -------------------------------------------------------------------------------- /tests/stylegan2/discriminator/aux_files/afhqdog_input_img.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/afhqdog_input_img.npy -------------------------------------------------------------------------------- /tests/stylegan2/discriminator/aux_files/afhqdog_output_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/afhqdog_output_ref.npy -------------------------------------------------------------------------------- /tests/stylegan2/discriminator/aux_files/afhqwild_input_img.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/afhqwild_input_img.npy -------------------------------------------------------------------------------- /tests/stylegan2/discriminator/aux_files/afhqwild_output_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/afhqwild_output_ref.npy -------------------------------------------------------------------------------- /tests/stylegan2/discriminator/aux_files/brecahad_input_img.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/brecahad_input_img.npy -------------------------------------------------------------------------------- /tests/stylegan2/discriminator/aux_files/brecahad_output_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/brecahad_output_ref.npy -------------------------------------------------------------------------------- /tests/stylegan2/discriminator/aux_files/car_input_img.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/car_input_img.npy -------------------------------------------------------------------------------- /tests/stylegan2/discriminator/aux_files/car_output_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/car_output_ref.npy -------------------------------------------------------------------------------- /tests/stylegan2/discriminator/aux_files/cat_input_img.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/cat_input_img.npy -------------------------------------------------------------------------------- /tests/stylegan2/discriminator/aux_files/cat_output_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/cat_output_ref.npy -------------------------------------------------------------------------------- /tests/stylegan2/discriminator/aux_files/church_input_img.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/church_input_img.npy -------------------------------------------------------------------------------- /tests/stylegan2/discriminator/aux_files/church_output_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/church_output_ref.npy -------------------------------------------------------------------------------- /tests/stylegan2/discriminator/aux_files/cifar10_input_img.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/cifar10_input_img.npy -------------------------------------------------------------------------------- /tests/stylegan2/discriminator/aux_files/cifar10_input_label.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/cifar10_input_label.npy -------------------------------------------------------------------------------- /tests/stylegan2/discriminator/aux_files/cifar10_output_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/cifar10_output_ref.npy -------------------------------------------------------------------------------- /tests/stylegan2/discriminator/aux_files/ffhq_input_img.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/ffhq_input_img.npy -------------------------------------------------------------------------------- /tests/stylegan2/discriminator/aux_files/ffhq_output_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/ffhq_output_ref.npy -------------------------------------------------------------------------------- /tests/stylegan2/discriminator/aux_files/horse_input_img.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/horse_input_img.npy -------------------------------------------------------------------------------- /tests/stylegan2/discriminator/aux_files/horse_output_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/horse_output_ref.npy -------------------------------------------------------------------------------- /tests/stylegan2/discriminator/aux_files/metfaces_input_img.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/metfaces_input_img.npy -------------------------------------------------------------------------------- /tests/stylegan2/discriminator/aux_files/metfaces_output_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/metfaces_output_ref.npy -------------------------------------------------------------------------------- /tests/stylegan2/discriminator/test_discriminator.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import numpy as np 4 | import flaxmodels as fm 5 | from PIL import Image 6 | 7 | 8 | def test_reference_output_afhqcat(): 9 | img = jnp.load('tests/stylegan2/discriminator/aux_files/afhqcat_input_img.npy') 10 | out_ref = jnp.load('tests/stylegan2/discriminator/aux_files/afhqcat_output_ref.npy') 11 | 12 | key = jax.random.PRNGKey(0) 13 | discriminator = fm.stylegan2.Discriminator(pretrained='afhqcat') 
14 | params = discriminator.init(key, img) 15 | out = discriminator.apply(params, img) 16 | 17 | diff = jnp.mean(jnp.abs(out - out_ref)) 18 | 19 | assert diff < 5e-4 20 | 21 | 22 | def test_reference_output_afhqdog(): 23 | img = jnp.load('tests/stylegan2/discriminator/aux_files/afhqdog_input_img.npy') 24 | out_ref = jnp.load('tests/stylegan2/discriminator/aux_files/afhqdog_output_ref.npy') 25 | 26 | key = jax.random.PRNGKey(0) 27 | discriminator = fm.stylegan2.Discriminator(pretrained='afhqdog') 28 | params = discriminator.init(key, img) 29 | out = discriminator.apply(params, img) 30 | 31 | diff = jnp.mean(jnp.abs(out - out_ref)) 32 | 33 | assert diff < 4e-3 34 | 35 | 36 | def test_reference_output_afhqwild(): 37 | img = jnp.load('tests/stylegan2/discriminator/aux_files/afhqwild_input_img.npy') 38 | out_ref = jnp.load('tests/stylegan2/discriminator/aux_files/afhqwild_output_ref.npy') 39 | 40 | key = jax.random.PRNGKey(0) 41 | discriminator = fm.stylegan2.Discriminator(pretrained='afhqwild') 42 | params = discriminator.init(key, img) 43 | out = discriminator.apply(params, img) 44 | 45 | diff = jnp.mean(jnp.abs(out - out_ref)) 46 | 47 | assert diff < 4e-4 48 | 49 | 50 | def test_reference_output_brecahad(): 51 | img = jnp.load('tests/stylegan2/discriminator/aux_files/brecahad_input_img.npy') 52 | out_ref = jnp.load('tests/stylegan2/discriminator/aux_files/brecahad_output_ref.npy') 53 | 54 | key = jax.random.PRNGKey(0) 55 | discriminator = fm.stylegan2.Discriminator(pretrained='brecahad') 56 | params = discriminator.init(key, img) 57 | out = discriminator.apply(params, img) 58 | 59 | diff = jnp.mean(jnp.abs(out - out_ref)) 60 | 61 | assert diff < 2e-4 62 | 63 | 64 | def test_reference_output_car(): 65 | img = jnp.load('tests/stylegan2/discriminator/aux_files/car_input_img.npy') 66 | out_ref = jnp.load('tests/stylegan2/discriminator/aux_files/car_output_ref.npy') 67 | 68 | key = jax.random.PRNGKey(0) 69 | discriminator = fm.stylegan2.Discriminator(pretrained='car') 70 | params = discriminator.init(key, img) 71 | out = discriminator.apply(params, img) 72 | 73 | diff = jnp.mean(jnp.abs(out - out_ref)) 74 | 75 | assert diff < 3e-4 76 | 77 | 78 | def test_reference_output_cat(): 79 | img = jnp.load('tests/stylegan2/discriminator/aux_files/cat_input_img.npy') 80 | out_ref = jnp.load('tests/stylegan2/discriminator/aux_files/cat_output_ref.npy') 81 | 82 | key = jax.random.PRNGKey(0) 83 | discriminator = fm.stylegan2.Discriminator(pretrained='cat') 84 | params = discriminator.init(key, img) 85 | out = discriminator.apply(params, img) 86 | 87 | diff = jnp.mean(jnp.abs(out - out_ref)) 88 | 89 | assert diff < 1e-4 90 | 91 | 92 | def test_reference_output_church(): 93 | img = jnp.load('tests/stylegan2/discriminator/aux_files/church_input_img.npy') 94 | out_ref = jnp.load('tests/stylegan2/discriminator/aux_files/church_output_ref.npy') 95 | 96 | key = jax.random.PRNGKey(0) 97 | discriminator = fm.stylegan2.Discriminator(pretrained='church') 98 | params = discriminator.init(key, img) 99 | out = discriminator.apply(params, img) 100 | 101 | diff = jnp.mean(jnp.abs(out - out_ref)) 102 | 103 | assert diff < 1e-4 104 | 105 | 106 | def test_reference_output_cifar10(): 107 | img = jnp.load('tests/stylegan2/discriminator/aux_files/cifar10_input_img.npy') 108 | label = jnp.load('tests/stylegan2/discriminator/aux_files/cifar10_input_label.npy') 109 | out_ref = jnp.load('tests/stylegan2/discriminator/aux_files/cifar10_output_ref.npy') 110 | 111 | key = jax.random.PRNGKey(0) 112 | discriminator = 
fm.stylegan2.Discriminator(pretrained='cifar10') 113 | params = discriminator.init(key, img, label) 114 | out = discriminator.apply(params, img, label) 115 | 116 | diff = jnp.mean(jnp.abs(out - out_ref)) 117 | 118 | assert diff < 1e-2 119 | 120 | 121 | def test_reference_output_ffhq(): 122 | img = jnp.load('tests/stylegan2/discriminator/aux_files/ffhq_input_img.npy') 123 | out_ref = jnp.load('tests/stylegan2/discriminator/aux_files/ffhq_output_ref.npy') 124 | 125 | key = jax.random.PRNGKey(0) 126 | discriminator = fm.stylegan2.Discriminator(pretrained='ffhq') 127 | params = discriminator.init(key, img) 128 | out = discriminator.apply(params, img) 129 | 130 | diff = jnp.mean(jnp.abs(out - out_ref)) 131 | 132 | assert diff < 1e-4 133 | 134 | 135 | def test_reference_output_horse(): 136 | img = jnp.load('tests/stylegan2/discriminator/aux_files/horse_input_img.npy') 137 | out_ref = jnp.load('tests/stylegan2/discriminator/aux_files/horse_output_ref.npy') 138 | 139 | key = jax.random.PRNGKey(0) 140 | discriminator = fm.stylegan2.Discriminator(pretrained='horse') 141 | params = discriminator.init(key, img) 142 | out = discriminator.apply(params, img) 143 | 144 | diff = jnp.mean(jnp.abs(out - out_ref)) 145 | 146 | assert diff < 1e-4 147 | 148 | 149 | def test_reference_output_metfaces(): 150 | img = jnp.load('tests/stylegan2/discriminator/aux_files/metfaces_input_img.npy') 151 | out_ref = jnp.load('tests/stylegan2/discriminator/aux_files/metfaces_output_ref.npy') 152 | 153 | key = jax.random.PRNGKey(0) 154 | discriminator = fm.stylegan2.Discriminator(pretrained='metfaces') 155 | params = discriminator.init(key, img) 156 | out = discriminator.apply(params, img) 157 | 158 | diff = jnp.mean(jnp.abs(out - out_ref)) 159 | 160 | assert diff < 2e-3 161 | 162 | 163 | -------------------------------------------------------------------------------- /tests/stylegan2/generator/aux_files/afhqcat_input_z.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/afhqcat_input_z.npy -------------------------------------------------------------------------------- /tests/stylegan2/generator/aux_files/afhqcat_output_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/afhqcat_output_ref.npy -------------------------------------------------------------------------------- /tests/stylegan2/generator/aux_files/afhqdog_input_z.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/afhqdog_input_z.npy -------------------------------------------------------------------------------- /tests/stylegan2/generator/aux_files/afhqdog_output_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/afhqdog_output_ref.npy -------------------------------------------------------------------------------- /tests/stylegan2/generator/aux_files/afhqwild_input_z.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/afhqwild_input_z.npy -------------------------------------------------------------------------------- /tests/stylegan2/generator/aux_files/afhqwild_output_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/afhqwild_output_ref.npy -------------------------------------------------------------------------------- /tests/stylegan2/generator/aux_files/brecahad_input_z.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/brecahad_input_z.npy -------------------------------------------------------------------------------- /tests/stylegan2/generator/aux_files/brecahad_output_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/brecahad_output_ref.npy -------------------------------------------------------------------------------- /tests/stylegan2/generator/aux_files/car_input_z.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/car_input_z.npy -------------------------------------------------------------------------------- /tests/stylegan2/generator/aux_files/car_output_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/car_output_ref.npy -------------------------------------------------------------------------------- /tests/stylegan2/generator/aux_files/cat_input_z.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/cat_input_z.npy -------------------------------------------------------------------------------- /tests/stylegan2/generator/aux_files/cat_output_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/cat_output_ref.npy -------------------------------------------------------------------------------- /tests/stylegan2/generator/aux_files/church_input_z.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/church_input_z.npy -------------------------------------------------------------------------------- /tests/stylegan2/generator/aux_files/church_output_ref.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/church_output_ref.npy -------------------------------------------------------------------------------- /tests/stylegan2/generator/aux_files/cifar10_input_label.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/cifar10_input_label.npy -------------------------------------------------------------------------------- /tests/stylegan2/generator/aux_files/cifar10_input_z.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/cifar10_input_z.npy -------------------------------------------------------------------------------- /tests/stylegan2/generator/aux_files/cifar10_output_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/cifar10_output_ref.npy -------------------------------------------------------------------------------- /tests/stylegan2/generator/aux_files/ffhq_input_z.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/ffhq_input_z.npy -------------------------------------------------------------------------------- /tests/stylegan2/generator/aux_files/ffhq_output_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/ffhq_output_ref.npy -------------------------------------------------------------------------------- /tests/stylegan2/generator/aux_files/horse_input_z.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/horse_input_z.npy -------------------------------------------------------------------------------- /tests/stylegan2/generator/aux_files/horse_output_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/horse_output_ref.npy -------------------------------------------------------------------------------- /tests/stylegan2/generator/aux_files/metfaces_input_z.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/metfaces_input_z.npy -------------------------------------------------------------------------------- /tests/stylegan2/generator/aux_files/metfaces_output_ref.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/metfaces_output_ref.npy -------------------------------------------------------------------------------- /tests/stylegan2/generator/test_generator.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import numpy as np 4 | import flaxmodels as fm 5 | from PIL import Image 6 | 7 | 8 | def test_reference_output_afhqcat(): 9 | z = jnp.load('tests/stylegan2/generator/aux_files/afhqcat_input_z.npy') 10 | out_ref = jnp.load('tests/stylegan2/generator/aux_files/afhqcat_output_ref.npy') 11 | 12 | key = jax.random.PRNGKey(0) 13 | generator = fm.stylegan2.Generator(pretrained='afhqcat') 14 | params = generator.init(key, z) 15 | out = generator.apply(params, z, noise_mode='const', train=False) 16 | 17 | diff = jnp.mean(jnp.abs(out - out_ref)) 18 | 19 | assert diff < 2e-2 20 | 21 | 22 | def test_reference_output_afhqdog(): 23 | z = jnp.load('tests/stylegan2/generator/aux_files/afhqdog_input_z.npy') 24 | out_ref = jnp.load('tests/stylegan2/generator/aux_files/afhqdog_output_ref.npy') 25 | 26 | key = jax.random.PRNGKey(0) 27 | generator = fm.stylegan2.Generator(pretrained='afhqdog') 28 | params = generator.init(key, z) 29 | out = generator.apply(params, z, noise_mode='const', train=False) 30 | 31 | diff = jnp.mean(jnp.abs(out - out_ref)) 32 | 33 | assert diff < 5e-3 34 | 35 | 36 | def test_reference_output_afhqwild(): 37 | z = jnp.load('tests/stylegan2/generator/aux_files/afhqwild_input_z.npy') 38 | out_ref = jnp.load('tests/stylegan2/generator/aux_files/afhqwild_output_ref.npy') 39 | 40 | key = jax.random.PRNGKey(0) 41 | generator = fm.stylegan2.Generator(pretrained='afhqwild') 42 | params = generator.init(key, z) 43 | out = generator.apply(params, z, noise_mode='const', train=False) 44 | 45 | diff = jnp.mean(jnp.abs(out - out_ref)) 46 | 47 | assert diff < 3e-3 48 | 49 | 50 | def test_reference_output_brecahad(): 51 | z = jnp.load('tests/stylegan2/generator/aux_files/brecahad_input_z.npy') 52 | out_ref = jnp.load('tests/stylegan2/generator/aux_files/brecahad_output_ref.npy') 53 | 54 | key = jax.random.PRNGKey(0) 55 | generator = fm.stylegan2.Generator(pretrained='brecahad') 56 | params = generator.init(key, z) 57 | out = generator.apply(params, z, noise_mode='const', train=False) 58 | 59 | diff = jnp.mean(jnp.abs(out - out_ref)) 60 | 61 | assert diff < 3e-3 62 | 63 | 64 | def test_reference_output_car(): 65 | z = jnp.load('tests/stylegan2/generator/aux_files/car_input_z.npy') 66 | out_ref = jnp.load('tests/stylegan2/generator/aux_files/car_output_ref.npy') 67 | 68 | key = jax.random.PRNGKey(0) 69 | generator = fm.stylegan2.Generator(pretrained='car') 70 | params = generator.init(key, z) 71 | out = generator.apply(params, z, noise_mode='const', train=False) 72 | 73 | diff = jnp.mean(jnp.abs(out - out_ref)) 74 | 75 | assert diff < 1e-5 76 | 77 | 78 | def test_reference_output_cat(): 79 | z = jnp.load('tests/stylegan2/generator/aux_files/cat_input_z.npy') 80 | out_ref = jnp.load('tests/stylegan2/generator/aux_files/cat_output_ref.npy') 81 | 82 | key = jax.random.PRNGKey(0) 83 | generator = fm.stylegan2.Generator(pretrained='cat') 84 | params = generator.init(key, z) 85 | out = generator.apply(params, z, noise_mode='const', train=False) 86 | 87 | diff = jnp.mean(jnp.abs(out - out_ref)) 88 | 89 | assert diff < 1e-5 90 | 91 | 92 | def test_reference_output_church(): 93 | z = 
jnp.load('tests/stylegan2/generator/aux_files/church_input_z.npy') 94 | out_ref = jnp.load('tests/stylegan2/generator/aux_files/church_output_ref.npy') 95 | 96 | key = jax.random.PRNGKey(0) 97 | generator = fm.stylegan2.Generator(pretrained='church') 98 | params = generator.init(key, z) 99 | out = generator.apply(params, z, noise_mode='const', train=False) 100 | 101 | diff = jnp.mean(jnp.abs(out - out_ref)) 102 | 103 | assert diff < 1e-5 104 | 105 | 106 | def test_reference_output_cifar10(): 107 | z = jnp.load('tests/stylegan2/generator/aux_files/cifar10_input_z.npy') 108 | label = jnp.load('tests/stylegan2/generator/aux_files/cifar10_input_label.npy') 109 | out_ref = jnp.load('tests/stylegan2/generator/aux_files/cifar10_output_ref.npy') 110 | 111 | key = jax.random.PRNGKey(0) 112 | generator = fm.stylegan2.Generator(pretrained='cifar10') 113 | params = generator.init(key, z, label) 114 | out = generator.apply(params, z, label, noise_mode='const', train=False) 115 | 116 | diff = jnp.mean(jnp.abs(out - out_ref)) 117 | 118 | assert diff < 3e-3 119 | 120 | 121 | def test_reference_output_ffhq(): 122 | z = jnp.load('tests/stylegan2/generator/aux_files/ffhq_input_z.npy') 123 | out_ref = jnp.load('tests/stylegan2/generator/aux_files/ffhq_output_ref.npy') 124 | 125 | key = jax.random.PRNGKey(0) 126 | generator = fm.stylegan2.Generator(pretrained='ffhq') 127 | params = generator.init(key, z) 128 | out = generator.apply(params, z, noise_mode='const', train=False) 129 | 130 | diff = jnp.mean(jnp.abs(out - out_ref)) 131 | 132 | assert diff < 1e-5 133 | 134 | 135 | def test_reference_output_horse(): 136 | z = jnp.load('tests/stylegan2/generator/aux_files/horse_input_z.npy') 137 | out_ref = jnp.load('tests/stylegan2/generator/aux_files/horse_output_ref.npy') 138 | 139 | key = jax.random.PRNGKey(0) 140 | generator = fm.stylegan2.Generator(pretrained='horse') 141 | params = generator.init(key, z) 142 | out = generator.apply(params, z, noise_mode='const', train=False) 143 | 144 | diff = jnp.mean(jnp.abs(out - out_ref)) 145 | 146 | assert diff < 1e-5 147 | 148 | 149 | def test_reference_output_metfaces(): 150 | z = jnp.load('tests/stylegan2/generator/aux_files/metfaces_input_z.npy') 151 | out_ref = jnp.load('tests/stylegan2/generator/aux_files/metfaces_output_ref.npy') 152 | 153 | key = jax.random.PRNGKey(0) 154 | generator = fm.stylegan2.Generator(pretrained='metfaces') 155 | params = generator.init(key, z) 156 | out = generator.apply(params, z, noise_mode='const', train=False) 157 | 158 | diff = jnp.mean(jnp.abs(out - out_ref)) 159 | 160 | assert diff < 3e-4 161 | 162 | 163 | -------------------------------------------------------------------------------- /tests/vgg/aux_files/vgg16_elefant_output_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/vgg/aux_files/vgg16_elefant_output_ref.npy -------------------------------------------------------------------------------- /tests/vgg/aux_files/vgg19_elefant_output_ref.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/vgg/aux_files/vgg19_elefant_output_ref.npy -------------------------------------------------------------------------------- /tests/vgg/test_vgg16.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import 
jax.numpy as jnp 3 | import numpy as np 4 | import flaxmodels as fm 5 | from PIL import Image 6 | 7 | 8 | def test_output_softmax(): 9 | # If output='softmax', the output should be in range [0, 1] 10 | key = jax.random.PRNGKey(0) 11 | x = jax.random.uniform(key, shape=(1, 224, 224, 3), minval=0, maxval=1) 12 | 13 | vgg16 = fm.VGG16(output='softmax', pretrained=None) 14 | init_rngs = {'params': jax.random.PRNGKey(1), 'dropout': jax.random.PRNGKey(2)} 15 | params = vgg16.init(init_rngs, x) 16 | out = vgg16.apply(params, x, train=False) 17 | 18 | assert jnp.min(out) >= 0.0 and jnp.max(out) <= 1.0 19 | 20 | 21 | def test_output_activations(): 22 | # If output='activations', the output should be a dict. 23 | key = jax.random.PRNGKey(0) 24 | x = jax.random.uniform(key, shape=(1, 224, 224, 3), minval=0, maxval=1) 25 | 26 | vgg16 = fm.VGG16(output='activations', pretrained=None) 27 | init_rngs = {'params': jax.random.PRNGKey(1), 'dropout': jax.random.PRNGKey(2)} 28 | params = vgg16.init(init_rngs, x) 29 | out = vgg16.apply(params, x, train=False) 30 | 31 | assert isinstance(out, dict) 32 | 33 | 34 | def test_include_head_true(): 35 | # If include_head=True, the output should be a tensor of shape [B, 1000]. 36 | key = jax.random.PRNGKey(0) 37 | x = jax.random.uniform(key, shape=(1, 224, 224, 3), minval=0, maxval=1) 38 | 39 | vgg16 = fm.VGG16(include_head=True, pretrained=None) 40 | init_rngs = {'params': jax.random.PRNGKey(1), 'dropout': jax.random.PRNGKey(2)} 41 | params = vgg16.init(init_rngs, x) 42 | out = vgg16.apply(params, x, train=False) 43 | 44 | assert hasattr(out, 'shape') and len(out.shape) == 2 and out.shape[0] == 1 and out.shape[1] == 1000 45 | 46 | 47 | def test_include_head_false(): 48 | # If include_head=False, the output should be a tensor of shape [B, *, *, 512]. 
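# (512 is the channel width of the last VGG conv block; for a 224x224 input the spatial
#  dimensions are typically 7x7 after the five pooling stages.)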
49 | key = jax.random.PRNGKey(0) 50 | x = jax.random.uniform(key, shape=(1, 224, 224, 3), minval=0, maxval=1) 51 | 52 | vgg16 = fm.VGG16(include_head=False, pretrained=None) 53 | init_rngs = {'params': jax.random.PRNGKey(1), 'dropout': jax.random.PRNGKey(2)} 54 | params = vgg16.init(init_rngs, x) 55 | out = vgg16.apply(params, x, train=False) 56 | 57 | assert hasattr(out, 'shape') and len(out.shape) == 4 and out.shape[0] == 1 and out.shape[-1] == 512 58 | 59 | 60 | def test_reference_output(): 61 | key = jax.random.PRNGKey(0) 62 | img = Image.open('tests/aux_files/elefant.jpg') 63 | img = img.resize((224, 224)) 64 | x = jnp.array(img, dtype=jnp.float32) / 255.0 65 | x = jnp.expand_dims(x, axis=0) 66 | 67 | vgg16 = fm.VGG16(output='logits', pretrained='imagenet') 68 | init_rngs = {'params': jax.random.PRNGKey(1), 'dropout': jax.random.PRNGKey(2)} 69 | params = vgg16.init(init_rngs, x) 70 | out = vgg16.apply(params, x, train=False) 71 | 72 | out_ref = jnp.load('tests/vgg/aux_files/vgg16_elefant_output_ref.npy') 73 | diff = jnp.mean(jnp.abs(out - out_ref)) 74 | 75 | assert diff < 1e-6 76 | 77 | -------------------------------------------------------------------------------- /tests/vgg/test_vgg19.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import numpy as np 4 | import flaxmodels as fm 5 | from PIL import Image 6 | 7 | 8 | def test_output_softmax(): 9 | # If output='softmax', the output should be in range [0, 1] 10 | key = jax.random.PRNGKey(0) 11 | x = jax.random.uniform(key, shape=(1, 224, 224, 3), minval=0, maxval=1) 12 | 13 | vgg19 = fm.VGG19(output='softmax', pretrained=None) 14 | init_rngs = {'params': jax.random.PRNGKey(1), 'dropout': jax.random.PRNGKey(2)} 15 | params = vgg19.init(init_rngs, x) 16 | out = vgg19.apply(params, x, train=False) 17 | 18 | assert jnp.min(out) >= 0.0 and jnp.max(out) <= 1.0 19 | 20 | 21 | def test_output_activations(): 22 | # If output='activations', the output should be a dict. 23 | key = jax.random.PRNGKey(0) 24 | x = jax.random.uniform(key, shape=(1, 224, 224, 3), minval=0, maxval=1) 25 | 26 | vgg19 = fm.VGG19(output='activations', pretrained=None) 27 | init_rngs = {'params': jax.random.PRNGKey(1), 'dropout': jax.random.PRNGKey(2)} 28 | params = vgg19.init(init_rngs, x) 29 | out = vgg19.apply(params, x, train=False) 30 | 31 | assert isinstance(out, dict) 32 | 33 | 34 | def test_include_head_true(): 35 | # If include_head=True, the output should be a tensor of shape [B, 1000]. 36 | key = jax.random.PRNGKey(0) 37 | x = jax.random.uniform(key, shape=(1, 224, 224, 3), minval=0, maxval=1) 38 | 39 | vgg19 = fm.VGG19(include_head=True, pretrained=None) 40 | init_rngs = {'params': jax.random.PRNGKey(1), 'dropout': jax.random.PRNGKey(2)} 41 | params = vgg19.init(init_rngs, x) 42 | out = vgg19.apply(params, x, train=False) 43 | 44 | assert hasattr(out, 'shape') and len(out.shape) == 2 and out.shape[0] == 1 and out.shape[1] == 1000 45 | 46 | 47 | def test_include_head_false(): 48 | # If include_head=False, the output should be a tensor of shape [B, *, *, 512]. 
49 | key = jax.random.PRNGKey(0) 50 | x = jax.random.uniform(key, shape=(1, 224, 224, 3), minval=0, maxval=1) 51 | 52 | vgg19 = fm.VGG19(include_head=False, pretrained=None) 53 | init_rngs = {'params': jax.random.PRNGKey(1), 'dropout': jax.random.PRNGKey(2)} 54 | params = vgg19.init(init_rngs, x) 55 | out = vgg19.apply(params, x, train=False) 56 | 57 | assert hasattr(out, 'shape') and len(out.shape) == 4 and out.shape[0] == 1 and out.shape[-1] == 512 58 | 59 | 60 | def test_reference_output(): 61 | key = jax.random.PRNGKey(0) 62 | img = Image.open('tests/aux_files/elefant.jpg') 63 | img = img.resize((224, 224)) 64 | x = jnp.array(img, dtype=jnp.float32) / 255.0 65 | x = jnp.expand_dims(x, axis=0) 66 | 67 | vgg19 = fm.VGG19(output='logits', pretrained='imagenet') 68 | init_rngs = {'params': jax.random.PRNGKey(1), 'dropout': jax.random.PRNGKey(2)} 69 | params = vgg19.init(init_rngs, x) 70 | out = vgg19.apply(params, x, train=False) 71 | 72 | out_ref = jnp.load('tests/vgg/aux_files/vgg19_elefant_output_ref.npy') 73 | diff = jnp.mean(jnp.abs(out - out_ref)) 74 | 75 | assert diff < 1e-5 76 | 77 | -------------------------------------------------------------------------------- /training/few_shot_gan_adaption/README.md: -------------------------------------------------------------------------------- 1 | # Few-shot Image Generation via Cross-domain Correspondence Training in Jax/Flax 2 | This is the training code for the [Jax/Flax implementation](https://github.com/matthias-wright/flaxmodels/tree/main/flaxmodels/few_shot_gan_adaption) of [Few-shot Image Generation via Cross-domain Correspondence](https://arxiv.org/abs/2104.06820). 3 | 4 |
img
5 | 6 | #### Table of Contents 7 | * [Getting Started](#getting-started) 8 | * [Preparing Datasets for Training](#preparing-datasets-for-training) 9 | * [Training](#training) 10 | * [Checkpoints](#checkpoints) 11 | * [Generating Images](#generating-images) 12 | * [References](#references) 13 | * [License](#license) 14 | 15 | 16 | ## Getting Started 17 | You will need Python 3.7 or later. 18 | 19 | 1. Clone the repository: 20 | ```sh 21 | > git clone https://github.com/matthias-wright/flaxmodels.git 22 | ``` 23 | 2. Go into the directory: 24 | ```sh 25 | > cd flaxmodels/training/few_shot_gan_adaption 26 | ``` 27 | 3. Install Jax with CUDA. 28 | 4. Install requirements: 29 | ```sh 30 | > pip install -r requirements.txt 31 | ``` 32 | 33 | ## Preparing Datasets for Training 34 | Before training, the images should be stored in a [TFRecord dataset](https://www.tensorflow.org/tutorials/load_data/tfrecord). The TFRecord format stores your data as a sequence of bytes, which allows for fast data loading. 35 | Alternatively, you can use [tfds.folder_dataset.ImageFolder](https://www.tensorflow.org/datasets/api_docs/python/tfds/folder_dataset/ImageFolder) directly on the image directory, but you will have to replace the `tf.data.TFRecordDataset` in `data_pipeline.py` with `tfds.folder_dataset.ImageFolder` (see [this](https://github.com/matthias-wright/flaxmodels/issues/8#issue-1020780783) thread for more info). 36 | 37 | 1. Download the dataset from [here](https://github.com/utkarshojha/few-shot-gan-adaptation#choose-the-target-domain). 38 | 2. Put all images into a directory: 39 | ``` 40 | /path/to/image_dir/ 41 | 0.jpg 42 | 1.jpg 43 | 2.jpg 44 | 4.jpg 45 | ... 46 | ``` 47 | 3. Create the TFRecord dataset: 48 | ```sh 49 | > python dataset_utils/images_to_tfrecords.py --image_dir /path/to/image_dir/ --data_dir /path/to/tfrecord 50 | ``` 51 | `--image_dir` is the path to the image directory. 52 | `--data_dir` is the path where the TFRecord dataset is stored. 53 | 54 | 55 | ## Training 56 | Download the checkpoint of the source model: 57 | ```sh 58 | > wget https://www.dropbox.com/s/hyh1k8ixtzy24ye/ffhq_256x256.pickle\?dl\=1 -O ffhq_256x256.pickle 59 | ``` 60 | 61 | Start training: 62 | ```sh 63 | > CUDA_VISIBLE_DEVICES=a,b,c,d python main.py --data_dir /path/to/tfrecord --source_ckpt_path ffhq_256x256.pickle 64 | ``` 65 | Here `a`, `b`, `c`, `d` are the GPU indices. Multi-GPU training (data parallelism) works by default and will automatically use all the devices that you make visible. 66 | 67 | 68 | ### Logging 69 | I use [Weights & Biases](https://wandb.ai/site) for logging, but you can replace it with the logging method of your choice. All logging happens in the training loop implemented in `training.py`. To enable logging with Weights & Biases, pass `--wandb`. 70 | 71 | ### Checkpointing 72 | By default, the FID score is evaluated on `10,000` images every `1000` training steps. The checkpoint with the best (lowest) FID score is saved. You can change the evaluation frequency using the `--eval_fid_every` argument and the number of images used for the FID evaluation using `--num_fid_images`. 73 | You can disable the FID score evaluation using `--disable_fid`. In that case, a checkpoint will be saved every `2000` steps (can be changed using `--save_every`).
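A checkpoint written during training can be reloaded much like the released checkpoints below. A minimal sketch, assuming the default `--work_dir`, `--group`, and `--name` (so checkpoints end up in `logging/default/default/checkpoints`) and that the script is run from this directory:
```python
import dill as pickle
from checkpoint import get_latest_checkpoint  # helper defined in checkpoint.py in this directory

# Pick the most recently written ckpt_<step>.pickle (path assumes the default flags).
ckpt_path = get_latest_checkpoint('logging/default/default/checkpoints')
ckpt = pickle.load(open(ckpt_path, 'rb'))
params_ema_G = ckpt['params_ema_G']  # EMA generator weights, same key as in the released pickles
```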
74 | 75 | 76 | ## Checkpoints 77 | * [Sketches](https://www.dropbox.com/s/azr6b316juhme6c/sketches.pickle?dl=1) (357,2 MB) 78 | * [Amedeo Modigliani](https://www.dropbox.com/s/xrh4a7wt2kggn4v/amedeo_modigliani.pickle?dl=1) (357,2 MB) 79 | * [Babies](https://www.dropbox.com/s/ntictyzrisqg5zh/babies.pickle?dl=1) (357,2 MB) 80 | * [Otto Dix](https://www.dropbox.com/s/u1g18nv73uac21m/otto_dix.pickle?dl=1) (357,2 MB) 81 | * [Rafael](https://www.dropbox.com/s/b8w928s4wffuo2c/raphael.pickle?dl=1) (357,2 MB) 82 | 83 | ## Generating Images 84 | ```python 85 | import jax 86 | import numpy as np 87 | import dill as pickle 88 | from PIL import Image 89 | 90 | import flaxmodels as fm 91 | 92 | ckpt = pickle.load(open('sketches.pickle', 'rb')) 93 | params = ckpt['params_ema_G'] 94 | 95 | generator = fm.few_shot_gan_adaption.Generator() 96 | 97 | # Seed 98 | key = jax.random.PRNGKey(0) 99 | 100 | # Input noise 101 | z = jax.random.normal(key, shape=(4, 512)) 102 | 103 | # Generate images 104 | images, _ = generator.apply(params, z, truncation_psi=0.5, train=False, noise_mode='const') 105 | 106 | # Normalize images to be in range [0, 1] 107 | images = (images - np.min(images)) / (np.max(images) - np.min(images)) 108 | 109 | # Save images 110 | for i in range(images.shape[0]): 111 | Image.fromarray(np.uint8(images[i] * 255)).save(f'image_{i}.jpg') 112 | 113 | ``` 114 | 115 | ## References 116 | * [utkarshojha/few-shot-gan-adaptation](https://github.com/utkarshojha/few-shot-gan-adaptation) 117 | 118 | 119 | ## License 120 | [MIT License](https://opensource.org/licenses/MIT) 121 | 122 | -------------------------------------------------------------------------------- /training/few_shot_gan_adaption/checkpoint.py: -------------------------------------------------------------------------------- 1 | import flax 2 | import dill as pickle 3 | import os 4 | import glob 5 | 6 | 7 | def save_checkpoint(ckpt_dir, state_G, state_D, params_ema_G, config, step, z_latent_anchors, keep=2): 8 | """ 9 | Saves checkpoint. 10 | 11 | Args: 12 | ckpt_dir (str): Path to the directory, where checkpoints are saved. 13 | state_G (train_state.TrainState): Generator state. 14 | state_D (train_state.TrainState): Discriminator state. 15 | params_ema_G (frozen_dict.FrozenDict): Parameters of the ema generator. 16 | config (argparse.Namespace): Configuration. 17 | step (int): Current step. 18 | z_latent_anchors (DeviceArray): Noise anchors. 19 | keep (int): Number of checkpoints to keep. 20 | """ 21 | 22 | state_G = flax.jax_utils.unreplicate(state_G) 23 | state_D = flax.jax_utils.unreplicate(state_D) 24 | params_G = {'params': {'mapping_network': state_G.params.unfreeze()['mapping'], 'synthesis_network': state_G.params.unfreeze()['synthesis']}, 25 | 'moving_stats': {'mapping_network': state_G.moving_stats}, 26 | 'noise_consts': {'synthesis_network': state_G.noise_consts}} 27 | 28 | ckpt_dict = {'params_G': params_G, 29 | 'params_D': state_D.params, 30 | 'params_ema_G': params_ema_G, 31 | 'z_latent_anchors': z_latent_anchors, 32 | 'config': config} 33 | 34 | with open(os.path.join(ckpt_dir, f'ckpt_{step}.pickle'), 'wb') as handle: 35 | pickle.dump(ckpt_dict, handle, protocol=pickle.DEFAULT_PROTOCOL) 36 | 37 | ckpts = glob.glob(os.path.join(ckpt_dir, '*.pickle')) 38 | if len(ckpts) > keep: 39 | oldest_ckpt = min(ckpts, key=os.path.getctime) 40 | os.remove(oldest_ckpt) 41 | 42 | 43 | def load_checkpoint(filename, replicate=True): 44 | """ 45 | Loads checkpoints. 46 | 47 | Args: 48 | filename (str): Path to the checkpoint file. 
49 | replicate (bool): If True, replicate parameters across devices. 50 | 51 | Returns: 52 | (dict): Checkpoint. 53 | """ 54 | state_dict = pickle.load(open(filename, 'rb')) 55 | if replicate: 56 | state_dict['state_G'] = flax.jax_utils.replicate(state_dict['state_G']) 57 | state_dict['state_D'] = flax.jax_utils.replicate(state_dict['state_D']) 58 | state_dict['pl_mean'] = flax.jax_utils.replicate(state_dict['pl_mean']) 59 | return state_dict 60 | 61 | 62 | def get_latest_checkpoint(ckpt_dir): 63 | """ 64 | Returns the path of the latest checkpoint. 65 | 66 | Args: 67 | ckpt_dir (str): Path to the directory, where checkpoints are saved. 68 | 69 | Returns: 70 | (str): Path to latest checkpoint (if it exists). 71 | """ 72 | ckpts = glob.glob(os.path.join(ckpt_dir, '*.pickle')) 73 | if len(ckpts) == 0: 74 | return None 75 | latest_ckpt = max(ckpts, key=os.path.getctime) 76 | return latest_ckpt 77 | 78 | -------------------------------------------------------------------------------- /training/few_shot_gan_adaption/data_pipeline.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow_datasets as tfds 3 | import jax 4 | import flax 5 | import numpy as np 6 | from PIL import Image 7 | import os 8 | from typing import Sequence 9 | from tqdm import tqdm 10 | import json 11 | from tqdm import tqdm 12 | 13 | 14 | def prefetch(dataset, n_prefetch): 15 | # Taken from: https://github.com/google-research/vision_transformer/blob/master/vit_jax/input_pipeline.py 16 | ds_iter = iter(dataset) 17 | ds_iter = map(lambda x: jax.tree_map(lambda t: np.asarray(memoryview(t)), x), 18 | ds_iter) 19 | if n_prefetch: 20 | ds_iter = flax.jax_utils.prefetch_to_device(ds_iter, n_prefetch) 21 | return ds_iter 22 | 23 | 24 | def get_data(data_dir, img_size, img_channels, num_classes, num_devices, batch_size, shuffle_buffer=1000): 25 | """ 26 | 27 | Args: 28 | data_dir (str): Root directory of the dataset. 29 | img_size (int): Image size for training. 30 | img_channels (int): Number of image channels. 31 | num_classes (int): Number of classes, 0 for no classes. 32 | num_devices (int): Number of devices. 33 | batch_size (int): Batch size (per device). 34 | shuffle_buffer (int): Buffer used for shuffling the dataset. 35 | 36 | Returns: 37 | (tf.data.Dataset): Dataset. 
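(dict): Dataset info (contents of dataset_info.json, i.e. number of examples and classes).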
38 | """ 39 | 40 | def pre_process(serialized_example): 41 | feature = {'height': tf.io.FixedLenFeature([], tf.int64), 42 | 'width': tf.io.FixedLenFeature([], tf.int64), 43 | 'channels': tf.io.FixedLenFeature([], tf.int64), 44 | 'image': tf.io.FixedLenFeature([], tf.string), 45 | 'label': tf.io.FixedLenFeature([], tf.int64)} 46 | example = tf.io.parse_single_example(serialized_example, feature) 47 | 48 | height = tf.cast(example['height'], dtype=tf.int64) 49 | width = tf.cast(example['width'], dtype=tf.int64) 50 | channels = tf.cast(example['channels'], dtype=tf.int64) 51 | 52 | image = tf.io.decode_raw(example['image'], out_type=tf.uint8) 53 | image = tf.reshape(image, shape=[height, width, channels]) 54 | 55 | image = tf.cast(image, dtype='float32') 56 | image = tf.image.resize(image, size=[img_size, img_size], method='bicubic', antialias=True) 57 | image = tf.image.random_flip_left_right(image) 58 | 59 | image = (image - 127.5) / 127.5 60 | 61 | label = tf.one_hot(example['label'], num_classes) 62 | return {'image': image, 'label': label} 63 | 64 | def shard(data): 65 | # Reshape images from [num_devices * batch_size, H, W, C] to [num_devices, batch_size, H, W, C] 66 | # because the first dimension will be mapped across devices using jax.pmap 67 | data['image'] = tf.reshape(data['image'], [num_devices, -1, img_size, img_size, img_channels]) 68 | data['label'] = tf.reshape(data['label'], [num_devices, -1, num_classes]) 69 | return data 70 | 71 | print('Loading TFRecord...') 72 | with open(os.path.join(data_dir, 'dataset_info.json'), 'r') as fin: 73 | dataset_info = json.load(fin) 74 | 75 | ds = tf.data.TFRecordDataset(filenames=os.path.join(data_dir, 'dataset.tfrecords')) 76 | 77 | ds = ds.shuffle(min(dataset_info['num_examples'], shuffle_buffer)) 78 | ds = ds.map(pre_process, tf.data.AUTOTUNE) 79 | ds = ds.batch(batch_size * num_devices, drop_remainder=True) 80 | ds = ds.map(shard, tf.data.AUTOTUNE) 81 | ds = ds.prefetch(1) 82 | return ds, dataset_info 83 | 84 | 85 | 86 | -------------------------------------------------------------------------------- /training/few_shot_gan_adaption/dataset_utils/images_to_tfrecords.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from PIL import Image 4 | from typing import Sequence 5 | from tqdm import tqdm 6 | import argparse 7 | import json 8 | import os 9 | 10 | 11 | def images_to_tfrecords(image_dir, data_dir, has_labels): 12 | """ 13 | Converts a folder of images to a TFRecord file. 14 | 15 | The image directory should have one of the following structures: 16 | 17 | If has_labels = False, image_dir should look like this: 18 | 19 | path/to/image_dir/ 20 | 0.jpg 21 | 1.jpg 22 | 2.jpg 23 | 4.jpg 24 | ... 25 | 26 | 27 | If has_labels = True, image_dir should look like this: 28 | 29 | path/to/image_dir/ 30 | label0/ 31 | 0.jpg 32 | 1.jpg 33 | ... 34 | label1/ 35 | a.jpg 36 | b.jpg 37 | c.jpg 38 | ... 39 | ... 40 | 41 | 42 | The labels will be label0 -> 0, label1 -> 1. 43 | 44 | Args: 45 | image_dir (str): Path to images. 46 | data_dir (str): Path where the TFrecords dataset is stored. 47 | has_labels (bool): If True, 'image_dir' contains label directories. 48 | 49 | Returns: 50 | (dict): Dataset info. 
51 | """ 52 | 53 | def _bytes_feature(value): 54 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 55 | 56 | def _int64_feature(value): 57 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 58 | 59 | os.makedirs(data_dir, exist_ok=True) 60 | writer = tf.io.TFRecordWriter(os.path.join(data_dir, 'dataset.tfrecords')) 61 | 62 | num_examples = 0 63 | num_classes = 0 64 | 65 | if has_labels: 66 | for label_dir in os.listdir(image_dir): 67 | if not os.path.isdir(os.path.join(image_dir, label_dir)): 68 | print('The image directory should contain one directory for each label.') 69 | print('These label directories should contain the image files.') 70 | if os.path.exists(os.path.join(data_dir, 'dataset.tfrecords')): 71 | os.remove(os.path.join(data_dir, 'dataset.tfrecords')) 72 | return 73 | 74 | for img_file in tqdm(os.listdir(os.path.join(image_dir, label_dir))): 75 | file_format = img_file[img_file.rfind('.') + 1:] 76 | if file_format not in ['png', 'jpg', 'jpeg']: 77 | continue 78 | 79 | #img = Image.open(os.path.join(image_dir, label_dir, img_file)).resize(img_size) 80 | img = Image.open(os.path.join(image_dir, label_dir, img_file)) 81 | img = np.array(img, dtype=np.uint8) 82 | 83 | height = img.shape[0] 84 | width = img.shape[1] 85 | channels = img.shape[2] 86 | 87 | img_encoded = img.tobytes() 88 | 89 | example = tf.train.Example(features=tf.train.Features(feature={ 90 | 'height': _int64_feature(height), 91 | 'width': _int64_feature(width), 92 | 'channels': _int64_feature(channels), 93 | 'image': _bytes_feature(img_encoded), 94 | 'label': _int64_feature(num_classes)})) 95 | 96 | writer.write(example.SerializeToString()) 97 | num_examples += 1 98 | 99 | num_classes += 1 100 | else: 101 | for img_file in tqdm(os.listdir(os.path.join(image_dir))): 102 | file_format = img_file[img_file.rfind('.') + 1:] 103 | if file_format not in ['png', 'jpg', 'jpeg']: 104 | continue 105 | 106 | #img = Image.open(os.path.join(image_dir, label_dir, img_file)).resize(img_size) 107 | img = Image.open(os.path.join(image_dir, img_file)) 108 | img = np.array(img, dtype=np.uint8) 109 | 110 | height = img.shape[0] 111 | width = img.shape[1] 112 | channels = img.shape[2] 113 | 114 | img_encoded = img.tobytes() 115 | 116 | example = tf.train.Example(features=tf.train.Features(feature={ 117 | 'height': _int64_feature(height), 118 | 'width': _int64_feature(width), 119 | 'channels': _int64_feature(channels), 120 | 'image': _bytes_feature(img_encoded), 121 | 'label': _int64_feature(num_classes)})) # dummy label 122 | 123 | writer.write(example.SerializeToString()) 124 | num_examples += 1 125 | 126 | writer.close() 127 | 128 | dataset_info = {'num_examples': num_examples, 'num_classes': num_classes} 129 | with open(os.path.join(data_dir, 'dataset_info.json'), 'w') as fout: 130 | json.dump(dataset_info, fout) 131 | 132 | 133 | if __name__ == '__main__': 134 | parser = argparse.ArgumentParser() 135 | parser.add_argument('--image_dir', type=str, help='Path to the image directory.') 136 | parser.add_argument('--data_dir', type=str, help='Path where the TFRecords dataset is stored.') 137 | parser.add_argument('--has_labels', action='store_true', help='If True, image_dir contains label directories.') 138 | 139 | args = parser.parse_args() 140 | 141 | images_to_tfrecords(args.image_dir, args.data_dir, args.has_labels) 142 | 143 | -------------------------------------------------------------------------------- /training/few_shot_gan_adaption/datatest.py: 
-------------------------------------------------------------------------------- 1 | import data_pipeline 2 | 3 | ds_train, dataset_info = data_pipeline.get_data(data_dir='datasets/Sketches', 4 | img_size=256, 5 | img_channels=3, 6 | num_classes=0, 7 | num_devices=1, 8 | batch_size=5) 9 | -------------------------------------------------------------------------------- /training/few_shot_gan_adaption/fid/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import FID 2 | -------------------------------------------------------------------------------- /training/few_shot_gan_adaption/fid/core.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import flax 4 | import flax.linen as nn 5 | import numpy as np 6 | import os 7 | import functools 8 | import argparse 9 | import scipy 10 | from tqdm import tqdm 11 | 12 | from . import inception 13 | from . import utils 14 | 15 | 16 | class FID: 17 | 18 | def __init__(self, generator, dataset, config, use_cache=True, truncation_psi=1.0): 19 | """ 20 | Evaluates the FID score for a given generator and a given dataset. 21 | Implementation mostly taken from https://github.com/matthias-wright/jax-fid 22 | 23 | Reference: https://arxiv.org/abs/1706.08500 24 | 25 | Args: 26 | generator (nn.Module): Generator network. 27 | dataset (tf.data.Dataset): Dataset containing the real images. 28 | config (argparse.Namespace): Configuration. 29 | use_cache (bool): If True, only compute the activation stats once for the real images and store them. 30 | truncation_psi (float): Controls truncation (trading off variation for quality). If 1, truncation is disabled. 31 | """ 32 | self.num_images = config.num_fid_images 33 | self.batch_size = config.batch_size 34 | self.c_dim = config.c_dim 35 | self.z_dim = config.z_dim 36 | self.dataset = dataset 37 | self.num_devices = jax.device_count() 38 | self.use_cache = use_cache 39 | 40 | if self.use_cache: 41 | self.cache = {} 42 | 43 | rng = jax.random.PRNGKey(0) 44 | inception_net = inception.InceptionV3(pretrained=True) 45 | self.inception_params = inception_net.init(rng, jnp.ones((1, config.resolution, config.resolution, 3))) 46 | self.inception_params = flax.jax_utils.replicate(self.inception_params) 47 | #self.inception = jax.jit(functools.partial(model.apply, train=False)) 48 | self.inception_apply = jax.pmap(functools.partial(inception_net.apply, train=False), axis_name='batch') 49 | 50 | self.generator_apply = jax.pmap(functools.partial(generator.apply, truncation_psi=truncation_psi, train=False, noise_mode='const'), axis_name='batch') 51 | 52 | def compute_fid(self, generator_params, seed_offset=0): 53 | generator_params = flax.jax_utils.replicate(generator_params) 54 | mu_real, sigma_real = self.compute_stats_for_dataset() 55 | mu_fake, sigma_fake = self.compute_stats_for_generator(generator_params, seed_offset) 56 | fid_score = self.compute_frechet_distance(mu_real, mu_fake, sigma_real, sigma_fake, eps=1e-6) 57 | return fid_score 58 | 59 | def compute_frechet_distance(self, mu1, mu2, sigma1, sigma2, eps=1e-6): 60 | # Taken from: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py 61 | mu1 = np.atleast_1d(mu1) 62 | mu2 = np.atleast_1d(mu2) 63 | sigma1 = np.atleast_1d(sigma1) 64 | sigma2 = np.atleast_1d(sigma2) 65 | 66 | assert mu1.shape == mu2.shape 67 | assert sigma1.shape == sigma2.shape 68 | 69 | diff = mu1 - mu2 70 | 71 | covmean, _ = 
scipy.linalg.sqrtm(sigma1.dot(sigma2), disp=False) 72 | if not np.isfinite(covmean).all(): 73 | msg = ('fid calculation produces singular product; ' 74 | 'adding %s to diagonal of cov estimates') % eps 75 | print(msg) 76 | offset = np.eye(sigma1.shape[0]) * eps 77 | covmean = scipy.linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 78 | 79 | # Numerical error might give slight imaginary component 80 | if np.iscomplexobj(covmean): 81 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 82 | m = np.max(np.abs(covmean.imag)) 83 | raise ValueError('Imaginary component {}'.format(m)) 84 | covmean = covmean.real 85 | 86 | tr_covmean = np.trace(covmean) 87 | return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean) 88 | 89 | def compute_stats_for_dataset(self): 90 | if self.use_cache and 'mu' in self.cache and 'sigma' in self.cache: 91 | print('Use cached statistics for dataset...') 92 | return self.cache['mu'], self.cache['sigma'] 93 | 94 | print() 95 | print('Compute statistics for dataset...') 96 | pbar = tqdm(total=self.num_images) 97 | image_count = 0 98 | 99 | activations = [] 100 | for batch in utils.prefetch(self.dataset, n_prefetch=2): 101 | act = self.inception_apply(self.inception_params, jax.lax.stop_gradient(batch['image'])) 102 | act = jnp.reshape(act, (self.num_devices * self.batch_size, -1)) 103 | activations.append(act) 104 | 105 | pbar.update(self.num_devices * self.batch_size) 106 | image_count += self.num_devices * self.batch_size 107 | if image_count >= self.num_images: 108 | break 109 | pbar.close() 110 | 111 | activations = jnp.concatenate(activations, axis=0) 112 | activations = activations[:self.num_images] 113 | mu = np.mean(activations, axis=0) 114 | sigma = np.cov(activations, rowvar=False) 115 | self.cache['mu'] = mu 116 | self.cache['sigma'] = sigma 117 | return mu, sigma 118 | 119 | def compute_stats_for_generator(self, generator_params, seed_offset): 120 | print() 121 | print('Compute statistics for generator...') 122 | num_batches = int(np.ceil(self.num_images / (self.batch_size * self.num_devices))) 123 | 124 | pbar = tqdm(total=self.num_images) 125 | activations = [] 126 | 127 | for i in range(num_batches): 128 | rng = jax.random.PRNGKey(seed_offset + i) 129 | z_latent = jax.random.normal(rng, shape=(self.num_devices, self.batch_size, self.z_dim)) 130 | 131 | labels = None 132 | if self.c_dim > 0: 133 | labels = jax.random.randint(rng, shape=(self.num_devices * self.batch_size,), minval=0, maxval=self.c_dim) 134 | labels = jax.nn.one_hot(labels, num_classes=self.c_dim) 135 | labels = jnp.reshape(labels, (self.num_devices, self.batch_size, self.c_dim)) 136 | 137 | image, _ = self.generator_apply(generator_params, jax.lax.stop_gradient(z_latent), labels) 138 | image = (image - jnp.min(image)) / (jnp.max(image) - jnp.min(image)) 139 | 140 | image = 2 * image - 1 141 | act = self.inception_apply(self.inception_params, jax.lax.stop_gradient(image)) 142 | act = jnp.reshape(act, (self.num_devices * self.batch_size, -1)) 143 | activations.append(act) 144 | pbar.update(self.num_devices * self.batch_size) 145 | pbar.close() 146 | 147 | activations = jnp.concatenate(activations, axis=0) 148 | activations = activations[:self.num_images] 149 | mu = np.mean(activations, axis=0) 150 | sigma = np.cov(activations, rowvar=False) 151 | return mu, sigma 152 | 153 | 154 | -------------------------------------------------------------------------------- /training/few_shot_gan_adaption/fid/utils.py: 
-------------------------------------------------------------------------------- 1 | import jax 2 | import flax 3 | import numpy as np 4 | from tqdm import tqdm 5 | import requests 6 | import os 7 | import tempfile 8 | 9 | 10 | def download(url, ckpt_dir=None): 11 | name = url[url.rfind('/') + 1 : url.rfind('?')] 12 | if ckpt_dir is None: 13 | ckpt_dir = tempfile.gettempdir() 14 | ckpt_dir = os.path.join(ckpt_dir, 'flaxmodels') 15 | ckpt_file = os.path.join(ckpt_dir, name) 16 | if not os.path.exists(ckpt_file): 17 | print(f'Downloading: \"{url[:url.rfind("?")]}\" to {ckpt_file}') 18 | if not os.path.exists(ckpt_dir): 19 | os.makedirs(ckpt_dir) 20 | 21 | response = requests.get(url, stream=True) 22 | total_size_in_bytes = int(response.headers.get('content-length', 0)) 23 | progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) 24 | 25 | # first create temp file, in case the download fails 26 | ckpt_file_temp = os.path.join(ckpt_dir, name + '.temp') 27 | with open(ckpt_file_temp, 'wb') as file: 28 | for data in response.iter_content(chunk_size=1024): 29 | progress_bar.update(len(data)) 30 | file.write(data) 31 | progress_bar.close() 32 | 33 | if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: 34 | print('An error occured while downloading, please try again.') 35 | if os.path.exists(ckpt_file_temp): 36 | os.remove(ckpt_file_temp) 37 | else: 38 | # if download was successful, rename the temp file 39 | os.rename(ckpt_file_temp, ckpt_file) 40 | return ckpt_file 41 | 42 | 43 | def get(dictionary, key): 44 | if dictionary is None or key not in dictionary: 45 | return None 46 | return dictionary[key] 47 | 48 | 49 | def prefetch(dataset, n_prefetch): 50 | # Taken from: https://github.com/google-research/vision_transformer/blob/master/vit_jax/input_pipeline.py 51 | ds_iter = iter(dataset) 52 | ds_iter = map(lambda x: jax.tree_map(lambda t: np.asarray(memoryview(t)), x), 53 | ds_iter) 54 | if n_prefetch: 55 | ds_iter = flax.jax_utils.prefetch_to_device(ds_iter, n_prefetch) 56 | return ds_iter 57 | -------------------------------------------------------------------------------- /training/few_shot_gan_adaption/images/overview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/training/few_shot_gan_adaption/images/overview.jpg -------------------------------------------------------------------------------- /training/few_shot_gan_adaption/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import jax 4 | import wandb 5 | import training 6 | 7 | 8 | def main(): 9 | parser = argparse.ArgumentParser() 10 | # Paths 11 | parser.add_argument('--work_dir', type=str, default='logging', help='Directory for logging and checkpoints.') 12 | parser.add_argument('--data_dir', type=str, default='datasets/Sketches', help='Directory of the dataset.') 13 | parser.add_argument('--project', type=str, default='few-shot-adapt', help='Name of this project.') 14 | parser.add_argument('--name', type=str, default='default', help='Name of this experiment.') 15 | parser.add_argument('--group', type=str, default='default', help='Group name of this experiment (for Weights&Biases).') 16 | parser.add_argument('--source_ckpt_path', type=str, default='ffhq_256x256.pickle', help='Path to the checkpoint of the source model.') 17 | # Training 18 | parser.add_argument('--num_epochs', type=int, 
default=10000, help='Number of epochs.') 19 | parser.add_argument('--learning_rate', type=float, default=0.002, help='Learning rate.') 20 | parser.add_argument('--batch_size', type=int, default=8, help='Batch size.') 21 | parser.add_argument('--num_prefetch', type=int, default=2, help='Number of prefetched examples for the data pipeline.') 22 | parser.add_argument('--resolution', type=int, default=256, help='Image resolution. Must be a multiple of 2.') 23 | parser.add_argument('--img_channels', type=int, default=3, help='Number of image channels.') 24 | parser.add_argument('--mixed_precision', action='store_true', help='Use mixed precision training.') 25 | parser.add_argument('--random_seed', type=int, default=0, help='Random seed.') 26 | parser.add_argument('--subspace_std', type=float, default=0.1, help='Std for sampling z_latents around the anchor points.') 27 | parser.add_argument('--subspace_freq', type=int, default=4, help='Frequency for sampling z_latents from anchor regions.') 28 | # Generator 29 | parser.add_argument('--fmap_base', type=int, default=16384, help='Overall multiplier for the number of feature maps.') 30 | # Discriminator 31 | parser.add_argument('--mbstd_group_size', type=int, help='Group size for the minibatch standard deviation layer, None = entire minibatch.') 32 | # Exponentially Moving Average of Generator Weights 33 | parser.add_argument('--ema_kimg', type=float, default=20.0, help='Controls the ema of the generator weights (larger value -> larger beta).') 34 | # Losses 35 | parser.add_argument('--pl_decay', type=float, default=0.01, help='Exponentially decay for mean of path length (Path length regul).') 36 | parser.add_argument('--pl_weight', type=float, default=2, help='Weight for path length regularization.') 37 | parser.add_argument('--kl_weight', type=float, default=1000.0, help='Weight for distance consistency loss.') 38 | # Regularization 39 | parser.add_argument('--mixing_prob', type=float, default=0.9, help='Probability for style mixing.') 40 | parser.add_argument('--G_reg_interval', type=int, default=4, help='How often to perform regularization for G.') 41 | parser.add_argument('--D_reg_interval', type=int, default=16, help='How often to perform regularization for D.') 42 | parser.add_argument('--r1_gamma', type=float, default=10.0, help='Weight for R1 regularization.') 43 | # Model 44 | parser.add_argument('--c_dim', type=int, default=0, help='Conditioning label (C) dimensionality, 0 = no label.') 45 | # Logging 46 | parser.add_argument('--wandb', action='store_true', help='Log to Weights&bBiases.') 47 | parser.add_argument('--log_every', type=int, default=50, help='Log every log_every steps.') 48 | parser.add_argument('--save_every', type=int, default=1000, help='Save every save_every steps. 
Will be ignored if FID evaluation is enabled.') 49 | # FID 50 | parser.add_argument('--eval_fid_every', type=int, default=1000, help='Compute FID score every eval_fid_every steps.') 51 | parser.add_argument('--num_fid_images', type=int, default=10000, help='Number of images to use for FID computation.') 52 | parser.add_argument('--disable_fid', action='store_true', help='Disable FID evaluation.') 53 | 54 | args = parser.parse_args() 55 | 56 | if jax.process_index() == 0: 57 | args.ckpt_dir = os.path.join(args.work_dir, args.group, args.name, 'checkpoints') 58 | if not os.path.exists(args.ckpt_dir): 59 | os.makedirs(args.ckpt_dir) 60 | 61 | if args.wandb: 62 | wandb.init(project=args.project, 63 | group=args.group, 64 | config=args, 65 | name=args.name, 66 | dir=os.path.join(args.work_dir, args.group, args.name)) 67 | 68 | training.train_and_evaluate(args) 69 | 70 | 71 | if __name__ == '__main__': 72 | main() 73 | 74 | -------------------------------------------------------------------------------- /training/resnet/README.md: -------------------------------------------------------------------------------- 1 | # ResNet Training 2 | 3 | ##### Table of Contents 4 | * [Getting Started](#getting_started) 5 | * [Training](#training) 6 | * [Options](#options) 7 | * [Results](#results) 8 | * [References](#references) 9 | * [License](#license) 10 | 11 | 12 | 13 | ## Getting Started 14 | You will need Python 3.7 or later. 15 | 16 | 1. Clone the repository: 17 | ```sh 18 | > git clone https://github.com/matthias-wright/flaxmodels.git 19 | ``` 20 | 2. Go into the directory: 21 | ```sh 22 | > cd flaxmodels/training/resnet 23 | ``` 24 | 3. Install Jax with CUDA. 25 | 4. Install requirements: 26 | ```sh 27 | > pip install -r requirements.txt 28 | ``` 29 | 30 | 31 | ## Training 32 | 33 | ### Basic Training 34 | ```python 35 | CUDA_VISIBLE_DEVICES=0 python main.py 36 | ``` 37 | 38 | ### Multi GPU Training 39 | The script will automatically use all the visible GPUs for distributed training. 40 | ```python 41 | CUDA_VISIBLE_DEVICES=0,1 python main.py 42 | ``` 43 | 44 | ### Mixed-Precision Training 45 | ```python 46 | CUDA_VISIBLE_DEVICES=0,1 python main.py --mixed_precision 47 | ``` 48 | 49 | 50 | ## Options 51 | * `--work_dir` - Path to directory for logging and checkpoints (str). 52 | * `--data_dir` - Path for storing the dataset (str). 53 | * `--name` - Name of the training run (str). 54 | * `--group` - Group name of the training run (str). 55 | * `--arch` - Architecture (str). Options: resnet18, resnet34, resnet50, resnet101, resnet152. 56 | * `--resume` - Resume training from best checkpoint (bool). 57 | * `--num_epochs` - Number of epochs (int). 58 | * `--learning_rate` - Learning rate (float). 59 | * `--warmup_epochs` - Number of warmup epochs with lower learning rate (int). 60 | * `--batch_size` - Batch size (int). 61 | * `--num_classes` - Number of classes (int). 62 | * `--img_size` - Image size (int). 63 | * `--img_channels` - Number of image channels (int). 64 | * `--mixed_precision` - Use mixed precision training (bool). 65 | * `--random_seed` - Random seed (int). 66 | * `--wandb` - Use Weights&Biases for logging (bool). 67 | * `--log_every` - Log every log_every steps (int). 68 | 69 | 70 | 71 | ## Results 72 | ResNet18 was trained on the Imagenette dataset. The validation accuracy is around 90%. 73 | 74 | * Images were resized to 256x256 (random crops for training and center crops for evaluation). 75 | * Data augmentation: flipping, brightness, hue, contrast. 
76 | * Learning rate schedule: Cosine Annealing. 77 | * Training was done from scratch, no transfer learning. 78 | 79 | 80 | 81 | ## References 82 | * [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) 83 | * [pytorch/vision/torchvision/models/resnet.py](https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py) 84 | * [google/flax/examples](https://github.com/google/flax/tree/main/examples) 85 | 86 | 87 | 88 | ## License 89 | [MIT License](https://opensource.org/licenses/MIT) 90 | 91 | 92 | -------------------------------------------------------------------------------- /training/resnet/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import jax 4 | import wandb 5 | import training 6 | 7 | 8 | def main(): 9 | parser = argparse.ArgumentParser() 10 | # Paths 11 | parser.add_argument('--work_dir', type=str, default='/export/scratch/mwright/projects/misc/imagenette', help='Directory for logging and checkpoints.') 12 | parser.add_argument('--data_dir', type=str, default='/export/data/mwright/tensorflow_datasets', help='Directory for storing data.') 13 | parser.add_argument('--name', type=str, default='test', help='Name of this experiment.') 14 | parser.add_argument('--group', type=str, default='default', help='Group name of this experiment.') 15 | # Training 16 | parser.add_argument('--arch', type=str, default='resnet18', choices=['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'], help='Architecture.') 17 | parser.add_argument('--resume', action='store_true', help='Resume training from best checkpoint.') 18 | parser.add_argument('--num_epochs', type=int, default=200, help='Number of epochs.') 19 | parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate.') 20 | parser.add_argument('--warmup_epochs', type=int, default=9, help='Number of warmup epochs with lower learning rate.') 21 | parser.add_argument('--batch_size', type=int, default=128, help='Batch size.') 22 | parser.add_argument('--num_classes', type=int, default=10, help='Number of classes.') 23 | parser.add_argument('--img_size', type=int, default=224, help='Image size.') 24 | parser.add_argument('--img_channels', type=int, default=3, help='Number of image channels.') 25 | parser.add_argument('--mixed_precision', action='store_true', help='Use mixed precision training.') 26 | parser.add_argument('--random_seed', type=int, default=0, help='Random seed.') 27 | # Logging 28 | parser.add_argument('--wandb', action='store_true', help='Log to Weights&bBiases.') 29 | parser.add_argument('--log_every', type=int, default=100, help='Log every log_every steps.') 30 | args = parser.parse_args() 31 | 32 | if jax.process_index() == 0: 33 | args.ckpt_dir = os.path.join(args.work_dir, args.group, args.name, 'checkpoints') 34 | if not os.path.exists(args.ckpt_dir): 35 | os.makedirs(args.ckpt_dir) 36 | 37 | if args.wandb: 38 | wandb.init(entity='matthias-wright', 39 | project='imagenette', 40 | group=args.group, 41 | config=args, 42 | name=args.name, 43 | dir=os.path.join(args.work_dir, args.group, args.name)) 44 | 45 | training.train_and_evaluate(args) 46 | 47 | 48 | if __name__ == '__main__': 49 | main() 50 | -------------------------------------------------------------------------------- /training/resnet/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | git+https://github.com/matthias-wright/flaxmodels.git 3 | 
tensorflow-datasets 4 | optax 5 | argparse 6 | wandb 7 | tqdm 8 | -------------------------------------------------------------------------------- /training/stylegan2/checkpoint.py: -------------------------------------------------------------------------------- 1 | import flax 2 | import dill as pickle 3 | import os 4 | import glob 5 | 6 | 7 | def save_checkpoint(ckpt_dir, state_G, state_D, params_ema_G, pl_mean, config, step, epoch, fid_score=None, keep=2): 8 | """ 9 | Saves checkpoint. 10 | 11 | Args: 12 | ckpt_dir (str): Path to the directory, where checkpoints are saved. 13 | state_G (train_state.TrainState): Generator state. 14 | state_D (train_state.TrainState): Discriminator state. 15 | params_ema_G (frozen_dict.FrozenDict): Parameters of the ema generator. 16 | pl_mean (array): Moving average of the path length (generator regularization). 17 | config (argparse.Namespace): Configuration. 18 | step (int): Current step. 19 | epoch (int): Current epoch. 20 | fid_score (float): FID score corresponding to the checkpoint. 21 | keep (int): Number of checkpoints to keep. 22 | """ 23 | state_dict = {'state_G': flax.jax_utils.unreplicate(state_G), 24 | 'state_D': flax.jax_utils.unreplicate(state_D), 25 | 'params_ema_G': params_ema_G, 26 | 'pl_mean': flax.jax_utils.unreplicate(pl_mean), 27 | 'config': config, 28 | 'fid_score': fid_score, 29 | 'step': step, 30 | 'epoch': epoch} 31 | 32 | with open(os.path.join(ckpt_dir, f'ckpt_{step}.pickle'), 'wb') as handle: 33 | pickle.dump(state_dict, handle, protocol=pickle.HIGHEST_PROTOCOL) 34 | 35 | ckpts = glob.glob(os.path.join(ckpt_dir, '*.pickle')) 36 | if len(ckpts) > keep: 37 | oldest_ckpt = min(ckpts, key=os.path.getctime) 38 | os.remove(oldest_ckpt) 39 | 40 | 41 | def load_checkpoint(filename): 42 | """ 43 | Loads checkpoints. 44 | 45 | Args: 46 | filename (str): Path to the checkpoint file. 47 | 48 | Returns: 49 | (dict): Checkpoint. 50 | """ 51 | state_dict = pickle.load(open(filename, 'rb')) 52 | state_dict['state_G'] = flax.jax_utils.replicate(state_dict['state_G']) 53 | state_dict['state_D'] = flax.jax_utils.replicate(state_dict['state_D']) 54 | state_dict['pl_mean'] = flax.jax_utils.replicate(state_dict['pl_mean']) 55 | return state_dict 56 | 57 | 58 | def get_latest_checkpoint(ckpt_dir): 59 | """ 60 | Returns the path of the latest checkpoint. 61 | 62 | Args: 63 | ckpt_dir (str): Path to the directory, where checkpoints are saved. 64 | 65 | Returns: 66 | (str): Path to latest checkpoint (if it exists). 
67 | """ 68 | ckpts = glob.glob(os.path.join(ckpt_dir, '*.pickle')) 69 | if len(ckpts) == 0: 70 | return None 71 | latest_ckpt = max(ckpts, key=os.path.getctime) 72 | return latest_ckpt 73 | 74 | -------------------------------------------------------------------------------- /training/stylegan2/data_pipeline.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow_datasets as tfds 3 | import jax 4 | import flax 5 | import numpy as np 6 | from PIL import Image 7 | import os 8 | from typing import Sequence 9 | from tqdm import tqdm 10 | import json 11 | from tqdm import tqdm 12 | 13 | 14 | def prefetch(dataset, n_prefetch): 15 | # Taken from: https://github.com/google-research/vision_transformer/blob/master/vit_jax/input_pipeline.py 16 | ds_iter = iter(dataset) 17 | ds_iter = map(lambda x: jax.tree_map(lambda t: np.asarray(memoryview(t)), x), 18 | ds_iter) 19 | if n_prefetch: 20 | ds_iter = flax.jax_utils.prefetch_to_device(ds_iter, n_prefetch) 21 | return ds_iter 22 | 23 | 24 | def get_data(data_dir, img_size, img_channels, num_classes, num_devices, batch_size, shuffle_buffer=1000): 25 | """ 26 | 27 | Args: 28 | data_dir (str): Root directory of the dataset. 29 | img_size (int): Image size for training. 30 | img_channels (int): Number of image channels. 31 | num_classes (int): Number of classes, 0 for no classes. 32 | num_devices (int): Number of devices. 33 | batch_size (int): Batch size (per device). 34 | shuffle_buffer (int): Buffer used for shuffling the dataset. 35 | 36 | Returns: 37 | (tf.data.Dataset): Dataset. 38 | """ 39 | 40 | def pre_process(serialized_example): 41 | feature = {'height': tf.io.FixedLenFeature([], tf.int64), 42 | 'width': tf.io.FixedLenFeature([], tf.int64), 43 | 'channels': tf.io.FixedLenFeature([], tf.int64), 44 | 'image': tf.io.FixedLenFeature([], tf.string), 45 | 'label': tf.io.FixedLenFeature([], tf.int64)} 46 | example = tf.io.parse_single_example(serialized_example, feature) 47 | 48 | height = tf.cast(example['height'], dtype=tf.int64) 49 | width = tf.cast(example['width'], dtype=tf.int64) 50 | channels = tf.cast(example['channels'], dtype=tf.int64) 51 | 52 | image = tf.io.decode_raw(example['image'], out_type=tf.uint8) 53 | image = tf.reshape(image, shape=[height, width, channels]) 54 | 55 | image = tf.cast(image, dtype='float32') 56 | image = tf.image.resize(image, size=[img_size, img_size], method='bicubic', antialias=True) 57 | image = tf.image.random_flip_left_right(image) 58 | 59 | image = (image - 127.5) / 127.5 60 | 61 | label = tf.one_hot(example['label'], num_classes) 62 | return {'image': image, 'label': label} 63 | 64 | def shard(data): 65 | # Reshape images from [num_devices * batch_size, H, W, C] to [num_devices, batch_size, H, W, C] 66 | # because the first dimension will be mapped across devices using jax.pmap 67 | data['image'] = tf.reshape(data['image'], [num_devices, -1, img_size, img_size, img_channels]) 68 | data['label'] = tf.reshape(data['label'], [num_devices, -1, num_classes]) 69 | return data 70 | 71 | print('Loading TFRecord...') 72 | with open(os.path.join(data_dir, 'dataset_info.json'), 'r') as fin: 73 | dataset_info = json.load(fin) 74 | 75 | ds = tf.data.TFRecordDataset(filenames=os.path.join(data_dir, 'dataset.tfrecords')) 76 | 77 | ds = ds.shuffle(min(dataset_info['num_examples'], shuffle_buffer)) 78 | ds = ds.map(pre_process, tf.data.AUTOTUNE) 79 | ds = ds.batch(batch_size * num_devices, drop_remainder=True) 80 | ds = ds.map(shard, 
tf.data.AUTOTUNE) 81 | ds = ds.prefetch(1) 82 | return ds, dataset_info 83 | 84 | 85 | 86 | -------------------------------------------------------------------------------- /training/stylegan2/dataset_utils/crop_image_borders.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import os 4 | from tqdm import tqdm 5 | import argparse 6 | 7 | 8 | def crop_border(x, constant=0.0): 9 | top = 0 10 | while True: 11 | if np.sum(x[top] != constant) != 0.0: 12 | break 13 | top += 1 14 | bottom = x.shape[0] - 1 15 | while True: 16 | if np.sum(x[bottom] != constant) != 0.0: 17 | bottom += 1 18 | break 19 | bottom -= 1 20 | left = 0 21 | while True: 22 | if np.sum(x[:, left] != constant) != 0.0: 23 | break 24 | left += 1 25 | right = x.shape[1] - 1 26 | while True: 27 | if np.sum(x[:, right] != constant) != 0.0: 28 | right += 1 29 | break 30 | right -= 1 31 | return x[top:bottom, left:right] 32 | 33 | 34 | def crop_images(path, constant_value): 35 | print('Crop image borders...') 36 | for f in tqdm(os.listdir(path)): 37 | img = Image.open(os.path.join(path, f)) 38 | img = crop_border(np.array(img), constant=constant_value) 39 | img = Image.fromarray(img) 40 | img.save(os.path.join(path, f)) 41 | 42 | 43 | if __name__ == '__main__': 44 | parser = argparse.ArgumentParser() 45 | parser.add_argument('--image_dir', type=str, help='Path to the image directory.') 46 | parser.add_argument('--constant_value', type=float, default=0.0, help='Value of the border that should be cropped.') 47 | 48 | args = parser.parse_args() 49 | 50 | crop_images(args.image_dir, args.constant_value) 51 | 52 | -------------------------------------------------------------------------------- /training/stylegan2/dataset_utils/images_to_tfrecords.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from PIL import Image 4 | from typing import Sequence 5 | from tqdm import tqdm 6 | import argparse 7 | import json 8 | import os 9 | 10 | 11 | def images_to_tfrecords(image_dir, data_dir, has_labels): 12 | """ 13 | Converts a folder of images to a TFRecord file. 14 | 15 | The image directory should have one of the following structures: 16 | 17 | If has_labels = False, image_dir should look like this: 18 | 19 | path/to/image_dir/ 20 | 0.jpg 21 | 1.jpg 22 | 2.jpg 23 | 4.jpg 24 | ... 25 | 26 | 27 | If has_labels = True, image_dir should look like this: 28 | 29 | path/to/image_dir/ 30 | label0/ 31 | 0.jpg 32 | 1.jpg 33 | ... 34 | label1/ 35 | a.jpg 36 | b.jpg 37 | c.jpg 38 | ... 39 | ... 40 | 41 | 42 | The labels will be label0 -> 0, label1 -> 1. 43 | 44 | Args: 45 | image_dir (str): Path to images. 46 | data_dir (str): Path where the TFrecords dataset is stored. 47 | has_labels (bool): If True, 'image_dir' contains label directories. 48 | 49 | Returns: 50 | (dict): Dataset info. 
51 | """ 52 | 53 | def _bytes_feature(value): 54 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 55 | 56 | def _int64_feature(value): 57 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 58 | 59 | os.makedirs(data_dir, exist_ok=True) 60 | writer = tf.io.TFRecordWriter(os.path.join(data_dir, 'dataset.tfrecords')) 61 | 62 | num_examples = 0 63 | num_classes = 0 64 | 65 | if has_labels: 66 | for label_dir in os.listdir(image_dir): 67 | if not os.path.isdir(os.path.join(image_dir, label_dir)): 68 | print('The image directory should contain one directory for each label.') 69 | print('These label directories should contain the image files.') 70 | if os.path.exists(os.path.join(data_dir, 'dataset.tfrecords')): 71 | os.remove(os.path.join(data_dir, 'dataset.tfrecords')) 72 | return 73 | 74 | for img_file in tqdm(os.listdir(os.path.join(image_dir, label_dir))): 75 | file_format = img_file[img_file.rfind('.') + 1:] 76 | if file_format not in ['png', 'jpg', 'jpeg']: 77 | continue 78 | 79 | #img = Image.open(os.path.join(image_dir, label_dir, img_file)).resize(img_size) 80 | img = Image.open(os.path.join(image_dir, label_dir, img_file)) 81 | img = np.array(img, dtype=np.uint8) 82 | 83 | height = img.shape[0] 84 | width = img.shape[1] 85 | channels = img.shape[2] 86 | 87 | img_encoded = img.tobytes() 88 | 89 | example = tf.train.Example(features=tf.train.Features(feature={ 90 | 'height': _int64_feature(height), 91 | 'width': _int64_feature(width), 92 | 'channels': _int64_feature(channels), 93 | 'image': _bytes_feature(img_encoded), 94 | 'label': _int64_feature(num_classes)})) 95 | 96 | writer.write(example.SerializeToString()) 97 | num_examples += 1 98 | 99 | num_classes += 1 100 | else: 101 | for img_file in tqdm(os.listdir(os.path.join(image_dir))): 102 | file_format = img_file[img_file.rfind('.') + 1:] 103 | if file_format not in ['png', 'jpg', 'jpeg']: 104 | continue 105 | 106 | #img = Image.open(os.path.join(image_dir, label_dir, img_file)).resize(img_size) 107 | img = Image.open(os.path.join(image_dir, img_file)) 108 | img = np.array(img, dtype=np.uint8) 109 | 110 | height = img.shape[0] 111 | width = img.shape[1] 112 | channels = img.shape[2] 113 | 114 | img_encoded = img.tobytes() 115 | 116 | example = tf.train.Example(features=tf.train.Features(feature={ 117 | 'height': _int64_feature(height), 118 | 'width': _int64_feature(width), 119 | 'channels': _int64_feature(channels), 120 | 'image': _bytes_feature(img_encoded), 121 | 'label': _int64_feature(num_classes)})) # dummy label 122 | 123 | writer.write(example.SerializeToString()) 124 | num_examples += 1 125 | 126 | writer.close() 127 | 128 | dataset_info = {'num_examples': num_examples, 'num_classes': num_classes} 129 | with open(os.path.join(data_dir, 'dataset_info.json'), 'w') as fout: 130 | json.dump(dataset_info, fout) 131 | 132 | 133 | if __name__ == '__main__': 134 | parser = argparse.ArgumentParser() 135 | parser.add_argument('--image_dir', type=str, help='Path to the image directory.') 136 | parser.add_argument('--data_dir', type=str, help='Path where the TFRecords dataset is stored.') 137 | parser.add_argument('--has_labels', action='store_true', help='If True, image_dir contains label directories.') 138 | 139 | args = parser.parse_args() 140 | 141 | images_to_tfrecords(args.image_dir, args.data_dir, args.has_labels) 142 | 143 | -------------------------------------------------------------------------------- /training/stylegan2/fid/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .core import FID 2 | -------------------------------------------------------------------------------- /training/stylegan2/fid/core.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import flax 4 | import flax.linen as nn 5 | import numpy as np 6 | import os 7 | import functools 8 | import argparse 9 | import scipy 10 | from tqdm import tqdm 11 | 12 | from . import inception 13 | from . import utils 14 | 15 | 16 | class FID: 17 | 18 | def __init__(self, generator, dataset, config, use_cache=True, truncation_psi=1.0): 19 | """ 20 | Evaluates the FID score for a given generator and a given dataset. 21 | Implementation mostly taken from https://github.com/matthias-wright/jax-fid 22 | 23 | Reference: https://arxiv.org/abs/1706.08500 24 | 25 | Args: 26 | generator (nn.Module): Generator network. 27 | dataset (tf.data.Dataset): Dataset containing the real images. 28 | config (argparse.Namespace): Configuration. 29 | use_cache (bool): If True, only compute the activation stats once for the real images and store them. 30 | truncation_psi (float): Controls truncation (trading off variation for quality). If 1, truncation is disabled. 31 | """ 32 | self.num_images = config.num_fid_images 33 | self.batch_size = config.batch_size 34 | self.c_dim = config.c_dim 35 | self.z_dim = config.z_dim 36 | self.dataset = dataset 37 | self.num_devices = jax.device_count() 38 | self.use_cache = use_cache 39 | 40 | if self.use_cache: 41 | self.cache = {} 42 | 43 | rng = jax.random.PRNGKey(0) 44 | inception_net = inception.InceptionV3(pretrained=True) 45 | self.inception_params = inception_net.init(rng, jnp.ones((1, config.resolution, config.resolution, 3))) 46 | self.inception_params = flax.jax_utils.replicate(self.inception_params) 47 | #self.inception = jax.jit(functools.partial(model.apply, train=False)) 48 | self.inception_apply = jax.pmap(functools.partial(inception_net.apply, train=False), axis_name='batch') 49 | 50 | self.generator_apply = jax.pmap(functools.partial(generator.apply, truncation_psi=truncation_psi, train=False, noise_mode='const'), axis_name='batch') 51 | 52 | def compute_fid(self, generator_params, seed_offset=0): 53 | generator_params = flax.jax_utils.replicate(generator_params) 54 | mu_real, sigma_real = self.compute_stats_for_dataset() 55 | mu_fake, sigma_fake = self.compute_stats_for_generator(generator_params, seed_offset) 56 | fid_score = self.compute_frechet_distance(mu_real, mu_fake, sigma_real, sigma_fake, eps=1e-6) 57 | return fid_score 58 | 59 | def compute_frechet_distance(self, mu1, mu2, sigma1, sigma2, eps=1e-6): 60 | # Taken from: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py 61 | mu1 = np.atleast_1d(mu1) 62 | mu2 = np.atleast_1d(mu2) 63 | sigma1 = np.atleast_1d(sigma1) 64 | sigma2 = np.atleast_1d(sigma2) 65 | 66 | assert mu1.shape == mu2.shape 67 | assert sigma1.shape == sigma2.shape 68 | 69 | diff = mu1 - mu2 70 | 71 | covmean, _ = scipy.linalg.sqrtm(sigma1.dot(sigma2), disp=False) 72 | if not np.isfinite(covmean).all(): 73 | msg = ('fid calculation produces singular product; ' 74 | 'adding %s to diagonal of cov estimates') % eps 75 | print(msg) 76 | offset = np.eye(sigma1.shape[0]) * eps 77 | covmean = scipy.linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 78 | 79 | # Numerical error might give slight imaginary component 80 | if np.iscomplexobj(covmean): 81 | if not 
np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 82 | m = np.max(np.abs(covmean.imag)) 83 | raise ValueError('Imaginary component {}'.format(m)) 84 | covmean = covmean.real 85 | 86 | tr_covmean = np.trace(covmean) 87 | return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean) 88 | 89 | def compute_stats_for_dataset(self): 90 | if self.use_cache and 'mu' in self.cache and 'sigma' in self.cache: 91 | print('Use cached statistics for dataset...') 92 | return self.cache['mu'], self.cache['sigma'] 93 | 94 | print() 95 | print('Compute statistics for dataset...') 96 | pbar = tqdm(total=self.num_images) 97 | image_count = 0 98 | 99 | activations = [] 100 | for batch in utils.prefetch(self.dataset, n_prefetch=2): 101 | act = self.inception_apply(self.inception_params, jax.lax.stop_gradient(batch['image'])) 102 | act = jnp.reshape(act, (self.num_devices * self.batch_size, -1)) 103 | activations.append(act) 104 | 105 | pbar.update(self.num_devices * self.batch_size) 106 | image_count += self.num_devices * self.batch_size 107 | if image_count >= self.num_images: 108 | break 109 | pbar.close() 110 | 111 | activations = jnp.concatenate(activations, axis=0) 112 | activations = activations[:self.num_images] 113 | mu = np.mean(activations, axis=0) 114 | sigma = np.cov(activations, rowvar=False) 115 | self.cache['mu'] = mu 116 | self.cache['sigma'] = sigma 117 | return mu, sigma 118 | 119 | def compute_stats_for_generator(self, generator_params, seed_offset): 120 | print() 121 | print('Compute statistics for generator...') 122 | num_batches = int(np.ceil(self.num_images / (self.batch_size * self.num_devices))) 123 | 124 | pbar = tqdm(total=self.num_images) 125 | activations = [] 126 | 127 | for i in range(num_batches): 128 | rng = jax.random.PRNGKey(seed_offset + i) 129 | z_latent = jax.random.normal(rng, shape=(self.num_devices, self.batch_size, self.z_dim)) 130 | 131 | labels = None 132 | if self.c_dim > 0: 133 | labels = jax.random.randint(rng, shape=(self.num_devices * self.batch_size,), minval=0, maxval=self.c_dim) 134 | labels = jax.nn.one_hot(labels, num_classes=self.c_dim) 135 | labels = jnp.reshape(labels, (self.num_devices, self.batch_size, self.c_dim)) 136 | 137 | image = self.generator_apply(generator_params, jax.lax.stop_gradient(z_latent), labels) 138 | image = (image - jnp.min(image)) / (jnp.max(image) - jnp.min(image)) 139 | 140 | image = 2 * image - 1 141 | act = self.inception_apply(self.inception_params, jax.lax.stop_gradient(image)) 142 | act = jnp.reshape(act, (self.num_devices * self.batch_size, -1)) 143 | activations.append(act) 144 | pbar.update(self.num_devices * self.batch_size) 145 | pbar.close() 146 | 147 | activations = jnp.concatenate(activations, axis=0) 148 | activations = activations[:self.num_images] 149 | mu = np.mean(activations, axis=0) 150 | sigma = np.cov(activations, rowvar=False) 151 | return mu, sigma 152 | 153 | 154 | -------------------------------------------------------------------------------- /training/stylegan2/fid/utils.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import flax 3 | import numpy as np 4 | from tqdm import tqdm 5 | import requests 6 | import os 7 | import tempfile 8 | 9 | 10 | def download(url, ckpt_dir=None): 11 | name = url[url.rfind('/') + 1 : url.rfind('?')] 12 | if ckpt_dir is None: 13 | ckpt_dir = tempfile.gettempdir() 14 | ckpt_dir = os.path.join(ckpt_dir, 'flaxmodels') 15 | ckpt_file = os.path.join(ckpt_dir, name) 16 | if not os.path.exists(ckpt_file): 
17 | print(f'Downloading: \"{url[:url.rfind("?")]}\" to {ckpt_file}') 18 | if not os.path.exists(ckpt_dir): 19 | os.makedirs(ckpt_dir) 20 | 21 | response = requests.get(url, stream=True) 22 | total_size_in_bytes = int(response.headers.get('content-length', 0)) 23 | progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) 24 | 25 | # first create temp file, in case the download fails 26 | ckpt_file_temp = os.path.join(ckpt_dir, name + '.temp') 27 | with open(ckpt_file_temp, 'wb') as file: 28 | for data in response.iter_content(chunk_size=1024): 29 | progress_bar.update(len(data)) 30 | file.write(data) 31 | progress_bar.close() 32 | 33 | if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: 34 | print('An error occured while downloading, please try again.') 35 | if os.path.exists(ckpt_file_temp): 36 | os.remove(ckpt_file_temp) 37 | else: 38 | # if download was successful, rename the temp file 39 | os.rename(ckpt_file_temp, ckpt_file) 40 | return ckpt_file 41 | 42 | 43 | def get(dictionary, key): 44 | if dictionary is None or key not in dictionary: 45 | return None 46 | return dictionary[key] 47 | 48 | 49 | def prefetch(dataset, n_prefetch): 50 | # Taken from: https://github.com/google-research/vision_transformer/blob/master/vit_jax/input_pipeline.py 51 | ds_iter = iter(dataset) 52 | ds_iter = map(lambda x: jax.tree_map(lambda t: np.asarray(memoryview(t)), x), 53 | ds_iter) 54 | if n_prefetch: 55 | ds_iter = flax.jax_utils.prefetch_to_device(ds_iter, n_prefetch) 56 | return ds_iter 57 | -------------------------------------------------------------------------------- /training/stylegan2/generate_images.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import flax 4 | import numpy as np 5 | import dill as pickle 6 | import flaxmodels as fm 7 | import data_pipeline 8 | import checkpoint 9 | from PIL import Image 10 | from tqdm import tqdm 11 | import argparse 12 | import functools 13 | import os 14 | 15 | 16 | def generate_images(args): 17 | num_devices = jax.device_count() 18 | ckpt = checkpoint.load_checkpoint(args.ckpt_path) 19 | config = ckpt['config'] 20 | 21 | dtype = jnp.float32 22 | 23 | generator_ema = fm.stylegan2.Generator(resolution=config.resolution, 24 | num_channels=config.img_channels, 25 | z_dim=config.z_dim, 26 | c_dim=config.c_dim, 27 | w_dim=config.w_dim, 28 | num_ws=int(np.log2(config.resolution)) * 2 - 3, 29 | num_mapping_layers=8, 30 | fmap_base=config.fmap_base, 31 | dtype=dtype) 32 | 33 | generator_apply = jax.jit(functools.partial(generator_ema.apply, truncation_psi=args.truncation_psi, train=False, noise_mode='const')) 34 | params_ema_G = ckpt['params_ema_G'] 35 | 36 | for seed in tqdm(args.seeds): 37 | rng = jax.random.PRNGKey(seed) 38 | z_latent = jax.random.normal(rng, shape=(1, config.z_dim)) 39 | labels = None 40 | if config.c_dim > 0: 41 | labels = jax.random.randint(rng, shape=(1,), minval=0, maxval=config.c_dim) 42 | labels = jax.nn.one_hot(labels, num_classes=config.c_dim) 43 | labels = jnp.reshape(labels, (1, config.c_dim)) 44 | 45 | image = generator_apply(params_ema_G, jax.lax.stop_gradient(z_latent), labels) 46 | image = (image - jnp.min(image)) / (jnp.max(image) - jnp.min(image)) 47 | 48 | Image.fromarray(np.uint8(np.clip(image[0] * 255, 0, 255))).save(os.path.join(args.out_path, f'{seed}.png')) 49 | print('Images saved at:', args.out_path) 50 | 51 | 52 | if __name__ == '__main__': 53 | parser = argparse.ArgumentParser() 54 | 
parser.add_argument('--ckpt_path', type=str, help='Path to the checkpoint.') 55 | parser.add_argument('--out_path', type=str, default='generated_images', help='Path where the generated images are stored.') 56 | parser.add_argument('--truncation_psi', type=float, default=0.5, help='Controls truncation (trading off variation for quality). If 1, truncation is disabled.') 57 | parser.add_argument('--seeds', type=int, nargs='*', help='List of random seeds.') 58 | args = parser.parse_args() 59 | os.makedirs(args.out_path, exist_ok=True) 60 | 61 | generate_images(args) 62 | 63 | 64 | -------------------------------------------------------------------------------- /training/stylegan2/images/anime_grid.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/training/stylegan2/images/anime_grid.jpg -------------------------------------------------------------------------------- /training/stylegan2/images/anime_overview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/training/stylegan2/images/anime_overview.jpg -------------------------------------------------------------------------------- /training/stylegan2/images/anime_style_mixing.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/training/stylegan2/images/anime_style_mixing.jpg -------------------------------------------------------------------------------- /training/stylegan2/images/ffhq_grid.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/training/stylegan2/images/ffhq_grid.jpg -------------------------------------------------------------------------------- /training/stylegan2/images/ffhq_overview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/training/stylegan2/images/ffhq_overview.jpg -------------------------------------------------------------------------------- /training/stylegan2/images/ffhq_style_mixing.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/training/stylegan2/images/ffhq_style_mixing.jpg -------------------------------------------------------------------------------- /training/stylegan2/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import jax 4 | import wandb 5 | import training 6 | 7 | 8 | def main(): 9 | parser = argparse.ArgumentParser() 10 | # Paths 11 | parser.add_argument('--work_dir', type=str, default='logging', help='Directory for logging and checkpoints.') 12 | parser.add_argument('--data_dir', type=str, help='Directory of the dataset.') 13 | parser.add_argument('--project', type=str, default='stylegan', help='Name of this project.') 14 | parser.add_argument('--name', type=str, default='test', help='Name of this experiment.') 15 | parser.add_argument('--group', type=str, default='default', help='Group name of this 
experiment (for Weights&Biases).') 16 | # Training 17 | parser.add_argument('--resume', action='store_true', help='Resume training from latest checkpoint.') 18 | parser.add_argument('--num_epochs', type=int, default=10000, help='Number of epochs.') 19 | parser.add_argument('--learning_rate', type=float, default=0.002, help='Learning rate.') 20 | parser.add_argument('--batch_size', type=int, default=8, help='Batch size.') 21 | parser.add_argument('--num_prefetch', type=int, default=2, help='Number of prefetched examples for the data pipeline.') 22 | parser.add_argument('--resolution', type=int, default=128, help='Image resolution. Must be a power of 2.') 23 | parser.add_argument('--img_channels', type=int, default=3, help='Number of image channels.') 24 | parser.add_argument('--mixed_precision', action='store_true', help='Use mixed precision training.') 25 | parser.add_argument('--random_seed', type=int, default=0, help='Random seed.') 26 | # Generator 27 | parser.add_argument('--fmap_base', type=int, default=16384, help='Overall multiplier for the number of feature maps.') 28 | # Discriminator 29 | parser.add_argument('--mbstd_group_size', type=int, help='Group size for the minibatch standard deviation layer, None = entire minibatch.') 30 | # Exponentially Moving Average of Generator Weights 31 | parser.add_argument('--ema_kimg', type=float, default=20.0, help='Controls the ema of the generator weights (larger value -> larger beta).') 32 | # Losses 33 | parser.add_argument('--pl_decay', type=float, default=0.01, help='Exponential decay for the mean of the path length (path length regularization).') 34 | parser.add_argument('--pl_weight', type=float, default=2, help='Weight for path length regularization.') 35 | # Regularization 36 | parser.add_argument('--mixing_prob', type=float, default=0.9, help='Probability for style mixing.') 37 | parser.add_argument('--G_reg_interval', type=int, default=4, help='How often to perform regularization for G.') 38 | parser.add_argument('--D_reg_interval', type=int, default=16, help='How often to perform regularization for D.') 39 | parser.add_argument('--r1_gamma', type=float, default=10.0, help='Weight for R1 regularization.') 40 | # Model 41 | parser.add_argument('--z_dim', type=int, default=512, help='Input latent (Z) dimensionality.') 42 | parser.add_argument('--c_dim', type=int, default=0, help='Conditioning label (C) dimensionality, 0 = no label.') 43 | parser.add_argument('--w_dim', type=int, default=512, help='Intermediate latent (W) dimensionality.') 44 | # Logging 45 | parser.add_argument('--wandb', action='store_true', help='Log to Weights&Biases.') 46 | parser.add_argument('--log_every', type=int, default=50, help='Log every log_every steps.') 47 | parser.add_argument('--save_every', type=int, default=2000, help='Save every save_every steps.
Will be ignored if FID evaluation is enabled.') 48 | # FID 49 | parser.add_argument('--eval_fid_every', type=int, default=1000, help='Compute FID score every eval_fid_every steps.') 50 | parser.add_argument('--num_fid_images', type=int, default=10000, help='Number of images to use for FID computation.') 51 | parser.add_argument('--disable_fid', action='store_true', help='Disable FID evaluation.') 52 | 53 | args = parser.parse_args() 54 | 55 | if jax.process_index() == 0: 56 | args.ckpt_dir = os.path.join(args.work_dir, args.group, args.name, 'checkpoints') 57 | if not os.path.exists(args.ckpt_dir): 58 | os.makedirs(args.ckpt_dir) 59 | 60 | if args.wandb: 61 | wandb.init(project=args.project, 62 | group=args.group, 63 | config=args, 64 | name=args.name, 65 | dir=os.path.join(args.work_dir, args.group, args.name)) 66 | 67 | training.train_and_evaluate(args) 68 | 69 | 70 | if __name__ == '__main__': 71 | main() 72 | 73 | -------------------------------------------------------------------------------- /training/stylegan2/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | git+https://github.com/matthias-wright/flaxmodels.git 3 | tensorflow-datasets 4 | tensorflow==2.4.1 5 | optax 6 | argparse 7 | wandb 8 | tqdm 9 | dill 10 | -------------------------------------------------------------------------------- /training/stylegan2/style_mixing.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import flax 4 | from flax.core import frozen_dict 5 | import numpy as np 6 | import dill as pickle 7 | import flaxmodels as fm 8 | import data_pipeline 9 | import checkpoint 10 | from PIL import Image 11 | from tqdm import tqdm 12 | import argparse 13 | import functools 14 | import os 15 | 16 | 17 | def style_mixing(args): 18 | num_devices = jax.device_count() 19 | ckpt = checkpoint.load_checkpoint(args.ckpt_path) 20 | config = ckpt['config'] 21 | 22 | dtype = jnp.float32 23 | 24 | mapping_net = fm.stylegan2.MappingNetwork(z_dim=config.z_dim, 25 | c_dim=config.c_dim, 26 | w_dim=config.w_dim, 27 | num_ws=int(np.log2(config.resolution)) * 2 - 3, 28 | num_layers=8, 29 | dtype=dtype) 30 | 31 | synthesis_net = fm.stylegan2.SynthesisNetwork(resolution=config.resolution, 32 | num_channels=config.img_channels, 33 | w_dim=config.w_dim, 34 | fmap_base=config.fmap_base, 35 | dtype=dtype) 36 | 37 | params_ema_G = ckpt['params_ema_G'] 38 | params_ema_G = params_ema_G.unfreeze() 39 | synthesis_params = {'params': params_ema_G['params']['synthesis_network'], 40 | 'noise_consts': params_ema_G['noise_consts']['synthesis_network']} 41 | synthesis_params = frozen_dict.freeze(synthesis_params) 42 | 43 | mapping_params = {'params': params_ema_G['params']['mapping_network'], 44 | 'moving_stats': params_ema_G['moving_stats']['mapping_network']} 45 | mapping_params = frozen_dict.freeze(mapping_params) 46 | 47 | synthesis_apply = jax.jit(functools.partial(synthesis_net.apply, noise_mode='const')) 48 | mapping_apply = jax.jit(functools.partial(mapping_net.apply, truncation_psi=args.truncation_psi, train=False)) 49 | 50 | all_seeds = args.row_seeds + args.col_seeds 51 | # Generate noise inputs, [minibatch, component] 52 | all_z = jnp.concatenate([jax.random.normal(jax.random.PRNGKey(seed), shape=(1, 512)) for seed in all_seeds]) 53 | # Generate latent vectors, [minibatch, num_ws, component] 54 | all_w = mapping_apply(mapping_params, all_z) 55 | # Generate images, [minibatch, H, W, 3] 56 | 
all_images = synthesis_apply(synthesis_params, all_w) 57 | # Normalize image to be in range [0, 1] 58 | all_images = (all_images - jnp.min(all_images)) / (jnp.max(all_images) - jnp.min(all_images)) 59 | col_images = jnp.concatenate([all_images[i] for i in range(len(args.row_seeds))], axis=0) 60 | row_images = jnp.concatenate([all_images[len(args.row_seeds) + i] for i in range(len(args.col_seeds))], axis=1) 61 | 62 | images_grid = [] 63 | 64 | cutoff = mapping_net.num_ws // 2 65 | 66 | # Generate style mixing images 67 | for row in range(len(args.row_seeds)): 68 | image_row = [] 69 | for col in range(len(args.col_seeds)): 70 | # Combine first 9 dimensions from row seed latent w with last 9 dimensions from col seed latent w 71 | w = jnp.concatenate([all_w[row, :cutoff], all_w[len(args.row_seeds) + col, cutoff:]], axis=0) 72 | # Add batch dimension 73 | w = jnp.expand_dims(w, axis=0) 74 | image = synthesis_apply(synthesis_params, w) 75 | # Remove batch dimension 76 | image = jnp.squeeze(image, axis=0) 77 | 78 | # Normalize image to be in range [0, 1] 79 | image = (image - jnp.min(image)) / (jnp.max(image) - jnp.min(image)) 80 | image_row.append(image) 81 | image_row = jnp.concatenate(image_row, axis=1) 82 | images_grid.append(image_row) 83 | 84 | images_grid = jnp.concatenate(images_grid, axis=0) 85 | 86 | # Add row and column images to the grid 87 | border = 20 88 | grid = np.ones((row_images.shape[0] + images_grid.shape[0] + border, 89 | col_images.shape[1] + images_grid.shape[1] + border, 90 | 3)) 91 | grid[grid.shape[0] - images_grid.shape[0]:, grid.shape[1] - images_grid.shape[1]:] = images_grid 92 | grid[:row_images.shape[0], grid.shape[1] - row_images.shape[1]:] = row_images 93 | grid[grid.shape[0] - col_images.shape[0]:, :col_images.shape[1]] = col_images 94 | Image.fromarray(np.uint8(np.clip(grid * 255, 0, 255))).save(os.path.join(args.out_path, 'style_mixing.png')) 95 | print('Style mixing grid saved at:', args.out_path) 96 | 97 | 98 | if __name__ == '__main__': 99 | parser = argparse.ArgumentParser() 100 | parser.add_argument('--ckpt_path', type=str, help='Path to the checkpoint.') 101 | parser.add_argument('--out_path', type=str, default='generated_images', help='Path where the generated images are stored.') 102 | parser.add_argument('--num_images', type=int, default=100, help='Number of images to generate.') 103 | parser.add_argument('--truncation_psi', type=float, default=0.5, help='Controls truncation (trading off variation for quality). If 1, truncation is disabled.') 104 | parser.add_argument('--row_seeds', type=int, nargs='*', help='List of random seeds for row images.') 105 | parser.add_argument('--col_seeds', type=int, nargs='*', help='List of random seeds for column images.') 106 | args = parser.parse_args() 107 | assert len(args.row_seeds) == len(args.col_seeds), 'row_seeds and col_seeds must have the same length.' 
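# Note: the assert above is why --row_seeds and --col_seeds must have the same number of entries.
# Illustrative invocation only (the checkpoint path and seed values below are placeholders, not
# files or settings shipped with the repo); all flags are defined in the parser above:
#   CUDA_VISIBLE_DEVICES=0 python style_mixing.py \
#       --ckpt_path logging/default/test/checkpoints/ckpt_100000.pickle \
#       --row_seeds 85 100 75 --col_seeds 55 821 1789 \
#       --truncation_psi 0.7 --out_path style_mixing_grids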
108 | os.makedirs(args.out_path, exist_ok=True) 109 | 110 | style_mixing(args) 111 | 112 | 113 | -------------------------------------------------------------------------------- /training/stylegan2/training_utils.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from jaxlib.xla_extension import DeviceArray 4 | import flax 5 | from flax.optim import dynamic_scale as dynamic_scale_lib 6 | from flax.core import frozen_dict 7 | from flax.training import train_state 8 | from flax import struct 9 | import numpy as np 10 | from PIL import Image 11 | from typing import Any, Callable 12 | 13 | 14 | def sync_moving_stats(state): 15 | """ 16 | Sync moving statistics across devices. 17 | 18 | Args: 19 | state (train_state.TrainState): Training state. 20 | 21 | Returns: 22 | (train_state.TrainState): Updated training state. 23 | """ 24 | cross_replica_mean = jax.pmap(lambda x: jax.lax.pmean(x, 'x'), 'x') 25 | return state.replace(moving_stats=cross_replica_mean(state.moving_stats)) 26 | 27 | 28 | def update_generator_ema(state_G, params_ema_G, config, ema_beta=None): 29 | """ 30 | Update exponentially moving average of the generator weights. 31 | Moving stats and noise constants will be copied over. 32 | 33 | Args: 34 | state_G (train_state.TrainState): Generator state. 35 | params_ema_G (frozen_dict.FrozenDict): Parameters of the ema generator. 36 | config (Any): Config object. 37 | ema_beta (float): Beta parameter of the ema. If None, will be computed 38 | from 'ema_nimg' and 'batch_size'. 39 | 40 | Returns: 41 | (frozen_dict.FrozenDict): Updates parameters of the ema generator. 42 | """ 43 | def _update_ema(src, trg, beta): 44 | for name, src_child in src.items(): 45 | if isinstance(src_child, DeviceArray): 46 | trg[name] = src[name] + ema_beta * (trg[name] - src[name]) 47 | else: 48 | _update_ema(src_child, trg[name], beta) 49 | 50 | if ema_beta is None: 51 | ema_nimg = config.ema_kimg * 1000 52 | ema_beta = 0.5 ** (config.batch_size / max(ema_nimg, 1e-8)) 53 | 54 | params_ema_G = params_ema_G.unfreeze() 55 | 56 | # Copy over moving stats 57 | params_ema_G['moving_stats']['mapping_network'] = state_G.moving_stats 58 | params_ema_G['noise_consts']['synthesis_network'] = state_G.noise_consts 59 | 60 | # Update exponentially moving average of the trainable parameters 61 | _update_ema(state_G.params['mapping'], params_ema_G['params']['mapping_network'], ema_beta) 62 | _update_ema(state_G.params['synthesis'], params_ema_G['params']['synthesis_network'], ema_beta) 63 | 64 | params_ema_G = frozen_dict.freeze(params_ema_G) 65 | return params_ema_G 66 | 67 | 68 | class TrainStateG(train_state.TrainState): 69 | """ 70 | Generator train state for a single Optax optimizer. 71 | 72 | Attributes: 73 | apply_mapping (Callable): Apply function of the Mapping Network. 74 | apply_synthesis (Callable): Apply function of the Synthesis Network. 75 | dynamic_scale (dynamic_scale_lib.DynamicScale): Dynamic loss scaling for mixed precision gradients. 76 | epoch (int): Current epoch. 77 | moving_stats (Any): Moving average of the latent W. 78 | noise_consts (Any): Noise constants from synthesis layers. 
79 | """ 80 | apply_mapping: Callable = struct.field(pytree_node=False) 81 | apply_synthesis: Callable = struct.field(pytree_node=False) 82 | dynamic_scale_main: dynamic_scale_lib.DynamicScale 83 | dynamic_scale_reg: dynamic_scale_lib.DynamicScale 84 | epoch: int 85 | moving_stats: Any=None 86 | noise_consts: Any=None 87 | 88 | 89 | class TrainStateD(train_state.TrainState): 90 | """ 91 | Discriminator train state for a single Optax optimizer. 92 | 93 | Attributes: 94 | dynamic_scale (dynamic_scale_lib.DynamicScale): Dynamic loss scaling for mixed precision gradients. 95 | epoch (int): Current epoch. 96 | """ 97 | dynamic_scale_main: dynamic_scale_lib.DynamicScale 98 | dynamic_scale_reg: dynamic_scale_lib.DynamicScale 99 | epoch: int 100 | 101 | 102 | def get_training_snapshot(image_real, image_gen, max_num=10): 103 | """ 104 | Creates a snapshot of generated images and real images. 105 | 106 | Args: 107 | images_real (DeviceArray): Batch of real images, shape [B, H, W, C]. 108 | images_gen (DeviceArray): Batch of generated images, shape [B, H, W, C]. 109 | max_num (int): Maximum number of images used for snapshot. 110 | 111 | Returns: 112 | (PIL.Image): Training snapshot. Top row: generated images, bottom row: real images. 113 | """ 114 | if image_real.shape[0] > max_num: 115 | image_real = image_real[:max_num] 116 | if image_gen.shape[0] > max_num: 117 | image_gen = image_gen[:max_num] 118 | 119 | image_real = jnp.split(image_real, image_real.shape[0], axis=0) 120 | image_gen = jnp.split(image_gen, image_gen.shape[0], axis=0) 121 | 122 | image_real = [jnp.squeeze(x, axis=0) for x in image_real] 123 | image_gen = [jnp.squeeze(x, axis=0) for x in image_gen] 124 | 125 | image_real = jnp.concatenate(image_real, axis=1) 126 | image_gen = jnp.concatenate(image_gen, axis=1) 127 | 128 | image_gen = (image_gen - np.min(image_gen)) / (np.max(image_gen) - np.min(image_gen)) 129 | image_real = (image_real - np.min(image_real)) / (np.max(image_real) - np.min(image_real)) 130 | image = jnp.concatenate((image_gen, image_real), axis=0) 131 | 132 | image = np.uint8(image * 255) 133 | if image.shape[-1] == 1: 134 | image = np.repeat(image, 3, axis=-1) 135 | return Image.fromarray(image) 136 | 137 | 138 | def get_eval_snapshot(image, max_num=10): 139 | """ 140 | Creates a snapshot of generated images. 141 | 142 | Args: 143 | image (DeviceArray): Generated images, shape [B, H, W, C]. 144 | 145 | Returns: 146 | (PIL.Image): Eval snapshot. 147 | """ 148 | if image.shape[0] > max_num: 149 | image = image[:max_num] 150 | 151 | image = jnp.split(image, image.shape[0], axis=0) 152 | image = [jnp.squeeze(x, axis=0) for x in image] 153 | image = jnp.concatenate(image, axis=1) 154 | image = (image - np.min(image)) / (np.max(image) - np.min(image)) 155 | image = np.uint8(image * 255) 156 | if image.shape[-1] == 1: 157 | image = np.repeat(image, 3, axis=-1) 158 | return Image.fromarray(image) 159 | -------------------------------------------------------------------------------- /training/vgg/README.md: -------------------------------------------------------------------------------- 1 | # VGG Training in Jax/Flax 2 | This is the training code for the [Jax/Flax implementation](https://github.com/matthias-wright/flaxmodels/tree/main/flaxmodels/vgg) of [VGG](https://arxiv.org/abs/1409.1556). 
3 | 4 | ##### Table of Contents 5 | * [Getting Started](#getting_started) 6 | * [Training](#training) 7 | * [Options](#options) 8 | * [References](#references) 9 | * [License](#license) 10 | 11 | 12 | 13 | ## Getting Started 14 | You will need Python 3.7 or later. 15 | 16 | 1. Clone the repository: 17 | ```sh 18 | > git clone https://github.com/matthias-wright/flaxmodels.git 19 | ``` 20 | 2. Go into the directory: 21 | ```sh 22 | > cd flaxmodels/training/vgg 23 | ``` 24 | 3. Install Jax with CUDA. 25 | 4. Install requirements: 26 | ```sh 27 | > pip install -r requirements.txt 28 | ``` 29 | 30 | 31 | ## Training 32 | 33 | ### Basic Training 34 | ```python 35 | CUDA_VISIBLE_DEVICES=0 python main.py 36 | ``` 37 | 38 | ### Multi GPU Training 39 | The script will automatically use all the visible GPUs for distributed training. 40 | ```python 41 | CUDA_VISIBLE_DEVICES=0,1 python main.py 42 | ``` 43 | 44 | ### Mixed-Precision Training 45 | ```python 46 | CUDA_VISIBLE_DEVICES=0,1 python main.py --mixed_precision 47 | ``` 48 | 49 | 50 | ## Options 51 | * `--work_dir` - Path to directory for logging and checkpoints (str). 52 | * `--data_dir` - Path for storing the dataset (str). 53 | * `--name` - Name of the training run (str). 54 | * `--group` - Group name of the training run (str). 55 | * `--arch` - Architecture (str). Options: vgg16, vgg19. 56 | * `--resume` - Resume training from best checkpoint (bool). 57 | * `--num_epochs` - Number of epochs (int). 58 | * `--learning_rate` - Learning rate (float). 59 | * `--warmup_epochs` - Number of warmup epochs with lower learning rate (int). 60 | * `--batch_size` - Batch size (int). 61 | * `--num_classes` - Number of classes (int). 62 | * `--img_size` - Image size (int). 63 | * `--img_channels` - Number of image channels (int). 64 | * `--mixed_precision` - Use mixed precision training (bool). 65 | * `--random_seed` - Random seed (int). 66 | * `--wandb` - Use Weights&Biases for logging (bool). 67 | * `--log_every` - Log every log_every steps (int). 
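Putting several of the options above together, a run might look like the following (the flag values are illustrative, not tuned recommendations):
```sh
> CUDA_VISIBLE_DEVICES=0,1 python main.py --arch vgg19 --num_epochs 100 --batch_size 64 --learning_rate 0.001 --mixed_precision --wandb --group baseline --name vgg19_mixed
```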
68 | 69 | 70 | ## References 71 | * [Very Deep Convolutional Networks for Large-Scale Image Recognition](https://arxiv.org/abs/1409.1556) 72 | * [pytorch/vision/torchvision/models/vgg.py](https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py) 73 | * [google/flax/examples](https://github.com/google/flax/tree/main/examples) 74 | 75 | 76 | 77 | ## License 78 | [MIT License](https://opensource.org/licenses/MIT) 79 | 80 | -------------------------------------------------------------------------------- /training/vgg/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import jax 4 | import wandb 5 | import training 6 | 7 | 8 | def main(): 9 | parser = argparse.ArgumentParser() 10 | # Paths 11 | parser.add_argument('--work_dir', type=str, default='/export/scratch/mwright/projects/misc/imagenette', help='Directory for logging and checkpoints.') 12 | parser.add_argument('--data_dir', type=str, default='/export/data/mwright/tensorflow_datasets', help='Directory for storing data.') 13 | parser.add_argument('--name', type=str, default='test', help='Name of this experiment.') 14 | parser.add_argument('--group', type=str, default='default', help='Group name of this experiment.') 15 | # Training 16 | parser.add_argument('--arch', type=str, default='vgg16', choices=['vgg16', 'vgg19'], help='Architecture.') 17 | parser.add_argument('--resume', action='store_true', help='Resume training from best checkpoint.') 18 | parser.add_argument('--num_epochs', type=int, default=200, help='Number of epochs.') 19 | parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate.') 20 | parser.add_argument('--warmup_epochs', type=int, default=9, help='Number of warmup epochs with lower learning rate.') 21 | parser.add_argument('--batch_size', type=int, default=128, help='Batch size.') 22 | parser.add_argument('--num_classes', type=int, default=10, help='Number of classes.') 23 | parser.add_argument('--img_size', type=int, default=224, help='Image size.') 24 | parser.add_argument('--img_channels', type=int, default=3, help='Number of image channels.') 25 | parser.add_argument('--mixed_precision', action='store_true', help='Use mixed precision training.') 26 | parser.add_argument('--random_seed', type=int, default=0, help='Random seed.') 27 | # Logging 28 | parser.add_argument('--wandb', action='store_true', help='Log to Weights&bBiases.') 29 | parser.add_argument('--log_every', type=int, default=100, help='Log every log_every steps.') 30 | args = parser.parse_args() 31 | 32 | if jax.process_index() == 0: 33 | args.ckpt_dir = os.path.join(args.work_dir, args.group, args.name, 'checkpoints') 34 | if not os.path.exists(args.ckpt_dir): 35 | os.makedirs(args.ckpt_dir) 36 | 37 | if args.wandb: 38 | wandb.init(entity='matthias-wright', 39 | project='imagenette', 40 | group=args.group, 41 | config=args, 42 | name=args.name, 43 | dir=os.path.join(args.work_dir, args.group, args.name)) 44 | 45 | training.train_and_evaluate(args) 46 | 47 | 48 | if __name__ == '__main__': 49 | main() 50 | -------------------------------------------------------------------------------- /training/vgg/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | git+https://github.com/matthias-wright/flaxmodels.git 3 | tensorflow 4 | tensorflow-datasets 5 | optax 6 | argparse 7 | wandb 8 | tqdm 9 | --------------------------------------------------------------------------------