├── .gitattributes
├── .gitignore
├── README.md
├── docs
├── Documentation.md
└── img
│ ├── flax.png
│ └── gpt2_diagram.png
├── flaxmodels
├── __init__.py
├── few_shot_gan_adaption
│ ├── README.md
│ ├── __init__.py
│ ├── discriminator.py
│ ├── generator.py
│ ├── images
│ │ ├── AmedeoModigliani.jpg
│ │ ├── Babies.jpg
│ │ ├── OttoDix.jpg
│ │ ├── Rafael.jpg
│ │ └── Sketches.jpg
│ └── ops.py
├── gpt2
│ ├── README.md
│ ├── __init__.py
│ ├── gpt2.py
│ ├── gpt2_demo.ipynb
│ ├── ops.py
│ ├── third_party
│ │ ├── __init__.py
│ │ └── huggingface_transformers
│ │ │ ├── __init__.py
│ │ │ ├── configuration_gpt2.py
│ │ │ └── utils
│ │ │ ├── __init__.py
│ │ │ ├── file_utils.py
│ │ │ ├── hf_api.py
│ │ │ ├── logging.py
│ │ │ ├── tokenization_utils.py
│ │ │ ├── tokenization_utils_base.py
│ │ │ └── versions.py
│ └── tokenizer.py
├── resnet
│ ├── README.md
│ ├── __init__.py
│ ├── ops.py
│ ├── resnet.py
│ └── resnet_demo.ipynb
├── stylegan2
│ ├── README.md
│ ├── __init__.py
│ ├── discriminator.py
│ ├── generator.py
│ ├── images
│ │ ├── afhqcat.jpg
│ │ ├── afhqdog.jpg
│ │ ├── afhqwild.jpg
│ │ ├── brecahad.jpg
│ │ ├── car.jpg
│ │ ├── cat.jpg
│ │ ├── church.jpg
│ │ ├── cifar10.jpg
│ │ ├── ffhq.jpg
│ │ ├── gen_images_w_trunc.jpg
│ │ ├── gen_images_with_labels.jpg
│ │ ├── gen_images_wo_trunc.jpg
│ │ ├── horse.jpg
│ │ ├── metfaces.jpg
│ │ ├── style_mixing.jpg
│ │ └── title.jpg
│ ├── ops.py
│ └── stylegan2_demo.ipynb
├── utils.py
└── vgg
│ ├── README.md
│ ├── __init__.py
│ ├── vgg.py
│ └── vgg_demo.ipynb
├── setup.py
├── tests
├── aux_files
│ └── elefant.jpg
├── gpt2
│ ├── aux_files
│ │ ├── gpt2-large
│ │ │ ├── gpt2-large_lmhead_input_embds_input.npy
│ │ │ ├── gpt2-large_lmhead_input_embds_labels.npy
│ │ │ ├── gpt2-large_lmhead_input_embds_logits_ref.npy
│ │ │ ├── gpt2-large_lmhead_input_embds_loss_ref.npy
│ │ │ ├── gpt2-large_lmhead_input_ids_input.npy
│ │ │ ├── gpt2-large_lmhead_input_ids_labels.npy
│ │ │ ├── gpt2-large_lmhead_input_ids_logits_ref.npy
│ │ │ ├── gpt2-large_lmhead_input_ids_loss_ref.npy
│ │ │ ├── gpt2-large_model_input_embds_input.npy
│ │ │ ├── gpt2-large_model_input_embds_output_ref.npy
│ │ │ ├── gpt2-large_model_input_ids_input.npy
│ │ │ └── gpt2-large_model_input_ids_output_ref.npy
│ │ ├── gpt2-medium
│ │ │ ├── gpt2-medium_lmhead_input_embds_input.npy
│ │ │ ├── gpt2-medium_lmhead_input_embds_labels.npy
│ │ │ ├── gpt2-medium_lmhead_input_embds_logits_ref.npy
│ │ │ ├── gpt2-medium_lmhead_input_embds_loss_ref.npy
│ │ │ ├── gpt2-medium_lmhead_input_ids_input.npy
│ │ │ ├── gpt2-medium_lmhead_input_ids_labels.npy
│ │ │ ├── gpt2-medium_lmhead_input_ids_logits_ref.npy
│ │ │ ├── gpt2-medium_lmhead_input_ids_loss_ref.npy
│ │ │ ├── gpt2-medium_model_input_embds_input.npy
│ │ │ ├── gpt2-medium_model_input_embds_output_ref.npy
│ │ │ ├── gpt2-medium_model_input_ids_input.npy
│ │ │ └── gpt2-medium_model_input_ids_output_ref.npy
│ │ ├── gpt2-xl
│ │ │ ├── gpt2-xl_lmhead_input_embds_input.npy
│ │ │ ├── gpt2-xl_lmhead_input_embds_labels.npy
│ │ │ ├── gpt2-xl_lmhead_input_embds_logits_ref.npy
│ │ │ ├── gpt2-xl_lmhead_input_embds_loss_ref.npy
│ │ │ ├── gpt2-xl_lmhead_input_ids_input.npy
│ │ │ ├── gpt2-xl_lmhead_input_ids_labels.npy
│ │ │ ├── gpt2-xl_lmhead_input_ids_logits_ref.npy
│ │ │ ├── gpt2-xl_lmhead_input_ids_loss_ref.npy
│ │ │ ├── gpt2-xl_model_input_embds_input.npy
│ │ │ ├── gpt2-xl_model_input_embds_output_ref.npy
│ │ │ ├── gpt2-xl_model_input_ids_input.npy
│ │ │ └── gpt2-xl_model_input_ids_output_ref.npy
│ │ └── gpt2
│ │ │ ├── gpt2_lmhead_input_embds_input.npy
│ │ │ ├── gpt2_lmhead_input_embds_labels.npy
│ │ │ ├── gpt2_lmhead_input_embds_logits_ref.npy
│ │ │ ├── gpt2_lmhead_input_embds_loss_ref.npy
│ │ │ ├── gpt2_lmhead_input_ids_input.npy
│ │ │ ├── gpt2_lmhead_input_ids_labels.npy
│ │ │ ├── gpt2_lmhead_input_ids_logits_ref.npy
│ │ │ ├── gpt2_lmhead_input_ids_loss_ref.npy
│ │ │ ├── gpt2_model_input_embds_input.npy
│ │ │ ├── gpt2_model_input_embds_output_ref.npy
│ │ │ ├── gpt2_model_input_ids_input.npy
│ │ │ └── gpt2_model_input_ids_output_ref.npy
│ ├── test_gpt2.py
│ ├── test_gpt2_large.py
│ ├── test_gpt2_medium.py
│ └── test_gpt2_xl.py
├── resnet
│ ├── aux_files
│ │ ├── resnet101_elefant_output_ref.npy
│ │ ├── resnet152_elefant_output_ref.npy
│ │ ├── resnet18_elefant_output_ref.npy
│ │ ├── resnet34_elefant_output_ref.npy
│ │ └── resnet50_elefant_output_ref.npy
│ ├── test_resnet101.py
│ ├── test_resnet152.py
│ ├── test_resnet18.py
│ ├── test_resnet34.py
│ └── test_resnet50.py
├── stylegan2
│ ├── discriminator
│ │ ├── aux_files
│ │ │ ├── afhqcat_input_img.npy
│ │ │ ├── afhqcat_output_ref.npy
│ │ │ ├── afhqdog_input_img.npy
│ │ │ ├── afhqdog_output_ref.npy
│ │ │ ├── afhqwild_input_img.npy
│ │ │ ├── afhqwild_output_ref.npy
│ │ │ ├── brecahad_input_img.npy
│ │ │ ├── brecahad_output_ref.npy
│ │ │ ├── car_input_img.npy
│ │ │ ├── car_output_ref.npy
│ │ │ ├── cat_input_img.npy
│ │ │ ├── cat_output_ref.npy
│ │ │ ├── church_input_img.npy
│ │ │ ├── church_output_ref.npy
│ │ │ ├── cifar10_input_img.npy
│ │ │ ├── cifar10_input_label.npy
│ │ │ ├── cifar10_output_ref.npy
│ │ │ ├── ffhq_input_img.npy
│ │ │ ├── ffhq_output_ref.npy
│ │ │ ├── horse_input_img.npy
│ │ │ ├── horse_output_ref.npy
│ │ │ ├── metfaces_input_img.npy
│ │ │ └── metfaces_output_ref.npy
│ │ └── test_discriminator.py
│ └── generator
│ │ ├── aux_files
│ │ ├── afhqcat_input_z.npy
│ │ ├── afhqcat_output_ref.npy
│ │ ├── afhqdog_input_z.npy
│ │ ├── afhqdog_output_ref.npy
│ │ ├── afhqwild_input_z.npy
│ │ ├── afhqwild_output_ref.npy
│ │ ├── brecahad_input_z.npy
│ │ ├── brecahad_output_ref.npy
│ │ ├── car_input_z.npy
│ │ ├── car_output_ref.npy
│ │ ├── cat_input_z.npy
│ │ ├── cat_output_ref.npy
│ │ ├── church_input_z.npy
│ │ ├── church_output_ref.npy
│ │ ├── cifar10_input_label.npy
│ │ ├── cifar10_input_z.npy
│ │ ├── cifar10_output_ref.npy
│ │ ├── ffhq_input_z.npy
│ │ ├── ffhq_output_ref.npy
│ │ ├── horse_input_z.npy
│ │ ├── horse_output_ref.npy
│ │ ├── metfaces_input_z.npy
│ │ └── metfaces_output_ref.npy
│ │ └── test_generator.py
└── vgg
│ ├── aux_files
│ ├── vgg16_elefant_output_ref.npy
│ └── vgg19_elefant_output_ref.npy
│ ├── test_vgg16.py
│ └── test_vgg19.py
└── training
├── few_shot_gan_adaption
├── README.md
├── checkpoint.py
├── data_pipeline.py
├── dataset_utils
│ └── images_to_tfrecords.py
├── datatest.py
├── fid
│ ├── __init__.py
│ ├── core.py
│ ├── inception.py
│ └── utils.py
├── images
│ └── overview.jpg
├── main.py
├── training.py
├── training_steps.py
└── training_utils.py
├── resnet
├── README.md
├── main.py
├── requirements.txt
└── training.py
├── stylegan2
├── README.md
├── checkpoint.py
├── data_pipeline.py
├── dataset_utils
│ ├── crop_image_borders.py
│ └── images_to_tfrecords.py
├── fid
│ ├── __init__.py
│ ├── core.py
│ ├── inception.py
│ └── utils.py
├── generate_images.py
├── images
│ ├── anime_grid.jpg
│ ├── anime_overview.jpg
│ ├── anime_style_mixing.jpg
│ ├── ffhq_grid.jpg
│ ├── ffhq_overview.jpg
│ └── ffhq_style_mixing.jpg
├── main.py
├── requirements.txt
├── style_mixing.py
├── training.py
├── training_steps.py
└── training_utils.py
└── vgg
├── README.md
├── main.py
├── requirements.txt
└── training.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.ipynb linguist-vendored
2 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.swp
3 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Flax Models
3 | A collection of pretrained models in Flax.
4 |
5 |
6 |
7 |
8 | ### About
9 | The goal of this project is to make current deep learning models more easily available for the awesome Jax/Flax ecosystem.
10 |
11 | ### Models
12 | * GPT2 [[model](flaxmodels/gpt2)]
13 | * StyleGAN2 [[model](flaxmodels/stylegan2)] [[training](training/stylegan2)]
14 | * ResNet{18, 34, 50, 101, 152} [[model](flaxmodels/resnet)] [[training](training/resnet)]
15 | * VGG{16, 19} [[model](flaxmodels/vgg)] [[training](training/vgg)]
16 | * FewShotGanAdaption [[model](flaxmodels/few_shot_gan_adaption)] [[training](training/few_shot_gan_adaption)]
17 |
18 |
19 | ### Installation
20 | You will need Python 3.7 or later.
21 |
22 | 1. For GPU usage, first install Jax with CUDA support by following the Jax installation instructions.
23 | 2. Then install:
24 | ```sh
25 | > pip install --upgrade git+https://github.com/matthias-wright/flaxmodels.git
26 | ```
27 | For CPU-only usage, you can skip step 1.
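A quick way to check the installation (a minimal sketch that mirrors the ResNet example from the model README below; it assumes the pretrained ImageNet weights can be downloaded):
```python
import jax
import jax.numpy as jnp
import flaxmodels as fm

key = jax.random.PRNGKey(0)
# Dummy image batch in range [0, 1], NHWC layout
x = jnp.zeros((1, 224, 224, 3), dtype=jnp.float32)

resnet18 = fm.ResNet18(output='logits', pretrained='imagenet')
params = resnet18.init(key, x)
out = resnet18.apply(params, x, train=False)
print(out.shape)  # expected: (1, 1000)
```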
28 |
29 | ### Documentation
30 | The documentation for the models can be found [here](docs/Documentation.md#models).
31 |
32 | ### Checkpoints
33 | The checkpoints are taken from the repositories that are referenced on the model pages. The processing steps and the format of the checkpoints are documented [here](docs/Documentation.md#1-checkpoints).
34 |
35 | ### Testing
36 | To run the tests, pytest needs to be installed.
37 | ```sh
38 | > git clone https://github.com/matthias-wright/flaxmodels.git
39 | > cd flaxmodels
40 | > python -m pytest tests/
41 | ```
42 | See [here](docs/Documentation.md#2-testing) for an explanation of the testing strategy.
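The tests for a single model can be run by pointing pytest at the corresponding subdirectory (standard pytest path selection; `tests/gpt2` is one of the directories in the repository tree above):
```sh
> python -m pytest tests/gpt2/
```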
43 |
44 |
45 | ### Acknowledgments
46 | Thank you to the developers of Jax and Flax. The title image is a photograph of a flax flower, kindly made available by Marta Matyszczyk.
47 |
48 | ### License
49 | Each model has an individual license.
50 |
--------------------------------------------------------------------------------
/docs/img/flax.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/docs/img/flax.png
--------------------------------------------------------------------------------
/docs/img/gpt2_diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/docs/img/gpt2_diagram.png
--------------------------------------------------------------------------------
/flaxmodels/__init__.py:
--------------------------------------------------------------------------------
1 | from .vgg import *
2 | from .resnet import *
3 | from . import stylegan2
4 | from . import gpt2
5 | from . import few_shot_gan_adaption
6 |
7 | __version__ = '0.1.2'
8 |
--------------------------------------------------------------------------------
/flaxmodels/few_shot_gan_adaption/README.md:
--------------------------------------------------------------------------------
1 | # Few-shot Image Generation via Cross-domain Correspondence
2 |
3 |
4 |
5 | Paper: https://arxiv.org/abs/2104.06820
6 | Repository: https://github.com/utkarshojha/few-shot-gan-adaptation
7 |
8 | ##### Table of Contents
9 | * [1. Basic Usage](#usage)
10 | * [2. Checkpoints](#checkpoints)
11 | * [3. Documentation](#documentation)
12 | * [4. Training](#training)
13 | * [5. Images](#images)
14 | * [6. License](#license)
15 |
16 |
17 |
18 | ## 1. Basic Usage
19 |
20 | ```python
21 | import jax
22 | import numpy as np
23 | import dill as pickle
24 | from PIL import Image
25 |
26 | import flaxmodels as fm
27 |
28 | ckpt = pickle.load(open('sketches.pickle', 'rb'))
29 | params = ckpt['params_ema_G']
30 |
31 | generator = fm.few_shot_gan_adaption.Generator()
32 |
33 | # Seed
34 | key = jax.random.PRNGKey(0)
35 |
36 | # Input noise
37 | z = jax.random.normal(key, shape=(4, 512))
38 |
39 | # Generate images
40 | images, _ = generator.apply(params, z, truncation_psi=0.5, train=False, noise_mode='const')
41 |
42 | # Normalize images to be in range [0, 1]
43 | images = (images - np.min(images)) / (np.max(images) - np.min(images))
44 |
45 | # Save images
46 | for i in range(images.shape[0]):
47 | Image.fromarray(np.uint8(images[i] * 255)).save(f'image_{i}.jpg')
48 |
49 | ```
50 |
51 |
52 | ## 2. Checkpoints
53 | * [Sketches](https://www.dropbox.com/s/azr6b316juhme6c/sketches.pickle?dl=1) (357.2 MB)
54 | * [Amedeo Modigliani](https://www.dropbox.com/s/xrh4a7wt2kggn4v/amedeo_modigliani.pickle?dl=1) (357.2 MB)
55 | * [Babies](https://www.dropbox.com/s/ntictyzrisqg5zh/babies.pickle?dl=1) (357.2 MB)
56 | * [Otto Dix](https://www.dropbox.com/s/u1g18nv73uac21m/otto_dix.pickle?dl=1) (357.2 MB)
57 | * [Rafael](https://www.dropbox.com/s/b8w928s4wffuo2c/raphael.pickle?dl=1) (357.2 MB)
58 |
59 |
60 |
61 | ## 3. Documentation
62 | The documentation can be found [here](../../docs/Documentation.md#few_shot_gan_adaption).
63 |
64 |
65 | ## 4. Training
66 | If you want to train this model in Jax/Flax, go [here](https://github.com/matthias-wright/flaxmodels/tree/main/training/few_shot_gan_adaption).
67 |
68 |
69 | ## 5. Images
70 |
71 | ### Sketches
72 |
73 |
74 | ### Amedeo Modigliani
75 |
76 |
77 | ### Babies
78 |
79 |
80 | ### Otto Dix
81 |
82 |
83 | ### Rafael
84 |
85 |
86 |
87 |
88 | ## 6. License
89 | MIT License
90 |
91 |
92 |
--------------------------------------------------------------------------------
/flaxmodels/few_shot_gan_adaption/__init__.py:
--------------------------------------------------------------------------------
1 | from .generator import MappingNetwork
2 | from .generator import SynthesisNetwork
3 | from .generator import Generator
4 | from .discriminator import Discriminator
5 |
--------------------------------------------------------------------------------
/flaxmodels/few_shot_gan_adaption/images/AmedeoModigliani.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/few_shot_gan_adaption/images/AmedeoModigliani.jpg
--------------------------------------------------------------------------------
/flaxmodels/few_shot_gan_adaption/images/Babies.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/few_shot_gan_adaption/images/Babies.jpg
--------------------------------------------------------------------------------
/flaxmodels/few_shot_gan_adaption/images/OttoDix.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/few_shot_gan_adaption/images/OttoDix.jpg
--------------------------------------------------------------------------------
/flaxmodels/few_shot_gan_adaption/images/Rafael.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/few_shot_gan_adaption/images/Rafael.jpg
--------------------------------------------------------------------------------
/flaxmodels/few_shot_gan_adaption/images/Sketches.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/few_shot_gan_adaption/images/Sketches.jpg
--------------------------------------------------------------------------------
/flaxmodels/gpt2/README.md:
--------------------------------------------------------------------------------
1 | # Better Language Models and Their Implications (GPT2)
2 |
3 |
4 | Paper: https://openai.com/blog/better-language-models/
5 | Repository: https://github.com/huggingface/transformers/tree/master/src/transformers/models/gpt2
6 |
7 |
8 | ##### Table of Contents
9 | * [1. Models](#models)
10 | * [2. Basic Usage](#usage)
11 | * [3. Documentation](#documentation)
12 | * [4. Acknowledgments](#ack)
13 | * [5. License](#license)
14 |
15 |
16 |
17 | ## 1. Models
18 |
19 | | Model | Parameters | Size | URL |
20 | | ------------- | ------------- | ------------- | ------------- |
21 | | gpt2 | ~ 120 Million | ~ 500 MB | https://huggingface.co/gpt2 |
22 | | gpt2-medium | ~ 350 Million | ~ 1.5 GB | https://huggingface.co/gpt2-medium |
23 | | gpt2-large | ~ 800 Million | ~ 3 GB | https://huggingface.co/gpt2-large |
24 | | gpt2-xl | ~ 1.5 Billion | ~ 6 GB | https://huggingface.co/gpt2-xl |
25 |
26 |
27 |
28 | ## 2. Basic Usage
29 | For more usage examples check out this [Colab](gpt2_demo.ipynb).
30 |
31 | This is very simple greedy text generation; more sophisticated decoding methods exist (a sampling sketch follows the example below).
32 | ```python
33 | import jax
34 | import jax.numpy as jnp
35 | import flaxmodels as fm
36 |
37 | key = jax.random.PRNGKey(0)
38 |
39 | # Initialize tokenizer
40 | tokenizer = fm.gpt2.get_tokenizer()
41 |
42 | # Encode start sequence
43 | generated = tokenizer.encode('The Manhattan bridge')
44 |
45 | context = jnp.array([generated])
46 | past = None
47 |
48 | # Initialize model
49 | # Models to choose from ['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl']
50 | model = fm.gpt2.GPT2LMHeadModel(pretrained='gpt2')
51 | params = model.init(key, input_ids=context, past_key_values=past)
52 |
53 | for i in range(20):
54 | # Predict next token in sequence
55 | output = model.apply(params, input_ids=context, past_key_values=past, use_cache=True)
56 | token = jnp.argmax(output['logits'][..., -1, :])
57 | context = jnp.expand_dims(token, axis=0)
58 | # Add token to sequence
59 | generated += [token]
60 | # Update past keys and values
61 | past = output['past_key_values']
62 |
63 | # Decode sequence of tokens
64 | sequence = tokenizer.decode(generated)
65 | print(sequence)
66 | ```
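As a sketch of one such alternative, the greedy `jnp.argmax` above could be swapped for temperature sampling over the same logits. The helper below is illustrative only and not part of this repository:
```python
import jax
import jax.numpy as jnp

def sample_next_token(key, logits, temperature=0.8):
    # logits: 1-D unnormalized scores for the next token
    return jax.random.categorical(key, logits / temperature)

# Inside the generation loop, instead of jnp.argmax:
#   key, subkey = jax.random.split(key)
#   next_logits = output['logits'][..., -1, :].reshape(-1)  # assumes batch size 1
#   token = sample_next_token(subkey, next_logits)
```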
67 |
68 |
69 | ## 3. Documentation
70 | The documentation can be found [here](../../docs/Documentation.md#gpt2).
71 |
72 |
73 | ## 4. Acknowledgments
74 | The tokenizer is taken from Huggingface.
75 |
76 |
77 | ## 5. License
78 | Apache-2.0 License
79 |
80 |
81 |
--------------------------------------------------------------------------------
/flaxmodels/gpt2/__init__.py:
--------------------------------------------------------------------------------
1 | from .gpt2 import GPT2Model
2 | from .gpt2 import GPT2LMHeadModel
3 | from .tokenizer import *
4 |
--------------------------------------------------------------------------------
/flaxmodels/gpt2/third_party/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/gpt2/third_party/__init__.py
--------------------------------------------------------------------------------
/flaxmodels/gpt2/third_party/huggingface_transformers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/gpt2/third_party/huggingface_transformers/__init__.py
--------------------------------------------------------------------------------
/flaxmodels/gpt2/third_party/huggingface_transformers/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/gpt2/third_party/huggingface_transformers/utils/__init__.py
--------------------------------------------------------------------------------
/flaxmodels/gpt2/third_party/huggingface_transformers/utils/hf_api.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019-present, the HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import io
18 | import os
19 | from os.path import expanduser
20 | from typing import Dict, List, Optional, Tuple
21 |
22 | from tqdm import tqdm
23 |
24 | import requests
25 |
26 |
27 | ENDPOINT = "https://huggingface.co"
28 |
29 |
30 | class RepoObj:
31 | """
32 | HuggingFace git-based system, data structure that represents a file belonging to the current user.
33 | """
34 |
35 | def __init__(self, filename: str, lastModified: str, commit: str, size: int, **kwargs):
36 | self.filename = filename
37 | self.lastModified = lastModified
38 | self.commit = commit
39 | self.size = size
40 |
41 |
42 | class ModelSibling:
43 | """
44 | Data structure that represents a public file inside a model, accessible from huggingface.co
45 | """
46 |
47 | def __init__(self, rfilename: str, **kwargs):
48 | self.rfilename = rfilename # filename relative to the model root
49 | for k, v in kwargs.items():
50 | setattr(self, k, v)
51 |
52 |
53 | class ModelInfo:
54 | """
55 | Info about a public model accessible from huggingface.co
56 | """
57 |
58 | def __init__(
59 | self,
60 | modelId: Optional[str] = None, # id of model
61 | tags: List[str] = [],
62 | pipeline_tag: Optional[str] = None,
63 | siblings: Optional[List[Dict]] = None, # list of files that constitute the model
64 | **kwargs
65 | ):
66 | self.modelId = modelId
67 | self.tags = tags
68 | self.pipeline_tag = pipeline_tag
69 | self.siblings = [ModelSibling(**x) for x in siblings] if siblings is not None else None
70 | for k, v in kwargs.items():
71 | setattr(self, k, v)
72 |
73 |
74 | class HfApi:
75 | def __init__(self, endpoint=None):
76 | self.endpoint = endpoint if endpoint is not None else ENDPOINT
77 |
78 | def login(self, username: str, password: str) -> str:
79 | """
80 | Call HF API to sign in a user and get a token if credentials are valid.
81 |
82 | Outputs: token if credentials are valid
83 |
84 | Throws: requests.exceptions.HTTPError if credentials are invalid
85 | """
86 | path = f"{self.endpoint}/api/login"
87 | r = requests.post(path, json={"username": username, "password": password})
88 | r.raise_for_status()
89 | d = r.json()
90 | return d["token"]
91 |
92 | def whoami(self, token: str) -> Tuple[str, List[str]]:
93 | """
94 | Call HF API to know "whoami"
95 | """
96 | path = f"{self.endpoint}/api/whoami"
97 | r = requests.get(path, headers={"authorization": f"Bearer {token}"})
98 | r.raise_for_status()
99 | d = r.json()
100 | return d["user"], d["orgs"]
101 |
102 | def logout(self, token: str) -> None:
103 | """
104 | Call HF API to log out.
105 | """
106 | path = f"{self.endpoint}/api/logout"
107 | r = requests.post(path, headers={"authorization": f"Bearer {token}"})
108 | r.raise_for_status()
109 |
110 | def model_list(self) -> List[ModelInfo]:
111 | """
112 | Get the public list of all the models on huggingface.co
113 | """
114 | path = f"{self.endpoint}/api/models"
115 | r = requests.get(path)
116 | r.raise_for_status()
117 | d = r.json()
118 | return [ModelInfo(**x) for x in d]
119 |
120 | def list_repos_objs(self, token: str, organization: Optional[str] = None) -> List[RepoObj]:
121 | """
122 | HuggingFace git-based system, used for models.
123 |
124 | Call HF API to list all stored files for user (or one of their organizations).
125 | """
126 | path = f"{self.endpoint}/api/repos/ls"
127 | params = {"organization": organization} if organization is not None else None
128 | r = requests.get(path, params=params, headers={"authorization": f"Bearer {token}"})
129 | r.raise_for_status()
130 | d = r.json()
131 | return [RepoObj(**x) for x in d]
132 |
133 | def create_repo(
134 | self,
135 | token: str,
136 | name: str,
137 | organization: Optional[str] = None,
138 | private: Optional[bool] = None,
139 | exist_ok=False,
140 | lfsmultipartthresh: Optional[int] = None,
141 | ) -> str:
142 | """
143 | HuggingFace git-based system, used for models.
144 |
145 | Call HF API to create a whole repo.
146 |
147 | Params:
148 | private: Whether the model repo should be private (requires a paid huggingface.co account)
149 |
150 | exist_ok: Do not raise an error if repo already exists
151 |
152 | lfsmultipartthresh: Optional: internal param for testing purposes.
153 | """
154 | path = f"{self.endpoint}/api/repos/create"
155 | json = {"name": name, "organization": organization, "private": private}
156 | if lfsmultipartthresh is not None:
157 | json["lfsmultipartthresh"] = lfsmultipartthresh
158 | r = requests.post(
159 | path,
160 | headers={"authorization": f"Bearer {token}"},
161 | json=json,
162 | )
163 | if exist_ok and r.status_code == 409:
164 | return ""
165 | r.raise_for_status()
166 | d = r.json()
167 | return d["url"]
168 |
169 | def delete_repo(self, token: str, name: str, organization: Optional[str] = None):
170 | """
171 | HuggingFace git-based system, used for models.
172 |
173 | Call HF API to delete a whole repo.
174 |
175 | CAUTION(this is irreversible).
176 | """
177 | path = f"{self.endpoint}/api/repos/delete"
178 | r = requests.delete(
179 | path,
180 | headers={"authorization": f"Bearer {token}"},
181 | json={"name": name, "organization": organization},
182 | )
183 | r.raise_for_status()
184 |
185 |
186 | class TqdmProgressFileReader:
187 | """
188 | Wrap an io.BufferedReader `f` (such as the output of `open(…, "rb")`) and override `f.read()` so as to display a
189 | tqdm progress bar.
190 |
191 | see github.com/huggingface/transformers/pull/2078#discussion_r354739608 for implementation details.
192 | """
193 |
194 | def __init__(self, f: io.BufferedReader):
195 | self.f = f
196 | self.total_size = os.fstat(f.fileno()).st_size
197 | self.pbar = tqdm(total=self.total_size, leave=False)
198 | self.read = f.read
199 | f.read = self._read
200 |
201 | def _read(self, n=-1):
202 | self.pbar.update(n)
203 | return self.read(n)
204 |
205 | def close(self):
206 | self.pbar.close()
207 |
208 |
209 | class HfFolder:
210 | path_token = expanduser("~/.huggingface/token")
211 |
212 | @classmethod
213 | def save_token(cls, token):
214 | """
215 | Save token, creating folder as needed.
216 | """
217 | os.makedirs(os.path.dirname(cls.path_token), exist_ok=True)
218 | with open(cls.path_token, "w+") as f:
219 | f.write(token)
220 |
221 | @classmethod
222 | def get_token(cls):
223 | """
224 | Get token or None if not existent.
225 | """
226 | try:
227 | with open(cls.path_token, "r") as f:
228 | return f.read()
229 | except FileNotFoundError:
230 | pass
231 |
232 | @classmethod
233 | def delete_token(cls):
234 | """
235 | Delete token. Do not fail if token does not exist.
236 | """
237 | try:
238 | os.remove(cls.path_token)
239 | except FileNotFoundError:
240 | pass
241 |
--------------------------------------------------------------------------------
/flaxmodels/gpt2/third_party/huggingface_transformers/utils/logging.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 Optuna, Hugging Face
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ Logging utilities. """
16 |
17 | import logging
18 | import os
19 | import sys
20 | import threading
21 | from logging import CRITICAL # NOQA
22 | from logging import DEBUG # NOQA
23 | from logging import ERROR # NOQA
24 | from logging import FATAL # NOQA
25 | from logging import INFO # NOQA
26 | from logging import NOTSET # NOQA
27 | from logging import WARN # NOQA
28 | from logging import WARNING # NOQA
29 | from typing import Optional
30 |
31 |
32 | _lock = threading.Lock()
33 | _default_handler: Optional[logging.Handler] = None
34 |
35 | log_levels = {
36 | "debug": logging.DEBUG,
37 | "info": logging.INFO,
38 | "warning": logging.WARNING,
39 | "error": logging.ERROR,
40 | "critical": logging.CRITICAL,
41 | }
42 |
43 | _default_log_level = logging.WARNING
44 |
45 |
46 | def _get_default_logging_level():
47 | """
48 | If TRANSFORMERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is
49 | not - fall back to ``_default_log_level``
50 | """
51 | env_level_str = os.getenv("TRANSFORMERS_VERBOSITY", None)
52 | if env_level_str:
53 | if env_level_str in log_levels:
54 | return log_levels[env_level_str]
55 | else:
56 | logging.getLogger().warning(
57 | f"Unknown option TRANSFORMERS_VERBOSITY={env_level_str}, "
58 | f"has to be one of: { ', '.join(log_levels.keys()) }"
59 | )
60 | return _default_log_level
61 |
62 |
63 | def _get_library_name() -> str:
64 |
65 | return __name__.split(".")[0]
66 |
67 |
68 | def _get_library_root_logger() -> logging.Logger:
69 |
70 | return logging.getLogger(_get_library_name())
71 |
72 |
73 | def _configure_library_root_logger() -> None:
74 |
75 | global _default_handler
76 |
77 | with _lock:
78 | if _default_handler:
79 | # This library has already configured the library root logger.
80 | return
81 | _default_handler = logging.StreamHandler() # Set sys.stderr as stream.
82 | _default_handler.flush = sys.stderr.flush
83 |
84 | # Apply our default configuration to the library root logger.
85 | library_root_logger = _get_library_root_logger()
86 | library_root_logger.addHandler(_default_handler)
87 | library_root_logger.setLevel(_get_default_logging_level())
88 | library_root_logger.propagate = False
89 |
90 |
91 | def _reset_library_root_logger() -> None:
92 |
93 | global _default_handler
94 |
95 | with _lock:
96 | if not _default_handler:
97 | return
98 |
99 | library_root_logger = _get_library_root_logger()
100 | library_root_logger.removeHandler(_default_handler)
101 | library_root_logger.setLevel(logging.NOTSET)
102 | _default_handler = None
103 |
104 |
105 | def get_logger(name: Optional[str] = None) -> logging.Logger:
106 | """
107 | Return a logger with the specified name.
108 |
109 | This function is not supposed to be directly accessed unless you are writing a custom transformers module.
110 | """
111 |
112 | if name is None:
113 | name = _get_library_name()
114 |
115 | _configure_library_root_logger()
116 | return logging.getLogger(name)
117 |
118 |
119 | def get_verbosity() -> int:
120 | """
121 | Return the current level for the 🤗 Transformers's root logger as an int.
122 |
123 | Returns:
124 | :obj:`int`: The logging level.
125 |
126 | .. note::
127 |
128 | 🤗 Transformers has following logging levels:
129 |
130 | - 50: ``transformers.logging.CRITICAL`` or ``transformers.logging.FATAL``
131 | - 40: ``transformers.logging.ERROR``
132 | - 30: ``transformers.logging.WARNING`` or ``transformers.logging.WARN``
133 | - 20: ``transformers.logging.INFO``
134 | - 10: ``transformers.logging.DEBUG``
135 | """
136 |
137 | _configure_library_root_logger()
138 | return _get_library_root_logger().getEffectiveLevel()
139 |
140 |
141 | def set_verbosity(verbosity: int) -> None:
142 | """
143 |     Set the verbosity level for the 🤗 Transformers's root logger.
144 |
145 | Args:
146 | verbosity (:obj:`int`):
147 | Logging level, e.g., one of:
148 |
149 | - ``transformers.logging.CRITICAL`` or ``transformers.logging.FATAL``
150 | - ``transformers.logging.ERROR``
151 | - ``transformers.logging.WARNING`` or ``transformers.logging.WARN``
152 | - ``transformers.logging.INFO``
153 | - ``transformers.logging.DEBUG``
154 | """
155 |
156 | _configure_library_root_logger()
157 | _get_library_root_logger().setLevel(verbosity)
158 |
159 |
160 | def set_verbosity_info():
161 | """Set the verbosity to the :obj:`INFO` level."""
162 | return set_verbosity(INFO)
163 |
164 |
165 | def set_verbosity_warning():
166 | """Set the verbosity to the :obj:`WARNING` level."""
167 | return set_verbosity(WARNING)
168 |
169 |
170 | def set_verbosity_debug():
171 | """Set the verbosity to the :obj:`DEBUG` level."""
172 | return set_verbosity(DEBUG)
173 |
174 |
175 | def set_verbosity_error():
176 | """Set the verbosity to the :obj:`ERROR` level."""
177 | return set_verbosity(ERROR)
178 |
179 |
180 | def disable_default_handler() -> None:
181 | """Disable the default handler of the HuggingFace Transformers's root logger."""
182 |
183 | _configure_library_root_logger()
184 |
185 | assert _default_handler is not None
186 | _get_library_root_logger().removeHandler(_default_handler)
187 |
188 |
189 | def enable_default_handler() -> None:
190 | """Enable the default handler of the HuggingFace Transformers's root logger."""
191 |
192 | _configure_library_root_logger()
193 |
194 | assert _default_handler is not None
195 | _get_library_root_logger().addHandler(_default_handler)
196 |
197 |
198 | def add_handler(handler: logging.Handler) -> None:
199 | """adds a handler to the HuggingFace Transformers's root logger."""
200 |
201 | _configure_library_root_logger()
202 |
203 | assert handler is not None
204 | _get_library_root_logger().addHandler(handler)
205 |
206 |
207 | def remove_handler(handler: logging.Handler) -> None:
208 | """removes given handler from the HuggingFace Transformers's root logger."""
209 |
210 | _configure_library_root_logger()
211 |
212 | assert handler is not None and handler not in _get_library_root_logger().handlers
213 | _get_library_root_logger().removeHandler(handler)
214 |
215 |
216 | def disable_propagation() -> None:
217 | """
218 | Disable propagation of the library log outputs. Note that log propagation is disabled by default.
219 | """
220 |
221 | _configure_library_root_logger()
222 | _get_library_root_logger().propagate = False
223 |
224 |
225 | def enable_propagation() -> None:
226 | """
227 | Enable propagation of the library log outputs. Please disable the HuggingFace Transformers's default handler to
228 | prevent double logging if the root logger has been configured.
229 | """
230 |
231 | _configure_library_root_logger()
232 | _get_library_root_logger().propagate = True
233 |
234 |
235 | def enable_explicit_format() -> None:
236 | """
237 | Enable explicit formatting for every HuggingFace Transformers's logger. The explicit formatter is as follows:
238 |
239 | ::
240 |
241 | [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
242 |
243 | All handlers currently bound to the root logger are affected by this method.
244 | """
245 | handlers = _get_library_root_logger().handlers
246 |
247 | for handler in handlers:
248 | formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s")
249 | handler.setFormatter(formatter)
250 |
251 |
252 | def reset_format() -> None:
253 | """
254 | Resets the formatting for HuggingFace Transformers's loggers.
255 |
256 | All handlers currently bound to the root logger are affected by this method.
257 | """
258 | handlers = _get_library_root_logger().handlers
259 |
260 | for handler in handlers:
261 | handler.setFormatter(None)
262 |
--------------------------------------------------------------------------------
/flaxmodels/gpt2/third_party/huggingface_transformers/utils/versions.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Utilities for working with package versions
16 | """
17 |
18 | import operator
19 | import re
20 | import sys
21 | from typing import Optional
22 |
23 | from packaging import version
24 |
25 |
26 | # The package importlib_metadata is in a different place, depending on the python version.
27 | if sys.version_info < (3, 8):
28 | import importlib_metadata
29 | else:
30 | import importlib.metadata as importlib_metadata
31 |
32 |
33 | ops = {
34 | "<": operator.lt,
35 | "<=": operator.le,
36 | "==": operator.eq,
37 | "!=": operator.ne,
38 | ">=": operator.ge,
39 | ">": operator.gt,
40 | }
41 |
42 |
43 | def _compare_versions(op, got_ver, want_ver, requirement, pkg, hint):
44 | if got_ver is None:
45 | raise ValueError("got_ver is None")
46 | if want_ver is None:
47 | raise ValueError("want_ver is None")
48 | if not ops[op](version.parse(got_ver), version.parse(want_ver)):
49 | raise ImportError(
50 | f"{requirement} is required for a normal functioning of this module, but found {pkg}=={got_ver}.{hint}"
51 | )
52 |
53 |
54 | def require_version(requirement: str, hint: Optional[str] = None) -> None:
55 | """
56 | Perform a runtime check of the dependency versions, using the exact same syntax used by pip.
57 |
58 | The installed module version comes from the `site-packages` dir via `importlib_metadata`.
59 |
60 | Args:
61 | requirement (:obj:`str`): pip style definition, e.g., "tokenizers==0.9.4", "tqdm>=4.27", "numpy"
62 | hint (:obj:`str`, `optional`): what suggestion to print in case of requirements not being met
63 |
64 | Example::
65 |
66 | require_version("pandas>1.1.2")
67 | require_version("numpy>1.18.5", "this is important to have for whatever reason")
68 |
69 | """
70 |
71 | hint = f"\n{hint}" if hint is not None else ""
72 |
73 | # non-versioned check
74 | if re.match(r"^[\w_\-\d]+$", requirement):
75 | pkg, op, want_ver = requirement, None, None
76 | else:
77 | match = re.findall(r"^([^!=<>\s]+)([\s!=<>]{1,2}.+)", requirement)
78 | if not match:
79 | raise ValueError(
80 | f"requirement needs to be in the pip package format, .e.g., package_a==1.23, or package_b>=1.23, but got {requirement}"
81 | )
82 | pkg, want_full = match[0]
83 | want_range = want_full.split(",") # there could be multiple requirements
84 | wanted = {}
85 | for w in want_range:
86 | match = re.findall(r"^([\s!=<>]{1,2})(.+)", w)
87 | if not match:
88 | raise ValueError(
89 | f"requirement needs to be in the pip package format, .e.g., package_a==1.23, or package_b>=1.23, but got {requirement}"
90 | )
91 | op, want_ver = match[0]
92 | wanted[op] = want_ver
93 | if op not in ops:
94 | raise ValueError(f"{requirement}: need one of {list(ops.keys())}, but got {op}")
95 |
96 | # special case
97 | if pkg == "python":
98 | got_ver = ".".join([str(x) for x in sys.version_info[:3]])
99 | for op, want_ver in wanted.items():
100 | _compare_versions(op, got_ver, want_ver, requirement, pkg, hint)
101 | return
102 |
103 | # check if any version is installed
104 | try:
105 | got_ver = importlib_metadata.version(pkg)
106 | except importlib_metadata.PackageNotFoundError:
107 | raise importlib_metadata.PackageNotFoundError(
108 | f"The '{requirement}' distribution was not found and is required by this application. {hint}"
109 | )
110 |
111 | # check that the right version is installed if version number or a range was provided
112 | if want_ver is not None:
113 | for op, want_ver in wanted.items():
114 | _compare_versions(op, got_ver, want_ver, requirement, pkg, hint)
115 |
116 |
117 | def require_version_core(requirement):
118 | """ require_version wrapper which emits a core-specific hint on failure """
119 | hint = "Try: pip install transformers -U or pip install -e '.[dev]' if you're working with git master"
120 | return require_version(requirement, hint)
121 |
122 |
123 | def require_version_examples(requirement):
124 | """ require_version wrapper which emits examples-specific hint on failure """
125 | hint = "Try: pip install -r examples/requirements.txt"
126 | return require_version(requirement, hint)
127 |
--------------------------------------------------------------------------------
/flaxmodels/gpt2/tokenizer.py:
--------------------------------------------------------------------------------
1 | from .third_party.huggingface_transformers.configuration_gpt2 import GPT2Tokenizer
2 | from .. import utils
3 |
4 |
5 | def get_tokenizer(errors='replace',
6 | unk_token='<|endoftext|>',
7 | bos_token='<|endoftext|>',
8 | eos_token='<|endoftext|>',
9 | add_prefix_space=False,
10 | ckpt_dir=None):
11 | """
12 | Returns the GPT2Tokenizer from Huggingface with loaded merges and vocab files.
13 | See: https://huggingface.co/transformers/model_doc/gpt2.html#gpt2tokenizer
14 |
15 | Args:
16 | errors (str): Paradigm to follow when decoding bytes to UTF-8.
17 | unk_token (str): The unknown token. A token that is not in the
18 | vocabulary cannot be converted to an ID and is set to be this token instead.
19 | bos_token (str): The beginning of sequence token.
20 | eos_token (str): The end of sequence token.
21 | add_prefix_space (bool): Whether or not to add an initial space to the input.
22 |             This allows the leading word to be treated just like any other word.
23 | ckpt_dir (str): Path to directory, where merges and vocab files are downloaded to.
24 | If None, the files will be downloaded to a temp directory.
25 |
26 | Returns:
27 | (GPT2Tokenizer): GPT2 Tokenizer.
28 |
29 | """
30 | merges_file = utils.download(ckpt_dir, 'https://www.dropbox.com/s/7f5n1gf348sy1mt/merges.txt?dl=1')
31 | vocab_file = utils.download(ckpt_dir, 'https://www.dropbox.com/s/s93xkhgcac5nbmn/vocab.json?dl=1')
32 |
33 | return GPT2Tokenizer(vocab_file=vocab_file,
34 | merges_file=merges_file,
35 | errors=errors,
36 | unk_token=unk_token,
37 | bos_token=bos_token,
38 | eos_token=eos_token,
39 | add_prefix_space=add_prefix_space)
40 |
41 |
42 |
43 |
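# Illustrative usage sketch (the encode/decode calls mirror the GPT2 README):
#   tokenizer = get_tokenizer()
#   token_ids = tokenizer.encode('The Manhattan bridge')
#   text = tokenizer.decode(token_ids)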
--------------------------------------------------------------------------------
/flaxmodels/resnet/README.md:
--------------------------------------------------------------------------------
1 | # Deep Residual Learning for Image Recognition
2 |
3 | Paper: https://arxiv.org/abs/1512.03385
4 | Repository: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
5 |
6 | ##### Table of Contents
7 | * [1. Important Note](#note)
8 | * [2. Basic Usage](#usage)
9 | * [3. Documentation](#documentation)
10 | * [4. Training](#training)
11 | * [5. License](#license)
12 |
13 |
14 | ## 1. Important Note
15 | Images must be in range [0, 1]. If the pretrained ImageNet weights are selected, the images are internally normalized with the ImageNet mean and standard deviation. If you don't want the images to be normalized, use `normalize=False` (see [here](https://github.com/matthias-wright/flaxmodels/blob/main/docs/Documentation.md#33-resnet18-34-50-101-152) for details).
16 |
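For reference, disabling the internal ImageNet normalization is just the constructor flag mentioned above (a sketch, assuming the same setup as in the example below):
```python
resnet18 = fm.ResNet18(output='logits', pretrained='imagenet', normalize=False)
```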
17 |
18 | ## 2. Basic Usage
19 | For more usage examples check out this [Colab](resnet_demo.ipynb).
20 |
21 | ```python
22 | from PIL import Image
23 | import jax
24 | import jax.numpy as jnp
25 | import flaxmodels as fm
26 |
27 | key = jax.random.PRNGKey(0)
28 |
29 | # Load image
30 | img = Image.open('example.jpg')
31 | # Image should be in range [0, 1]
32 | x = jnp.array(img, dtype=jnp.float32) / 255.0
33 | # Add batch dimension
34 | x = jnp.expand_dims(x, axis=0)
35 |
36 | resnet18 = fm.ResNet18(output='logits', pretrained='imagenet')
37 | params = resnet18.init(key, x)
38 | # Shape [1, 1000]
39 | out = resnet18.apply(params, x, train=False)
40 |
41 | ```
42 | Usage is equivalent for ResNet34, ResNet50, ResNet101, and ResNet152.
43 |
44 |
45 | ## 3. Documentation
46 | The documentation can be found [here](../../docs/Documentation.md#resnet).
47 |
48 |
49 | ## 4. Training
50 | If you want to train ResNet in Jax/Flax, go [here](https://github.com/matthias-wright/flaxmodels/tree/main/training/resnet).
51 |
52 |
53 | ## 5. License
54 | MIT License
55 |
56 |
57 |
--------------------------------------------------------------------------------
/flaxmodels/resnet/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet import ResNet18
2 | from .resnet import ResNet34
3 | from .resnet import ResNet50
4 | from .resnet import ResNet101
5 | from .resnet import ResNet152
6 |
--------------------------------------------------------------------------------
/flaxmodels/resnet/ops.py:
--------------------------------------------------------------------------------
1 | import flax.linen as nn
2 | from flax.core import FrozenDict
3 | import jax.numpy as jnp
4 |
5 | from flax.linen.module import compact, merge_param
6 | from typing import (Any, Callable, Optional, Tuple)
7 | from jax.nn import initializers
8 | from jax import lax
9 |
10 | PRNGKey = Any
11 | Array = Any
12 | Shape = Tuple[int]
13 | Dtype = Any
14 |
15 |
16 | #---------------------------------------------------------------#
17 | # Normalization
18 | #---------------------------------------------------------------#
19 | def batch_norm(x, train, epsilon=1e-05, momentum=0.99, params=None, dtype='float32'):
20 | if params is None:
21 | x = BatchNorm(epsilon=epsilon,
22 | momentum=momentum,
23 | use_running_average=not train,
24 | dtype=dtype)(x)
25 | else:
26 | x = BatchNorm(epsilon=epsilon,
27 | momentum=momentum,
28 | bias_init=lambda *_ : jnp.array(params['bias']),
29 | scale_init=lambda *_ : jnp.array(params['scale']),
30 | mean_init=lambda *_ : jnp.array(params['mean']),
31 | var_init=lambda *_ : jnp.array(params['var']),
32 | use_running_average=not train,
33 | dtype=dtype)(x)
34 | return x
35 |
36 |
37 | def _absolute_dims(rank, dims):
38 | return tuple([rank + dim if dim < 0 else dim for dim in dims])
39 |
40 |
41 | class BatchNorm(nn.Module):
42 | """BatchNorm Module.
43 |
44 | Taken from: https://github.com/google/flax/blob/master/flax/linen/normalization.py
45 |
46 | Attributes:
47 | use_running_average: if True, the statistics stored in batch_stats
48 | will be used instead of computing the batch statistics on the input.
49 | axis: the feature or non-batch axis of the input.
50 | momentum: decay rate for the exponential moving average of the batch statistics.
51 | epsilon: a small float added to variance to avoid dividing by zero.
52 | dtype: the dtype of the computation (default: float32).
53 | use_bias: if True, bias (beta) is added.
54 | use_scale: if True, multiply by scale (gamma).
55 | When the next layer is linear (also e.g. nn.relu), this can be disabled
56 | since the scaling will be done by the next layer.
57 | bias_init: initializer for bias, by default, zero.
58 | scale_init: initializer for scale, by default, one.
59 | axis_name: the axis name used to combine batch statistics from multiple
60 | devices. See `jax.pmap` for a description of axis names (default: None).
61 | axis_index_groups: groups of axis indices within that named axis
62 | representing subsets of devices to reduce over (default: None). For
63 | example, `[[0, 1], [2, 3]]` would independently batch-normalize over
64 | the examples on the first two and last two devices. See `jax.lax.psum`
65 | for more details.
66 | """
67 | use_running_average: Optional[bool] = None
68 | axis: int = -1
69 | momentum: float = 0.99
70 | epsilon: float = 1e-5
71 | dtype: Dtype = jnp.float32
72 | use_bias: bool = True
73 | use_scale: bool = True
74 | bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros
75 | scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones
76 | mean_init: Callable[[Shape], Array] = lambda s: jnp.zeros(s, jnp.float32)
77 | var_init: Callable[[Shape], Array] = lambda s: jnp.ones(s, jnp.float32)
78 | axis_name: Optional[str] = None
79 | axis_index_groups: Any = None
80 |
81 | @compact
82 | def __call__(self, x, use_running_average: Optional[bool] = None):
83 | """Normalizes the input using batch statistics.
84 |
85 | NOTE:
86 | During initialization (when parameters are mutable) the running average
87 | of the batch statistics will not be updated. Therefore, the inputs
88 | fed during initialization don't need to match that of the actual input
89 | distribution and the reduction axis (set with `axis_name`) does not have
90 | to exist.
91 | Args:
92 | x: the input to be normalized.
93 | use_running_average: if true, the statistics stored in batch_stats
94 | will be used instead of computing the batch statistics on the input.
95 | Returns:
96 | Normalized inputs (the same shape as inputs).
97 | """
98 | use_running_average = merge_param(
99 | 'use_running_average', self.use_running_average, use_running_average)
100 | x = jnp.asarray(x, jnp.float32)
101 | axis = self.axis if isinstance(self.axis, tuple) else (self.axis,)
102 | axis = _absolute_dims(x.ndim, axis)
103 | feature_shape = tuple(d if i in axis else 1 for i, d in enumerate(x.shape))
104 | reduced_feature_shape = tuple(d for i, d in enumerate(x.shape) if i in axis)
105 | reduction_axis = tuple(i for i in range(x.ndim) if i not in axis)
106 |
107 | # see NOTE above on initialization behavior
108 | initializing = self.is_mutable_collection('params')
109 |
110 | ra_mean = self.variable('batch_stats', 'mean',
111 | self.mean_init,
112 | reduced_feature_shape)
113 | ra_var = self.variable('batch_stats', 'var',
114 | self.var_init,
115 | reduced_feature_shape)
116 |
117 | if use_running_average:
118 | mean, var = ra_mean.value, ra_var.value
119 | else:
120 | mean = jnp.mean(x, axis=reduction_axis, keepdims=False)
121 | mean2 = jnp.mean(lax.square(x), axis=reduction_axis, keepdims=False)
122 | if self.axis_name is not None and not initializing:
123 | concatenated_mean = jnp.concatenate([mean, mean2])
124 | mean, mean2 = jnp.split(
125 | lax.pmean(
126 | concatenated_mean,
127 | axis_name=self.axis_name,
128 | axis_index_groups=self.axis_index_groups), 2)
129 | var = mean2 - lax.square(mean)
130 |
131 | if not initializing:
132 | ra_mean.value = self.momentum * ra_mean.value + (1 - self.momentum) * mean
133 | ra_var.value = self.momentum * ra_var.value + (1 - self.momentum) * var
134 |
135 | y = x - mean.reshape(feature_shape)
136 | mul = lax.rsqrt(var + self.epsilon)
137 | if self.use_scale:
138 | scale = self.param('scale',
139 | self.scale_init,
140 | reduced_feature_shape).reshape(feature_shape)
141 | mul = mul * scale
142 | y = y * mul
143 | if self.use_bias:
144 | bias = self.param('bias',
145 | self.bias_init,
146 | reduced_feature_shape).reshape(feature_shape)
147 | y = y + bias
148 | return jnp.asarray(y, self.dtype)
149 |
150 |
151 |
--------------------------------------------------------------------------------
/flaxmodels/stylegan2/README.md:
--------------------------------------------------------------------------------
1 | # Analyzing and Improving the Image Quality of StyleGAN
2 |
3 |
4 |
5 | Paper: https://arxiv.org/abs/1912.04958
6 | Repository: https://github.com/NVlabs/stylegan2 and https://github.com/NVlabs/stylegan2-ada
7 |
8 | ##### Table of Contents
9 | * [1. Basic Usage](#usage)
10 | * [2. Documentation](#documentation)
11 | * [3. Training](#training)
12 | * [4. Pretrained Models](#models)
13 | * [5. License](#license)
14 |
15 |
16 |
17 | ## 1. Basic Usage
18 | For more usage examples check out this [Colab](stylegan2_demo.ipynb).
19 |
20 | ```python
21 | import numpy as np
22 | from PIL import Image
23 | import jax
24 | import jax.numpy as jnp
25 | import flaxmodels as fm
26 |
27 | # Seed
28 | key = jax.random.PRNGKey(0)
29 |
30 | # Input noise
31 | z = jax.random.normal(key, shape=(4, 512))
32 |
33 | generator = fm.stylegan2.Generator(pretrained='metfaces')
34 | params = generator.init(key, z)
35 | images = generator.apply(params, z, train=False)
36 |
37 | # Normalize images to be in range [0, 1]
38 | images = (images - jnp.min(images)) / (jnp.max(images) - jnp.min(images))
39 |
40 | # Save images
41 | for i in range(images.shape[0]):
42 | Image.fromarray(np.uint8(images[i] * 255)).save(f'image_{i}.jpg')
43 |
44 | ```
45 |
46 |
47 |
48 |
49 | ## 2. Documentation
50 | The documentation can be found [here](../../docs/Documentation.md#stylegan2).
51 |
52 |
53 | ## 3. Training
54 | If you want to train StyleGAN2 in Jax/Flax, go [here](https://github.com/matthias-wright/flaxmodels/tree/main/training/stylegan2).
55 |
56 |
57 | ## 4. Pretrained Models
58 |
59 | ### Metfaces
60 |
61 |
62 | ### FFHQ
63 |
64 |
65 | ### AFHQ Wild
66 |
67 |
68 | ### AFHQ Dog
69 |
70 |
71 | ### AFHQ Cat
72 |
73 |
74 | ### LSUN Cat
75 |
76 |
77 | ### LSUN Horse
78 |
79 |
80 | ### LSUN Car
81 |
82 |
83 | ### BreCaHAD
84 |
85 |
86 | ### CIFAR-10
87 |
88 |
89 | ### LSUN Church
90 |
91 |
92 |
93 |
94 | ## 5. License
95 | Nvidia Source Code License-NC
96 |
97 |
98 |
--------------------------------------------------------------------------------
/flaxmodels/stylegan2/__init__.py:
--------------------------------------------------------------------------------
1 | from .generator import SynthesisNetwork
2 | from .generator import MappingNetwork
3 | from .generator import Generator
4 | from .discriminator import Discriminator
5 |
6 |
--------------------------------------------------------------------------------
/flaxmodels/stylegan2/images/afhqcat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/stylegan2/images/afhqcat.jpg
--------------------------------------------------------------------------------
/flaxmodels/stylegan2/images/afhqdog.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/stylegan2/images/afhqdog.jpg
--------------------------------------------------------------------------------
/flaxmodels/stylegan2/images/afhqwild.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/stylegan2/images/afhqwild.jpg
--------------------------------------------------------------------------------
/flaxmodels/stylegan2/images/brecahad.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/stylegan2/images/brecahad.jpg
--------------------------------------------------------------------------------
/flaxmodels/stylegan2/images/car.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/stylegan2/images/car.jpg
--------------------------------------------------------------------------------
/flaxmodels/stylegan2/images/cat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/stylegan2/images/cat.jpg
--------------------------------------------------------------------------------
/flaxmodels/stylegan2/images/church.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/stylegan2/images/church.jpg
--------------------------------------------------------------------------------
/flaxmodels/stylegan2/images/cifar10.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/stylegan2/images/cifar10.jpg
--------------------------------------------------------------------------------
/flaxmodels/stylegan2/images/ffhq.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/stylegan2/images/ffhq.jpg
--------------------------------------------------------------------------------
/flaxmodels/stylegan2/images/gen_images_w_trunc.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/stylegan2/images/gen_images_w_trunc.jpg
--------------------------------------------------------------------------------
/flaxmodels/stylegan2/images/gen_images_with_labels.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/stylegan2/images/gen_images_with_labels.jpg
--------------------------------------------------------------------------------
/flaxmodels/stylegan2/images/gen_images_wo_trunc.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/stylegan2/images/gen_images_wo_trunc.jpg
--------------------------------------------------------------------------------
/flaxmodels/stylegan2/images/horse.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/stylegan2/images/horse.jpg
--------------------------------------------------------------------------------
/flaxmodels/stylegan2/images/metfaces.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/stylegan2/images/metfaces.jpg
--------------------------------------------------------------------------------
/flaxmodels/stylegan2/images/style_mixing.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/stylegan2/images/style_mixing.jpg
--------------------------------------------------------------------------------
/flaxmodels/stylegan2/images/title.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/flaxmodels/stylegan2/images/title.jpg
--------------------------------------------------------------------------------
/flaxmodels/utils.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import requests
3 | import os
4 | import tempfile
5 |
6 |
7 | def download(ckpt_dir, url):
8 | name = url[url.rfind('/') + 1 : url.rfind('?')]
9 | if ckpt_dir is None:
10 | ckpt_dir = tempfile.gettempdir()
11 | ckpt_dir = os.path.join(ckpt_dir, 'flaxmodels')
12 | ckpt_file = os.path.join(ckpt_dir, name)
13 | if not os.path.exists(ckpt_file):
14 | print(f'Downloading: \"{url[:url.rfind("?")]}\" to {ckpt_file}')
15 | if not os.path.exists(ckpt_dir):
16 | os.makedirs(ckpt_dir)
17 |
18 | response = requests.get(url, stream=True)
19 | total_size_in_bytes = int(response.headers.get('content-length', 0))
20 | progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
21 |
22 | # first create temp file, in case the download fails
23 | ckpt_file_temp = os.path.join(ckpt_dir, name + '.temp')
24 | with open(ckpt_file_temp, 'wb') as file:
25 | for data in response.iter_content(chunk_size=1024):
26 | progress_bar.update(len(data))
27 | file.write(data)
28 | progress_bar.close()
29 |
30 | if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
31 |             print('An error occurred while downloading, please try again.')
32 | if os.path.exists(ckpt_file_temp):
33 | os.remove(ckpt_file_temp)
34 | else:
35 | # if download was successful, rename the temp file
36 | os.rename(ckpt_file_temp, ckpt_file)
37 | return ckpt_file
38 |
--------------------------------------------------------------------------------
/flaxmodels/vgg/README.md:
--------------------------------------------------------------------------------
1 | # Very Deep Convolutional Networks for Large-Scale Image Recognition
2 |
3 | Paper: https://arxiv.org/abs/1409.1556
4 | Project Page: https://www.robots.ox.ac.uk/~vgg/research/very_deep/
5 | Repository: https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py
6 |
7 | ##### Table of Contents
8 | * [1. Important Note](#note)
9 | * [2. Basic Usage](#usage)
10 | * [3. Documentation](#documentation)
11 | * [4. Training](#training)
12 | * [5. License](#license)
13 |
14 |
15 |
16 | ## 1. Important Note
17 | Images must be in range [0, 1]. If the pretrained ImageNet weights are selected, the images are internally normalized with the ImageNet mean and standard deviation. If you don't want the images to be normalized, use `normalize=False` (see [here](https://github.com/matthias-wright/flaxmodels/blob/main/docs/Documentation.md#34-vgg16-19) for details).
18 |
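As a minimal sketch of the second case (assuming the same `flaxmodels` import as in the usage example below), the internal normalization can be switched off like this:

```python
import flaxmodels as fm

# Skip the built-in ImageNet normalization.
# The input images are still expected to be in range [0, 1].
vgg16 = fm.VGG16(output='logits', pretrained='imagenet', normalize=False)
```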
19 |
20 | ## 2. Basic Usage
21 | For more usage examples, check out this [Colab](vgg_demo.ipynb).
22 |
23 | ```python
24 | from PIL import Image
25 | import jax
26 | import jax.numpy as jnp
27 | import flaxmodels as fm
28 |
29 | key = jax.random.PRNGKey(0)
30 |
31 | # Load image
32 | img = Image.open('example.jpg')
33 | # Image must be 224x224 if classification head is included
34 | img = img.resize((224, 224))
35 | # Image should be in range [0, 1]
36 | x = jnp.array(img, dtype=jnp.float32) / 255.0
37 | # Add batch dimension
38 | x = jnp.expand_dims(x, axis=0)
39 |
40 | vgg16 = fm.VGG16(output='logits', pretrained='imagenet')
41 | params = vgg16.init(key, x)
42 | out = vgg16.apply(params, x, train=False)
43 |
44 | ```
45 | Usage is equivalent for VGG19.
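For instance, a sketch of the same pipeline with VGG19, reusing `key` and the preprocessed input `x` from above:

```python
vgg19 = fm.VGG19(output='logits', pretrained='imagenet')
params = vgg19.init(key, x)
out = vgg19.apply(params, x, train=False)
```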
46 |
47 |
48 | ## 3. Documentation
49 | The documentation can be found [here](../../docs/Documentation.md#vgg).
50 |
51 |
52 | ## 4. Training
53 | If you want to train VGG in Jax/Flax, go [here](https://github.com/matthias-wright/flaxmodels/tree/main/training/vgg).
54 |
55 |
56 | ## 5. License
57 | Creative Commons Attribution License
58 |
--------------------------------------------------------------------------------
/flaxmodels/vgg/__init__.py:
--------------------------------------------------------------------------------
1 | from .vgg import VGG16
2 | from .vgg import VGG19
3 |
4 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | import os
3 |
4 |
5 | directory = os.path.abspath(os.path.dirname(__file__))
6 | with open(os.path.join(directory, 'README.md'), encoding='utf-8') as f:
7 | long_description = f.read()
8 |
9 | setup(name='flaxmodels',
10 | version='0.1.3',
11 | url='https://github.com/matthias-wright/flaxmodels',
12 | author='Matthias Wright',
13 | packages=find_packages(),
14 | install_requires=['h5py>=2.10.0',
15 | 'numpy>=1.19.5',
16 | 'requests>=2.23.0',
17 | 'packaging>=20.9',
18 | 'dataclasses>=0.6',
19 | 'filelock>=3.0.12',
20 | 'jax>=0.3',
21 | 'jaxlib',
22 | 'flax>=0.4.0',
23 | 'Pillow>=7.1.2',
24 | 'regex>=2021.4.4',
25 | 'tqdm>=4.60.0'],
26 | extras_require={
27 | 'testing': ['pytest'],
28 | },
29 | python_requires='>=3.6',
30 | license='Each model has an individual license.',
31 | description='A collection of pretrained models in Flax.',
32 | long_description=long_description,
33 | long_description_content_type='text/markdown')
34 |
--------------------------------------------------------------------------------
/tests/aux_files/elefant.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/aux_files/elefant.jpg
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_embds_input.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_embds_input.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_embds_labels.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_embds_labels.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_embds_logits_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_embds_logits_ref.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_embds_loss_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_embds_loss_ref.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_ids_input.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_ids_input.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_ids_labels.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_ids_labels.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_ids_logits_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_ids_logits_ref.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_ids_loss_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_ids_loss_ref.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2-large/gpt2-large_model_input_embds_input.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-large/gpt2-large_model_input_embds_input.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2-large/gpt2-large_model_input_embds_output_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-large/gpt2-large_model_input_embds_output_ref.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2-large/gpt2-large_model_input_ids_input.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-large/gpt2-large_model_input_ids_input.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2-large/gpt2-large_model_input_ids_output_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-large/gpt2-large_model_input_ids_output_ref.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_embds_input.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_embds_input.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_embds_labels.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_embds_labels.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_embds_logits_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_embds_logits_ref.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_embds_loss_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_embds_loss_ref.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_ids_input.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_ids_input.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_ids_labels.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_ids_labels.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_ids_logits_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_ids_logits_ref.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_ids_loss_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_ids_loss_ref.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2-medium/gpt2-medium_model_input_embds_input.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-medium/gpt2-medium_model_input_embds_input.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2-medium/gpt2-medium_model_input_embds_output_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-medium/gpt2-medium_model_input_embds_output_ref.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2-medium/gpt2-medium_model_input_ids_input.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-medium/gpt2-medium_model_input_ids_input.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2-medium/gpt2-medium_model_input_ids_output_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-medium/gpt2-medium_model_input_ids_output_ref.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_embds_input.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_embds_input.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_embds_labels.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_embds_labels.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_embds_logits_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_embds_logits_ref.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_embds_loss_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_embds_loss_ref.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_ids_input.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_ids_input.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_ids_labels.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_ids_labels.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_ids_logits_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_ids_logits_ref.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_ids_loss_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_ids_loss_ref.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2-xl/gpt2-xl_model_input_embds_input.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-xl/gpt2-xl_model_input_embds_input.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2-xl/gpt2-xl_model_input_embds_output_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-xl/gpt2-xl_model_input_embds_output_ref.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2-xl/gpt2-xl_model_input_ids_input.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-xl/gpt2-xl_model_input_ids_input.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2-xl/gpt2-xl_model_input_ids_output_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2-xl/gpt2-xl_model_input_ids_output_ref.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_embds_input.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_embds_input.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_embds_labels.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_embds_labels.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_embds_logits_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_embds_logits_ref.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_embds_loss_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_embds_loss_ref.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_ids_input.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_ids_input.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_ids_labels.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_ids_labels.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_ids_logits_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_ids_logits_ref.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_ids_loss_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_ids_loss_ref.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2/gpt2_model_input_embds_input.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2/gpt2_model_input_embds_input.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2/gpt2_model_input_embds_output_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2/gpt2_model_input_embds_output_ref.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2/gpt2_model_input_ids_input.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2/gpt2_model_input_ids_input.npy
--------------------------------------------------------------------------------
/tests/gpt2/aux_files/gpt2/gpt2_model_input_ids_output_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/gpt2/aux_files/gpt2/gpt2_model_input_ids_output_ref.npy
--------------------------------------------------------------------------------
/tests/gpt2/test_gpt2.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import flaxmodels as fm
4 |
5 |
6 | def test_reference_output_lm_head_input_ids():
7 | input_ids = jnp.load(f'tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_ids_input.npy')
8 | labels = jnp.load(f'tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_ids_labels.npy')
9 |
10 | ref_logits = jnp.load(f'tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_ids_logits_ref.npy')
11 | ref_loss = jnp.load(f'tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_ids_loss_ref.npy')
12 |
13 | key = jax.random.PRNGKey(0)
14 | model = fm.gpt2.GPT2LMHeadModel(pretrained='gpt2')
15 |
16 | params = model.init(key, input_ids=input_ids, labels=labels)
17 | output = model.apply(params, input_ids=input_ids, labels=labels)
18 | logits, loss = output['logits'], output['loss']
19 |
20 | diff_logits = jnp.mean(jnp.abs(ref_logits - logits))
21 | diff_loss = jnp.mean(jnp.abs(ref_loss - loss))
22 |
23 | assert diff_logits < 1e-4 and diff_loss < 1e-4
24 |
25 |
26 | def test_reference_output_lm_head_input_embds():
27 | input_embds = jnp.load(f'tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_embds_input.npy')
28 | labels = jnp.load(f'tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_embds_labels.npy')
29 |
30 | ref_logits = jnp.load(f'tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_embds_logits_ref.npy')
31 | ref_loss = jnp.load(f'tests/gpt2/aux_files/gpt2/gpt2_lmhead_input_embds_loss_ref.npy')
32 |
33 | key = jax.random.PRNGKey(0)
34 | model = fm.gpt2.GPT2LMHeadModel(pretrained='gpt2')
35 |
36 | params = model.init(key, input_embds=input_embds, labels=labels)
37 | output = model.apply(params, input_embds=input_embds, labels=labels)
38 | logits, loss = output['logits'], output['loss']
39 |
40 | diff_logits = jnp.mean(jnp.abs(ref_logits - logits))
41 | diff_loss = jnp.mean(jnp.abs(ref_loss - loss))
42 |
43 | assert diff_logits < 1e-4 and diff_loss < 1e-4
44 |
45 |
46 | def test_reference_output_model_input_ids():
47 | input_ids = jnp.load(f'tests/gpt2/aux_files/gpt2/gpt2_model_input_ids_input.npy')
48 |
49 | ref_output = jnp.load(f'tests/gpt2/aux_files/gpt2/gpt2_model_input_ids_output_ref.npy')
50 |
51 | key = jax.random.PRNGKey(0)
52 | model = fm.gpt2.GPT2Model(pretrained='gpt2')
53 |
54 | params = model.init(key, input_ids=input_ids)
55 | output = model.apply(params, input_ids=input_ids)
56 |
57 | diff = jnp.mean(jnp.abs(ref_output - output['last_hidden_state']))
58 |
59 | assert diff < 1e-4
60 |
61 |
62 | def test_reference_output_model_input_embds():
63 | input_embds = jnp.load(f'tests/gpt2/aux_files/gpt2/gpt2_model_input_embds_input.npy')
64 |
65 | ref_output = jnp.load(f'tests/gpt2/aux_files/gpt2/gpt2_model_input_embds_output_ref.npy')
66 |
67 | key = jax.random.PRNGKey(0)
68 | model = fm.gpt2.GPT2Model(pretrained='gpt2')
69 |
70 | params = model.init(key, input_embds=input_embds)
71 | output = model.apply(params, input_embds=input_embds)
72 |
73 | diff = jnp.mean(jnp.abs(ref_output - output['last_hidden_state']))
74 |
75 | assert diff < 1e-4
76 |
77 |
--------------------------------------------------------------------------------
/tests/gpt2/test_gpt2_large.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import flaxmodels as fm
4 |
5 |
6 | def test_reference_output_lm_head_input_ids():
7 | input_ids = jnp.load(f'tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_ids_input.npy')
8 | labels = jnp.load(f'tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_ids_labels.npy')
9 |
10 | ref_logits = jnp.load(f'tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_ids_logits_ref.npy')
11 | ref_loss = jnp.load(f'tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_ids_loss_ref.npy')
12 |
13 | key = jax.random.PRNGKey(0)
14 | model = fm.gpt2.GPT2LMHeadModel(pretrained='gpt2-large')
15 |
16 | params = model.init(key, input_ids=input_ids, labels=labels)
17 | output = model.apply(params, input_ids=input_ids, labels=labels)
18 | logits, loss = output['logits'], output['loss']
19 |
20 | diff_logits = jnp.mean(jnp.abs(ref_logits - logits))
21 | diff_loss = jnp.mean(jnp.abs(ref_loss - loss))
22 |
23 | assert diff_logits < 1e-3 and diff_loss < 1e-3
24 |
25 |
26 | def test_reference_output_lm_head_input_embds():
27 | input_embds = jnp.load(f'tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_embds_input.npy')
28 | labels = jnp.load(f'tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_embds_labels.npy')
29 |
30 | ref_logits = jnp.load(f'tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_embds_logits_ref.npy')
31 | ref_loss = jnp.load(f'tests/gpt2/aux_files/gpt2-large/gpt2-large_lmhead_input_embds_loss_ref.npy')
32 |
33 | key = jax.random.PRNGKey(0)
34 | model = fm.gpt2.GPT2LMHeadModel(pretrained='gpt2-large')
35 |
36 | params = model.init(key, input_embds=input_embds, labels=labels)
37 | output = model.apply(params, input_embds=input_embds, labels=labels)
38 | logits, loss = output['logits'], output['loss']
39 |
40 | diff_logits = jnp.mean(jnp.abs(ref_logits - logits))
41 | diff_loss = jnp.mean(jnp.abs(ref_loss - loss))
42 |
43 | assert diff_logits < 1e-3 and diff_loss < 1e-3
44 |
45 |
46 | def test_reference_output_model_input_ids():
47 | input_ids = jnp.load(f'tests/gpt2/aux_files/gpt2-large/gpt2-large_model_input_ids_input.npy')
48 |
49 | ref_output = jnp.load(f'tests/gpt2/aux_files/gpt2-large/gpt2-large_model_input_ids_output_ref.npy')
50 |
51 | key = jax.random.PRNGKey(0)
52 | model = fm.gpt2.GPT2Model(pretrained='gpt2-large')
53 |
54 | params = model.init(key, input_ids=input_ids)
55 | output = model.apply(params, input_ids=input_ids)
56 |
57 | diff = jnp.mean(jnp.abs(ref_output - output['last_hidden_state']))
58 |
59 | assert diff < 1e-4
60 |
61 |
62 | def test_reference_output_model_input_embds():
63 | input_embds = jnp.load(f'tests/gpt2/aux_files/gpt2-large/gpt2-large_model_input_embds_input.npy')
64 |
65 | ref_output = jnp.load(f'tests/gpt2/aux_files/gpt2-large/gpt2-large_model_input_embds_output_ref.npy')
66 |
67 | key = jax.random.PRNGKey(0)
68 | model = fm.gpt2.GPT2Model(pretrained='gpt2-large')
69 |
70 | params = model.init(key, input_embds=input_embds)
71 | output = model.apply(params, input_embds=input_embds)
72 |
73 | diff = jnp.mean(jnp.abs(ref_output - output['last_hidden_state']))
74 |
75 | assert diff < 1e-4
76 |
77 |
--------------------------------------------------------------------------------
/tests/gpt2/test_gpt2_medium.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import flaxmodels as fm
4 |
5 |
6 | def test_reference_output_lm_head_input_ids():
7 | input_ids = jnp.load(f'tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_ids_input.npy')
8 | labels = jnp.load(f'tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_ids_labels.npy')
9 |
10 | ref_logits = jnp.load(f'tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_ids_logits_ref.npy')
11 | ref_loss = jnp.load(f'tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_ids_loss_ref.npy')
12 |
13 | key = jax.random.PRNGKey(0)
14 | model = fm.gpt2.GPT2LMHeadModel(pretrained='gpt2-medium')
15 |
16 | params = model.init(key, input_ids=input_ids, labels=labels)
17 | output = model.apply(params, input_ids=input_ids, labels=labels)
18 | logits, loss = output['logits'], output['loss']
19 |
20 | diff_logits = jnp.mean(jnp.abs(ref_logits - logits))
21 | diff_loss = jnp.mean(jnp.abs(ref_loss - loss))
22 |
23 | assert diff_logits < 1e-3 and diff_loss < 1e-3
24 |
25 |
26 | def test_reference_output_lm_head_input_embds():
27 | input_embds = jnp.load(f'tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_embds_input.npy')
28 | labels = jnp.load(f'tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_embds_labels.npy')
29 |
30 | ref_logits = jnp.load(f'tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_embds_logits_ref.npy')
31 | ref_loss = jnp.load(f'tests/gpt2/aux_files/gpt2-medium/gpt2-medium_lmhead_input_embds_loss_ref.npy')
32 |
33 | key = jax.random.PRNGKey(0)
34 | model = fm.gpt2.GPT2LMHeadModel(pretrained='gpt2-medium')
35 |
36 | params = model.init(key, input_embds=input_embds, labels=labels)
37 | output = model.apply(params, input_embds=input_embds, labels=labels)
38 | logits, loss = output['logits'], output['loss']
39 |
40 | diff_logits = jnp.mean(jnp.abs(ref_logits - logits))
41 | diff_loss = jnp.mean(jnp.abs(ref_loss - loss))
42 |
43 | assert diff_logits < 1e-3 and diff_loss < 1e-3
44 |
45 |
46 | def test_reference_output_model_input_ids():
47 | input_ids = jnp.load(f'tests/gpt2/aux_files/gpt2-medium/gpt2-medium_model_input_ids_input.npy')
48 |
49 | ref_output = jnp.load(f'tests/gpt2/aux_files/gpt2-medium/gpt2-medium_model_input_ids_output_ref.npy')
50 |
51 | key = jax.random.PRNGKey(0)
52 | model = fm.gpt2.GPT2Model(pretrained='gpt2-medium')
53 |
54 | params = model.init(key, input_ids=input_ids)
55 | output = model.apply(params, input_ids=input_ids)
56 |
57 | diff = jnp.mean(jnp.abs(ref_output - output['last_hidden_state']))
58 |
59 | assert diff < 1e-4
60 |
61 |
62 | def test_reference_output_model_input_embds():
63 | input_embds = jnp.load(f'tests/gpt2/aux_files/gpt2-medium/gpt2-medium_model_input_embds_input.npy')
64 |
65 | ref_output = jnp.load(f'tests/gpt2/aux_files/gpt2-medium/gpt2-medium_model_input_embds_output_ref.npy')
66 |
67 | key = jax.random.PRNGKey(0)
68 | model = fm.gpt2.GPT2Model(pretrained='gpt2-medium')
69 |
70 | params = model.init(key, input_embds=input_embds)
71 | output = model.apply(params, input_embds=input_embds)
72 |
73 | diff = jnp.mean(jnp.abs(ref_output - output['last_hidden_state']))
74 |
75 | assert diff < 1e-4
76 |
77 |
--------------------------------------------------------------------------------
/tests/gpt2/test_gpt2_xl.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import flaxmodels as fm
4 |
5 |
6 | def test_reference_output_lm_head_input_ids():
7 | input_ids = jnp.load(f'tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_ids_input.npy')
8 | labels = jnp.load(f'tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_ids_labels.npy')
9 |
10 | ref_logits = jnp.load(f'tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_ids_logits_ref.npy')
11 | ref_loss = jnp.load(f'tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_ids_loss_ref.npy')
12 |
13 | key = jax.random.PRNGKey(0)
14 | model = fm.gpt2.GPT2LMHeadModel(pretrained='gpt2-xl')
15 |
16 | params = model.init(key, input_ids=input_ids, labels=labels)
17 | output = model.apply(params, input_ids=input_ids, labels=labels)
18 | logits, loss = output['logits'], output['loss']
19 |
20 | diff_logits = jnp.mean(jnp.abs(ref_logits - logits))
21 | diff_loss = jnp.mean(jnp.abs(ref_loss - loss))
22 |
23 | assert diff_logits < 1e-3 and diff_loss < 1e-3
24 |
25 |
26 | def test_reference_output_lm_head_input_embds():
27 | input_embds = jnp.load(f'tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_embds_input.npy')
28 | labels = jnp.load(f'tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_embds_labels.npy')
29 |
30 | ref_logits = jnp.load(f'tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_embds_logits_ref.npy')
31 | ref_loss = jnp.load(f'tests/gpt2/aux_files/gpt2-xl/gpt2-xl_lmhead_input_embds_loss_ref.npy')
32 |
33 | key = jax.random.PRNGKey(0)
34 | model = fm.gpt2.GPT2LMHeadModel(pretrained='gpt2-xl')
35 |
36 | params = model.init(key, input_embds=input_embds, labels=labels)
37 | output = model.apply(params, input_embds=input_embds, labels=labels)
38 | logits, loss = output['logits'], output['loss']
39 |
40 | diff_logits = jnp.mean(jnp.abs(ref_logits - logits))
41 | diff_loss = jnp.mean(jnp.abs(ref_loss - loss))
42 |
43 | assert diff_logits < 1e-3 and diff_loss < 1e-3
44 |
45 |
46 | def test_reference_output_model_input_ids():
47 | input_ids = jnp.load(f'tests/gpt2/aux_files/gpt2-xl/gpt2-xl_model_input_ids_input.npy')
48 |
49 | ref_output = jnp.load(f'tests/gpt2/aux_files/gpt2-xl/gpt2-xl_model_input_ids_output_ref.npy')
50 |
51 | key = jax.random.PRNGKey(0)
52 | model = fm.gpt2.GPT2Model(pretrained='gpt2-xl')
53 |
54 | params = model.init(key, input_ids=input_ids)
55 | output = model.apply(params, input_ids=input_ids)
56 |
57 | diff = jnp.mean(jnp.abs(ref_output - output['last_hidden_state']))
58 |
59 | assert diff < 1e-4
60 |
61 |
62 | def test_reference_output_model_input_embds():
63 | input_embds = jnp.load(f'tests/gpt2/aux_files/gpt2-xl/gpt2-xl_model_input_embds_input.npy')
64 |
65 | ref_output = jnp.load(f'tests/gpt2/aux_files/gpt2-xl/gpt2-xl_model_input_embds_output_ref.npy')
66 |
67 | key = jax.random.PRNGKey(0)
68 | model = fm.gpt2.GPT2Model(pretrained='gpt2-xl')
69 |
70 | params = model.init(key, input_embds=input_embds)
71 | output = model.apply(params, input_embds=input_embds)
72 |
73 | diff = jnp.mean(jnp.abs(ref_output - output['last_hidden_state']))
74 |
75 | assert diff < 1e-4
76 |
77 |
--------------------------------------------------------------------------------
/tests/resnet/aux_files/resnet101_elefant_output_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/resnet/aux_files/resnet101_elefant_output_ref.npy
--------------------------------------------------------------------------------
/tests/resnet/aux_files/resnet152_elefant_output_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/resnet/aux_files/resnet152_elefant_output_ref.npy
--------------------------------------------------------------------------------
/tests/resnet/aux_files/resnet18_elefant_output_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/resnet/aux_files/resnet18_elefant_output_ref.npy
--------------------------------------------------------------------------------
/tests/resnet/aux_files/resnet34_elefant_output_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/resnet/aux_files/resnet34_elefant_output_ref.npy
--------------------------------------------------------------------------------
/tests/resnet/aux_files/resnet50_elefant_output_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/resnet/aux_files/resnet50_elefant_output_ref.npy
--------------------------------------------------------------------------------
/tests/resnet/test_resnet101.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import numpy as np
4 | import flaxmodels as fm
5 | from PIL import Image
6 |
7 |
8 | def test_output_softmax():
9 | # If output='softmax', the output should be in range [0, 1]
10 | key = jax.random.PRNGKey(0)
11 | x = jax.random.uniform(key, shape=(1, 224, 224, 3), minval=0, maxval=1)
12 |
13 | resnet101 = fm.ResNet101(output='softmax', pretrained=None)
14 | params = resnet101.init(key, x)
15 | out = resnet101.apply(params, x, train=False)
16 |
17 | assert jnp.min(out) >= 0.0 and jnp.max(out) <= 1.0
18 |
19 |
20 | def test_output_activations():
21 | # If output='activations', the output should be a dict.
22 | key = jax.random.PRNGKey(0)
23 | x = jax.random.uniform(key, shape=(1, 224, 224, 3), minval=0, maxval=1)
24 |
25 | resnet101 = fm.ResNet101(output='activations', pretrained=None)
26 | params = resnet101.init(key, x)
27 | out = resnet101.apply(params, x, train=False)
28 |
29 | assert isinstance(out, dict)
30 |
31 |
32 | def test_reference_output():
33 | key = jax.random.PRNGKey(0)
34 | img = Image.open('tests/aux_files/elefant.jpg')
35 | img = img.resize((224, 224))
36 | x = jnp.array(img, dtype=jnp.float32) / 255.0
37 | x = jnp.expand_dims(x, axis=0)
38 |
39 | resnet101 = fm.ResNet101(output='logits', pretrained='imagenet')
40 | params = resnet101.init(key, x)
41 | out = resnet101.apply(params, x, train=False)
42 |
43 | out_ref = jnp.load('tests/resnet/aux_files/resnet101_elefant_output_ref.npy')
44 | diff = jnp.mean(jnp.abs(out - out_ref))
45 |
46 | assert diff < 1e-5
47 |
--------------------------------------------------------------------------------
/tests/resnet/test_resnet152.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import numpy as np
4 | import flaxmodels as fm
5 | from PIL import Image
6 |
7 |
8 | def test_output_softmax():
9 | # If output='softmax', the output should be in range [0, 1]
10 | key = jax.random.PRNGKey(0)
11 | x = jax.random.uniform(key, shape=(1, 224, 224, 3), minval=0, maxval=1)
12 |
13 | resnet152 = fm.ResNet152(output='softmax', pretrained=None)
14 | params = resnet152.init(key, x)
15 | out = resnet152.apply(params, x, train=False)
16 |
17 | assert jnp.min(out) >= 0.0 and jnp.max(out) <= 1.0
18 |
19 |
20 | def test_output_activations():
21 | # If output='activations', the output should be a dict.
22 | key = jax.random.PRNGKey(0)
23 | x = jax.random.uniform(key, shape=(1, 224, 224, 3), minval=0, maxval=1)
24 |
25 | resnet152 = fm.ResNet152(output='activations', pretrained=None)
26 | params = resnet152.init(key, x)
27 | out = resnet152.apply(params, x, train=False)
28 |
29 | assert isinstance(out, dict)
30 |
31 |
32 | def test_reference_output():
33 | key = jax.random.PRNGKey(0)
34 | img = Image.open('tests/aux_files/elefant.jpg')
35 | img = img.resize((224, 224))
36 | x = jnp.array(img, dtype=jnp.float32) / 255.0
37 | x = jnp.expand_dims(x, axis=0)
38 |
39 | resnet152 = fm.ResNet152(output='logits', pretrained='imagenet')
40 | params = resnet152.init(key, x)
41 | out = resnet152.apply(params, x, train=False)
42 |
43 | out_ref = jnp.load('tests/resnet/aux_files/resnet152_elefant_output_ref.npy')
44 | diff = jnp.mean(jnp.abs(out - out_ref))
45 |
46 | assert diff < 1e-5
47 |
48 |
--------------------------------------------------------------------------------
/tests/resnet/test_resnet18.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import numpy as np
4 | import flaxmodels as fm
5 | from PIL import Image
6 |
7 |
8 | def test_output_softmax():
9 | # If output='softmax', the output should be in range [0, 1]
10 | key = jax.random.PRNGKey(0)
11 | x = jax.random.uniform(key, shape=(1, 224, 224, 3), minval=0, maxval=1)
12 |
13 | resnet18 = fm.ResNet18(output='softmax', pretrained=None)
14 | params = resnet18.init(key, x, train=False)
15 | out, _ = resnet18.apply(params, x, mutable=['batch_stats'])
16 |
17 | assert jnp.min(out) >= 0.0 and jnp.max(out) <= 1.0
18 |
19 |
20 | def test_output_activations():
21 | # If output='activations', the output should be a dict.
22 | key = jax.random.PRNGKey(0)
23 | x = jax.random.uniform(key, shape=(1, 224, 224, 3), minval=0, maxval=1)
24 |
25 | resnet18 = fm.ResNet18(output='activations', pretrained=None)
26 | params = resnet18.init(key, x, train=False)
27 | out, _ = resnet18.apply(params, x, mutable=['batch_stats'])
28 |
29 | assert isinstance(out, dict)
30 |
31 |
32 | def test_reference_output():
33 | key = jax.random.PRNGKey(0)
34 | img = Image.open('tests/aux_files/elefant.jpg')
35 | img = img.resize((224, 224))
36 | x = jnp.array(img, dtype=jnp.float32) / 255.0
37 | x = jnp.expand_dims(x, axis=0)
38 |
39 | resnet18 = fm.ResNet18(output='logits', pretrained='imagenet')
40 | params = resnet18.init(key, x)
41 | out = resnet18.apply(params, x, train=False)
42 |
43 | out_ref = jnp.load('tests/resnet/aux_files/resnet18_elefant_output_ref.npy')
44 | diff = jnp.mean(jnp.abs(out - out_ref))
45 |
46 | assert diff < 1e-5
47 |
48 |
--------------------------------------------------------------------------------
/tests/resnet/test_resnet34.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import numpy as np
4 | import flaxmodels as fm
5 | from PIL import Image
6 |
7 |
8 | def test_output_softmax():
9 | # If output='softmax', the output should be in range [0, 1]
10 | key = jax.random.PRNGKey(0)
11 | x = jax.random.uniform(key, shape=(1, 224, 224, 3), minval=0, maxval=1)
12 |
13 | resnet34 = fm.ResNet34(output='softmax', pretrained=None)
14 | params = resnet34.init(key, x)
15 | out = resnet34.apply(params, x, train=False)
16 |
17 | assert jnp.min(out) >= 0.0 and jnp.max(out) <= 1.0
18 |
19 |
20 | def test_output_activations():
21 | # If output='activations', the output should be a dict.
22 | key = jax.random.PRNGKey(0)
23 | x = jax.random.uniform(key, shape=(1, 224, 224, 3), minval=0, maxval=1)
24 |
25 | resnet34 = fm.ResNet34(output='activations', pretrained=None)
26 | params = resnet34.init(key, x)
27 | out = resnet34.apply(params, x, train=False)
28 |
29 | assert isinstance(out, dict)
30 |
31 |
32 | def test_reference_output():
33 | key = jax.random.PRNGKey(0)
34 | img = Image.open('tests/aux_files/elefant.jpg')
35 | img = img.resize((224, 224))
36 | x = jnp.array(img, dtype=jnp.float32) / 255.0
37 | x = jnp.expand_dims(x, axis=0)
38 |
39 | resnet34 = fm.ResNet34(output='logits', pretrained='imagenet')
40 | params = resnet34.init(key, x)
41 | out = resnet34.apply(params, x, train=False)
42 |
43 | out_ref = jnp.load('tests/resnet/aux_files/resnet34_elefant_output_ref.npy')
44 | diff = jnp.mean(jnp.abs(out - out_ref))
45 |
46 | assert diff < 1e-5
47 |
48 |
--------------------------------------------------------------------------------
/tests/resnet/test_resnet50.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import numpy as np
4 | import flaxmodels as fm
5 | from PIL import Image
6 |
7 |
8 | def test_output_softmax():
9 | # If output='softmax', the output should be in range [0, 1]
10 | key = jax.random.PRNGKey(0)
11 | x = jax.random.uniform(key, shape=(1, 224, 224, 3), minval=0, maxval=1)
12 |
13 | resnet50 = fm.ResNet50(output='softmax', pretrained=None)
14 | params = resnet50.init(key, x)
15 | out = resnet50.apply(params, x, train=False)
16 |
17 | assert jnp.min(out) >= 0.0 and jnp.max(out) <= 1.0
18 |
19 |
20 | def test_output_activations():
21 | # If output='activations', the output should be a dict.
22 | key = jax.random.PRNGKey(0)
23 | x = jax.random.uniform(key, shape=(1, 224, 224, 3), minval=0, maxval=1)
24 |
25 | resnet50 = fm.ResNet50(output='activations', pretrained=None)
26 | params = resnet50.init(key, x)
27 | out = resnet50.apply(params, x, train=False)
28 |
29 | assert isinstance(out, dict)
30 |
31 |
32 | def test_reference_output():
33 | key = jax.random.PRNGKey(0)
34 | img = Image.open('tests/aux_files/elefant.jpg')
35 | img = img.resize((224, 224))
36 | x = jnp.array(img, dtype=jnp.float32) / 255.0
37 | x = jnp.expand_dims(x, axis=0)
38 |
39 | resnet50 = fm.ResNet50(output='logits', pretrained='imagenet')
40 | params = resnet50.init(key, x)
41 | out = resnet50.apply(params, x, train=False)
42 |
43 | out_ref = jnp.load('tests/resnet/aux_files/resnet50_elefant_output_ref.npy')
44 | diff = jnp.mean(jnp.abs(out - out_ref))
45 |
46 | assert diff < 1e-5
47 |
--------------------------------------------------------------------------------
/tests/stylegan2/discriminator/aux_files/afhqcat_input_img.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/afhqcat_input_img.npy
--------------------------------------------------------------------------------
/tests/stylegan2/discriminator/aux_files/afhqcat_output_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/afhqcat_output_ref.npy
--------------------------------------------------------------------------------
/tests/stylegan2/discriminator/aux_files/afhqdog_input_img.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/afhqdog_input_img.npy
--------------------------------------------------------------------------------
/tests/stylegan2/discriminator/aux_files/afhqdog_output_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/afhqdog_output_ref.npy
--------------------------------------------------------------------------------
/tests/stylegan2/discriminator/aux_files/afhqwild_input_img.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/afhqwild_input_img.npy
--------------------------------------------------------------------------------
/tests/stylegan2/discriminator/aux_files/afhqwild_output_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/afhqwild_output_ref.npy
--------------------------------------------------------------------------------
/tests/stylegan2/discriminator/aux_files/brecahad_input_img.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/brecahad_input_img.npy
--------------------------------------------------------------------------------
/tests/stylegan2/discriminator/aux_files/brecahad_output_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/brecahad_output_ref.npy
--------------------------------------------------------------------------------
/tests/stylegan2/discriminator/aux_files/car_input_img.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/car_input_img.npy
--------------------------------------------------------------------------------
/tests/stylegan2/discriminator/aux_files/car_output_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/car_output_ref.npy
--------------------------------------------------------------------------------
/tests/stylegan2/discriminator/aux_files/cat_input_img.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/cat_input_img.npy
--------------------------------------------------------------------------------
/tests/stylegan2/discriminator/aux_files/cat_output_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/cat_output_ref.npy
--------------------------------------------------------------------------------
/tests/stylegan2/discriminator/aux_files/church_input_img.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/church_input_img.npy
--------------------------------------------------------------------------------
/tests/stylegan2/discriminator/aux_files/church_output_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/church_output_ref.npy
--------------------------------------------------------------------------------
/tests/stylegan2/discriminator/aux_files/cifar10_input_img.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/cifar10_input_img.npy
--------------------------------------------------------------------------------
/tests/stylegan2/discriminator/aux_files/cifar10_input_label.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/cifar10_input_label.npy
--------------------------------------------------------------------------------
/tests/stylegan2/discriminator/aux_files/cifar10_output_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/cifar10_output_ref.npy
--------------------------------------------------------------------------------
/tests/stylegan2/discriminator/aux_files/ffhq_input_img.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/ffhq_input_img.npy
--------------------------------------------------------------------------------
/tests/stylegan2/discriminator/aux_files/ffhq_output_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/ffhq_output_ref.npy
--------------------------------------------------------------------------------
/tests/stylegan2/discriminator/aux_files/horse_input_img.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/horse_input_img.npy
--------------------------------------------------------------------------------
/tests/stylegan2/discriminator/aux_files/horse_output_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/horse_output_ref.npy
--------------------------------------------------------------------------------
/tests/stylegan2/discriminator/aux_files/metfaces_input_img.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/metfaces_input_img.npy
--------------------------------------------------------------------------------
/tests/stylegan2/discriminator/aux_files/metfaces_output_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/discriminator/aux_files/metfaces_output_ref.npy
--------------------------------------------------------------------------------
/tests/stylegan2/discriminator/test_discriminator.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import numpy as np
4 | import flaxmodels as fm
5 | from PIL import Image
6 |
7 |
8 | def test_reference_output_afhqcat():
9 | img = jnp.load('tests/stylegan2/discriminator/aux_files/afhqcat_input_img.npy')
10 | out_ref = jnp.load('tests/stylegan2/discriminator/aux_files/afhqcat_output_ref.npy')
11 |
12 | key = jax.random.PRNGKey(0)
13 | discriminator = fm.stylegan2.Discriminator(pretrained='afhqcat')
14 | params = discriminator.init(key, img)
15 | out = discriminator.apply(params, img)
16 |
17 | diff = jnp.mean(jnp.abs(out - out_ref))
18 |
19 | assert diff < 5e-4
20 |
21 |
22 | def test_reference_output_afhqdog():
23 | img = jnp.load('tests/stylegan2/discriminator/aux_files/afhqdog_input_img.npy')
24 | out_ref = jnp.load('tests/stylegan2/discriminator/aux_files/afhqdog_output_ref.npy')
25 |
26 | key = jax.random.PRNGKey(0)
27 | discriminator = fm.stylegan2.Discriminator(pretrained='afhqdog')
28 | params = discriminator.init(key, img)
29 | out = discriminator.apply(params, img)
30 |
31 | diff = jnp.mean(jnp.abs(out - out_ref))
32 |
33 | assert diff < 4e-3
34 |
35 |
36 | def test_reference_output_afhqwild():
37 | img = jnp.load('tests/stylegan2/discriminator/aux_files/afhqwild_input_img.npy')
38 | out_ref = jnp.load('tests/stylegan2/discriminator/aux_files/afhqwild_output_ref.npy')
39 |
40 | key = jax.random.PRNGKey(0)
41 | discriminator = fm.stylegan2.Discriminator(pretrained='afhqwild')
42 | params = discriminator.init(key, img)
43 | out = discriminator.apply(params, img)
44 |
45 | diff = jnp.mean(jnp.abs(out - out_ref))
46 |
47 | assert diff < 4e-4
48 |
49 |
50 | def test_reference_output_brecahad():
51 | img = jnp.load('tests/stylegan2/discriminator/aux_files/brecahad_input_img.npy')
52 | out_ref = jnp.load('tests/stylegan2/discriminator/aux_files/brecahad_output_ref.npy')
53 |
54 | key = jax.random.PRNGKey(0)
55 | discriminator = fm.stylegan2.Discriminator(pretrained='brecahad')
56 | params = discriminator.init(key, img)
57 | out = discriminator.apply(params, img)
58 |
59 | diff = jnp.mean(jnp.abs(out - out_ref))
60 |
61 | assert diff < 2e-4
62 |
63 |
64 | def test_reference_output_car():
65 | img = jnp.load('tests/stylegan2/discriminator/aux_files/car_input_img.npy')
66 | out_ref = jnp.load('tests/stylegan2/discriminator/aux_files/car_output_ref.npy')
67 |
68 | key = jax.random.PRNGKey(0)
69 | discriminator = fm.stylegan2.Discriminator(pretrained='car')
70 | params = discriminator.init(key, img)
71 | out = discriminator.apply(params, img)
72 |
73 | diff = jnp.mean(jnp.abs(out - out_ref))
74 |
75 | assert diff < 3e-4
76 |
77 |
78 | def test_reference_output_cat():
79 | img = jnp.load('tests/stylegan2/discriminator/aux_files/cat_input_img.npy')
80 | out_ref = jnp.load('tests/stylegan2/discriminator/aux_files/cat_output_ref.npy')
81 |
82 | key = jax.random.PRNGKey(0)
83 | discriminator = fm.stylegan2.Discriminator(pretrained='cat')
84 | params = discriminator.init(key, img)
85 | out = discriminator.apply(params, img)
86 |
87 | diff = jnp.mean(jnp.abs(out - out_ref))
88 |
89 | assert diff < 1e-4
90 |
91 |
92 | def test_reference_output_church():
93 | img = jnp.load('tests/stylegan2/discriminator/aux_files/church_input_img.npy')
94 | out_ref = jnp.load('tests/stylegan2/discriminator/aux_files/church_output_ref.npy')
95 |
96 | key = jax.random.PRNGKey(0)
97 | discriminator = fm.stylegan2.Discriminator(pretrained='church')
98 | params = discriminator.init(key, img)
99 | out = discriminator.apply(params, img)
100 |
101 | diff = jnp.mean(jnp.abs(out - out_ref))
102 |
103 | assert diff < 1e-4
104 |
105 |
106 | def test_reference_output_cifar10():
107 | img = jnp.load('tests/stylegan2/discriminator/aux_files/cifar10_input_img.npy')
108 | label = jnp.load('tests/stylegan2/discriminator/aux_files/cifar10_input_label.npy')
109 | out_ref = jnp.load('tests/stylegan2/discriminator/aux_files/cifar10_output_ref.npy')
110 |
111 | key = jax.random.PRNGKey(0)
112 | discriminator = fm.stylegan2.Discriminator(pretrained='cifar10')
113 | params = discriminator.init(key, img, label)
114 | out = discriminator.apply(params, img, label)
115 |
116 | diff = jnp.mean(jnp.abs(out - out_ref))
117 |
118 | assert diff < 1e-2
119 |
120 |
121 | def test_reference_output_ffhq():
122 | img = jnp.load('tests/stylegan2/discriminator/aux_files/ffhq_input_img.npy')
123 | out_ref = jnp.load('tests/stylegan2/discriminator/aux_files/ffhq_output_ref.npy')
124 |
125 | key = jax.random.PRNGKey(0)
126 | discriminator = fm.stylegan2.Discriminator(pretrained='ffhq')
127 | params = discriminator.init(key, img)
128 | out = discriminator.apply(params, img)
129 |
130 | diff = jnp.mean(jnp.abs(out - out_ref))
131 |
132 | assert diff < 1e-4
133 |
134 |
135 | def test_reference_output_horse():
136 | img = jnp.load('tests/stylegan2/discriminator/aux_files/horse_input_img.npy')
137 | out_ref = jnp.load('tests/stylegan2/discriminator/aux_files/horse_output_ref.npy')
138 |
139 | key = jax.random.PRNGKey(0)
140 | discriminator = fm.stylegan2.Discriminator(pretrained='horse')
141 | params = discriminator.init(key, img)
142 | out = discriminator.apply(params, img)
143 |
144 | diff = jnp.mean(jnp.abs(out - out_ref))
145 |
146 | assert diff < 1e-4
147 |
148 |
149 | def test_reference_output_metfaces():
150 | img = jnp.load('tests/stylegan2/discriminator/aux_files/metfaces_input_img.npy')
151 | out_ref = jnp.load('tests/stylegan2/discriminator/aux_files/metfaces_output_ref.npy')
152 |
153 | key = jax.random.PRNGKey(0)
154 | discriminator = fm.stylegan2.Discriminator(pretrained='metfaces')
155 | params = discriminator.init(key, img)
156 | out = discriminator.apply(params, img)
157 |
158 | diff = jnp.mean(jnp.abs(out - out_ref))
159 |
160 | assert diff < 2e-3
161 |
162 |
163 |
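The tests above all share the same init/apply pattern. As a rough, hedged illustration (not part of the test suite), the sketch below scores a batch of images with the pretrained FFHQ discriminator; the random input array is a stand-in, and its resolution and value range are assumptions, since the reference inputs are only shipped as `.npy` files.

```python
# Hedged sketch of the init/apply pattern used in test_discriminator.py.
# Assumptions: NHWC float input at the checkpoint's native resolution
# (1024x1024 for FFHQ) and a value range comparable to the generator output.
import jax
import flaxmodels as fm

key = jax.random.PRNGKey(0)
images = jax.random.normal(key, shape=(1, 1024, 1024, 3))  # stand-in for real images

discriminator = fm.stylegan2.Discriminator(pretrained='ffhq')
params = discriminator.init(key, images)      # init returns the pretrained parameters here
scores = discriminator.apply(params, images)  # one realism score per input image
print(scores.shape)
```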
--------------------------------------------------------------------------------
/tests/stylegan2/generator/aux_files/afhqcat_input_z.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/afhqcat_input_z.npy
--------------------------------------------------------------------------------
/tests/stylegan2/generator/aux_files/afhqcat_output_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/afhqcat_output_ref.npy
--------------------------------------------------------------------------------
/tests/stylegan2/generator/aux_files/afhqdog_input_z.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/afhqdog_input_z.npy
--------------------------------------------------------------------------------
/tests/stylegan2/generator/aux_files/afhqdog_output_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/afhqdog_output_ref.npy
--------------------------------------------------------------------------------
/tests/stylegan2/generator/aux_files/afhqwild_input_z.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/afhqwild_input_z.npy
--------------------------------------------------------------------------------
/tests/stylegan2/generator/aux_files/afhqwild_output_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/afhqwild_output_ref.npy
--------------------------------------------------------------------------------
/tests/stylegan2/generator/aux_files/brecahad_input_z.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/brecahad_input_z.npy
--------------------------------------------------------------------------------
/tests/stylegan2/generator/aux_files/brecahad_output_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/brecahad_output_ref.npy
--------------------------------------------------------------------------------
/tests/stylegan2/generator/aux_files/car_input_z.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/car_input_z.npy
--------------------------------------------------------------------------------
/tests/stylegan2/generator/aux_files/car_output_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/car_output_ref.npy
--------------------------------------------------------------------------------
/tests/stylegan2/generator/aux_files/cat_input_z.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/cat_input_z.npy
--------------------------------------------------------------------------------
/tests/stylegan2/generator/aux_files/cat_output_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/cat_output_ref.npy
--------------------------------------------------------------------------------
/tests/stylegan2/generator/aux_files/church_input_z.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/church_input_z.npy
--------------------------------------------------------------------------------
/tests/stylegan2/generator/aux_files/church_output_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/church_output_ref.npy
--------------------------------------------------------------------------------
/tests/stylegan2/generator/aux_files/cifar10_input_label.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/cifar10_input_label.npy
--------------------------------------------------------------------------------
/tests/stylegan2/generator/aux_files/cifar10_input_z.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/cifar10_input_z.npy
--------------------------------------------------------------------------------
/tests/stylegan2/generator/aux_files/cifar10_output_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/cifar10_output_ref.npy
--------------------------------------------------------------------------------
/tests/stylegan2/generator/aux_files/ffhq_input_z.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/ffhq_input_z.npy
--------------------------------------------------------------------------------
/tests/stylegan2/generator/aux_files/ffhq_output_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/ffhq_output_ref.npy
--------------------------------------------------------------------------------
/tests/stylegan2/generator/aux_files/horse_input_z.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/horse_input_z.npy
--------------------------------------------------------------------------------
/tests/stylegan2/generator/aux_files/horse_output_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/horse_output_ref.npy
--------------------------------------------------------------------------------
/tests/stylegan2/generator/aux_files/metfaces_input_z.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/metfaces_input_z.npy
--------------------------------------------------------------------------------
/tests/stylegan2/generator/aux_files/metfaces_output_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/stylegan2/generator/aux_files/metfaces_output_ref.npy
--------------------------------------------------------------------------------
/tests/stylegan2/generator/test_generator.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import numpy as np
4 | import flaxmodels as fm
5 | from PIL import Image
6 |
7 |
8 | def test_reference_output_afhqcat():
9 | z = jnp.load('tests/stylegan2/generator/aux_files/afhqcat_input_z.npy')
10 | out_ref = jnp.load('tests/stylegan2/generator/aux_files/afhqcat_output_ref.npy')
11 |
12 | key = jax.random.PRNGKey(0)
13 | generator = fm.stylegan2.Generator(pretrained='afhqcat')
14 | params = generator.init(key, z)
15 | out = generator.apply(params, z, noise_mode='const', train=False)
16 |
17 | diff = jnp.mean(jnp.abs(out - out_ref))
18 |
19 | assert diff < 2e-2
20 |
21 |
22 | def test_reference_output_afhqdog():
23 | z = jnp.load('tests/stylegan2/generator/aux_files/afhqdog_input_z.npy')
24 | out_ref = jnp.load('tests/stylegan2/generator/aux_files/afhqdog_output_ref.npy')
25 |
26 | key = jax.random.PRNGKey(0)
27 | generator = fm.stylegan2.Generator(pretrained='afhqdog')
28 | params = generator.init(key, z)
29 | out = generator.apply(params, z, noise_mode='const', train=False)
30 |
31 | diff = jnp.mean(jnp.abs(out - out_ref))
32 |
33 | assert diff < 5e-3
34 |
35 |
36 | def test_reference_output_afhqwild():
37 | z = jnp.load('tests/stylegan2/generator/aux_files/afhqwild_input_z.npy')
38 | out_ref = jnp.load('tests/stylegan2/generator/aux_files/afhqwild_output_ref.npy')
39 |
40 | key = jax.random.PRNGKey(0)
41 | generator = fm.stylegan2.Generator(pretrained='afhqwild')
42 | params = generator.init(key, z)
43 | out = generator.apply(params, z, noise_mode='const', train=False)
44 |
45 | diff = jnp.mean(jnp.abs(out - out_ref))
46 |
47 | assert diff < 3e-3
48 |
49 |
50 | def test_reference_output_brecahad():
51 | z = jnp.load('tests/stylegan2/generator/aux_files/brecahad_input_z.npy')
52 | out_ref = jnp.load('tests/stylegan2/generator/aux_files/brecahad_output_ref.npy')
53 |
54 | key = jax.random.PRNGKey(0)
55 | generator = fm.stylegan2.Generator(pretrained='brecahad')
56 | params = generator.init(key, z)
57 | out = generator.apply(params, z, noise_mode='const', train=False)
58 |
59 | diff = jnp.mean(jnp.abs(out - out_ref))
60 |
61 | assert diff < 3e-3
62 |
63 |
64 | def test_reference_output_car():
65 | z = jnp.load('tests/stylegan2/generator/aux_files/car_input_z.npy')
66 | out_ref = jnp.load('tests/stylegan2/generator/aux_files/car_output_ref.npy')
67 |
68 | key = jax.random.PRNGKey(0)
69 | generator = fm.stylegan2.Generator(pretrained='car')
70 | params = generator.init(key, z)
71 | out = generator.apply(params, z, noise_mode='const', train=False)
72 |
73 | diff = jnp.mean(jnp.abs(out - out_ref))
74 |
75 | assert diff < 1e-5
76 |
77 |
78 | def test_reference_output_cat():
79 | z = jnp.load('tests/stylegan2/generator/aux_files/cat_input_z.npy')
80 | out_ref = jnp.load('tests/stylegan2/generator/aux_files/cat_output_ref.npy')
81 |
82 | key = jax.random.PRNGKey(0)
83 | generator = fm.stylegan2.Generator(pretrained='cat')
84 | params = generator.init(key, z)
85 | out = generator.apply(params, z, noise_mode='const', train=False)
86 |
87 | diff = jnp.mean(jnp.abs(out - out_ref))
88 |
89 | assert diff < 1e-5
90 |
91 |
92 | def test_reference_output_church():
93 | z = jnp.load('tests/stylegan2/generator/aux_files/church_input_z.npy')
94 | out_ref = jnp.load('tests/stylegan2/generator/aux_files/church_output_ref.npy')
95 |
96 | key = jax.random.PRNGKey(0)
97 | generator = fm.stylegan2.Generator(pretrained='church')
98 | params = generator.init(key, z)
99 | out = generator.apply(params, z, noise_mode='const', train=False)
100 |
101 | diff = jnp.mean(jnp.abs(out - out_ref))
102 |
103 | assert diff < 1e-5
104 |
105 |
106 | def test_reference_output_cifar10():
107 | z = jnp.load('tests/stylegan2/generator/aux_files/cifar10_input_z.npy')
108 | label = jnp.load('tests/stylegan2/generator/aux_files/cifar10_input_label.npy')
109 | out_ref = jnp.load('tests/stylegan2/generator/aux_files/cifar10_output_ref.npy')
110 |
111 | key = jax.random.PRNGKey(0)
112 | generator = fm.stylegan2.Generator(pretrained='cifar10')
113 | params = generator.init(key, z, label)
114 | out = generator.apply(params, z, label, noise_mode='const', train=False)
115 |
116 | diff = jnp.mean(jnp.abs(out - out_ref))
117 |
118 | assert diff < 3e-3
119 |
120 |
121 | def test_reference_output_ffhq():
122 | z = jnp.load('tests/stylegan2/generator/aux_files/ffhq_input_z.npy')
123 | out_ref = jnp.load('tests/stylegan2/generator/aux_files/ffhq_output_ref.npy')
124 |
125 | key = jax.random.PRNGKey(0)
126 | generator = fm.stylegan2.Generator(pretrained='ffhq')
127 | params = generator.init(key, z)
128 | out = generator.apply(params, z, noise_mode='const', train=False)
129 |
130 | diff = jnp.mean(jnp.abs(out - out_ref))
131 |
132 | assert diff < 1e-5
133 |
134 |
135 | def test_reference_output_horse():
136 | z = jnp.load('tests/stylegan2/generator/aux_files/horse_input_z.npy')
137 | out_ref = jnp.load('tests/stylegan2/generator/aux_files/horse_output_ref.npy')
138 |
139 | key = jax.random.PRNGKey(0)
140 | generator = fm.stylegan2.Generator(pretrained='horse')
141 | params = generator.init(key, z)
142 | out = generator.apply(params, z, noise_mode='const', train=False)
143 |
144 | diff = jnp.mean(jnp.abs(out - out_ref))
145 |
146 | assert diff < 1e-5
147 |
148 |
149 | def test_reference_output_metfaces():
150 | z = jnp.load('tests/stylegan2/generator/aux_files/metfaces_input_z.npy')
151 | out_ref = jnp.load('tests/stylegan2/generator/aux_files/metfaces_output_ref.npy')
152 |
153 | key = jax.random.PRNGKey(0)
154 | generator = fm.stylegan2.Generator(pretrained='metfaces')
155 | params = generator.init(key, z)
156 | out = generator.apply(params, z, noise_mode='const', train=False)
157 |
158 | diff = jnp.mean(jnp.abs(out - out_ref))
159 |
160 | assert diff < 3e-4
161 |
162 |
163 |
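Outside the test harness, the same pattern can be used to actually sample and save an image. The following is a hedged sketch, not part of the tests; the latent size of 512 and the min/max normalization of the output are assumptions.

```python
# Hedged sketch: sampling one image from the pretrained FFHQ generator,
# mirroring the init/apply calls in test_generator.py.
import jax
import numpy as np
from PIL import Image
import flaxmodels as fm

key = jax.random.PRNGKey(0)
z = jax.random.normal(key, shape=(1, 512))  # assumed latent dimensionality

generator = fm.stylegan2.Generator(pretrained='ffhq')
params = generator.init(key, z)
images = generator.apply(params, z, noise_mode='const', train=False)

# Normalize to [0, 1] and save (same post-processing as in the few-shot README below).
images = (images - np.min(images)) / (np.max(images) - np.min(images))
Image.fromarray(np.uint8(images[0] * 255)).save('ffhq_sample.jpg')
```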
--------------------------------------------------------------------------------
/tests/vgg/aux_files/vgg16_elefant_output_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/vgg/aux_files/vgg16_elefant_output_ref.npy
--------------------------------------------------------------------------------
/tests/vgg/aux_files/vgg19_elefant_output_ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/tests/vgg/aux_files/vgg19_elefant_output_ref.npy
--------------------------------------------------------------------------------
/tests/vgg/test_vgg16.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import numpy as np
4 | import flaxmodels as fm
5 | from PIL import Image
6 |
7 |
8 | def test_output_softmax():
9 | # If output='softmax', the output should be in range [0, 1]
10 | key = jax.random.PRNGKey(0)
11 | x = jax.random.uniform(key, shape=(1, 224, 224, 3), minval=0, maxval=1)
12 |
13 | vgg16 = fm.VGG16(output='softmax', pretrained=None)
14 | init_rngs = {'params': jax.random.PRNGKey(1), 'dropout': jax.random.PRNGKey(2)}
15 | params = vgg16.init(init_rngs, x)
16 | out = vgg16.apply(params, x, train=False)
17 |
18 | assert jnp.min(out) >= 0.0 and jnp.max(out) <= 1.0
19 |
20 |
21 | def test_output_activations():
22 | # If output='activations', the output should be a dict.
23 | key = jax.random.PRNGKey(0)
24 | x = jax.random.uniform(key, shape=(1, 224, 224, 3), minval=0, maxval=1)
25 |
26 | vgg16 = fm.VGG16(output='activations', pretrained=None)
27 | init_rngs = {'params': jax.random.PRNGKey(1), 'dropout': jax.random.PRNGKey(2)}
28 | params = vgg16.init(init_rngs, x)
29 | out = vgg16.apply(params, x, train=False)
30 |
31 | assert isinstance(out, dict)
32 |
33 |
34 | def test_include_head_true():
35 | # If include_head=True, the output should be a tensor of shape [B, 1000].
36 | key = jax.random.PRNGKey(0)
37 | x = jax.random.uniform(key, shape=(1, 224, 224, 3), minval=0, maxval=1)
38 |
39 | vgg16 = fm.VGG16(include_head=True, pretrained=None)
40 | init_rngs = {'params': jax.random.PRNGKey(1), 'dropout': jax.random.PRNGKey(2)}
41 | params = vgg16.init(init_rngs, x)
42 | out = vgg16.apply(params, x, train=False)
43 |
44 | assert hasattr(out, 'shape') and len(out.shape) == 2 and out.shape[0] == 1 and out.shape[1] == 1000
45 |
46 |
47 | def test_include_head_false():
48 | # If include_head=False, the output should be a tensor of shape [B, *, *, 512].
49 | key = jax.random.PRNGKey(0)
50 | x = jax.random.uniform(key, shape=(1, 224, 224, 3), minval=0, maxval=1)
51 |
52 | vgg16 = fm.VGG16(include_head=False, pretrained=None)
53 | init_rngs = {'params': jax.random.PRNGKey(1), 'dropout': jax.random.PRNGKey(2)}
54 | params = vgg16.init(init_rngs, x)
55 | out = vgg16.apply(params, x, train=False)
56 |
57 | assert hasattr(out, 'shape') and len(out.shape) == 4 and out.shape[0] == 1 and out.shape[-1] == 512
58 |
59 |
60 | def test_reference_output():
61 | key = jax.random.PRNGKey(0)
62 | img = Image.open('tests/aux_files/elefant.jpg')
63 | img = img.resize((224, 224))
64 | x = jnp.array(img, dtype=jnp.float32) / 255.0
65 | x = jnp.expand_dims(x, axis=0)
66 |
67 | vgg16 = fm.VGG16(output='logits', pretrained='imagenet')
68 | init_rngs = {'params': jax.random.PRNGKey(1), 'dropout': jax.random.PRNGKey(2)}
69 | params = vgg16.init(init_rngs, x)
70 | out = vgg16.apply(params, x, train=False)
71 |
72 | out_ref = jnp.load('tests/vgg/aux_files/vgg16_elefant_output_ref.npy')
73 | diff = jnp.mean(jnp.abs(out - out_ref))
74 |
75 | assert diff < 1e-6
76 |
77 |
--------------------------------------------------------------------------------
/tests/vgg/test_vgg19.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import numpy as np
4 | import flaxmodels as fm
5 | from PIL import Image
6 |
7 |
8 | def test_output_softmax():
9 | # If output='softmax', the output should be in range [0, 1]
10 | key = jax.random.PRNGKey(0)
11 | x = jax.random.uniform(key, shape=(1, 224, 224, 3), minval=0, maxval=1)
12 |
13 | vgg19 = fm.VGG19(output='softmax', pretrained=None)
14 | init_rngs = {'params': jax.random.PRNGKey(1), 'dropout': jax.random.PRNGKey(2)}
15 | params = vgg19.init(init_rngs, x)
16 | out = vgg19.apply(params, x, train=False)
17 |
18 | assert jnp.min(out) >= 0.0 and jnp.max(out) <= 1.0
19 |
20 |
21 | def test_output_activations():
22 | # If output='activations', the output should be a dict.
23 | key = jax.random.PRNGKey(0)
24 | x = jax.random.uniform(key, shape=(1, 224, 224, 3), minval=0, maxval=1)
25 |
26 | vgg19 = fm.VGG19(output='activations', pretrained=None)
27 | init_rngs = {'params': jax.random.PRNGKey(1), 'dropout': jax.random.PRNGKey(2)}
28 | params = vgg19.init(init_rngs, x)
29 | out = vgg19.apply(params, x, train=False)
30 |
31 | assert isinstance(out, dict)
32 |
33 |
34 | def test_include_head_true():
35 | # If include_head=True, the output should be a tensor of shape [B, 1000].
36 | key = jax.random.PRNGKey(0)
37 | x = jax.random.uniform(key, shape=(1, 224, 224, 3), minval=0, maxval=1)
38 |
39 | vgg19 = fm.VGG19(include_head=True, pretrained=None)
40 | init_rngs = {'params': jax.random.PRNGKey(1), 'dropout': jax.random.PRNGKey(2)}
41 | params = vgg19.init(init_rngs, x)
42 | out = vgg19.apply(params, x, train=False)
43 |
44 | assert hasattr(out, 'shape') and len(out.shape) == 2 and out.shape[0] == 1 and out.shape[1] == 1000
45 |
46 |
47 | def test_include_head_false():
48 | # If include_head=False, the output should be a tensor of shape [B, *, *, 512].
49 | key = jax.random.PRNGKey(0)
50 | x = jax.random.uniform(key, shape=(1, 224, 224, 3), minval=0, maxval=1)
51 |
52 | vgg19 = fm.VGG19(include_head=False, pretrained=None)
53 | init_rngs = {'params': jax.random.PRNGKey(1), 'dropout': jax.random.PRNGKey(2)}
54 | params = vgg19.init(init_rngs, x)
55 | out = vgg19.apply(params, x, train=False)
56 |
57 | assert hasattr(out, 'shape') and len(out.shape) == 4 and out.shape[0] == 1 and out.shape[-1] == 512
58 |
59 |
60 | def test_reference_output():
61 | key = jax.random.PRNGKey(0)
62 | img = Image.open('tests/aux_files/elefant.jpg')
63 | img = img.resize((224, 224))
64 | x = jnp.array(img, dtype=jnp.float32) / 255.0
65 | x = jnp.expand_dims(x, axis=0)
66 |
67 | vgg19 = fm.VGG19(output='logits', pretrained='imagenet')
68 | init_rngs = {'params': jax.random.PRNGKey(1), 'dropout': jax.random.PRNGKey(2)}
69 | params = vgg19.init(init_rngs, x)
70 | out = vgg19.apply(params, x, train=False)
71 |
72 | out_ref = jnp.load('tests/vgg/aux_files/vgg19_elefant_output_ref.npy')
73 | diff = jnp.mean(jnp.abs(out - out_ref))
74 |
75 | assert diff < 1e-5
76 |
77 |
--------------------------------------------------------------------------------
/training/few_shot_gan_adaption/README.md:
--------------------------------------------------------------------------------
1 | # Few-shot Image Generation via Cross-domain Correspondence: Training in Jax/Flax
2 | This is the training code for the [Jax/Flax implementation](https://github.com/matthias-wright/flaxmodels/tree/main/flaxmodels/few_shot_gan_adaption) of [Few-shot Image Generation via Cross-domain Correspondence](https://arxiv.org/abs/2104.06820).
3 |
4 |
5 |
6 | #### Table of Contents
7 | * [Getting Started](#getting-started)
8 | * [Preparing Datasets for Training](#preparing-datasets-for-training)
9 | * [Training](#training)
10 | * [Checkpoints](#checkpoints)
11 | * [Generating Images](#generating-images)
12 | * [References](#references)
13 | * [License](#license)
14 |
15 |
16 | ## Getting Started
17 | You will need Python 3.7 or later.
18 |
19 | 1. Clone the repository:
20 | ```sh
21 | > git clone https://github.com/matthias-wright/flaxmodels.git
22 | ```
23 | 2. Go into the directory:
24 | ```sh
25 | > cd flaxmodels/training/few_shot_gan_adaption
26 | ```
27 | 3. Install Jax with CUDA.
28 | 4. Install requirements:
29 | ```sh
30 | > pip install -r requirements.txt
31 | ```
32 |
33 | ## Preparing Datasets for Training
34 | Before training, the images should be stored in a [TFRecord dataset](https://www.tensorflow.org/tutorials/load_data/tfrecord). The TFRecord format stores your data as a sequence of bytes, which allows for fast data loading.
35 | Alternatively, you can use [tfds.folder_dataset.ImageFolder](https://www.tensorflow.org/datasets/api_docs/python/tfds/folder_dataset/ImageFolder) directly on the image directory, but you will then have to replace the `tf.data.TFRecordDataset` in `data_pipeline.py` with `tfds.folder_dataset.ImageFolder` (see [this](https://github.com/matthias-wright/flaxmodels/issues/8#issue-1020780783) thread for more info). A sketch of this swap is shown directly below.
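A rough sketch of that swap (untested; it assumes the images are reorganized into the `image_dir/<split>/<label>/` layout that `ImageFolder` expects):

```python
import tensorflow_datasets as tfds

# Instead of: ds = tf.data.TFRecordDataset(filenames=os.path.join(data_dir, 'dataset.tfrecords'))
builder = tfds.folder_dataset.ImageFolder('/path/to/image_dir/')
ds = builder.as_dataset(split='train', shuffle_files=True)
# Each element is a dict with already-decoded 'image' and 'label' tensors, so the
# TFRecord parsing in pre_process() is no longer needed (resizing/augmentation still is).
```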
36 |
37 | 1. Download dataset from [here](https://github.com/utkarshojha/few-shot-gan-adaptation#choose-the-target-domain).
38 | 2. Put all images into a directory:
39 | ```
40 | /path/to/image_dir/
41 | 0.jpg
42 | 1.jpg
43 | 2.jpg
44 | 4.jpg
45 | ...
46 | ```
47 | 3. Create TFRecord dataset:
48 | ```sh
49 | > python dataset_utils/images_to_tfrecords.py --image_dir /path/to/image_dir/ --data_dir /path/to/tfrecord
50 | ```
51 | `--image_dir` is the path to the image directory.
52 | `--data_dir` is the path where the TFRecord dataset is stored.
53 |
54 |
55 | ## Training
56 | Download checkpoint of source model:
57 | ```sh
58 | > wget https://www.dropbox.com/s/hyh1k8ixtzy24ye/ffhq_256x256.pickle\?dl\=1 -O ffhq_256x256.pickle
59 | ```
60 |
61 | Start training:
62 | ```sh
63 | > CUDA_VISIBLE_DEVICES=a,b,c,d python main.py --data_dir /path/to/tfrecord --source_ckpt_path ffhq_256x256.pickle
64 | ```
65 | Here `a`, `b`, `c`, `d` are the GPU indices. Multi-GPU training (data parallelism) works by default and will automatically use all the devices that you make visible.
66 |
67 |
68 | ### Logging
69 | I use [Weights & Biases](https://wandb.ai/site) for logging, but you can replace it with the logging method of your choice. All logging happens in the training loop implemented in `training.py`. To enable logging with Weights & Biases, pass `--wandb`.
70 |
71 | ### Checkpointing
72 | By default, the FID score is evaluated every `1000` training steps on `10,000` generated images. The checkpoint with the best (lowest) FID score is kept. You can change the evaluation frequency with the `--eval_fid_every` argument and the number of images used for the FID computation with `--num_fid_images`.
73 | You can disable the FID evaluation with `--disable_fid`. In that case, a checkpoint is saved every `1000` steps by default (configurable via `--save_every`).
74 |
75 |
76 | ## Checkpoints
77 | * [Sketches](https://www.dropbox.com/s/azr6b316juhme6c/sketches.pickle?dl=1) (357.2 MB)
78 | * [Amedeo Modigliani](https://www.dropbox.com/s/xrh4a7wt2kggn4v/amedeo_modigliani.pickle?dl=1) (357.2 MB)
79 | * [Babies](https://www.dropbox.com/s/ntictyzrisqg5zh/babies.pickle?dl=1) (357.2 MB)
80 | * [Otto Dix](https://www.dropbox.com/s/u1g18nv73uac21m/otto_dix.pickle?dl=1) (357.2 MB)
81 | * [Rafael](https://www.dropbox.com/s/b8w928s4wffuo2c/raphael.pickle?dl=1) (357.2 MB)
82 |
83 | ## Generating Images
84 | ```python
85 | import jax
86 | import numpy as np
87 | import dill as pickle
88 | from PIL import Image
89 |
90 | import flaxmodels as fm
91 |
92 | ckpt = pickle.load(open('sketches.pickle', 'rb'))
93 | params = ckpt['params_ema_G']
94 |
95 | generator = fm.few_shot_gan_adaption.Generator()
96 |
97 | # Seed
98 | key = jax.random.PRNGKey(0)
99 |
100 | # Input noise
101 | z = jax.random.normal(key, shape=(4, 512))
102 |
103 | # Generate images
104 | images, _ = generator.apply(params, z, truncation_psi=0.5, train=False, noise_mode='const')
105 |
106 | # Normalize images to be in range [0, 1]
107 | images = (images - np.min(images)) / (np.max(images) - np.min(images))
108 |
109 | # Save images
110 | for i in range(images.shape[0]):
111 | Image.fromarray(np.uint8(images[i] * 255)).save(f'image_{i}.jpg')
112 |
113 | ```
114 |
115 | ## References
116 | * [utkarshojha/few-shot-gan-adaptation](https://github.com/utkarshojha/few-shot-gan-adaptation)
117 |
118 |
119 | ## License
120 | [MIT License](https://opensource.org/licenses/MIT)
121 |
122 |
--------------------------------------------------------------------------------
/training/few_shot_gan_adaption/checkpoint.py:
--------------------------------------------------------------------------------
1 | import flax
2 | import dill as pickle
3 | import os
4 | import glob
5 |
6 |
7 | def save_checkpoint(ckpt_dir, state_G, state_D, params_ema_G, config, step, z_latent_anchors, keep=2):
8 | """
9 | Saves checkpoint.
10 |
11 | Args:
12 | ckpt_dir (str): Path to the directory, where checkpoints are saved.
13 | state_G (train_state.TrainState): Generator state.
14 | state_D (train_state.TrainState): Discriminator state.
15 | params_ema_G (frozen_dict.FrozenDict): Parameters of the ema generator.
16 | config (argparse.Namespace): Configuration.
17 | step (int): Current step.
18 | z_latent_anchors (DeviceArray): Noise anchors.
19 | keep (int): Number of checkpoints to keep.
20 | """
21 |
22 | state_G = flax.jax_utils.unreplicate(state_G)
23 | state_D = flax.jax_utils.unreplicate(state_D)
24 | params_G = {'params': {'mapping_network': state_G.params.unfreeze()['mapping'], 'synthesis_network': state_G.params.unfreeze()['synthesis']},
25 | 'moving_stats': {'mapping_network': state_G.moving_stats},
26 | 'noise_consts': {'synthesis_network': state_G.noise_consts}}
27 |
28 | ckpt_dict = {'params_G': params_G,
29 | 'params_D': state_D.params,
30 | 'params_ema_G': params_ema_G,
31 | 'z_latent_anchors': z_latent_anchors,
32 | 'config': config}
33 |
34 | with open(os.path.join(ckpt_dir, f'ckpt_{step}.pickle'), 'wb') as handle:
35 | pickle.dump(ckpt_dict, handle, protocol=pickle.DEFAULT_PROTOCOL)
36 |
37 | ckpts = glob.glob(os.path.join(ckpt_dir, '*.pickle'))
38 | if len(ckpts) > keep:
39 | oldest_ckpt = min(ckpts, key=os.path.getctime)
40 | os.remove(oldest_ckpt)
41 |
42 |
43 | def load_checkpoint(filename, replicate=True):
44 | """
45 | Loads checkpoints.
46 |
47 | Args:
48 | filename (str): Path to the checkpoint file.
49 | replicate (bool): If True, replicate parameters across devices.
50 |
51 | Returns:
52 | (dict): Checkpoint.
53 | """
54 | state_dict = pickle.load(open(filename, 'rb'))
55 | if replicate:
56 | state_dict['state_G'] = flax.jax_utils.replicate(state_dict['state_G'])
57 | state_dict['state_D'] = flax.jax_utils.replicate(state_dict['state_D'])
58 | state_dict['pl_mean'] = flax.jax_utils.replicate(state_dict['pl_mean'])
59 | return state_dict
60 |
61 |
62 | def get_latest_checkpoint(ckpt_dir):
63 | """
64 | Returns the path of the latest checkpoint.
65 |
66 | Args:
67 | ckpt_dir (str): Path to the directory, where checkpoints are saved.
68 |
69 | Returns:
70 | (str): Path to latest checkpoint (if it exists).
71 | """
72 | ckpts = glob.glob(os.path.join(ckpt_dir, '*.pickle'))
73 | if len(ckpts) == 0:
74 | return None
75 | latest_ckpt = max(ckpts, key=os.path.getctime)
76 | return latest_ckpt
77 |
78 |
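A minimal usage sketch of the two loading helpers above, assuming a checkpoint directory populated by `save_checkpoint` (the `ckpts` path is a placeholder):

```python
# Hedged sketch: resume from the newest checkpoint if one exists.
from checkpoint import get_latest_checkpoint, load_checkpoint

latest = get_latest_checkpoint('ckpts')      # None if the directory holds no .pickle files
if latest is not None:
    ckpt = load_checkpoint(latest, replicate=False)
    params_ema_G = ckpt['params_ema_G']      # key written by save_checkpoint above
```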
--------------------------------------------------------------------------------
/training/few_shot_gan_adaption/data_pipeline.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import tensorflow_datasets as tfds
3 | import jax
4 | import flax
5 | import numpy as np
6 | from PIL import Image
7 | import os
8 | from typing import Sequence
9 | from tqdm import tqdm
10 | import json
11 | from tqdm import tqdm
12 |
13 |
14 | def prefetch(dataset, n_prefetch):
15 | # Taken from: https://github.com/google-research/vision_transformer/blob/master/vit_jax/input_pipeline.py
16 | ds_iter = iter(dataset)
17 | ds_iter = map(lambda x: jax.tree_map(lambda t: np.asarray(memoryview(t)), x),
18 | ds_iter)
19 | if n_prefetch:
20 | ds_iter = flax.jax_utils.prefetch_to_device(ds_iter, n_prefetch)
21 | return ds_iter
22 |
23 |
24 | def get_data(data_dir, img_size, img_channels, num_classes, num_devices, batch_size, shuffle_buffer=1000):
25 | """
26 |     Builds a sharded tf.data input pipeline from a TFRecord dataset for multi-device training.
27 | Args:
28 | data_dir (str): Root directory of the dataset.
29 | img_size (int): Image size for training.
30 | img_channels (int): Number of image channels.
31 | num_classes (int): Number of classes, 0 for no classes.
32 | num_devices (int): Number of devices.
33 | batch_size (int): Batch size (per device).
34 | shuffle_buffer (int): Buffer used for shuffling the dataset.
35 |
36 | Returns:
37 | (tf.data.Dataset): Dataset.
38 | """
39 |
40 | def pre_process(serialized_example):
41 | feature = {'height': tf.io.FixedLenFeature([], tf.int64),
42 | 'width': tf.io.FixedLenFeature([], tf.int64),
43 | 'channels': tf.io.FixedLenFeature([], tf.int64),
44 | 'image': tf.io.FixedLenFeature([], tf.string),
45 | 'label': tf.io.FixedLenFeature([], tf.int64)}
46 | example = tf.io.parse_single_example(serialized_example, feature)
47 |
48 | height = tf.cast(example['height'], dtype=tf.int64)
49 | width = tf.cast(example['width'], dtype=tf.int64)
50 | channels = tf.cast(example['channels'], dtype=tf.int64)
51 |
52 | image = tf.io.decode_raw(example['image'], out_type=tf.uint8)
53 | image = tf.reshape(image, shape=[height, width, channels])
54 |
55 | image = tf.cast(image, dtype='float32')
56 | image = tf.image.resize(image, size=[img_size, img_size], method='bicubic', antialias=True)
57 | image = tf.image.random_flip_left_right(image)
58 |
59 | image = (image - 127.5) / 127.5
60 |
61 | label = tf.one_hot(example['label'], num_classes)
62 | return {'image': image, 'label': label}
63 |
64 | def shard(data):
65 | # Reshape images from [num_devices * batch_size, H, W, C] to [num_devices, batch_size, H, W, C]
66 | # because the first dimension will be mapped across devices using jax.pmap
67 | data['image'] = tf.reshape(data['image'], [num_devices, -1, img_size, img_size, img_channels])
68 | data['label'] = tf.reshape(data['label'], [num_devices, -1, num_classes])
69 | return data
70 |
71 | print('Loading TFRecord...')
72 | with open(os.path.join(data_dir, 'dataset_info.json'), 'r') as fin:
73 | dataset_info = json.load(fin)
74 |
75 | ds = tf.data.TFRecordDataset(filenames=os.path.join(data_dir, 'dataset.tfrecords'))
76 |
77 | ds = ds.shuffle(min(dataset_info['num_examples'], shuffle_buffer))
78 | ds = ds.map(pre_process, tf.data.AUTOTUNE)
79 | ds = ds.batch(batch_size * num_devices, drop_remainder=True)
80 | ds = ds.map(shard, tf.data.AUTOTUNE)
81 | ds = ds.prefetch(1)
82 | return ds, dataset_info
83 |
84 |
85 |
86 |
--------------------------------------------------------------------------------
/training/few_shot_gan_adaption/dataset_utils/images_to_tfrecords.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | from PIL import Image
4 | from typing import Sequence
5 | from tqdm import tqdm
6 | import argparse
7 | import json
8 | import os
9 |
10 |
11 | def images_to_tfrecords(image_dir, data_dir, has_labels):
12 | """
13 | Converts a folder of images to a TFRecord file.
14 |
15 | The image directory should have one of the following structures:
16 |
17 | If has_labels = False, image_dir should look like this:
18 |
19 | path/to/image_dir/
20 | 0.jpg
21 | 1.jpg
22 | 2.jpg
23 | 4.jpg
24 | ...
25 |
26 |
27 | If has_labels = True, image_dir should look like this:
28 |
29 | path/to/image_dir/
30 | label0/
31 | 0.jpg
32 | 1.jpg
33 | ...
34 | label1/
35 | a.jpg
36 | b.jpg
37 | c.jpg
38 | ...
39 | ...
40 |
41 |
42 | The labels will be label0 -> 0, label1 -> 1.
43 |
44 | Args:
45 | image_dir (str): Path to images.
46 |         data_dir (str): Path where the TFRecord dataset is stored.
47 | has_labels (bool): If True, 'image_dir' contains label directories.
48 |
49 | Returns:
50 | (dict): Dataset info.
51 | """
52 |
53 | def _bytes_feature(value):
54 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
55 |
56 | def _int64_feature(value):
57 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
58 |
59 | os.makedirs(data_dir, exist_ok=True)
60 | writer = tf.io.TFRecordWriter(os.path.join(data_dir, 'dataset.tfrecords'))
61 |
62 | num_examples = 0
63 | num_classes = 0
64 |
65 | if has_labels:
66 | for label_dir in os.listdir(image_dir):
67 | if not os.path.isdir(os.path.join(image_dir, label_dir)):
68 | print('The image directory should contain one directory for each label.')
69 | print('These label directories should contain the image files.')
70 | if os.path.exists(os.path.join(data_dir, 'dataset.tfrecords')):
71 | os.remove(os.path.join(data_dir, 'dataset.tfrecords'))
72 | return
73 |
74 | for img_file in tqdm(os.listdir(os.path.join(image_dir, label_dir))):
75 | file_format = img_file[img_file.rfind('.') + 1:]
76 | if file_format not in ['png', 'jpg', 'jpeg']:
77 | continue
78 |
79 | #img = Image.open(os.path.join(image_dir, label_dir, img_file)).resize(img_size)
80 | img = Image.open(os.path.join(image_dir, label_dir, img_file))
81 | img = np.array(img, dtype=np.uint8)
82 |
83 | height = img.shape[0]
84 | width = img.shape[1]
85 | channels = img.shape[2]
86 |
87 | img_encoded = img.tobytes()
88 |
89 | example = tf.train.Example(features=tf.train.Features(feature={
90 | 'height': _int64_feature(height),
91 | 'width': _int64_feature(width),
92 | 'channels': _int64_feature(channels),
93 | 'image': _bytes_feature(img_encoded),
94 | 'label': _int64_feature(num_classes)}))
95 |
96 | writer.write(example.SerializeToString())
97 | num_examples += 1
98 |
99 | num_classes += 1
100 | else:
101 | for img_file in tqdm(os.listdir(os.path.join(image_dir))):
102 | file_format = img_file[img_file.rfind('.') + 1:]
103 | if file_format not in ['png', 'jpg', 'jpeg']:
104 | continue
105 |
106 | #img = Image.open(os.path.join(image_dir, label_dir, img_file)).resize(img_size)
107 | img = Image.open(os.path.join(image_dir, img_file))
108 | img = np.array(img, dtype=np.uint8)
109 |
110 | height = img.shape[0]
111 | width = img.shape[1]
112 | channels = img.shape[2]
113 |
114 | img_encoded = img.tobytes()
115 |
116 | example = tf.train.Example(features=tf.train.Features(feature={
117 | 'height': _int64_feature(height),
118 | 'width': _int64_feature(width),
119 | 'channels': _int64_feature(channels),
120 | 'image': _bytes_feature(img_encoded),
121 | 'label': _int64_feature(num_classes)})) # dummy label
122 |
123 | writer.write(example.SerializeToString())
124 | num_examples += 1
125 |
126 | writer.close()
127 |
128 | dataset_info = {'num_examples': num_examples, 'num_classes': num_classes}
129 | with open(os.path.join(data_dir, 'dataset_info.json'), 'w') as fout:
130 | json.dump(dataset_info, fout)
131 |
132 |
133 | if __name__ == '__main__':
134 | parser = argparse.ArgumentParser()
135 | parser.add_argument('--image_dir', type=str, help='Path to the image directory.')
136 | parser.add_argument('--data_dir', type=str, help='Path where the TFRecords dataset is stored.')
137 | parser.add_argument('--has_labels', action='store_true', help='If True, image_dir contains label directories.')
138 |
139 | args = parser.parse_args()
140 |
141 | images_to_tfrecords(args.image_dir, args.data_dir, args.has_labels)
142 |
143 |
--------------------------------------------------------------------------------
/training/few_shot_gan_adaption/datatest.py:
--------------------------------------------------------------------------------
1 | import data_pipeline
2 |
3 | ds_train, dataset_info = data_pipeline.get_data(data_dir='datasets/Sketches',
4 | img_size=256,
5 | img_channels=3,
6 | num_classes=0,
7 | num_devices=1,
8 | batch_size=5)
9 |
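As a complement to this smoke test, a hedged sketch of how the returned dataset could be consumed, using the `prefetch` helper from `data_pipeline.py` (batch shapes follow from the `shard` step there):

```python
# Hedged sketch: iterate over the sharded dataset with host-to-device prefetching.
import data_pipeline

ds_train, dataset_info = data_pipeline.get_data(data_dir='datasets/Sketches',
                                                img_size=256,
                                                img_channels=3,
                                                num_classes=0,
                                                num_devices=1,
                                                batch_size=5)

for batch in data_pipeline.prefetch(ds_train, n_prefetch=2):
    # batch['image'] has shape [num_devices, batch_size, 256, 256, 3], here [1, 5, 256, 256, 3],
    # ready to be fed to a jax.pmap'ed training step.
    images = batch['image']
    break
```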
--------------------------------------------------------------------------------
/training/few_shot_gan_adaption/fid/__init__.py:
--------------------------------------------------------------------------------
1 | from .core import FID
2 |
--------------------------------------------------------------------------------
/training/few_shot_gan_adaption/fid/core.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import flax
4 | import flax.linen as nn
5 | import numpy as np
6 | import os
7 | import functools
8 | import argparse
9 | import scipy
10 | from tqdm import tqdm
11 |
12 | from . import inception
13 | from . import utils
14 |
15 |
16 | class FID:
17 |
18 | def __init__(self, generator, dataset, config, use_cache=True, truncation_psi=1.0):
19 | """
20 | Evaluates the FID score for a given generator and a given dataset.
21 | Implementation mostly taken from https://github.com/matthias-wright/jax-fid
22 |
23 | Reference: https://arxiv.org/abs/1706.08500
24 |
25 | Args:
26 | generator (nn.Module): Generator network.
27 | dataset (tf.data.Dataset): Dataset containing the real images.
28 | config (argparse.Namespace): Configuration.
29 | use_cache (bool): If True, only compute the activation stats once for the real images and store them.
30 | truncation_psi (float): Controls truncation (trading off variation for quality). If 1, truncation is disabled.
31 | """
32 | self.num_images = config.num_fid_images
33 | self.batch_size = config.batch_size
34 | self.c_dim = config.c_dim
35 | self.z_dim = config.z_dim
36 | self.dataset = dataset
37 | self.num_devices = jax.device_count()
38 | self.use_cache = use_cache
39 |
40 | if self.use_cache:
41 | self.cache = {}
42 |
43 | rng = jax.random.PRNGKey(0)
44 | inception_net = inception.InceptionV3(pretrained=True)
45 | self.inception_params = inception_net.init(rng, jnp.ones((1, config.resolution, config.resolution, 3)))
46 | self.inception_params = flax.jax_utils.replicate(self.inception_params)
47 | #self.inception = jax.jit(functools.partial(model.apply, train=False))
48 | self.inception_apply = jax.pmap(functools.partial(inception_net.apply, train=False), axis_name='batch')
49 |
50 | self.generator_apply = jax.pmap(functools.partial(generator.apply, truncation_psi=truncation_psi, train=False, noise_mode='const'), axis_name='batch')
51 |
52 | def compute_fid(self, generator_params, seed_offset=0):
53 | generator_params = flax.jax_utils.replicate(generator_params)
54 | mu_real, sigma_real = self.compute_stats_for_dataset()
55 | mu_fake, sigma_fake = self.compute_stats_for_generator(generator_params, seed_offset)
56 | fid_score = self.compute_frechet_distance(mu_real, mu_fake, sigma_real, sigma_fake, eps=1e-6)
57 | return fid_score
58 |
59 | def compute_frechet_distance(self, mu1, mu2, sigma1, sigma2, eps=1e-6):
60 | # Taken from: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py
61 | mu1 = np.atleast_1d(mu1)
62 | mu2 = np.atleast_1d(mu2)
63 | sigma1 = np.atleast_1d(sigma1)
64 | sigma2 = np.atleast_1d(sigma2)
65 |
66 | assert mu1.shape == mu2.shape
67 | assert sigma1.shape == sigma2.shape
68 |
69 | diff = mu1 - mu2
70 |
71 | covmean, _ = scipy.linalg.sqrtm(sigma1.dot(sigma2), disp=False)
72 | if not np.isfinite(covmean).all():
73 | msg = ('fid calculation produces singular product; '
74 | 'adding %s to diagonal of cov estimates') % eps
75 | print(msg)
76 | offset = np.eye(sigma1.shape[0]) * eps
77 | covmean = scipy.linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
78 |
79 | # Numerical error might give slight imaginary component
80 | if np.iscomplexobj(covmean):
81 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
82 | m = np.max(np.abs(covmean.imag))
83 | raise ValueError('Imaginary component {}'.format(m))
84 | covmean = covmean.real
85 |
86 | tr_covmean = np.trace(covmean)
87 | return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean)
88 |
89 | def compute_stats_for_dataset(self):
90 | if self.use_cache and 'mu' in self.cache and 'sigma' in self.cache:
91 | print('Use cached statistics for dataset...')
92 | return self.cache['mu'], self.cache['sigma']
93 |
94 | print()
95 | print('Compute statistics for dataset...')
96 | pbar = tqdm(total=self.num_images)
97 | image_count = 0
98 |
99 | activations = []
100 | for batch in utils.prefetch(self.dataset, n_prefetch=2):
101 | act = self.inception_apply(self.inception_params, jax.lax.stop_gradient(batch['image']))
102 | act = jnp.reshape(act, (self.num_devices * self.batch_size, -1))
103 | activations.append(act)
104 |
105 | pbar.update(self.num_devices * self.batch_size)
106 | image_count += self.num_devices * self.batch_size
107 | if image_count >= self.num_images:
108 | break
109 | pbar.close()
110 |
111 | activations = jnp.concatenate(activations, axis=0)
112 | activations = activations[:self.num_images]
113 | mu = np.mean(activations, axis=0)
114 | sigma = np.cov(activations, rowvar=False)
115 | self.cache['mu'] = mu
116 | self.cache['sigma'] = sigma
117 | return mu, sigma
118 |
119 | def compute_stats_for_generator(self, generator_params, seed_offset):
120 | print()
121 | print('Compute statistics for generator...')
122 | num_batches = int(np.ceil(self.num_images / (self.batch_size * self.num_devices)))
123 |
124 | pbar = tqdm(total=self.num_images)
125 | activations = []
126 |
127 | for i in range(num_batches):
128 | rng = jax.random.PRNGKey(seed_offset + i)
129 | z_latent = jax.random.normal(rng, shape=(self.num_devices, self.batch_size, self.z_dim))
130 |
131 | labels = None
132 | if self.c_dim > 0:
133 | labels = jax.random.randint(rng, shape=(self.num_devices * self.batch_size,), minval=0, maxval=self.c_dim)
134 | labels = jax.nn.one_hot(labels, num_classes=self.c_dim)
135 | labels = jnp.reshape(labels, (self.num_devices, self.batch_size, self.c_dim))
136 |
137 | image, _ = self.generator_apply(generator_params, jax.lax.stop_gradient(z_latent), labels)
138 | image = (image - jnp.min(image)) / (jnp.max(image) - jnp.min(image))
139 |
140 | image = 2 * image - 1
141 | act = self.inception_apply(self.inception_params, jax.lax.stop_gradient(image))
142 | act = jnp.reshape(act, (self.num_devices * self.batch_size, -1))
143 | activations.append(act)
144 | pbar.update(self.num_devices * self.batch_size)
145 | pbar.close()
146 |
147 | activations = jnp.concatenate(activations, axis=0)
148 | activations = activations[:self.num_images]
149 | mu = np.mean(activations, axis=0)
150 | sigma = np.cov(activations, rowvar=False)
151 | return mu, sigma
152 |
153 |
154 |
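A hedged sketch of how this class might be wired up. The config values mirror the defaults in `main.py`; `z_dim=512` is an assumption, and `ds_train` / `params_ema_G` stand for a dataset from `data_pipeline.get_data` and EMA generator parameters from a checkpoint:

```python
# Hedged sketch: computing an FID score for the EMA generator.
import argparse
import flaxmodels as fm
from fid import FID

config = argparse.Namespace(num_fid_images=10000, batch_size=8, c_dim=0,
                            z_dim=512, resolution=256)

# ds_train: sharded dataset from data_pipeline.get_data(...)
# params_ema_G: e.g. ckpt['params_ema_G'] from a loaded checkpoint
fid_metric = FID(fm.few_shot_gan_adaption.Generator(), dataset=ds_train, config=config)
score = fid_metric.compute_fid(params_ema_G)  # lower is better
print(f'FID: {score:.2f}')
```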
--------------------------------------------------------------------------------
/training/few_shot_gan_adaption/fid/utils.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import flax
3 | import numpy as np
4 | from tqdm import tqdm
5 | import requests
6 | import os
7 | import tempfile
8 |
9 |
10 | def download(url, ckpt_dir=None):
11 | name = url[url.rfind('/') + 1 : url.rfind('?')]
12 | if ckpt_dir is None:
13 | ckpt_dir = tempfile.gettempdir()
14 | ckpt_dir = os.path.join(ckpt_dir, 'flaxmodels')
15 | ckpt_file = os.path.join(ckpt_dir, name)
16 | if not os.path.exists(ckpt_file):
17 | print(f'Downloading: \"{url[:url.rfind("?")]}\" to {ckpt_file}')
18 | if not os.path.exists(ckpt_dir):
19 | os.makedirs(ckpt_dir)
20 |
21 | response = requests.get(url, stream=True)
22 | total_size_in_bytes = int(response.headers.get('content-length', 0))
23 | progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
24 |
25 | # first create temp file, in case the download fails
26 | ckpt_file_temp = os.path.join(ckpt_dir, name + '.temp')
27 | with open(ckpt_file_temp, 'wb') as file:
28 | for data in response.iter_content(chunk_size=1024):
29 | progress_bar.update(len(data))
30 | file.write(data)
31 | progress_bar.close()
32 |
33 | if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
34 |             print('An error occurred while downloading, please try again.')
35 | if os.path.exists(ckpt_file_temp):
36 | os.remove(ckpt_file_temp)
37 | else:
38 | # if download was successful, rename the temp file
39 | os.rename(ckpt_file_temp, ckpt_file)
40 | return ckpt_file
41 |
42 |
43 | def get(dictionary, key):
44 | if dictionary is None or key not in dictionary:
45 | return None
46 | return dictionary[key]
47 |
48 |
49 | def prefetch(dataset, n_prefetch):
50 | # Taken from: https://github.com/google-research/vision_transformer/blob/master/vit_jax/input_pipeline.py
51 | ds_iter = iter(dataset)
52 | ds_iter = map(lambda x: jax.tree_map(lambda t: np.asarray(memoryview(t)), x),
53 | ds_iter)
54 | if n_prefetch:
55 | ds_iter = flax.jax_utils.prefetch_to_device(ds_iter, n_prefetch)
56 | return ds_iter
57 |
--------------------------------------------------------------------------------
/training/few_shot_gan_adaption/images/overview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/training/few_shot_gan_adaption/images/overview.jpg
--------------------------------------------------------------------------------
/training/few_shot_gan_adaption/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import jax
4 | import wandb
5 | import training
6 |
7 |
8 | def main():
9 | parser = argparse.ArgumentParser()
10 | # Paths
11 | parser.add_argument('--work_dir', type=str, default='logging', help='Directory for logging and checkpoints.')
12 | parser.add_argument('--data_dir', type=str, default='datasets/Sketches', help='Directory of the dataset.')
13 | parser.add_argument('--project', type=str, default='few-shot-adapt', help='Name of this project.')
14 | parser.add_argument('--name', type=str, default='default', help='Name of this experiment.')
15 | parser.add_argument('--group', type=str, default='default', help='Group name of this experiment (for Weights&Biases).')
16 | parser.add_argument('--source_ckpt_path', type=str, default='ffhq_256x256.pickle', help='Path to the checkpoint of the source model.')
17 | # Training
18 | parser.add_argument('--num_epochs', type=int, default=10000, help='Number of epochs.')
19 | parser.add_argument('--learning_rate', type=float, default=0.002, help='Learning rate.')
20 | parser.add_argument('--batch_size', type=int, default=8, help='Batch size.')
21 | parser.add_argument('--num_prefetch', type=int, default=2, help='Number of prefetched examples for the data pipeline.')
22 |     parser.add_argument('--resolution', type=int, default=256, help='Image resolution. Must be a power of 2.')
23 | parser.add_argument('--img_channels', type=int, default=3, help='Number of image channels.')
24 | parser.add_argument('--mixed_precision', action='store_true', help='Use mixed precision training.')
25 | parser.add_argument('--random_seed', type=int, default=0, help='Random seed.')
26 | parser.add_argument('--subspace_std', type=float, default=0.1, help='Std for sampling z_latents around the anchor points.')
27 | parser.add_argument('--subspace_freq', type=int, default=4, help='Frequency for sampling z_latents from anchor regions.')
28 | # Generator
29 | parser.add_argument('--fmap_base', type=int, default=16384, help='Overall multiplier for the number of feature maps.')
30 | # Discriminator
31 | parser.add_argument('--mbstd_group_size', type=int, help='Group size for the minibatch standard deviation layer, None = entire minibatch.')
32 | # Exponentially Moving Average of Generator Weights
33 | parser.add_argument('--ema_kimg', type=float, default=20.0, help='Controls the ema of the generator weights (larger value -> larger beta).')
34 | # Losses
35 |     parser.add_argument('--pl_decay', type=float, default=0.01, help='Exponential decay rate for the moving average of the path length (path length regularization).')
36 | parser.add_argument('--pl_weight', type=float, default=2, help='Weight for path length regularization.')
37 | parser.add_argument('--kl_weight', type=float, default=1000.0, help='Weight for distance consistency loss.')
38 | # Regularization
39 | parser.add_argument('--mixing_prob', type=float, default=0.9, help='Probability for style mixing.')
40 | parser.add_argument('--G_reg_interval', type=int, default=4, help='How often to perform regularization for G.')
41 | parser.add_argument('--D_reg_interval', type=int, default=16, help='How often to perform regularization for D.')
42 | parser.add_argument('--r1_gamma', type=float, default=10.0, help='Weight for R1 regularization.')
43 | # Model
44 | parser.add_argument('--c_dim', type=int, default=0, help='Conditioning label (C) dimensionality, 0 = no label.')
45 | # Logging
46 |     parser.add_argument('--wandb', action='store_true', help='Log to Weights&Biases.')
47 | parser.add_argument('--log_every', type=int, default=50, help='Log every log_every steps.')
48 | parser.add_argument('--save_every', type=int, default=1000, help='Save every save_every steps. Will be ignored if FID evaluation is enabled.')
49 | # FID
50 | parser.add_argument('--eval_fid_every', type=int, default=1000, help='Compute FID score every eval_fid_every steps.')
51 | parser.add_argument('--num_fid_images', type=int, default=10000, help='Number of images to use for FID computation.')
52 | parser.add_argument('--disable_fid', action='store_true', help='Disable FID evaluation.')
53 |
54 | args = parser.parse_args()
55 |
56 | if jax.process_index() == 0:
57 | args.ckpt_dir = os.path.join(args.work_dir, args.group, args.name, 'checkpoints')
58 | if not os.path.exists(args.ckpt_dir):
59 | os.makedirs(args.ckpt_dir)
60 |
61 | if args.wandb:
62 | wandb.init(project=args.project,
63 | group=args.group,
64 | config=args,
65 | name=args.name,
66 | dir=os.path.join(args.work_dir, args.group, args.name))
67 |
68 | training.train_and_evaluate(args)
69 |
70 |
71 | if __name__ == '__main__':
72 | main()
73 |
74 |
--------------------------------------------------------------------------------
/training/resnet/README.md:
--------------------------------------------------------------------------------
1 | # ResNet Training
2 |
3 | ##### Table of Contents
4 | * [Getting Started](#getting-started)
5 | * [Training](#training)
6 | * [Options](#options)
7 | * [Results](#results)
8 | * [References](#references)
9 | * [License](#license)
10 |
11 |
12 |
13 | ## Getting Started
14 | You will need Python 3.7 or later.
15 |
16 | 1. Clone the repository:
17 | ```sh
18 | > git clone https://github.com/matthias-wright/flaxmodels.git
19 | ```
20 | 2. Go into the directory:
21 | ```sh
22 | > cd flaxmodels/training/resnet
23 | ```
24 | 3. Install Jax with CUDA.
25 | 4. Install requirements:
26 | ```sh
27 | > pip install -r requirements.txt
28 | ```
29 |
30 |
31 | ## Training
32 |
33 | ### Basic Training
34 | ```sh
35 | CUDA_VISIBLE_DEVICES=0 python main.py
36 | ```
37 |
38 | ### Multi GPU Training
39 | The script will automatically use all the visible GPUs for distributed training.
40 | ```sh
41 | CUDA_VISIBLE_DEVICES=0,1 python main.py
42 | ```
43 |
44 | ### Mixed-Precision Training
45 | ```sh
46 | CUDA_VISIBLE_DEVICES=0,1 python main.py --mixed_precision
47 | ```
48 |
49 |
50 | ## Options
51 | * `--work_dir` - Path to directory for logging and checkpoints (str).
52 | * `--data_dir` - Path for storing the dataset (str).
53 | * `--name` - Name of the training run (str).
54 | * `--group` - Group name of the training run (str).
55 | * `--arch` - Architecture (str). Options: resnet18, resnet34, resnet50, resnet101, resnet152.
56 | * `--resume` - Resume training from best checkpoint (bool).
57 | * `--num_epochs` - Number of epochs (int).
58 | * `--learning_rate` - Learning rate (float).
59 | * `--warmup_epochs` - Number of warmup epochs with lower learning rate (int).
60 | * `--batch_size` - Batch size (int).
61 | * `--num_classes` - Number of classes (int).
62 | * `--img_size` - Image size (int).
63 | * `--img_channels` - Number of image channels (int).
64 | * `--mixed_precision` - Use mixed precision training (bool).
65 | * `--random_seed` - Random seed (int).
66 | * `--wandb` - Use Weights&Biases for logging (bool).
67 | * `--log_every` - Log every log_every steps (int).
68 |
69 |
70 |
71 | ## Results
72 | ResNet18 was trained on the Imagenette dataset. The validation accuracy is around 90%.
73 |
74 | * Images were resized to 256x256 (random crops for training and center crops for evaluation).
75 | * Data augmentation: flipping, brightness, hue, contrast.
76 | * Learning rate schedule: cosine annealing (a sketch of such a schedule is shown below).
77 | * Training was done from scratch, no transfer learning.
78 |
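A minimal sketch of such a warmup-plus-cosine schedule with `optax` is shown below. This is illustrative only: the peak learning rate, warmup length, and epoch count mirror the defaults in `main.py`, the dataset size is an assumption, and the actual training script may implement the warmup differently (the `--warmup_epochs` option only states that a lower learning rate is used).

```python
import optax

# All concrete numbers are illustrative assumptions, not the exact training configuration.
num_train_images = 9469                      # approximate Imagenette train split size (assumption)
batch_size = 128                             # --batch_size default
steps_per_epoch = num_train_images // batch_size

# Linear warmup to the peak learning rate, followed by cosine annealing to zero.
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=0.001,                        # --learning_rate default
    warmup_steps=9 * steps_per_epoch,        # --warmup_epochs default
    decay_steps=200 * steps_per_epoch,       # --num_epochs default
    end_value=0.0,
)
optimizer = optax.adam(learning_rate=schedule)
```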
79 |
80 |
81 | ## References
82 | * [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385)
83 | * [pytorch/vision/torchvision/models/resnet.py](https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py)
84 | * [google/flax/examples](https://github.com/google/flax/tree/main/examples)
85 |
86 |
87 |
88 | ## License
89 | [MIT License](https://opensource.org/licenses/MIT)
90 |
91 |
92 |
--------------------------------------------------------------------------------
/training/resnet/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import jax
4 | import wandb
5 | import training
6 |
7 |
8 | def main():
9 | parser = argparse.ArgumentParser()
10 | # Paths
11 | parser.add_argument('--work_dir', type=str, default='/export/scratch/mwright/projects/misc/imagenette', help='Directory for logging and checkpoints.')
12 | parser.add_argument('--data_dir', type=str, default='/export/data/mwright/tensorflow_datasets', help='Directory for storing data.')
13 | parser.add_argument('--name', type=str, default='test', help='Name of this experiment.')
14 | parser.add_argument('--group', type=str, default='default', help='Group name of this experiment.')
15 | # Training
16 | parser.add_argument('--arch', type=str, default='resnet18', choices=['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'], help='Architecture.')
17 | parser.add_argument('--resume', action='store_true', help='Resume training from best checkpoint.')
18 | parser.add_argument('--num_epochs', type=int, default=200, help='Number of epochs.')
19 | parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate.')
20 | parser.add_argument('--warmup_epochs', type=int, default=9, help='Number of warmup epochs with lower learning rate.')
21 | parser.add_argument('--batch_size', type=int, default=128, help='Batch size.')
22 | parser.add_argument('--num_classes', type=int, default=10, help='Number of classes.')
23 | parser.add_argument('--img_size', type=int, default=224, help='Image size.')
24 | parser.add_argument('--img_channels', type=int, default=3, help='Number of image channels.')
25 | parser.add_argument('--mixed_precision', action='store_true', help='Use mixed precision training.')
26 | parser.add_argument('--random_seed', type=int, default=0, help='Random seed.')
27 | # Logging
28 |     parser.add_argument('--wandb', action='store_true', help='Log to Weights&Biases.')
29 | parser.add_argument('--log_every', type=int, default=100, help='Log every log_every steps.')
30 | args = parser.parse_args()
31 |
32 | if jax.process_index() == 0:
33 | args.ckpt_dir = os.path.join(args.work_dir, args.group, args.name, 'checkpoints')
34 | if not os.path.exists(args.ckpt_dir):
35 | os.makedirs(args.ckpt_dir)
36 |
37 | if args.wandb:
38 | wandb.init(entity='matthias-wright',
39 | project='imagenette',
40 | group=args.group,
41 | config=args,
42 | name=args.name,
43 | dir=os.path.join(args.work_dir, args.group, args.name))
44 |
45 | training.train_and_evaluate(args)
46 |
47 |
48 | if __name__ == '__main__':
49 | main()
50 |
--------------------------------------------------------------------------------
/training/resnet/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | git+https://github.com/matthias-wright/flaxmodels.git
3 | tensorflow-datasets
4 | optax
5 | argparse
6 | wandb
7 | tqdm
8 |
--------------------------------------------------------------------------------
/training/stylegan2/checkpoint.py:
--------------------------------------------------------------------------------
1 | import flax
2 | import dill as pickle
3 | import os
4 | import glob
5 |
6 |
7 | def save_checkpoint(ckpt_dir, state_G, state_D, params_ema_G, pl_mean, config, step, epoch, fid_score=None, keep=2):
8 | """
9 | Saves checkpoint.
10 |
11 | Args:
12 | ckpt_dir (str): Path to the directory, where checkpoints are saved.
13 | state_G (train_state.TrainState): Generator state.
14 | state_D (train_state.TrainState): Discriminator state.
15 | params_ema_G (frozen_dict.FrozenDict): Parameters of the ema generator.
16 | pl_mean (array): Moving average of the path length (generator regularization).
17 | config (argparse.Namespace): Configuration.
18 | step (int): Current step.
19 | epoch (int): Current epoch.
20 | fid_score (float): FID score corresponding to the checkpoint.
21 | keep (int): Number of checkpoints to keep.
22 | """
23 | state_dict = {'state_G': flax.jax_utils.unreplicate(state_G),
24 | 'state_D': flax.jax_utils.unreplicate(state_D),
25 | 'params_ema_G': params_ema_G,
26 | 'pl_mean': flax.jax_utils.unreplicate(pl_mean),
27 | 'config': config,
28 | 'fid_score': fid_score,
29 | 'step': step,
30 | 'epoch': epoch}
31 |
32 | with open(os.path.join(ckpt_dir, f'ckpt_{step}.pickle'), 'wb') as handle:
33 | pickle.dump(state_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
34 |
35 | ckpts = glob.glob(os.path.join(ckpt_dir, '*.pickle'))
36 | if len(ckpts) > keep:
37 | oldest_ckpt = min(ckpts, key=os.path.getctime)
38 | os.remove(oldest_ckpt)
39 |
40 |
41 | def load_checkpoint(filename):
42 | """
43 |     Loads a checkpoint and replicates the train states across devices.
44 |
45 | Args:
46 | filename (str): Path to the checkpoint file.
47 |
48 | Returns:
49 | (dict): Checkpoint.
50 | """
51 | state_dict = pickle.load(open(filename, 'rb'))
52 | state_dict['state_G'] = flax.jax_utils.replicate(state_dict['state_G'])
53 | state_dict['state_D'] = flax.jax_utils.replicate(state_dict['state_D'])
54 | state_dict['pl_mean'] = flax.jax_utils.replicate(state_dict['pl_mean'])
55 | return state_dict
56 |
57 |
58 | def get_latest_checkpoint(ckpt_dir):
59 | """
60 | Returns the path of the latest checkpoint.
61 |
62 | Args:
63 | ckpt_dir (str): Path to the directory, where checkpoints are saved.
64 |
65 | Returns:
66 | (str): Path to latest checkpoint (if it exists).
67 | """
68 | ckpts = glob.glob(os.path.join(ckpt_dir, '*.pickle'))
69 | if len(ckpts) == 0:
70 | return None
71 | latest_ckpt = max(ckpts, key=os.path.getctime)
72 | return latest_ckpt
73 |
74 |
--------------------------------------------------------------------------------
/training/stylegan2/data_pipeline.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import tensorflow_datasets as tfds
3 | import jax
4 | import flax
5 | import numpy as np
6 | from PIL import Image
7 | import os
8 | from typing import Sequence
9 | from tqdm import tqdm
10 | import json
11 |
12 |
13 |
14 | def prefetch(dataset, n_prefetch):
15 | # Taken from: https://github.com/google-research/vision_transformer/blob/master/vit_jax/input_pipeline.py
16 | ds_iter = iter(dataset)
17 | ds_iter = map(lambda x: jax.tree_map(lambda t: np.asarray(memoryview(t)), x),
18 | ds_iter)
19 | if n_prefetch:
20 | ds_iter = flax.jax_utils.prefetch_to_device(ds_iter, n_prefetch)
21 | return ds_iter
22 |
23 |
24 | def get_data(data_dir, img_size, img_channels, num_classes, num_devices, batch_size, shuffle_buffer=1000):
25 | """
26 |     Loads the TFRecord dataset and builds the sharded input pipeline for training.
27 | Args:
28 | data_dir (str): Root directory of the dataset.
29 | img_size (int): Image size for training.
30 | img_channels (int): Number of image channels.
31 | num_classes (int): Number of classes, 0 for no classes.
32 | num_devices (int): Number of devices.
33 | batch_size (int): Batch size (per device).
34 | shuffle_buffer (int): Buffer used for shuffling the dataset.
35 |
36 | Returns:
37 | (tf.data.Dataset): Dataset.
38 | """
39 |
40 | def pre_process(serialized_example):
41 | feature = {'height': tf.io.FixedLenFeature([], tf.int64),
42 | 'width': tf.io.FixedLenFeature([], tf.int64),
43 | 'channels': tf.io.FixedLenFeature([], tf.int64),
44 | 'image': tf.io.FixedLenFeature([], tf.string),
45 | 'label': tf.io.FixedLenFeature([], tf.int64)}
46 | example = tf.io.parse_single_example(serialized_example, feature)
47 |
48 | height = tf.cast(example['height'], dtype=tf.int64)
49 | width = tf.cast(example['width'], dtype=tf.int64)
50 | channels = tf.cast(example['channels'], dtype=tf.int64)
51 |
52 | image = tf.io.decode_raw(example['image'], out_type=tf.uint8)
53 | image = tf.reshape(image, shape=[height, width, channels])
54 |
55 | image = tf.cast(image, dtype='float32')
56 | image = tf.image.resize(image, size=[img_size, img_size], method='bicubic', antialias=True)
57 | image = tf.image.random_flip_left_right(image)
58 |
59 | image = (image - 127.5) / 127.5
60 |
61 | label = tf.one_hot(example['label'], num_classes)
62 | return {'image': image, 'label': label}
63 |
64 | def shard(data):
65 | # Reshape images from [num_devices * batch_size, H, W, C] to [num_devices, batch_size, H, W, C]
66 | # because the first dimension will be mapped across devices using jax.pmap
67 | data['image'] = tf.reshape(data['image'], [num_devices, -1, img_size, img_size, img_channels])
68 | data['label'] = tf.reshape(data['label'], [num_devices, -1, num_classes])
69 | return data
70 |
71 | print('Loading TFRecord...')
72 | with open(os.path.join(data_dir, 'dataset_info.json'), 'r') as fin:
73 | dataset_info = json.load(fin)
74 |
75 | ds = tf.data.TFRecordDataset(filenames=os.path.join(data_dir, 'dataset.tfrecords'))
76 |
77 | ds = ds.shuffle(min(dataset_info['num_examples'], shuffle_buffer))
78 | ds = ds.map(pre_process, tf.data.AUTOTUNE)
79 | ds = ds.batch(batch_size * num_devices, drop_remainder=True)
80 | ds = ds.map(shard, tf.data.AUTOTUNE)
81 | ds = ds.prefetch(1)
82 | return ds, dataset_info
83 |
84 |
85 |
86 |
--------------------------------------------------------------------------------
/training/stylegan2/dataset_utils/crop_image_borders.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | import os
4 | from tqdm import tqdm
5 | import argparse
6 |
7 |
8 | def crop_border(x, constant=0.0):
9 | top = 0
10 | while True:
11 | if np.sum(x[top] != constant) != 0.0:
12 | break
13 | top += 1
14 | bottom = x.shape[0] - 1
15 | while True:
16 | if np.sum(x[bottom] != constant) != 0.0:
17 | bottom += 1
18 | break
19 | bottom -= 1
20 | left = 0
21 | while True:
22 | if np.sum(x[:, left] != constant) != 0.0:
23 | break
24 | left += 1
25 | right = x.shape[1] - 1
26 | while True:
27 | if np.sum(x[:, right] != constant) != 0.0:
28 | right += 1
29 | break
30 | right -= 1
31 | return x[top:bottom, left:right]
32 |
33 |
34 | def crop_images(path, constant_value):
35 |     print('Cropping image borders...')
36 | for f in tqdm(os.listdir(path)):
37 | img = Image.open(os.path.join(path, f))
38 | img = crop_border(np.array(img), constant=constant_value)
39 | img = Image.fromarray(img)
40 | img.save(os.path.join(path, f))
41 |
42 |
43 | if __name__ == '__main__':
44 | parser = argparse.ArgumentParser()
45 | parser.add_argument('--image_dir', type=str, help='Path to the image directory.')
46 | parser.add_argument('--constant_value', type=float, default=0.0, help='Value of the border that should be cropped.')
47 |
48 | args = parser.parse_args()
49 |
50 | crop_images(args.image_dir, args.constant_value)
51 |
52 |
--------------------------------------------------------------------------------
/training/stylegan2/dataset_utils/images_to_tfrecords.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | from PIL import Image
4 | from typing import Sequence
5 | from tqdm import tqdm
6 | import argparse
7 | import json
8 | import os
9 |
10 |
11 | def images_to_tfrecords(image_dir, data_dir, has_labels):
12 | """
13 | Converts a folder of images to a TFRecord file.
14 |
15 | The image directory should have one of the following structures:
16 |
17 | If has_labels = False, image_dir should look like this:
18 |
19 | path/to/image_dir/
20 | 0.jpg
21 | 1.jpg
22 | 2.jpg
23 | 4.jpg
24 | ...
25 |
26 |
27 | If has_labels = True, image_dir should look like this:
28 |
29 | path/to/image_dir/
30 | label0/
31 | 0.jpg
32 | 1.jpg
33 | ...
34 | label1/
35 | a.jpg
36 | b.jpg
37 | c.jpg
38 | ...
39 | ...
40 |
41 |
42 | The labels will be label0 -> 0, label1 -> 1.
43 |
44 | Args:
45 | image_dir (str): Path to images.
46 | data_dir (str): Path where the TFrecords dataset is stored.
47 | has_labels (bool): If True, 'image_dir' contains label directories.
48 |
49 | Returns:
50 | (dict): Dataset info.
51 | """
52 |
53 | def _bytes_feature(value):
54 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
55 |
56 | def _int64_feature(value):
57 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
58 |
59 | os.makedirs(data_dir, exist_ok=True)
60 | writer = tf.io.TFRecordWriter(os.path.join(data_dir, 'dataset.tfrecords'))
61 |
62 | num_examples = 0
63 | num_classes = 0
64 |
65 | if has_labels:
66 | for label_dir in os.listdir(image_dir):
67 | if not os.path.isdir(os.path.join(image_dir, label_dir)):
68 | print('The image directory should contain one directory for each label.')
69 | print('These label directories should contain the image files.')
70 | if os.path.exists(os.path.join(data_dir, 'dataset.tfrecords')):
71 | os.remove(os.path.join(data_dir, 'dataset.tfrecords'))
72 | return
73 |
74 | for img_file in tqdm(os.listdir(os.path.join(image_dir, label_dir))):
75 | file_format = img_file[img_file.rfind('.') + 1:]
76 | if file_format not in ['png', 'jpg', 'jpeg']:
77 | continue
78 |
79 | #img = Image.open(os.path.join(image_dir, label_dir, img_file)).resize(img_size)
80 | img = Image.open(os.path.join(image_dir, label_dir, img_file))
81 | img = np.array(img, dtype=np.uint8)
82 |
83 | height = img.shape[0]
84 | width = img.shape[1]
85 | channels = img.shape[2]
86 |
87 | img_encoded = img.tobytes()
88 |
89 | example = tf.train.Example(features=tf.train.Features(feature={
90 | 'height': _int64_feature(height),
91 | 'width': _int64_feature(width),
92 | 'channels': _int64_feature(channels),
93 | 'image': _bytes_feature(img_encoded),
94 | 'label': _int64_feature(num_classes)}))
95 |
96 | writer.write(example.SerializeToString())
97 | num_examples += 1
98 |
99 | num_classes += 1
100 | else:
101 | for img_file in tqdm(os.listdir(os.path.join(image_dir))):
102 | file_format = img_file[img_file.rfind('.') + 1:]
103 | if file_format not in ['png', 'jpg', 'jpeg']:
104 | continue
105 |
106 | #img = Image.open(os.path.join(image_dir, label_dir, img_file)).resize(img_size)
107 | img = Image.open(os.path.join(image_dir, img_file))
108 | img = np.array(img, dtype=np.uint8)
109 |
110 | height = img.shape[0]
111 | width = img.shape[1]
112 | channels = img.shape[2]
113 |
114 | img_encoded = img.tobytes()
115 |
116 | example = tf.train.Example(features=tf.train.Features(feature={
117 | 'height': _int64_feature(height),
118 | 'width': _int64_feature(width),
119 | 'channels': _int64_feature(channels),
120 | 'image': _bytes_feature(img_encoded),
121 | 'label': _int64_feature(num_classes)})) # dummy label
122 |
123 | writer.write(example.SerializeToString())
124 | num_examples += 1
125 |
126 | writer.close()
127 |
128 | dataset_info = {'num_examples': num_examples, 'num_classes': num_classes}
129 | with open(os.path.join(data_dir, 'dataset_info.json'), 'w') as fout:
130 | json.dump(dataset_info, fout)
131 |
132 |
133 | if __name__ == '__main__':
134 | parser = argparse.ArgumentParser()
135 | parser.add_argument('--image_dir', type=str, help='Path to the image directory.')
136 | parser.add_argument('--data_dir', type=str, help='Path where the TFRecords dataset is stored.')
137 | parser.add_argument('--has_labels', action='store_true', help='If True, image_dir contains label directories.')
138 |
139 | args = parser.parse_args()
140 |
141 | images_to_tfrecords(args.image_dir, args.data_dir, args.has_labels)
142 |
143 |
--------------------------------------------------------------------------------
/training/stylegan2/fid/__init__.py:
--------------------------------------------------------------------------------
1 | from .core import FID
2 |
--------------------------------------------------------------------------------
/training/stylegan2/fid/core.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import flax
4 | import flax.linen as nn
5 | import numpy as np
6 | import os
7 | import functools
8 | import argparse
9 | import scipy
10 | from tqdm import tqdm
11 |
12 | from . import inception
13 | from . import utils
14 |
15 |
16 | class FID:
17 |
18 | def __init__(self, generator, dataset, config, use_cache=True, truncation_psi=1.0):
19 | """
20 | Evaluates the FID score for a given generator and a given dataset.
21 | Implementation mostly taken from https://github.com/matthias-wright/jax-fid
22 |
23 | Reference: https://arxiv.org/abs/1706.08500
24 |
25 | Args:
26 | generator (nn.Module): Generator network.
27 | dataset (tf.data.Dataset): Dataset containing the real images.
28 | config (argparse.Namespace): Configuration.
29 | use_cache (bool): If True, only compute the activation stats once for the real images and store them.
30 | truncation_psi (float): Controls truncation (trading off variation for quality). If 1, truncation is disabled.
31 | """
32 | self.num_images = config.num_fid_images
33 | self.batch_size = config.batch_size
34 | self.c_dim = config.c_dim
35 | self.z_dim = config.z_dim
36 | self.dataset = dataset
37 | self.num_devices = jax.device_count()
38 | self.use_cache = use_cache
39 |
40 | if self.use_cache:
41 | self.cache = {}
42 |
43 | rng = jax.random.PRNGKey(0)
44 | inception_net = inception.InceptionV3(pretrained=True)
45 | self.inception_params = inception_net.init(rng, jnp.ones((1, config.resolution, config.resolution, 3)))
46 | self.inception_params = flax.jax_utils.replicate(self.inception_params)
47 | #self.inception = jax.jit(functools.partial(model.apply, train=False))
48 | self.inception_apply = jax.pmap(functools.partial(inception_net.apply, train=False), axis_name='batch')
49 |
50 | self.generator_apply = jax.pmap(functools.partial(generator.apply, truncation_psi=truncation_psi, train=False, noise_mode='const'), axis_name='batch')
51 |
52 | def compute_fid(self, generator_params, seed_offset=0):
53 | generator_params = flax.jax_utils.replicate(generator_params)
54 | mu_real, sigma_real = self.compute_stats_for_dataset()
55 | mu_fake, sigma_fake = self.compute_stats_for_generator(generator_params, seed_offset)
56 | fid_score = self.compute_frechet_distance(mu_real, mu_fake, sigma_real, sigma_fake, eps=1e-6)
57 | return fid_score
58 |
59 | def compute_frechet_distance(self, mu1, mu2, sigma1, sigma2, eps=1e-6):
60 | # Taken from: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py
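        # Frechet distance between two Gaussians N(mu1, sigma1) and N(mu2, sigma2):
        #   d^2 = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2))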
61 | mu1 = np.atleast_1d(mu1)
62 | mu2 = np.atleast_1d(mu2)
63 | sigma1 = np.atleast_1d(sigma1)
64 | sigma2 = np.atleast_1d(sigma2)
65 |
66 | assert mu1.shape == mu2.shape
67 | assert sigma1.shape == sigma2.shape
68 |
69 | diff = mu1 - mu2
70 |
71 | covmean, _ = scipy.linalg.sqrtm(sigma1.dot(sigma2), disp=False)
72 | if not np.isfinite(covmean).all():
73 | msg = ('fid calculation produces singular product; '
74 | 'adding %s to diagonal of cov estimates') % eps
75 | print(msg)
76 | offset = np.eye(sigma1.shape[0]) * eps
77 | covmean = scipy.linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
78 |
79 | # Numerical error might give slight imaginary component
80 | if np.iscomplexobj(covmean):
81 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
82 | m = np.max(np.abs(covmean.imag))
83 | raise ValueError('Imaginary component {}'.format(m))
84 | covmean = covmean.real
85 |
86 | tr_covmean = np.trace(covmean)
87 | return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean)
88 |
89 | def compute_stats_for_dataset(self):
90 | if self.use_cache and 'mu' in self.cache and 'sigma' in self.cache:
91 | print('Use cached statistics for dataset...')
92 | return self.cache['mu'], self.cache['sigma']
93 |
94 | print()
95 | print('Compute statistics for dataset...')
96 | pbar = tqdm(total=self.num_images)
97 | image_count = 0
98 |
99 | activations = []
100 | for batch in utils.prefetch(self.dataset, n_prefetch=2):
101 | act = self.inception_apply(self.inception_params, jax.lax.stop_gradient(batch['image']))
102 | act = jnp.reshape(act, (self.num_devices * self.batch_size, -1))
103 | activations.append(act)
104 |
105 | pbar.update(self.num_devices * self.batch_size)
106 | image_count += self.num_devices * self.batch_size
107 | if image_count >= self.num_images:
108 | break
109 | pbar.close()
110 |
111 | activations = jnp.concatenate(activations, axis=0)
112 | activations = activations[:self.num_images]
113 | mu = np.mean(activations, axis=0)
114 | sigma = np.cov(activations, rowvar=False)
115 | self.cache['mu'] = mu
116 | self.cache['sigma'] = sigma
117 | return mu, sigma
118 |
119 | def compute_stats_for_generator(self, generator_params, seed_offset):
120 | print()
121 | print('Compute statistics for generator...')
122 | num_batches = int(np.ceil(self.num_images / (self.batch_size * self.num_devices)))
123 |
124 | pbar = tqdm(total=self.num_images)
125 | activations = []
126 |
127 | for i in range(num_batches):
128 | rng = jax.random.PRNGKey(seed_offset + i)
129 | z_latent = jax.random.normal(rng, shape=(self.num_devices, self.batch_size, self.z_dim))
130 |
131 | labels = None
132 | if self.c_dim > 0:
133 | labels = jax.random.randint(rng, shape=(self.num_devices * self.batch_size,), minval=0, maxval=self.c_dim)
134 | labels = jax.nn.one_hot(labels, num_classes=self.c_dim)
135 | labels = jnp.reshape(labels, (self.num_devices, self.batch_size, self.c_dim))
136 |
137 | image = self.generator_apply(generator_params, jax.lax.stop_gradient(z_latent), labels)
138 | image = (image - jnp.min(image)) / (jnp.max(image) - jnp.min(image))
139 |
140 | image = 2 * image - 1
141 | act = self.inception_apply(self.inception_params, jax.lax.stop_gradient(image))
142 | act = jnp.reshape(act, (self.num_devices * self.batch_size, -1))
143 | activations.append(act)
144 | pbar.update(self.num_devices * self.batch_size)
145 | pbar.close()
146 |
147 | activations = jnp.concatenate(activations, axis=0)
148 | activations = activations[:self.num_images]
149 | mu = np.mean(activations, axis=0)
150 | sigma = np.cov(activations, rowvar=False)
151 | return mu, sigma
152 |
153 |
154 |
--------------------------------------------------------------------------------
/training/stylegan2/fid/utils.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import flax
3 | import numpy as np
4 | from tqdm import tqdm
5 | import requests
6 | import os
7 | import tempfile
8 |
9 |
10 | def download(url, ckpt_dir=None):
11 | name = url[url.rfind('/') + 1 : url.rfind('?')]
12 | if ckpt_dir is None:
13 | ckpt_dir = tempfile.gettempdir()
14 | ckpt_dir = os.path.join(ckpt_dir, 'flaxmodels')
15 | ckpt_file = os.path.join(ckpt_dir, name)
16 | if not os.path.exists(ckpt_file):
17 | print(f'Downloading: \"{url[:url.rfind("?")]}\" to {ckpt_file}')
18 | if not os.path.exists(ckpt_dir):
19 | os.makedirs(ckpt_dir)
20 |
21 | response = requests.get(url, stream=True)
22 | total_size_in_bytes = int(response.headers.get('content-length', 0))
23 | progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
24 |
25 | # first create temp file, in case the download fails
26 | ckpt_file_temp = os.path.join(ckpt_dir, name + '.temp')
27 | with open(ckpt_file_temp, 'wb') as file:
28 | for data in response.iter_content(chunk_size=1024):
29 | progress_bar.update(len(data))
30 | file.write(data)
31 | progress_bar.close()
32 |
33 | if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
34 |             print('An error occurred while downloading, please try again.')
35 | if os.path.exists(ckpt_file_temp):
36 | os.remove(ckpt_file_temp)
37 | else:
38 | # if download was successful, rename the temp file
39 | os.rename(ckpt_file_temp, ckpt_file)
40 | return ckpt_file
41 |
42 |
43 | def get(dictionary, key):
44 | if dictionary is None or key not in dictionary:
45 | return None
46 | return dictionary[key]
47 |
48 |
49 | def prefetch(dataset, n_prefetch):
50 | # Taken from: https://github.com/google-research/vision_transformer/blob/master/vit_jax/input_pipeline.py
51 | ds_iter = iter(dataset)
52 | ds_iter = map(lambda x: jax.tree_map(lambda t: np.asarray(memoryview(t)), x),
53 | ds_iter)
54 | if n_prefetch:
55 | ds_iter = flax.jax_utils.prefetch_to_device(ds_iter, n_prefetch)
56 | return ds_iter
57 |
--------------------------------------------------------------------------------
/training/stylegan2/generate_images.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import flax
4 | import numpy as np
5 | import dill as pickle
6 | import flaxmodels as fm
7 | import data_pipeline
8 | import checkpoint
9 | from PIL import Image
10 | from tqdm import tqdm
11 | import argparse
12 | import functools
13 | import os
14 |
15 |
16 | def generate_images(args):
17 | num_devices = jax.device_count()
18 | ckpt = checkpoint.load_checkpoint(args.ckpt_path)
19 | config = ckpt['config']
20 |
21 | dtype = jnp.float32
22 |
23 | generator_ema = fm.stylegan2.Generator(resolution=config.resolution,
24 | num_channels=config.img_channels,
25 | z_dim=config.z_dim,
26 | c_dim=config.c_dim,
27 | w_dim=config.w_dim,
28 | num_ws=int(np.log2(config.resolution)) * 2 - 3,
29 | num_mapping_layers=8,
30 | fmap_base=config.fmap_base,
31 | dtype=dtype)
32 |
33 | generator_apply = jax.jit(functools.partial(generator_ema.apply, truncation_psi=args.truncation_psi, train=False, noise_mode='const'))
34 | params_ema_G = ckpt['params_ema_G']
35 |
36 | for seed in tqdm(args.seeds):
37 | rng = jax.random.PRNGKey(seed)
38 | z_latent = jax.random.normal(rng, shape=(1, config.z_dim))
39 | labels = None
40 | if config.c_dim > 0:
41 | labels = jax.random.randint(rng, shape=(1,), minval=0, maxval=config.c_dim)
42 | labels = jax.nn.one_hot(labels, num_classes=config.c_dim)
43 | labels = jnp.reshape(labels, (1, config.c_dim))
44 |
45 | image = generator_apply(params_ema_G, jax.lax.stop_gradient(z_latent), labels)
46 | image = (image - jnp.min(image)) / (jnp.max(image) - jnp.min(image))
47 |
48 | Image.fromarray(np.uint8(np.clip(image[0] * 255, 0, 255))).save(os.path.join(args.out_path, f'{seed}.png'))
49 | print('Images saved at:', args.out_path)
50 |
51 |
52 | if __name__ == '__main__':
53 | parser = argparse.ArgumentParser()
54 | parser.add_argument('--ckpt_path', type=str, help='Path to the checkpoint.')
55 | parser.add_argument('--out_path', type=str, default='generated_images', help='Path where the generated images are stored.')
56 | parser.add_argument('--truncation_psi', type=float, default=0.5, help='Controls truncation (trading off variation for quality). If 1, truncation is disabled.')
57 | parser.add_argument('--seeds', type=int, nargs='*', help='List of random seeds.')
58 | args = parser.parse_args()
59 | os.makedirs(args.out_path, exist_ok=True)
60 |
61 | generate_images(args)
62 |
63 |
64 |
--------------------------------------------------------------------------------
/training/stylegan2/images/anime_grid.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/training/stylegan2/images/anime_grid.jpg
--------------------------------------------------------------------------------
/training/stylegan2/images/anime_overview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/training/stylegan2/images/anime_overview.jpg
--------------------------------------------------------------------------------
/training/stylegan2/images/anime_style_mixing.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/training/stylegan2/images/anime_style_mixing.jpg
--------------------------------------------------------------------------------
/training/stylegan2/images/ffhq_grid.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/training/stylegan2/images/ffhq_grid.jpg
--------------------------------------------------------------------------------
/training/stylegan2/images/ffhq_overview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/training/stylegan2/images/ffhq_overview.jpg
--------------------------------------------------------------------------------
/training/stylegan2/images/ffhq_style_mixing.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matthias-wright/flaxmodels/09bc77215032375d124e3f1eba828dd89c80a850/training/stylegan2/images/ffhq_style_mixing.jpg
--------------------------------------------------------------------------------
/training/stylegan2/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import jax
4 | import wandb
5 | import training
6 |
7 |
8 | def main():
9 | parser = argparse.ArgumentParser()
10 | # Paths
11 | parser.add_argument('--work_dir', type=str, default='logging', help='Directory for logging and checkpoints.')
12 | parser.add_argument('--data_dir', type=str, help='Directory of the dataset.')
13 | parser.add_argument('--project', type=str, default='stylegan', help='Name of this project.')
14 | parser.add_argument('--name', type=str, default='test', help='Name of this experiment.')
15 | parser.add_argument('--group', type=str, default='default', help='Group name of this experiment (for Weights&Biases).')
16 | # Training
17 | parser.add_argument('--resume', action='store_true', help='Resume training from latest checkpoint.')
18 | parser.add_argument('--num_epochs', type=int, default=10000, help='Number of epochs.')
19 | parser.add_argument('--learning_rate', type=float, default=0.002, help='Learning rate.')
20 | parser.add_argument('--batch_size', type=int, default=8, help='Batch size.')
21 | parser.add_argument('--num_prefetch', type=int, default=2, help='Number of prefetched examples for the data pipeline.')
22 |     parser.add_argument('--resolution', type=int, default=128, help='Image resolution. Must be a power of 2.')
23 | parser.add_argument('--img_channels', type=int, default=3, help='Number of image channels.')
24 | parser.add_argument('--mixed_precision', action='store_true', help='Use mixed precision training.')
25 | parser.add_argument('--random_seed', type=int, default=0, help='Random seed.')
26 | # Generator
27 | parser.add_argument('--fmap_base', type=int, default=16384, help='Overall multiplier for the number of feature maps.')
28 | # Discriminator
29 | parser.add_argument('--mbstd_group_size', type=int, help='Group size for the minibatch standard deviation layer, None = entire minibatch.')
30 | # Exponentially Moving Average of Generator Weights
31 | parser.add_argument('--ema_kimg', type=float, default=20.0, help='Controls the ema of the generator weights (larger value -> larger beta).')
32 | # Losses
33 |     parser.add_argument('--pl_decay', type=float, default=0.01, help='Exponential decay rate for the moving average of the path length (path length regularization).')
34 | parser.add_argument('--pl_weight', type=float, default=2, help='Weight for path length regularization.')
35 | # Regularization
36 | parser.add_argument('--mixing_prob', type=float, default=0.9, help='Probability for style mixing.')
37 | parser.add_argument('--G_reg_interval', type=int, default=4, help='How often to perform regularization for G.')
38 | parser.add_argument('--D_reg_interval', type=int, default=16, help='How often to perform regularization for D.')
39 | parser.add_argument('--r1_gamma', type=float, default=10.0, help='Weight for R1 regularization.')
40 | # Model
41 | parser.add_argument('--z_dim', type=int, default=512, help='Input latent (Z) dimensionality.')
42 | parser.add_argument('--c_dim', type=int, default=0, help='Conditioning label (C) dimensionality, 0 = no label.')
43 |     parser.add_argument('--w_dim', type=int, default=512, help='Intermediate latent (W) dimensionality.')
44 | # Logging
45 |     parser.add_argument('--wandb', action='store_true', help='Log to Weights&Biases.')
46 | parser.add_argument('--log_every', type=int, default=50, help='Log every log_every steps.')
47 | parser.add_argument('--save_every', type=int, default=2000, help='Save every save_every steps. Will be ignored if FID evaluation is enabled.')
48 | # FID
49 | parser.add_argument('--eval_fid_every', type=int, default=1000, help='Compute FID score every eval_fid_every steps.')
50 | parser.add_argument('--num_fid_images', type=int, default=10000, help='Number of images to use for FID computation.')
51 | parser.add_argument('--disable_fid', action='store_true', help='Disable FID evaluation.')
52 |
53 | args = parser.parse_args()
54 |
55 | if jax.process_index() == 0:
56 | args.ckpt_dir = os.path.join(args.work_dir, args.group, args.name, 'checkpoints')
57 | if not os.path.exists(args.ckpt_dir):
58 | os.makedirs(args.ckpt_dir)
59 |
60 | if args.wandb:
61 | wandb.init(project=args.project,
62 | group=args.group,
63 | config=args,
64 | name=args.name,
65 | dir=os.path.join(args.work_dir, args.group, args.name))
66 |
67 | training.train_and_evaluate(args)
68 |
69 |
70 | if __name__ == '__main__':
71 | main()
72 |
73 |
--------------------------------------------------------------------------------
/training/stylegan2/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | git+https://github.com/matthias-wright/flaxmodels.git
3 | tensorflow-datasets
4 | tensorflow==2.4.1
5 | optax
6 | argparse
7 | wandb
8 | tqdm
9 | dill
10 |
--------------------------------------------------------------------------------
/training/stylegan2/style_mixing.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import flax
4 | from flax.core import frozen_dict
5 | import numpy as np
6 | import dill as pickle
7 | import flaxmodels as fm
8 | import data_pipeline
9 | import checkpoint
10 | from PIL import Image
11 | from tqdm import tqdm
12 | import argparse
13 | import functools
14 | import os
15 |
16 |
17 | def style_mixing(args):
18 | num_devices = jax.device_count()
19 | ckpt = checkpoint.load_checkpoint(args.ckpt_path)
20 | config = ckpt['config']
21 |
22 | dtype = jnp.float32
23 |
24 | mapping_net = fm.stylegan2.MappingNetwork(z_dim=config.z_dim,
25 | c_dim=config.c_dim,
26 | w_dim=config.w_dim,
27 | num_ws=int(np.log2(config.resolution)) * 2 - 3,
28 | num_layers=8,
29 | dtype=dtype)
30 |
31 | synthesis_net = fm.stylegan2.SynthesisNetwork(resolution=config.resolution,
32 | num_channels=config.img_channels,
33 | w_dim=config.w_dim,
34 | fmap_base=config.fmap_base,
35 | dtype=dtype)
36 |
37 | params_ema_G = ckpt['params_ema_G']
38 | params_ema_G = params_ema_G.unfreeze()
39 | synthesis_params = {'params': params_ema_G['params']['synthesis_network'],
40 | 'noise_consts': params_ema_G['noise_consts']['synthesis_network']}
41 | synthesis_params = frozen_dict.freeze(synthesis_params)
42 |
43 | mapping_params = {'params': params_ema_G['params']['mapping_network'],
44 | 'moving_stats': params_ema_G['moving_stats']['mapping_network']}
45 | mapping_params = frozen_dict.freeze(mapping_params)
46 |
47 | synthesis_apply = jax.jit(functools.partial(synthesis_net.apply, noise_mode='const'))
48 | mapping_apply = jax.jit(functools.partial(mapping_net.apply, truncation_psi=args.truncation_psi, train=False))
49 |
50 | all_seeds = args.row_seeds + args.col_seeds
51 | # Generate noise inputs, [minibatch, component]
52 | all_z = jnp.concatenate([jax.random.normal(jax.random.PRNGKey(seed), shape=(1, 512)) for seed in all_seeds])
53 | # Generate latent vectors, [minibatch, num_ws, component]
54 | all_w = mapping_apply(mapping_params, all_z)
55 | # Generate images, [minibatch, H, W, 3]
56 | all_images = synthesis_apply(synthesis_params, all_w)
57 | # Normalize image to be in range [0, 1]
58 | all_images = (all_images - jnp.min(all_images)) / (jnp.max(all_images) - jnp.min(all_images))
59 | col_images = jnp.concatenate([all_images[i] for i in range(len(args.row_seeds))], axis=0)
60 | row_images = jnp.concatenate([all_images[len(args.row_seeds) + i] for i in range(len(args.col_seeds))], axis=1)
61 |
62 | images_grid = []
63 |
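    # The first half of the ws controls coarse styles (low resolutions), the second half controls fine styles.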
64 | cutoff = mapping_net.num_ws // 2
65 |
66 | # Generate style mixing images
67 | for row in range(len(args.row_seeds)):
68 | image_row = []
69 | for col in range(len(args.col_seeds)):
70 |             # Combine the first half of the ws (coarse styles) from the row seed with the second half (fine styles) from the col seed
71 | w = jnp.concatenate([all_w[row, :cutoff], all_w[len(args.row_seeds) + col, cutoff:]], axis=0)
72 | # Add batch dimension
73 | w = jnp.expand_dims(w, axis=0)
74 | image = synthesis_apply(synthesis_params, w)
75 | # Remove batch dimension
76 | image = jnp.squeeze(image, axis=0)
77 |
78 | # Normalize image to be in range [0, 1]
79 | image = (image - jnp.min(image)) / (jnp.max(image) - jnp.min(image))
80 | image_row.append(image)
81 | image_row = jnp.concatenate(image_row, axis=1)
82 | images_grid.append(image_row)
83 |
84 | images_grid = jnp.concatenate(images_grid, axis=0)
85 |
86 | # Add row and column images to the grid
87 | border = 20
88 | grid = np.ones((row_images.shape[0] + images_grid.shape[0] + border,
89 | col_images.shape[1] + images_grid.shape[1] + border,
90 | 3))
91 | grid[grid.shape[0] - images_grid.shape[0]:, grid.shape[1] - images_grid.shape[1]:] = images_grid
92 | grid[:row_images.shape[0], grid.shape[1] - row_images.shape[1]:] = row_images
93 | grid[grid.shape[0] - col_images.shape[0]:, :col_images.shape[1]] = col_images
94 | Image.fromarray(np.uint8(np.clip(grid * 255, 0, 255))).save(os.path.join(args.out_path, 'style_mixing.png'))
95 | print('Style mixing grid saved at:', args.out_path)
96 |
97 |
98 | if __name__ == '__main__':
99 | parser = argparse.ArgumentParser()
100 | parser.add_argument('--ckpt_path', type=str, help='Path to the checkpoint.')
101 | parser.add_argument('--out_path', type=str, default='generated_images', help='Path where the generated images are stored.')
102 | parser.add_argument('--num_images', type=int, default=100, help='Number of images to generate.')
103 | parser.add_argument('--truncation_psi', type=float, default=0.5, help='Controls truncation (trading off variation for quality). If 1, truncation is disabled.')
104 | parser.add_argument('--row_seeds', type=int, nargs='*', help='List of random seeds for row images.')
105 | parser.add_argument('--col_seeds', type=int, nargs='*', help='List of random seeds for column images.')
106 | args = parser.parse_args()
107 | assert len(args.row_seeds) == len(args.col_seeds), 'row_seeds and col_seeds must have the same length.'
108 | os.makedirs(args.out_path, exist_ok=True)
109 |
110 | style_mixing(args)
111 |
112 |
113 |
--------------------------------------------------------------------------------
/training/stylegan2/training_utils.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from jaxlib.xla_extension import DeviceArray
4 | import flax
5 | from flax.optim import dynamic_scale as dynamic_scale_lib
6 | from flax.core import frozen_dict
7 | from flax.training import train_state
8 | from flax import struct
9 | import numpy as np
10 | from PIL import Image
11 | from typing import Any, Callable
12 |
13 |
14 | def sync_moving_stats(state):
15 | """
16 | Sync moving statistics across devices.
17 |
18 | Args:
19 | state (train_state.TrainState): Training state.
20 |
21 | Returns:
22 | (train_state.TrainState): Updated training state.
23 | """
24 | cross_replica_mean = jax.pmap(lambda x: jax.lax.pmean(x, 'x'), 'x')
25 | return state.replace(moving_stats=cross_replica_mean(state.moving_stats))
26 |
27 |
28 | def update_generator_ema(state_G, params_ema_G, config, ema_beta=None):
29 | """
30 | Update exponentially moving average of the generator weights.
31 | Moving stats and noise constants will be copied over.
32 |
33 | Args:
34 | state_G (train_state.TrainState): Generator state.
35 | params_ema_G (frozen_dict.FrozenDict): Parameters of the ema generator.
36 | config (Any): Config object.
37 | ema_beta (float): Beta parameter of the ema. If None, will be computed
38 |             from 'ema_kimg' and 'batch_size'.
39 |
40 | Returns:
41 |         (frozen_dict.FrozenDict): Updated parameters of the ema generator.
42 | """
43 | def _update_ema(src, trg, beta):
44 | for name, src_child in src.items():
45 | if isinstance(src_child, DeviceArray):
46 |                 trg[name] = src[name] + beta * (trg[name] - src[name])
47 | else:
48 | _update_ema(src_child, trg[name], beta)
49 |
50 | if ema_beta is None:
51 | ema_nimg = config.ema_kimg * 1000
52 | ema_beta = 0.5 ** (config.batch_size / max(ema_nimg, 1e-8))
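        # i.e. the EMA has a half-life of ema_nimg images: ema_beta ** (ema_nimg / batch_size) = 0.5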
53 |
54 | params_ema_G = params_ema_G.unfreeze()
55 |
56 | # Copy over moving stats
57 | params_ema_G['moving_stats']['mapping_network'] = state_G.moving_stats
58 | params_ema_G['noise_consts']['synthesis_network'] = state_G.noise_consts
59 |
60 | # Update exponentially moving average of the trainable parameters
61 | _update_ema(state_G.params['mapping'], params_ema_G['params']['mapping_network'], ema_beta)
62 | _update_ema(state_G.params['synthesis'], params_ema_G['params']['synthesis_network'], ema_beta)
63 |
64 | params_ema_G = frozen_dict.freeze(params_ema_G)
65 | return params_ema_G
66 |
67 |
68 | class TrainStateG(train_state.TrainState):
69 | """
70 | Generator train state for a single Optax optimizer.
71 |
72 | Attributes:
73 | apply_mapping (Callable): Apply function of the Mapping Network.
74 | apply_synthesis (Callable): Apply function of the Synthesis Network.
75 |         dynamic_scale_main/dynamic_scale_reg (dynamic_scale_lib.DynamicScale): Dynamic loss scaling for the main and regularization losses (mixed precision).
76 | epoch (int): Current epoch.
77 | moving_stats (Any): Moving average of the latent W.
78 | noise_consts (Any): Noise constants from synthesis layers.
79 | """
80 | apply_mapping: Callable = struct.field(pytree_node=False)
81 | apply_synthesis: Callable = struct.field(pytree_node=False)
82 | dynamic_scale_main: dynamic_scale_lib.DynamicScale
83 | dynamic_scale_reg: dynamic_scale_lib.DynamicScale
84 | epoch: int
85 | moving_stats: Any=None
86 | noise_consts: Any=None
87 |
88 |
89 | class TrainStateD(train_state.TrainState):
90 | """
91 | Discriminator train state for a single Optax optimizer.
92 |
93 | Attributes:
94 |         dynamic_scale_main/dynamic_scale_reg (dynamic_scale_lib.DynamicScale): Dynamic loss scaling for the main and regularization losses (mixed precision).
95 | epoch (int): Current epoch.
96 | """
97 | dynamic_scale_main: dynamic_scale_lib.DynamicScale
98 | dynamic_scale_reg: dynamic_scale_lib.DynamicScale
99 | epoch: int
100 |
101 |
102 | def get_training_snapshot(image_real, image_gen, max_num=10):
103 | """
104 | Creates a snapshot of generated images and real images.
105 |
106 | Args:
107 |         image_real (DeviceArray): Batch of real images, shape [B, H, W, C].
108 |         image_gen (DeviceArray): Batch of generated images, shape [B, H, W, C].
109 | max_num (int): Maximum number of images used for snapshot.
110 |
111 | Returns:
112 | (PIL.Image): Training snapshot. Top row: generated images, bottom row: real images.
113 | """
114 | if image_real.shape[0] > max_num:
115 | image_real = image_real[:max_num]
116 | if image_gen.shape[0] > max_num:
117 | image_gen = image_gen[:max_num]
118 |
119 | image_real = jnp.split(image_real, image_real.shape[0], axis=0)
120 | image_gen = jnp.split(image_gen, image_gen.shape[0], axis=0)
121 |
122 | image_real = [jnp.squeeze(x, axis=0) for x in image_real]
123 | image_gen = [jnp.squeeze(x, axis=0) for x in image_gen]
124 |
125 | image_real = jnp.concatenate(image_real, axis=1)
126 | image_gen = jnp.concatenate(image_gen, axis=1)
127 |
128 | image_gen = (image_gen - np.min(image_gen)) / (np.max(image_gen) - np.min(image_gen))
129 | image_real = (image_real - np.min(image_real)) / (np.max(image_real) - np.min(image_real))
130 | image = jnp.concatenate((image_gen, image_real), axis=0)
131 |
132 | image = np.uint8(image * 255)
133 | if image.shape[-1] == 1:
134 | image = np.repeat(image, 3, axis=-1)
135 | return Image.fromarray(image)
136 |
137 |
138 | def get_eval_snapshot(image, max_num=10):
139 | """
140 | Creates a snapshot of generated images.
141 |
142 | Args:
143 | image (DeviceArray): Generated images, shape [B, H, W, C].
144 |
145 | Returns:
146 | (PIL.Image): Eval snapshot.
147 | """
148 | if image.shape[0] > max_num:
149 | image = image[:max_num]
150 |
151 | image = jnp.split(image, image.shape[0], axis=0)
152 | image = [jnp.squeeze(x, axis=0) for x in image]
153 | image = jnp.concatenate(image, axis=1)
154 | image = (image - np.min(image)) / (np.max(image) - np.min(image))
155 | image = np.uint8(image * 255)
156 | if image.shape[-1] == 1:
157 | image = np.repeat(image, 3, axis=-1)
158 | return Image.fromarray(image)
159 |
--------------------------------------------------------------------------------
/training/vgg/README.md:
--------------------------------------------------------------------------------
1 | # VGG Training in Jax/Flax
2 | This is the training code for the [Jax/Flax implementation](https://github.com/matthias-wright/flaxmodels/tree/main/flaxmodels/vgg) of [VGG](https://arxiv.org/abs/1409.1556).
3 |
4 | ##### Table of Contents
5 | * [Getting Started](#getting-started)
6 | * [Training](#training)
7 | * [Options](#options)
8 | * [References](#references)
9 | * [License](#license)
10 |
11 |
12 |
13 | ## Getting Started
14 | You will need Python 3.7 or later.
15 |
16 | 1. Clone the repository:
17 | ```sh
18 | > git clone https://github.com/matthias-wright/flaxmodels.git
19 | ```
20 | 2. Go into the directory:
21 | ```sh
22 | > cd flaxmodels/training/vgg
23 | ```
24 | 3. Install Jax with CUDA.
25 | 4. Install requirements:
26 | ```sh
27 | > pip install -r requirements.txt
28 | ```
29 |
30 |
31 | ## Training
32 |
33 | ### Basic Training
34 | ```sh
35 | CUDA_VISIBLE_DEVICES=0 python main.py
36 | ```
37 |
38 | ### Multi GPU Training
39 | The script will automatically use all the visible GPUs for distributed training.
40 | ```sh
41 | CUDA_VISIBLE_DEVICES=0,1 python main.py
42 | ```
43 |
44 | ### Mixed-Precision Training
45 | ```sh
46 | CUDA_VISIBLE_DEVICES=0,1 python main.py --mixed_precision
47 | ```
48 |
49 |
50 | ## Options
51 | * `--work_dir` - Path to directory for logging and checkpoints (str).
52 | * `--data_dir` - Path for storing the dataset (str).
53 | * `--name` - Name of the training run (str).
54 | * `--group` - Group name of the training run (str).
55 | * `--arch` - Architecture (str). Options: vgg16, vgg19.
56 | * `--resume` - Resume training from best checkpoint (bool).
57 | * `--num_epochs` - Number of epochs (int).
58 | * `--learning_rate` - Learning rate (float).
59 | * `--warmup_epochs` - Number of warmup epochs with lower learning rate (int).
60 | * `--batch_size` - Batch size (int).
61 | * `--num_classes` - Number of classes (int).
62 | * `--img_size` - Image size (int).
63 | * `--img_channels` - Number of image channels (int).
64 | * `--mixed_precision` - Use mixed precision training (bool).
65 | * `--random_seed` - Random seed (int).
66 | * `--wandb` - Use Weights&Biases for logging (bool).
67 | * `--log_every` - Log every log_every steps (int).
68 |
69 |
70 | ## References
71 | * [Very Deep Convolutional Networks for Large-Scale Image Recognition](https://arxiv.org/abs/1409.1556)
72 | * [pytorch/vision/torchvision/models/vgg.py](https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py)
73 | * [google/flax/examples](https://github.com/google/flax/tree/main/examples)
74 |
75 |
76 |
77 | ## License
78 | [MIT License](https://opensource.org/licenses/MIT)
79 |
80 |
--------------------------------------------------------------------------------
/training/vgg/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import jax
4 | import wandb
5 | import training
6 |
7 |
8 | def main():
9 | parser = argparse.ArgumentParser()
10 | # Paths
11 | parser.add_argument('--work_dir', type=str, default='/export/scratch/mwright/projects/misc/imagenette', help='Directory for logging and checkpoints.')
12 | parser.add_argument('--data_dir', type=str, default='/export/data/mwright/tensorflow_datasets', help='Directory for storing data.')
13 | parser.add_argument('--name', type=str, default='test', help='Name of this experiment.')
14 | parser.add_argument('--group', type=str, default='default', help='Group name of this experiment.')
15 | # Training
16 | parser.add_argument('--arch', type=str, default='vgg16', choices=['vgg16', 'vgg19'], help='Architecture.')
17 | parser.add_argument('--resume', action='store_true', help='Resume training from best checkpoint.')
18 | parser.add_argument('--num_epochs', type=int, default=200, help='Number of epochs.')
19 | parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate.')
20 | parser.add_argument('--warmup_epochs', type=int, default=9, help='Number of warmup epochs with lower learning rate.')
21 | parser.add_argument('--batch_size', type=int, default=128, help='Batch size.')
22 | parser.add_argument('--num_classes', type=int, default=10, help='Number of classes.')
23 | parser.add_argument('--img_size', type=int, default=224, help='Image size.')
24 | parser.add_argument('--img_channels', type=int, default=3, help='Number of image channels.')
25 | parser.add_argument('--mixed_precision', action='store_true', help='Use mixed precision training.')
26 | parser.add_argument('--random_seed', type=int, default=0, help='Random seed.')
27 | # Logging
28 |     parser.add_argument('--wandb', action='store_true', help='Log to Weights&Biases.')
29 | parser.add_argument('--log_every', type=int, default=100, help='Log every log_every steps.')
30 | args = parser.parse_args()
31 |
32 | if jax.process_index() == 0:
33 | args.ckpt_dir = os.path.join(args.work_dir, args.group, args.name, 'checkpoints')
34 | if not os.path.exists(args.ckpt_dir):
35 | os.makedirs(args.ckpt_dir)
36 |
37 | if args.wandb:
38 | wandb.init(entity='matthias-wright',
39 | project='imagenette',
40 | group=args.group,
41 | config=args,
42 | name=args.name,
43 | dir=os.path.join(args.work_dir, args.group, args.name))
44 |
45 | training.train_and_evaluate(args)
46 |
47 |
48 | if __name__ == '__main__':
49 | main()
50 |
--------------------------------------------------------------------------------
/training/vgg/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | git+https://github.com/matthias-wright/flaxmodels.git
3 | tensorflow
4 | tensorflow-datasets
5 | optax
6 | argparse
7 | wandb
8 | tqdm
9 |
--------------------------------------------------------------------------------