`_.
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
1 | [home](../README.md) > examples
2 |
3 | # Examples
4 |
5 | This directory contains fully functional code examples that you can use to learn more about Objax.
6 |
7 | Examples from classic machine learning:
8 | * [Image Classification](image_classification/README.md)
9 | * [Text Generation](text_generation/README.md)
10 |
11 | Examples from recent research:
12 | * [Model-Agnostic Meta-Learning](maml/README.md)
13 | * [FixMatch](fixmatch/README.md)
14 | * [GPT-2](gpt-2/README.md)
15 |
16 | Other examples:
17 | * [Tutorials](tutorials/README.md)
18 | * [JaxBoard](jaxboard/README.md)
--------------------------------------------------------------------------------
/examples/fixmatch/README.md:
--------------------------------------------------------------------------------
1 | [home](../../README.md) > [examples](../README.md) > fixmatch
2 |
3 | # Semi-Supervised Image Classification with [FixMatch](https://arxiv.org/abs/2001.07685)
4 |
5 | ## Setup
6 |
7 | ### Required environment variables
8 |
9 | ```bash
10 | export PYTHONPATH=$PYTHONPATH:.
11 | export ML_DATA="path to where you want the datasets saved"
12 | export PROJECT="ObjaxSSL"
13 | export SSL_PATH=examples/fixmatch
14 | ```
15 |
16 | ## Data preparation
17 |
18 | ```bash
19 | # Download datasets
20 | CUDA_VISIBLE_DEVICES= $SSL_PATH/scripts/create_datasets.py
21 | cp $ML_DATA/$PROJECT/svhn-test.tfrecord $ML_DATA/$PROJECT/svhnx-test.tfrecord
22 |
23 | # Create unlabeled datasets
24 | CUDA_VISIBLE_DEVICES= $SSL_PATH/scripts/create_unlabeled.py $ML_DATA/$PROJECT/SSL/cifar10 $ML_DATA/$PROJECT/cifar10-train.tfrecord &
25 | CUDA_VISIBLE_DEVICES= $SSL_PATH/scripts/create_unlabeled.py $ML_DATA/$PROJECT/SSL/cifar100 $ML_DATA/$PROJECT/cifar100-train.tfrecord &
26 | CUDA_VISIBLE_DEVICES= $SSL_PATH/scripts/create_unlabeled.py $ML_DATA/$PROJECT/SSL/stl10 $ML_DATA/$PROJECT/stl10-train.tfrecord $ML_DATA/$PROJECT/stl10-unlabeled.tfrecord &
27 | CUDA_VISIBLE_DEVICES= $SSL_PATH/scripts/create_unlabeled.py $ML_DATA/$PROJECT/SSL/svhn $ML_DATA/$PROJECT/svhn-train.tfrecord &
28 | CUDA_VISIBLE_DEVICES= $SSL_PATH/scripts/create_unlabeled.py $ML_DATA/$PROJECT/SSL/svhnx $ML_DATA/$PROJECT/svhn-train.tfrecord $ML_DATA/$PROJECT/svhn-extra.tfrecord &
29 | wait
30 |
31 | # Create semi-supervised subsets
32 | for seed in 0 1 2 3 4 5; do
33 | for size in 40 100 250 1000 4000; do
34 | CUDA_VISIBLE_DEVICES= $SSL_PATH/scripts/create_split.py --seed=$seed --size=$size $ML_DATA/$PROJECT/SSL/cifar10 $ML_DATA/$PROJECT/cifar10-train.tfrecord &
35 | CUDA_VISIBLE_DEVICES= $SSL_PATH/scripts/create_split.py --seed=$seed --size=$size $ML_DATA/$PROJECT/SSL/svhn $ML_DATA/$PROJECT/svhn-train.tfrecord &
36 | CUDA_VISIBLE_DEVICES= $SSL_PATH/scripts/create_split.py --seed=$seed --size=$size $ML_DATA/$PROJECT/SSL/svhnx $ML_DATA/$PROJECT/svhn-train.tfrecord $ML_DATA/$PROJECT/svhn-extra.tfrecord &
37 | done
38 | for size in 400 1000 2500 10000; do
39 | CUDA_VISIBLE_DEVICES= $SSL_PATH/scripts/create_split.py --seed=$seed --size=$size $ML_DATA/$PROJECT/SSL/cifar100 $ML_DATA/$PROJECT/cifar100-train.tfrecord &
40 | done
41 | CUDA_VISIBLE_DEVICES= $SSL_PATH/scripts/create_split.py --seed=$seed --size=1000 $ML_DATA/$PROJECT/SSL/stl10 $ML_DATA/$PROJECT/stl10-train.tfrecord $ML_DATA/$PROJECT/stl10-unlabeled.tfrecord &
42 | wait
43 | done
44 | CUDA_VISIBLE_DEVICES= $SSL_PATH/scripts/create_split.py --seed=1 --size=5000 $ML_DATA/$PROJECT/SSL/stl10 $ML_DATA/$PROJECT/stl10-train.tfrecord $ML_DATA/$PROJECT/stl10-unlabeled.tfrecord
45 | ```
46 |
47 | ## Training
48 |
49 | ```bash
50 | # FixMatch
51 | python $SSL_PATH/fixmatch.py --dataset=cifar10.3@250-0 --unlabeled=cifar10 --uratio=5 --augment='CTA(sm,sm,sm)'
52 | ```
53 |
54 | ## Tensorboard
55 |
56 | ```bash
57 | tensorboard --port 6006 --logdir_spec=experiments
58 | ```
59 |
--------------------------------------------------------------------------------
/examples/fixmatch/libml/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/examples/fixmatch/libml/augment/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/examples/fixmatch/libml/augment/core.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Augmentations for images.
16 | """
17 |
18 | import tensorflow as tf
19 |
20 |
def cutout(x, w):
    """Zero out a randomly placed ``w`` x ``w`` square of image ``x``.

    Args:
        x: image tensor whose first two axes are rows and columns; the mask built
            below has shape (rows, cols, 1) and broadcasts over the last axis.
        w: side length of the square, in pixels.

    Returns:
        ``x`` with the selected square multiplied by zero, in ``x``'s dtype.
    """
    offsets = tf.random.uniform([2], 0, 1)
    s = tf.shape(x)
    # Top-left corner drawn uniformly in [0, side - w] so the square stays inside the image.
    y0 = tf.cast(tf.round(offsets[0] * (tf.cast(s[0], tf.float32) - w)), tf.int32)
    x0 = tf.cast(tf.round(offsets[1] * (tf.cast(s[1], tf.float32) - w)), tf.int32)
    hr, wr = tf.range(s[0])[:, None, None], tf.range(s[1])[None, :, None]
    inside = (hr >= y0) & (hr < y0 + w) & (wr >= x0) & (wr < x0 + w)
    # Build the mask in x's dtype so non-float32 images (uint8, float16, ...) can be masked too;
    # a float32 mask would make `mask * x` fail on a dtype mismatch.
    mask = 1 - tf.cast(inside, x.dtype)
    return mask * x
29 |
30 |
def mirror(x):
    """Flip image ``x`` left-right at random (delegates to tf.image.random_flip_left_right)."""
    flipped = tf.image.random_flip_left_right(x)
    return flipped
33 |
34 |
def shift(x, w):
    """Randomly translate image ``x`` by up to ``w`` pixels along its first two axes.

    The image is reflect-padded by ``w`` on both sides of rows and columns (channels
    untouched) and a crop with the original shape is taken at a random offset.
    REFLECT padding requires ``w`` to be smaller than the padded dimensions.
    """
    pad_spec = [[w, w], [w, w], [0, 0]]
    padded = tf.pad(x, pad_spec, mode='REFLECT')
    original_shape = tf.shape(x)
    return tf.image.random_crop(padded, original_shape)
38 |
39 |
def noise(x, std):
    """Return ``x`` plus zero-mean Gaussian noise of standard deviation ``std`` (same shape and dtype)."""
    gaussian = tf.random.normal(tf.shape(x), dtype=x.dtype)
    return x + std * gaussian
42 |
43 |
def get_tf_augment(augment, size=32):
    """Return a function that applies the named augmentation to a record dict.

    Args:
        augment: one of 'x' (identity), 's' (shift), 'sc' (shift + cutout),
            'sm' (shift + mirror) or 'smc' (shift + mirror + cutout).
        size: image side length; shifts use up to ``size >> 3`` pixels and cutout
            removes a ``size >> 1`` square.

    Returns:
        A callable taking a dict of fields and returning a new dict with the same
        fields, where 'image' is augmented (required for every name except 'x').

    Raises:
        KeyError: if ``augment`` is not one of the names above. The error is now
            raised here, not on the first call of the returned function.
    """
    shift_px = size >> 3
    cutout_px = size >> 1
    aug = dict(
        x=lambda **kw: kw,
        s=lambda image, **kw: dict(image=shift(image, shift_px), **kw),
        sc=lambda image, **kw: dict(image=cutout(shift(image, shift_px), cutout_px), **kw),
        sm=lambda image, **kw: dict(image=mirror(shift(image, shift_px)), **kw),
        smc=lambda image, **kw: dict(image=cutout(mirror(shift(image, shift_px)), cutout_px), **kw))
    try:
        fn = aug[augment]
    except KeyError:
        raise KeyError('Unknown augmentation %r, expected one of %s' % (augment, sorted(aug))) from None
    return lambda x: fn(**x)
52 |
--------------------------------------------------------------------------------
/examples/fixmatch/libml/augment/randaugment/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .randaugment import RandAugment
16 |
--------------------------------------------------------------------------------
/examples/fixmatch/libml/data/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/examples/fixmatch/libml/data/ssl.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | from typing import Callable, List
17 |
18 | from absl import flags
19 |
20 | from examples.fixmatch.libml.data import core
21 |
FLAGS = flags.FLAGS


class DataSetsUnlabeled:
    """An unlabeled training DataSet together with the name it is registered under."""

    def __init__(self, name: str, train: core.DataSet):
        self.name = name
        self.train = train

    # The three properties below index `train.image_shape`; `creator` builds that
    # shape as (height, width, colors).
    @property
    def height(self):
        return self.train.image_shape[0]

    @property
    def width(self):
        return self.train.image_shape[1]

    @property
    def colors(self):
        return self.train.image_shape[2]

    @classmethod
    def creator(cls, name: str, train_files: List[str], parse_fn: Callable = core.record_parse,
                height: int = 32, width: int = 32, colors: int = 3, cache: bool = False):
        """Return ``(name, factory)``, where ``factory()`` builds the DataSetsUnlabeled.

        ``train_files`` are joined onto ``core.DATA_DIR`` right away; the DataSet itself
        is only constructed (and optionally cached) when the factory is called.
        """
        paths = [os.path.join(core.DATA_DIR, f) for f in train_files]

        def create():
            dataset = core.DataSet.from_files(paths, (height, width, colors), parse_fn=parse_fn)
            if cache:
                dataset = dataset.cache()
            return cls(name, dataset)

        return name, create
56 |
57 |
def create_datasets():
    """Build the registry of unlabeled datasets.

    Returns:
        dict mapping dataset name -> zero-argument factory that builds a DataSetsUnlabeled.
    """
    registry = dict([
        DataSetsUnlabeled.creator('mnist', ['mnist-train.tfrecord'], cache=True,
                                  parse_fn=core.record_parse_mnist),
        DataSetsUnlabeled.creator('cifar10', ['cifar10-train.tfrecord'], cache=True),
        DataSetsUnlabeled.creator('cifar100', ['cifar100-train.tfrecord'], cache=True),
        DataSetsUnlabeled.creator('svhn', ['SSL/svhn-unlabel.tfrecord']),
        DataSetsUnlabeled.creator('svhnx', ['SSL/svhnx-unlabel.tfrecord']),
        DataSetsUnlabeled.creator('stl10', ['SSL/stl10-unlabel.tfrecord'], height=96, width=96),
    ])
    return registry


# This is the factory function itself, not its result: call DATASETS_UNLABELED() to build the registry.
DATASETS_UNLABELED = create_datasets
71 |
--------------------------------------------------------------------------------
/examples/fixmatch/libml/models.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from examples.fixmatch.libml.zoo.convnet import ConvNet
16 | from examples.fixmatch.libml.zoo.resnet import ResNet
17 |
# Architecture names accepted by `network`.
ARCHS = ['convnet', 'resnet']


def network(arch: str):
    """Return the model class for architecture name ``arch``.

    Raises:
        ValueError: if ``arch`` is neither 'convnet' nor 'resnet'.
    """
    if arch not in ('convnet', 'resnet'):
        raise ValueError('Architecture not recognized', arch)
    return ConvNet if arch == 'convnet' else ResNet
27 |
--------------------------------------------------------------------------------
/examples/fixmatch/libml/util.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import tensorflow as tf
16 |
17 |
def setup_tf():
    """Make no GPU visible to TensorFlow.

    NOTE(review): presumably so TensorFlow (used for input pipelines) leaves the
    GPUs to JAX; confirm against callers.
    """
    no_devices = []
    tf.config.experimental.set_visible_devices(no_devices, "GPU")
20 |
--------------------------------------------------------------------------------
/examples/fixmatch/libml/zoo/convnet.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import functools
16 |
17 | import jax
18 |
19 | import objax
20 | from objax.typing import JaxArray
21 |
22 |
class ConvNet(objax.nn.Sequential):
    """Plain convolutional classifier built as an objax Sequential.

    Layout:
    - a 3x3 stem conv;
    - ``scales`` stages, each of two 3x3 convs (each followed by leaky ReLU) and a
      2x2 / stride-2 average pool;
    - a final 3x3 conv to ``nclass`` channels, averaged over axes 2 and 3.
    """

    @staticmethod
    def _mean_reduce(x: JaxArray) -> JaxArray:
        # Average over axes 2 and 3 (the spatial axes if inputs are NCHW - TODO confirm layout).
        return x.mean((2, 3))

    def __init__(self, nin, nclass, scales, filters, filters_max, **kwargs):
        del kwargs  # extra configuration keys are accepted and ignored

        def width(level):
            # Channels double with each level (filters << level), capped at filters_max.
            return min(filters_max, filters << level)

        # Layers are created in the same order as before so parameter initialization is unchanged.
        layers = [objax.nn.Conv2D(nin, width(0), 3), objax.functional.leaky_relu]
        for level in range(scales):
            layers += [
                objax.nn.Conv2D(width(level), width(level), 3),
                objax.functional.leaky_relu,
                objax.nn.Conv2D(width(level), width(level + 1), 3),
                objax.functional.leaky_relu,
                functools.partial(objax.functional.average_pool_2d, size=2, strides=2),
            ]
        layers += [objax.nn.Conv2D(width(scales), nclass, 3), self._mean_reduce]
        super().__init__(layers)
41 |
--------------------------------------------------------------------------------
/examples/fixmatch/libml/zoo/resnet.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | __all__ = ['ResNetBlock', 'ResNet']
16 |
17 | import functools
18 | from typing import Callable
19 |
20 | import jax
21 |
22 | import objax
23 | from objax.typing import JaxArray
24 |
25 |
def leaky_relu(x):
    """Leaky ReLU with a fixed negative slope of 0.1."""
    slope = 0.1
    return objax.functional.leaky_relu(x, slope)
28 |
29 |
def conv_args(k, f):
    """Keyword arguments that give a conv layer a He-style (fan-out) normal init.

    Weights are drawn from ``objax.random.normal`` with stddev ``rsqrt(0.5 * k * k * f)``,
    i.e. ``sqrt(2 / (k * k * f))`` for a ``k`` x ``k`` kernel; callers pass the
    output channel count as ``f``.
    """
    stddev = objax.functional.rsqrt(0.5 * k * k * f)
    init = functools.partial(objax.random.normal, stddev=stddev)
    return {'w_init': init}
32 |
33 |
class ResNetBlock(objax.Module):
    """Pre-activation residual block.

    Computes ``shortcut + conv(leaky_relu(bn(conv(y))))`` with ``y = leaky_relu(bn(x))``.
    The shortcut is ``x``, or ``y`` when ``activate_before_residual`` is set. It goes
    through a strided 1x1 projection whenever the residual branch changes the
    channel count or the spatial size.
    """

    def __init__(self, nin: int, nout: int, stride: int = 1, activate_before_residual: bool = False,
                 bn: Callable = objax.nn.BatchNorm2D):
        self.activate_before_residual = activate_before_residual
        self.bn = bn(nin, momentum=0.999)
        self.residual = objax.nn.Sequential([objax.nn.Conv2D(nin, nout, 3, strides=stride, **conv_args(3, nout)),
                                             bn(nout, momentum=0.999), leaky_relu,
                                             objax.nn.Conv2D(nout, nout, 3, **conv_args(3, nout))])
        # Project the shortcut when channels differ OR the stride shrinks the residual output;
        # previously a stride != 1 block with nin == nout had no projection and the sum had
        # mismatched spatial shapes.
        needs_projection = nin != nout or stride != 1
        self.passthrough = (objax.nn.Conv2D(nin, nout, 1, strides=stride, **conv_args(1, nout))
                            if needs_projection else None)

    def __call__(self, x: JaxArray, training: bool) -> JaxArray:
        y = leaky_relu(self.bn(x, training))
        if self.activate_before_residual:
            # The shortcut starts from the normalized/activated input instead of the raw input.
            x = y
        if self.passthrough is not None:
            x = self.passthrough(x)
        return x + self.residual(y, training=training)
51 |
52 |
class ResNet(objax.nn.Sequential):
    """ResNet classifier built as an objax Sequential of ResNetBlocks.

    Layout:
    - a 3x3 stem conv to 16 channels;
    - ``scales`` stages of ``repeat`` blocks, where stage ``s`` has ``filters << s``
      channels and every stage after the first starts with a stride-2 block;
    - then BN, leaky ReLU, a mean over axes 2 and 3, dropout, and a linear layer
      to ``nclass`` logits.
    """

    @staticmethod
    def mean_reduce(x: JaxArray) -> JaxArray:
        # Average over axes 2 and 3 (the spatial axes if inputs are NCHW - TODO confirm layout).
        return x.mean((2, 3))

    def __init__(self, nin: int, nclass: int, scales: int, filters: int, repeat: int, dropout: float = 0,
                 bn: Callable = objax.nn.BatchNorm2D, **kwargs):
        """Build the network.

        Args:
            nin: number of input channels.
            nclass: number of output logits.
            scales: number of stages.
            filters: base channel count; stage ``s`` uses ``filters << s`` channels.
            repeat: blocks per stage (the first block of each stage is always created).
            dropout: fraction of activations dropped before the final linear layer,
                a rate in [0, 1] (``objax.nn.Dropout`` receives ``1 - dropout``).
            bn: batch-norm layer constructor.
            **kwargs: extra configuration keys, ignored.
        """
        del kwargs
        n = 16
        ops = [objax.nn.Conv2D(nin, n, 3, **conv_args(3, n))]
        for scale in range(scales):
            last_n, n = n, filters << scale
            # Only the first stage keeps full resolution and feeds the activated input to its shortcut.
            ops.append(ResNetBlock(last_n, n, stride=2 if scale else 1, activate_before_residual=scale == 0, bn=bn))
            ops.extend([ResNetBlock(n, n, bn=bn) for _ in range(repeat - 1)])
        ops.extend([bn(n, momentum=0.999), leaky_relu, self.mean_reduce,
                    objax.nn.Dropout(1 - dropout),
                    objax.nn.Linear(n, nclass, w_init=objax.nn.init.xavier_truncated_normal)])
        super().__init__(ops)
71 |
--------------------------------------------------------------------------------
/examples/fixmatch/scripts/create_split.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | # Copyright 2020 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Script to create SSL splits from a dataset.
18 | """
19 |
20 | import json
21 | import os
22 | from collections import defaultdict
23 |
24 | import numpy as np
25 | import tensorflow as tf
26 | from absl import app, flags
27 | from tqdm import trange, tqdm
28 |
29 | from examples.fixmatch.libml.data import core
30 |
flags.DEFINE_integer('seed', 0, 'Random seed to use, 0 for no shuffling.')
flags.DEFINE_integer('size', 0, 'Size of labelled set.')

FLAGS = flags.FLAGS


def get_class(serialized_example):
    """Decode only the scalar int64 'label' feature of a serialized tf.train.Example."""
    label_spec = {'label': tf.io.FixedLenFeature([], tf.int64)}
    parsed = tf.io.parse_single_example(serialized_example, features=label_spec)
    return parsed['label']
40 |
41 |
def main(argv):
    """Create a labeled subset ("split") of the given TFRecord files.

    Usage: create_split.py --seed=SEED --size=SIZE <target_prefix> <input.tfrecord> [...]

    Labels are picked so per-class counts follow the class distribution of the
    inputs (uniform for STL-10). Writes two files:
    - ``<target_prefix>.<seed>@<size>-label.tfrecord`` with the selected records;
    - ``...-label.json`` with the distribution and the selected record positions.

    Raises:
        app.UsageError: if --size is not positive or arguments are missing.
        FileNotFoundError: if an input file does not exist.
        FileExistsError: if an output file already exists (outputs are never overwritten).
        ValueError: if labels are not the contiguous integers 0..nclass-1.
    """
    if FLAGS.size <= 0:
        raise app.UsageError('--size must be a positive integer.')
    argv = argv[1:]  # drop the program name
    if len(argv) < 2:
        raise app.UsageError('Expected a target prefix followed by one or more input files.')
    input_files = argv[1:]
    missing = [f for f in input_files if not tf.io.gfile.exists(f)]
    if missing:
        raise FileNotFoundError(missing)
    target = '%s.%d@%d' % (argv[0], FLAGS.seed, FLAGS.size)
    label_records = target + '-label.tfrecord'
    label_json = target + '-label.json'
    # `target` is only a filename prefix, so check the files that are actually written.
    for path in (label_records, label_json):
        if tf.io.gfile.exists(path):
            raise FileExistsError('For safety overwriting is not allowed', path)
    is_stl10 = 'stl10' in input_files[0]

    print('Computing class distribution')
    count = 0
    class_id = defaultdict(list)  # label -> positions of its records across all inputs
    dataset = tf.data.TFRecordDataset(input_files).map(get_class, 4).batch(1 << 10)
    with tqdm(leave=False) as t:  # one progress bar for the whole pass, not one per batch
        for batch in dataset:
            for label_value in batch.numpy():
                class_id[label_value].append(count)
                count += 1
            t.update(batch.shape[0])
    print('%d records found' % count)
    nclass = len(class_id)
    if not class_id or min(class_id) != 0 or max(class_id) != nclass - 1:
        raise ValueError('Labels must be the contiguous integers 0..nclass-1', sorted(class_id))
    train_stats = np.array([len(class_id[i]) for i in range(nclass)], np.float64)
    train_stats /= train_stats.max()
    if is_stl10:
        # All of the unlabeled data is given label 0, but we know that
        # STL has equally distributed data among the 10 classes.
        train_stats[:] = 1

    print(' Stats', ' '.join(['%.2f' % (100 * x) for x in train_stats]))
    class_id = [np.array(class_id[i], dtype=np.int64) for i in range(nclass)]
    if FLAGS.seed:
        # Seed 0 keeps input order; any other seed shuffles every class reproducibly.
        np.random.seed(FLAGS.seed)
        for i in range(nclass):
            np.random.shuffle(class_id[i])

    # Distribute labels to match the input distribution. Each pick takes the next record
    # of the class with the largest gap between its normalized frequency (train_stats)
    # and its picks so far relative to the most-picked class (npos / max(npos)).
    npos = np.zeros(nclass, np.int64)
    label = []
    for _ in range(FLAGS.size):
        c = np.argmax(train_stats - npos / max(npos.max(), 1))
        label.append(class_id[c][npos[c]])
        npos[c] += 1
    label = frozenset(int(x) for x in label)

    if is_stl10 and FLAGS.size == 1000:
        # Use the fold listed on line `seed` of stl10_fold_indices.txt instead of the picks above.
        fold_path = os.path.join(core.DATA_DIR, 'stl10_fold_indices.txt')
        with tf.io.gfile.GFile(fold_path, 'r') as f:
            folds = f.read().split('\n')
        label = frozenset(map(int, folds[FLAGS.seed].split()))

    print('Creating split in %s' % target)
    target_dir = os.path.dirname(target)
    if target_dir:
        tf.io.gfile.makedirs(target_dir)
    with tf.io.TFRecordWriter(label_records) as writer_label, trange(count, desc='Writing records') as loop:
        pos = 0  # global record position across all input files
        for input_file in input_files:
            for record in tf.compat.v1.python_io.tf_record_iterator(input_file):
                if pos in label:
                    writer_label.write(record)
                pos += 1
                loop.update()
    with tf.io.gfile.GFile(label_json, 'w') as writer:
        writer.write(json.dumps(dict(distribution=train_stats.tolist(), label=sorted(label)), indent=2, sort_keys=True))


if __name__ == '__main__':
    app.run(main)
112 |
--------------------------------------------------------------------------------
/examples/fixmatch/scripts/create_unlabeled.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | # Copyright 2020 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Script to create SSL splits from a dataset.
18 | """
19 |
20 | import json
21 | import os
22 | from collections import defaultdict
23 |
24 | import numpy as np
25 | import tensorflow as tf
26 | from absl import app
27 | from tqdm import trange, tqdm
28 |
29 |
def get_class(serialized_example):
    """Parse a serialized tf.train.Example and return just its int64 'label' value."""
    features = {'label': tf.io.FixedLenFeature([], tf.int64)}
    example = tf.io.parse_single_example(serialized_example, features=features)
    return example['label']
33 |
34 |
def main(argv):
    """Write the unlabeled training set ``<target_prefix>-unlabel.tfrecord``.

    Usage: create_unlabeled.py <target_prefix> <input.tfrecord> [...]

    Records are re-ordered so that classes are interleaved:
    - each incoming record is buffered by class;
    - records are then emitted greedily from the class whose share of the output
      lags its normalized frequency the most (uniform for STL-10);
    - records still buffered at the end are appended class by class.

    The class distribution and the input position of every written record (in
    output order) are saved to ``<target_prefix>-unlabel.json``.

    Raises:
        app.UsageError: if arguments are missing.
        FileNotFoundError: if an input file does not exist.
        ValueError: if labels are not the contiguous integers 0..nclass-1.
    """
    from collections import deque  # O(1) popleft for the per-class buffers (list.pop(0) is O(n))

    argv = argv[1:]  # drop the program name
    if len(argv) < 2:
        raise app.UsageError('Expected a target prefix followed by one or more input files.')
    target, input_files = argv[0], argv[1:]
    missing = [f for f in input_files if not tf.io.gfile.exists(f)]
    if missing:
        raise FileNotFoundError(missing)

    print('Computing class distribution')
    id_class = []  # label of every record, in input order
    class_count = defaultdict(int)
    dataset = tf.data.TFRecordDataset(input_files).map(get_class, 4).batch(1 << 10)
    with tqdm(leave=False) as t:  # one progress bar for the whole pass, not one per batch
        for batch in dataset:
            for label_value in batch.numpy():
                id_class.append(label_value)
                class_count[label_value] += 1
            t.update(batch.shape[0])
    count = len(id_class)
    print('%d records found' % count)
    nclass = len(class_count)
    if not class_count or min(class_count) != 0 or max(class_count) != nclass - 1:
        raise ValueError('Labels must be the contiguous integers 0..nclass-1', sorted(class_count))
    train_stats = np.array([class_count[i] for i in range(nclass)], np.float64)
    train_stats /= train_stats.max()
    if 'stl10' in input_files[0]:
        # All of the unlabeled data is given label 0, but we know that
        # STL has equally distributed data among the 10 classes.
        train_stats[:] = 1

    print(' Stats', ' '.join(['%.2f' % (100 * x) for x in train_stats]))
    del class_count

    print('Creating unlabeled dataset in %s' % target)
    npos = np.zeros(nclass, np.int64)
    class_data = [deque() for _ in range(nclass)]  # buffered (position, record) pairs per class
    unlabel = []
    target_dir = os.path.dirname(target)
    if target_dir:
        tf.io.gfile.makedirs(target_dir)
    with tf.io.TFRecordWriter(target + '-unlabel.tfrecord') as writer_unlabel, \
            trange(count, desc='Writing records') as loop:
        pos = 0  # global record position across all input files
        for input_file in input_files:
            for record in tf.compat.v1.python_io.tf_record_iterator(input_file):
                class_data[id_class[pos]].append((pos, record))
                # Emit from the most under-represented class until its buffer is empty.
                while True:
                    c = np.argmax(train_stats - npos / max(npos.max(), 1))
                    if not class_data[c]:
                        break
                    p, v = class_data[c].popleft()
                    unlabel.append(p)
                    writer_unlabel.write(v)
                    npos[c] += 1
                pos += 1
                loop.update()
        for remain in class_data:
            for p, v in remain:
                unlabel.append(p)
                writer_unlabel.write(v)
    with tf.io.gfile.GFile(target + '-unlabel.json', 'w') as writer:
        writer.write(json.dumps(dict(distribution=train_stats.tolist(), indexes=unlabel), indent=2, sort_keys=True))


if __name__ == '__main__':
    app.run(main)
98 |
--------------------------------------------------------------------------------
/examples/fixmatch/scripts/extract_accuracy.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | # Copyright 2020 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Extract and save accuracy to 'stats/accuracy.json'.
18 |
19 | The accuracy values are collected from all event files under `<folder>/tb`, read in sorted filename order.
20 | """
21 |
22 | import json
23 | import os
24 |
25 | import numpy as np
26 | import tensorflow as tf
27 | from absl import app, flags
28 |
FLAGS = flags.FLAGS
# Scalar summary tag whose simple_value entries are collected from the event files.
TAG = 'accuracy/test'
31 |
32 |
def summary_dict(accuracies):
    """Return the median accuracy over the last 1, 10, 20 and 50 entries.

    Keys are 'last01', 'last10', 'last20' and 'last50'; a window longer than
    ``accuracies`` simply covers the whole sequence.
    """
    summary = {}
    for window in (1, 10, 20, 50):
        tail = accuracies[-window:]
        summary['last%02d' % window] = np.median(tail)
    return summary
35 |
36 |
def main(argv):
    """Write median test accuracies of ``<folder>`` to ``<folder>/stats/accuracy.json``.

    Usage: extract_accuracy.py <folder>

    Reads every ``<folder>/tb/events.out.tfevents.*`` file in sorted filename order,
    collects the values logged under ``TAG``, and saves ``summary_dict`` of them.

    Raises:
        app.UsageError: if not exactly one folder argument is given.
        FileNotFoundError: if no event files are found.
        ValueError: if no event carries ``TAG``.
    """
    if len(argv) != 2:
        raise app.UsageError('Expected exactly one argument: the experiment folder.')
    folder = argv[1]
    matches = sorted(tf.io.gfile.glob(os.path.join(folder, 'tb/events.out.tfevents.*')))
    if not matches:
        raise FileNotFoundError('No events files found')
    tags = set()  # other tags seen before the first accuracy value; reported on failure
    accuracies = []
    for event_file in matches:
        try:
            for e in tf.compat.v1.train.summary_iterator(event_file):
                for v in e.summary.value:
                    if v.tag == TAG:
                        accuracies.append(v.simple_value)
                        break
                    elif not accuracies:
                        tags.add(v.tag)
        except tf.errors.DataLossError:
            # Best effort: keep whatever was read before the corrupt or truncated record.
            continue

    if not accuracies:
        raise ValueError('No "%s" tag found. Found tags = %s' % (TAG, tags))
    target_dir = os.path.join(folder, 'stats')
    target_file = os.path.join(target_dir, 'accuracy.json')
    tf.io.gfile.makedirs(target_dir)

    with tf.io.gfile.GFile(target_file, 'w') as f:
        json.dump(summary_dict(accuracies), f, sort_keys=True, indent=4)
    print('Saved: %s' % target_file)


if __name__ == '__main__':
    app.run(main)
69 |
--------------------------------------------------------------------------------
/examples/gpt-2/README.md:
--------------------------------------------------------------------------------
1 | # GPT-2 Example usage
2 |
3 | ## Setup
4 |
5 | ```bash
6 | cd examples/gpt-2
7 |
8 | # Install gpt-2 dependencies
9 | pip3 install --upgrade regex
10 |
11 | # Clone the OpenAI GPT-2 repository
12 | git clone https://github.com/openai/gpt-2.git
13 |
14 | # Download model weights
15 | cd gpt-2
16 | python3 download_model.py 124M
17 | cd ..
18 | ```
19 |
20 | ## Running
21 |
22 | ```bash
23 | python3 gpt2.py
24 | ```
25 |
26 | You should see something like this:
27 |
28 | > The definition of unicorn is a creature that is a unicorn, but not a
29 | >
30 | > All Together Now (all together now!)
31 | >
32 | > The definition of unicorn is a creature that is a unicorn, but not a
33 |
--------------------------------------------------------------------------------
/examples/image_classification/README.md:
--------------------------------------------------------------------------------
1 | [home](../../README.md) > [examples](../README.md) > image_classification
2 |
3 | # Image Classification Examples
4 |
5 | This directory contains various classification examples on image datasets:
6 |
7 | * [MNIST](http://yann.lecun.com/exdb/mnist/):
8 |
9 | * `mnist_dnn.py` - simple MNIST classification example.
10 | *Note*: The purpose of the example on MNIST is to demonstrate the use of a deep
11 | neural network for classification. As such, the network does not achieve State
12 | of the Art (SOTA) classification accuracy. A Convolutional Neural Network (CNN)
13 | should be used for that purpose.
14 |
15 | * `mnist_cnn.py` - a CNN-based MNIST classification example.
16 |
17 | * `mnist_dp.py` - MNIST example with differential privacy.
18 |
19 | * [CIFAR10](https://www.cs.toronto.edu/~kriz/cifar.html)
20 |
21 | * `cifar10_simple.py` - very simple CIFAR10 classification example which
22 | demonstrates how to write a basic training loop with data augmentation.
23 |
24 | * `cifar10_advanced.py` - more advanced CIFAR10 example which allows user to configure
25 | neural network architecture and other hyperparameters. It also supports training on multiple
26 | GPUs using `objax.Parallel`.
27 |
28 | * [Imagenet](http://www.image-net.org/challenges/LSVRC/2012/)
29 |
30 | * `imagenet_pretrained_vgg.py` - example which shows how to load pre-trained weights for a VGG model and use it
31 | to classify input images. For more details see [documentation](imagenet_pretrained_vgg.md).
32 |
33 | * `imagenet_resnet50_train.py` - example which shows how to train a ResNet50 model on ImageNet.
34 | For more details see example [documentation](imagenet_resnet50.md).
35 |
36 | * [Horses or Humans](https://www.kaggle.com/sanikamal/horses-or-humans-dataset)
37 |
38 | * `horses_or_humans_logistic.py` - simple example using logistic regression.
--------------------------------------------------------------------------------
/examples/image_classification/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/examples/image_classification/cifar10_simple.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import random
16 |
17 | import numpy as np
18 | import tensorflow as tf
19 |
20 | import objax
21 | from objax.zoo.wide_resnet import WideResNet
22 |
# Data
# CIFAR-10 via Keras. transpose(0, 3, 1, 2) moves the last axis to position 1
# (channels-last -> channels-first, assuming Keras returns NHWC — TODO confirm);
# dividing by 255.0 scales pixel values to [0, 1].
(X_train, Y_train), (X_test, Y_test) = tf.keras.datasets.cifar10.load_data()
X_train = X_train.transpose(0, 3, 1, 2) / 255.0
X_test = X_test.transpose(0, 3, 1, 2) / 255.0

# Model
# WideResNet with depth 28 and width 2 for 3 input channels and 10 classes, trained with Adam.
model = WideResNet(nin=3, nclass=10, depth=28, width=2)
opt = objax.optimizer.Adam(model.vars())
31 |
32 |
# Losses
@objax.Function.with_vars(model.vars())
def loss(x, label):
    """Mean sparse softmax cross-entropy of the model evaluated in training mode.

    Args:
        x: batch of input images.
        label: integer class labels for the batch.
    """
    return objax.functional.loss.cross_entropy_logits_sparse(model(x, training=True), label).mean()
38 |
39 |
gv = objax.GradValues(loss, model.vars())


@objax.Function.with_vars(model.vars() + gv.vars() + opt.vars())
def train_op(x, y, lr):
    """Take one Adam step on batch (x, y) with learning rate ``lr``.

    Returns:
        The loss values produced by ``gv`` for this batch.
    """
    grads, values = gv(x, y)
    opt(lr=lr, grads=grads)
    return values
48 |
49 |
# Compile the training step. `predict` runs the model with training=False forced and then
# applies softmax, so it returns class probabilities.
train_op = objax.Jit(train_op)
predict = objax.Jit(objax.nn.Sequential([
    objax.ForceArgs(model, training=False), objax.functional.softmax
]))
54 |
55 |
def augment(x, shift=4):
    """Randomly mirror and translate a whole batch of channels-first images.

    With probability 0.5 the batch is mirrored left-right (the width axis is reversed).
    The batch is then reflect-padded by ``shift`` pixels on each side of both spatial axes,
    and a crop of the original height/width is taken at a random offset in [0, 2 * shift].
    This translates every image in the batch by the same amount in [-shift, +shift] along
    each spatial axis; offset ``shift`` leaves the images unshifted.

    Args:
        x: array of shape (batch, channels, height, width).
        shift: maximum translation in pixels along each spatial axis.

    Returns:
        Array with the same shape as ``x``.
    """
    if random.random() < .5:
        x = x[:, :, :, ::-1]  # Mirror left-right: reverse the width axis.
    height, width = x.shape[2], x.shape[3]
    x_pad = np.pad(x, [[0, 0], [0, 0], [shift, shift], [shift, shift]], 'reflect')
    # np.random.randint excludes `high`; + 1 makes offset 2 * shift (a +shift move) reachable.
    rx, ry = np.random.randint(0, 2 * shift + 1), np.random.randint(0, 2 * shift + 1)
    return x_pad[:, :, rx:rx + height, ry:ry + width]
64 |
65 |
# Training
print(model.vars())
for epoch in range(30):
    # Train: one pass over a shuffled permutation of the training set, in batches of 64.
    # Named `losses` (not `loss`) so the module-level `loss` function is not shadowed.
    losses = []
    sel = np.arange(len(X_train))
    np.random.shuffle(sel)
    for it in range(0, X_train.shape[0], 64):
        idx = sel[it:it + 64]
        # Learning rate 4e-3 for the first 20 epochs, then 4e-4.
        losses.append(train_op(augment(X_train[idx]), Y_train[idx].flatten(),
                               4e-3 if epoch < 20 else 4e-4))

    # Eval: split the test set into 50 equal batches (requires len(X_test) divisible by 50).
    test_predictions = [predict(x_batch).argmax(1) for x_batch in X_test.reshape((50, -1) + X_test.shape[1:])]
    accuracy = np.array(test_predictions).flatten() == Y_test.flatten()
    print(f'Epoch {epoch + 1:4d} Loss {np.mean(losses):.2f} Accuracy {100 * np.mean(accuracy):.2f}')
81 |
--------------------------------------------------------------------------------
/examples/image_classification/horses_or_humans_logistic.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 |
17 | import numpy as np
18 | import tensorflow_datasets as tfds
19 |
20 | import objax
21 | from objax.util import EasyDict
22 |
# Data: train has 1027 images - test has 256 images
# Each image is 300 x 300 x 3 bytes
# TFDS cache directory: $HOME/TFDS.
DATA_DIR = os.path.join(os.environ['HOME'], 'TFDS')
# batch_size=-1: presumably each split is returned as one full batch of numpy arrays
# (later code indexes data['train']['image'] as a single array) — see the TFDS docs.
data = tfds.as_numpy(tfds.load(name='horses_or_humans', batch_size=-1, data_dir=DATA_DIR))
27 |
28 |
def prepare(x, downscale=3):
    """Average-pool, rescale and flatten a batch of channels-last images.

    Each non-overlapping ``downscale`` x ``downscale`` spatial patch is replaced by its mean.
    With the default, 300x300x3 images become 100x100x3. Pixel values are mapped from
    [0, 255] to [-1, 1], and each image is flattened into one row.

    Args:
        x: array of shape (batch, height, width, channels). Height and width must be
            divisible by ``downscale``.
        downscale: pooling factor along both spatial axes.

    Returns:
        float32 array of shape (batch, (height // downscale) * (width // downscale) * channels).
    """
    n, h, w, c = x.shape[0], x.shape[1], x.shape[2], x.shape[3]
    patches = x.astype('f').reshape((n, h // downscale, downscale, w // downscale, downscale, c))
    pooled = patches.mean((2, 4))
    flat = pooled.reshape((n, -1))
    return flat * (1 / 127.5) - 1
34 |
35 |
train = EasyDict(image=prepare(data['train']['image']), label=data['train']['label'])
test = EasyDict(image=prepare(data['test']['image']), label=data['test']['label'])
ndim = train.image.shape[-1]  # number of features per flattened, downscaled image
del data  # only the prepared copies are used from here on

# Settings
lr = 0.0001  # learning rate
batch = 256
epochs = 20

# Model
# Logistic regression: a single linear unit producing one logit per image.
model = objax.nn.Linear(ndim, 1)
opt = objax.optimizer.SGD(model.vars())
print(model.vars())
50 |
51 |
# Cross Entropy Loss
@objax.Function.with_vars(model.vars())
def loss(x, label):
    """Mean sigmoid (binary) cross-entropy between the model's single logit and ``label``."""
    logits = model(x)[:, 0]  # drop the size-1 output axis
    per_example = objax.functional.loss.sigmoid_cross_entropy_logits(logits, label)
    return per_example.mean()
56 |
57 |
gv = objax.GradValues(loss, model.vars())


@objax.Function.with_vars(model.vars() + gv.vars() + opt.vars())
def train_op(x, label):
    """Apply one SGD step (module-level ``lr``) on batch (x, label) and return the loss values."""
    gradients, values = gv(x, label)
    opt(lr, gradients)
    return values
66 |
67 |
# This line is optional: it is compiling the code to make it faster.
train_op = objax.Jit(train_op)

# Training
for epoch in range(epochs):
    # Train: one step per `batch`-sized stride over the training set; each batch is
    # sampled with replacement (np.random.randint), not a shuffled partition.
    avg_loss = 0
    for it in range(0, train.image.shape[0], batch):
        sel = np.random.randint(size=(batch,), low=0, high=train.image.shape[0])
        # [0] selects the first value returned by gv, i.e. the loss.
        avg_loss += float(train_op(train.image[sel], train.label[sel])[0]) * batch
    # After the loop, `it + batch` equals (number of steps) * batch = samples seen this epoch.
    avg_loss /= it + batch

    # Eval
    accuracy = 0
    for it in range(0, test.image.shape[0], batch):
        x, y = test.image[it: it + batch], test.label[it: it + batch]
        # Rounding the sigmoid probability thresholds it at 0.5 to get the predicted label.
        accuracy += (np.round(objax.functional.sigmoid(model(x)))[:, 0] == y).sum()
    accuracy /= test.image.shape[0]
    print('Epoch %04d Loss %.2f Accuracy %.2f' % (epoch + 1, avg_loss, 100 * accuracy))
87 |
--------------------------------------------------------------------------------
/examples/image_classification/imagenet_pretrained_vgg.md:
--------------------------------------------------------------------------------
1 | # Image Classification with Pretrained VGG model
2 |
3 | This [example](imagenet_pretrained_vgg.py) demonstrates how to run image classification with
4 | [VGG-19](https://www.robots.ox.ac.uk/~vgg/publications/2015/Simonyan15/simonyan15.pdf) model using
5 | weights pretrained on [ImageNet dataset](http://www.image-net.org/).
6 |
7 | ## Getting weights of VGG-19 pretrained model
8 |
9 | Please download the weights of VGG-19 pretrained model from this
10 | [link](https://mega.nz/file/xZ8glS6J#MAnE91ND_WyfZ_8mvkuSa2YcA7q-1ehfSm-Q1fxOvvs) and copy to
11 | `./objax/zoo/pretrained/vgg19.npy`.
12 |
13 | ## Classifying images
14 |
15 | This [example](imagenet_pretrained_vgg.py) shows how to classify an image downloaded from the internet.
16 | You can set an `IMAGE_PATH` to classify your own image.
17 |
--------------------------------------------------------------------------------
/examples/image_classification/imagenet_pretrained_vgg.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | from urllib import request
17 |
18 | import jax.numpy as jn
19 | import numpy as np
20 | from PIL import Image
21 |
22 | import objax
23 | from objax.zoo import vgg
24 |
IMAGE_URL = 'https://upload.wikimedia.org/wikipedia/commons/b/b0/Bengal_tiger_%28Panthera_tigris_tigris%29_female_3_crop.jpg'
IMAGE_PATH = './examples/classify/img/misc/001.jpg'
SYNSET_PATH = './objax/zoo/pretrained/synset.txt'

# Load input image.
# The sample image is downloaded only when IMAGE_PATH does not exist yet. Pointing IMAGE_PATH
# at your own file therefore classifies that file instead of overwriting it.
if not os.path.exists(IMAGE_PATH):
    image_dir = os.path.dirname(IMAGE_PATH)
    if image_dir:  # dirname is '' for a bare filename, and os.makedirs('') would fail.
        os.makedirs(image_dir, exist_ok=True)
    request.urlretrieve(IMAGE_URL, IMAGE_PATH)
with Image.open(IMAGE_PATH) as raw_img:  # close the underlying file once decoded
    img = raw_img.convert('RGB')  # force 3 channels (grayscale/RGBA/palette inputs)
img = np.array(img.resize((224, 224))).astype(np.float32)
img = jn.array(img).transpose((2, 0, 1))[None,]  # HWC -> CHW, then add a leading batch axis
36 |
# Load model with pretrained weights and make a prediction.
model = vgg.VGG19(pretrained=True)
logit = model(img)
prob = objax.functional.softmax(logit)[0]  # probabilities for the single image in the batch

# Present prediction output.
# One class name per line of the synset file; the `with` block closes the file afterwards.
with open(SYNSET_PATH) as synset_file:
    synset = [line.strip() for line in synset_file]
pred = jn.argsort(prob)[::-1][:5]  # indices of the 5 most probable classes, best first
for rank, idx in enumerate(pred):
    print('Top {:d} (prob {:.3f}) {}'.format(rank + 1, prob[idx], synset[idx]))
47 |
--------------------------------------------------------------------------------
/examples/image_classification/imagenet_resnet50.md:
--------------------------------------------------------------------------------
1 | # Example of training and evaluation of ResNet50 on Imagenet
2 |
3 | This example trains a ResNet50 model on the ImageNet2012 dataset.
4 |
5 | ## Getting data
6 |
7 | You have to obtain the Imagenet dataset to train the model.
8 |
9 | Internally this code uses [TFDS](https://github.com/tensorflow/datasets) which will show download instructions on the first run.
10 | Run `python examples/image_classification/imagenet_resnet50_train.py` and you will see download instructions, similar to the following:
11 |
12 | ```
13 | AssertionError: Manual directory /home/${USER}/tensorflow_datasets/downloads/manual does not exist or is empty. Create it and download/extract dataset artifacts in there. Additional instructions: manual_dir should contain two files: ILSVRC2012_img_train.tar and
14 | ILSVRC2012_img_val.tar.
15 | ```
16 |
17 | You have to download data from http://www.image-net.org/download-images and then put it into
18 | the directory mentioned in the message.
19 | Then run `imagenet_resnet50_train.py` again: it will process the data and rearrange it inside the data directory, which might take a while.
20 | Subsequent runs will re-use the already downloaded data.
21 |
22 | You can override TFDS data directory by providing the `--tfds_data_dir` flag. This might be useful if you don't have enough disk space in the default location or already have a copy of Imagenet data somewhere else.
23 |
24 | ## Training the model
25 |
26 | Use the following command to train:
27 |
28 | ```
29 | python examples/image_classification/imagenet_resnet50_train.py \
30 | --model_dir="${HOME}/experiments/resnet50"
31 | ```
32 |
33 | Some additional useful flags include the following:
34 |
35 | * `--train_device_batch_size` controls per-device training batch size. You may need to adjust it if you don't have enough GPU memory.
36 | * `--eval_device_batch_size` controls per-device evaluation batch size. You may need to adjust it if you don't have enough GPU memory.
37 | * `--eval_every_n_steps` controls the number of training steps between evaluation and checkpointing.
38 | * `--tfds_data_dir` overrides the directory where TFDS looks for datasets.
39 |
--------------------------------------------------------------------------------
/examples/image_classification/mnist_cnn.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 |
17 | import numpy as np
18 | import tensorflow_datasets as tfds
19 | from tqdm import trange
20 |
21 | import objax
22 | from objax.util import EasyDict
23 |
24 |
def simple_net_block(nin, nout):
    """Build one stage of SimpleNet.

    The stage is a 3x3 conv (nin -> nout) with leaky ReLU, a 2D max-pool, and a 3x3 conv
    (nout -> nout) with leaky ReLU.
    """
    layers = [
        objax.nn.Conv2D(nin, nout, k=3),
        objax.functional.leaky_relu,
        objax.functional.max_pool_2d,
        objax.nn.Conv2D(nout, nout, k=3),
        objax.functional.leaky_relu,
    ]
    return objax.nn.Sequential(layers)
31 |
32 |
class SimpleNet(objax.Module):
    """Small all-convolutional classifier.

    Args:
        nclass: number of output classes.
        colors: number of input channels.
        n: base channel width. The two stages use 2n and 4n channels.
    """

    def __init__(self, nclass, colors, n):
        self.pre_conv = objax.nn.Sequential([objax.nn.Conv2D(colors, n, k=3), objax.functional.leaky_relu])
        self.block1 = simple_net_block(n, 2 * n)
        self.block2 = simple_net_block(2 * n, 4 * n)
        self.post_conv = objax.nn.Conv2D(4 * n, nclass, k=3)

    def __call__(self, x, training=False):
        """Classify x of shape (batch, colors, height, width).

        Returns raw logits of shape (batch, nclass) when ``training`` is True, otherwise
        softmax probabilities.
        """
        h = x
        for stage in (self.pre_conv, self.block1, self.block2):
            h = stage(h)
        logits = self.post_conv(h).mean((2, 3))  # global average over the spatial axes
        return logits if training else objax.functional.softmax(logits)
48 |
49 |
# Data
# MNIST from TFDS, cached under $HOME/TFDS. transpose(0, 3, 1, 2) converts to channels-first
# (assuming TFDS yields NHWC — TODO confirm); dividing by 255 scales pixels to [0, 1].
DATA_DIR = os.path.join(os.environ['HOME'], 'TFDS')
data = tfds.as_numpy(tfds.load(name='mnist', batch_size=-1, data_dir=DATA_DIR))
train = EasyDict(image=data['train']['image'].transpose(0, 3, 1, 2) / 255, label=data['train']['label'])
test = EasyDict(image=data['test']['image'].transpose(0, 3, 1, 2) / 255, label=data['test']['label'])
del data
56 |
57 |
def augment(x, shift=4):  # Shift all images in the batch by up to "shift" pixels in any direction.
    """Translate the whole batch by one random offset of up to ``shift`` pixels per axis.

    The batch is zero-padded (np.pad's default constant mode) by ``shift`` pixels on each side
    of both spatial axes. A crop of the original height/width is then taken at a random offset
    in [0, 2 * shift]. This is a translation in [-shift, +shift] along each axis; offset
    ``shift`` leaves the images unchanged.

    Args:
        x: array of shape (batch, channels, height, width).
        shift: maximum translation in pixels.

    Returns:
        Array with the same shape as ``x``.
    """
    height, width = x.shape[2], x.shape[3]
    x_pad = np.pad(x, [[0, 0], [0, 0], [shift, shift], [shift, shift]])
    # np.random.randint excludes `high`, so 2 * shift + 1 makes every offset 0..2*shift reachable
    # (including the identity offset `shift`).
    rx, ry = np.random.randint(0, 2 * shift + 1, size=2)
    return x_pad[:, :, rx:rx + height, ry:ry + width]
62 |
63 |
# Settings
batch = 512
test_batch = 2048
weight_decay = 0.0001
epochs = 40
lr = 0.0004 * (batch / 64)  # base rate 4e-4 for a batch of 64, scaled linearly with batch size
train_size = train.image.shape[0]

# Model
model = SimpleNet(nclass=10, colors=1, n=16)  # Use higher values of n to get higher accuracy.
# Debiased exponential moving average of the weights (momentum 0.999); `predict` below wraps it.
model_ema = objax.optimizer.ExponentialMovingAverageModule(model, momentum=0.999, debias=True)
opt = objax.optimizer.Adam(model.vars())
76 |
77 |
@objax.Function.with_vars(model.vars())
def loss(x, y):
    """Cross-entropy plus L2 weight decay on every variable whose name ends in '.w'.

    Returns:
        (total_loss, {'loss/xe': cross_entropy, 'loss/l2': l2_term}), for logging.
    """
    logits = model(x, training=True)
    loss_xe = objax.functional.loss.cross_entropy_logits_sparse(logits, y).mean()
    squared_norms = [(var.value ** 2).sum() for name, var in model.vars().items() if name.endswith('.w')]
    loss_l2 = 0.5 * sum(squared_norms)
    return loss_xe + weight_decay * loss_l2, {'loss/xe': loss_xe, 'loss/l2': loss_l2}
84 |
85 |
gv = objax.GradValues(loss, model.vars())


@objax.Function.with_vars(model.vars() + gv.vars() + opt.vars() + model_ema.vars())
def train_op(x, y):
    """Take one Adam step on (x, y), then update the EMA copy of the weights.

    Returns:
        The values produced by ``gv`` (total loss and the logging dict).
    """
    grads, outputs = gv(x, y)
    opt(lr, grads)
    model_ema.update_ema()
    return outputs
95 |
96 |
train_op = objax.Jit(train_op)  # Compile train_op to make it run faster.
predict = objax.Jit(model_ema)  # Compile predict to make it run faster.

# Training
print(model.vars())
for epoch in range(epochs):
    # Train one epoch
    # unit_scale=batch makes the tqdm counter advance in images rather than steps.
    loop = trange(0, train_size, batch,
                  leave=False, unit='img', unit_scale=batch,
                  desc='Epoch %d/%d ' % (1 + epoch, epochs))
    for it in loop:
        # Each batch is sampled with replacement from the training set.
        sel = np.random.randint(size=(batch,), low=0, high=train.image.shape[0])
        v = train_op(augment(train.image[sel]), train.label[sel])

    # Eval
    # predict wraps model_ema; argmax over axis 1 gives the predicted class per image.
    accuracy = 0
    for it in trange(0, test.image.shape[0], test_batch, leave=False, desc='Evaluating'):
        x = test.image[it: it + test_batch]
        xl = test.label[it: it + test_batch]
        accuracy += (np.argmax(predict(x), axis=1) == xl).sum()
    accuracy /= test.image.shape[0]
    print(f'Epoch {epoch + 1:04d} Accuracy {100 * accuracy:.2f}')
119 |
--------------------------------------------------------------------------------
/examples/image_classification/mnist_dnn.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # *Note*: The purpose of the example on MNIST is to demonstrate the use of a deep
16 | # neural network for classification. As such, the network does not achieve State
17 | # of the Art (SOTA) classification accuracy. A Convolutional Neural Network (CNN)
18 | # should be used for that purpose.
19 |
20 | import os
21 |
22 | import numpy as np
23 | import tensorflow_datasets as tfds
24 | from tqdm import trange
25 |
26 | import objax
27 | from objax.functional import leaky_relu, one_hot
28 | from objax.jaxboard import SummaryWriter, Summary
29 | from objax.util import EasyDict
30 | from objax.zoo.dnnet import DNNet
31 |
# Data
DATA_DIR = os.path.join(os.environ['HOME'], 'TFDS')
data = tfds.as_numpy(tfds.load(name='mnist', batch_size=-1, data_dir=DATA_DIR))
train_size = len(data['train']['image'])
test_size = len(data['test']['image'])
train_shape = data['train']['image'].shape
image_size = train_shape[1] * train_shape[2] * train_shape[3]  # values per image (all axes but batch)
nclass = len(np.unique(data['train']['label']))  # number of distinct training labels
# Flatten each image into a vector of image_size values. `/ 127.5 - 1` maps 0..255 pixel
# values to [-1, 1] (assuming uint8 input — TODO confirm).
flat_train_images = np.reshape(data['train']['image'].transpose(0, 3, 1, 2) / 127.5 - 1,
                               (train_size, image_size))
flat_test_images = np.reshape(data['test']['image'].transpose(0, 3, 1, 2) / 127.5 - 1, (test_size, image_size))
test = EasyDict(image=flat_test_images, label=data['test']['label'])
train = EasyDict(image=flat_train_images, label=data['train']['label'])
del data
46 |
# Settings
lr = 0.0002
batch = 64
num_train_epochs = 40
# Input width is the flattened image size. The output width must equal nclass, because the
# training loop one-hot encodes labels with one_hot(..., nclass), so it is derived rather than
# hard-coded.
dnn_layer_sizes = image_size, 128, nclass
logdir = f'experiments/classify/img/mnist/filters{dnn_layer_sizes}'

# Model
model = DNNet(dnn_layer_sizes, leaky_relu)
# Exponential moving average of the weights (momentum 0.999); `predict` evaluates this copy.
model_ema = objax.optimizer.ExponentialMovingAverageModule(model, momentum=0.999)
opt = objax.optimizer.Adam(model.vars())
58 |
59 |
@objax.Function.with_vars(model.vars())
def loss(x, label):
    """Mean softmax cross-entropy of the DNN's logits against dense (one-hot) ``label`` targets."""
    logits = model(x)
    per_example = objax.functional.loss.cross_entropy_logits(logits, label)
    return per_example.mean()
64 |
65 |
gv = objax.GradValues(loss, model.vars())


@objax.Function.with_vars(model.vars() + gv.vars() + opt.vars() + model_ema.vars())
def train_op(x, xl):
    """Take one Adam step on (x, xl), refresh the EMA weights, and return the loss values."""
    grads, values = gv(x, xl)
    opt(lr, grads)
    model_ema.update_ema()
    return values
75 |
76 |
train_op = objax.Jit(train_op)  # Compile train_op to make it run faster.
predict = objax.Jit(model_ema)

# Training
print(model.vars())
print(f'Visualize results with: tensorboard --logdir "{logdir}"')
print("Disclaimer: This code demonstrates the DNNet class. For SOTA accuracy use a CNN instead.")
with SummaryWriter(os.path.join(logdir, 'tb')) as tensorboard:
    for epoch in range(num_train_epochs):
        # Train one epoch
        # A fresh Summary per epoch collects scalars that are written to tensorboard below.
        summary = Summary()
        loop = trange(0, train_size, batch,
                      leave=False, unit='img', unit_scale=batch,
                      desc='Epoch %d/%d' % (1 + epoch, num_train_epochs))
        for it in loop:
            # Batches are sampled with replacement; labels are one-hot encoded for the loss.
            sel = np.random.randint(size=(batch,), low=0, high=train_size)
            x, xl = train.image[sel], train.label[sel]
            xl = one_hot(xl, nclass)
            v = train_op(x, xl)
            summary.scalar('losses/xe', float(v[0]))

        # Eval
        accuracy = 0
        for it in trange(0, test.image.shape[0], batch, leave=False, desc='Evaluating'):
            x = test.image[it: it + batch]
            xl = test.label[it: it + batch]
            accuracy += (np.argmax(predict(x), axis=1) == xl).sum()
        accuracy /= test.image.shape[0]
        summary.scalar('eval/accuracy', 100 * accuracy)
        # Calling a summary entry presumably returns the mean of the values recorded under
        # that tag — verify against objax.jaxboard.Summary.
        print('Epoch %04d Loss %.2f Accuracy %.2f' % (epoch + 1, summary['losses/xe'](), summary['eval/accuracy']()))

        # Step counter is the number of training images processed so far.
        tensorboard.write(summary, step=(epoch + 1) * train_size)
109 |
--------------------------------------------------------------------------------
/examples/image_classification/tfdata/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/objax/9dd7dc37e5f9d0ea71896636d3e180440b2b729e/examples/image_classification/tfdata/__init__.py
--------------------------------------------------------------------------------
/examples/image_classification/tfdata/data.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import Callable, Optional, Tuple, List
16 |
17 | import numpy as np
18 | import tensorflow as tf
19 |
20 |
def record_parse(serialized_example: str, image_shape: Tuple[int, int, int]):
    """Decode one serialized example into a normalized image and its label.

    Args:
        serialized_example: serialized example containing an encoded ``image`` bytes feature
            and an int64 ``label`` feature.
        image_shape: static shape to attach to the decoded image.

    Returns:
        dict with ``image`` (float32, scaled by 2/255 and shifted by -1, i.e. [0, 255] -> [-1, 1])
        and ``label``.
    """
    features = tf.io.parse_single_example(serialized_example,
                                          features={'image': tf.io.FixedLenFeature([], tf.string),
                                                    'label': tf.io.FixedLenFeature([], tf.int64)})
    image = tf.image.decode_image(features['image'])
    # Tensor.set_shape updates the static shape in place and returns None. It must be a
    # separate statement; chaining it onto decode_image would bind `image` to None.
    image.set_shape(image_shape)
    image = tf.cast(image, tf.float32) * (2.0 / 255) - 1.0
    return dict(image=image, label=features['label'])
28 |
29 |
class DataSet:
    """Wrapper for tf.data.Dataset to permit extensions.

    Attributes that DataSet does not define are forwarded to the wrapped ``tf.data.Dataset``.
    Forwarded methods that return a ``tf.data.Dataset`` are re-wrapped in a ``DataSet`` with the
    same ``image_shape``, ``augment_fn`` and ``parse_fn``. Chains such as
    ``ds.parse().batch(b).nchw()`` therefore keep the extension methods available.
    """

    def __init__(self, data: tf.data.Dataset,
                 image_shape: Tuple[int, int, int],
                 augment_fn: Optional[Callable] = None,
                 parse_fn: Optional[Callable] = record_parse):
        """
        Args:
            data: the dataset to wrap.
            image_shape: per-example image shape, passed to ``parse_fn`` by ``parse``.
            augment_fn: optional per-element function applied by ``augment``.
            parse_fn: optional per-element decoding function applied by ``parse``. None means
                the elements are already decoded.
        """
        self.data = data
        self.parse_fn = parse_fn
        self.augment_fn = augment_fn
        self.image_shape = image_shape

    @classmethod
    def from_arrays(cls, images: np.ndarray, labels: np.ndarray, augment_fn: Optional[Callable] = None):
        """Build an already-decoded dataset of {'image', 'label'} dicts from in-memory arrays."""
        return cls(tf.data.Dataset.from_tensor_slices(dict(image=images, label=labels)), images.shape[1:],
                   augment_fn=augment_fn, parse_fn=None)

    @classmethod
    def from_files(cls, filenames: List[str],
                   image_shape: Tuple[int, int, int],
                   augment_fn: Optional[Callable],
                   parse_fn: Optional[Callable] = record_parse):
        """Build a dataset of serialized records from TFRecord files.

        Each entry of ``filenames`` may be a glob pattern. All matches are combined and sorted.

        Raises:
            ValueError: if no file matches any of the patterns.
        """
        filenames_in = filenames
        filenames = sorted(path for pattern in filenames for path in tf.io.gfile.glob(pattern))
        if not filenames:
            raise ValueError('Empty dataset, files not found:', filenames_in)
        return cls(tf.data.TFRecordDataset(filenames), image_shape, augment_fn=augment_fn, parse_fn=parse_fn)

    @classmethod
    def from_tfds(cls, dataset: tf.data.Dataset, image_shape: Tuple[int, int, int],
                  augment_fn: Optional[Callable] = None):
        """Wrap a TFDS dataset, converting 'image' to float32 via ``/ 127.5 - 1``.

        This maps 0..255 pixels to [-1, 1], assuming uint8 images — TODO confirm.
        """
        return cls(dataset.map(lambda x: dict(image=tf.cast(x['image'], tf.float32) / 127.5 - 1, label=x['label'])),
                   image_shape, augment_fn=augment_fn, parse_fn=None)

    def __iter__(self):
        return iter(self.data)

    def __getattr__(self, item):
        # Only reached when normal attribute lookup fails.
        if item in self.__dict__:
            return self.__dict__[item]
        try:
            data = self.__dict__['data']
        except KeyError:
            # Instance not initialised yet (e.g. during copy/unpickling, which probe for
            # __setstate__ etc.). Report a normal missing attribute instead of a KeyError.
            raise AttributeError(item) from None
        attr = getattr(data, item)
        if not callable(attr):
            # Plain attributes and properties (e.g. element_spec) are returned unchanged
            # instead of being wrapped in a function.
            return attr

        def call_and_update(*args, **kwargs):
            v = attr(*args, **kwargs)
            if isinstance(v, tf.data.Dataset):
                return self.__class__(v, self.image_shape, augment_fn=self.augment_fn, parse_fn=self.parse_fn)
            return v

        return call_and_update

    def augment(self, para_augment: int = 4):
        """Map ``augment_fn`` over the elements (``para_augment`` is passed as map's 2nd argument).

        Returns self unchanged when no ``augment_fn`` is set.
        """
        if self.augment_fn:
            return self.map(self.augment_fn, para_augment)
        return self

    def nchw(self):
        """Transpose 'image' with perm [0, 3, 1, 2]. Expects rank-4 (batched, channels-last) images."""
        return self.map(lambda x: dict(image=tf.transpose(x['image'], [0, 3, 1, 2]), label=x['label']))

    def one_hot(self, nclass: int):
        """Replace integer 'label' with its one-hot encoding over ``nclass`` classes."""
        return self.map(lambda x: dict(image=x['image'], label=tf.one_hot(x['label'], nclass)))

    def parse(self, para_parse: int = 2):
        """Decode raw elements with ``parse_fn``, passing ``image_shape`` when it is set.

        Returns self unchanged when there is no ``parse_fn`` (data already decoded).
        """
        if not self.parse_fn:
            return self
        if self.image_shape:
            return self.map(lambda x: self.parse_fn(x, self.image_shape), para_parse)
        return self.map(self.parse_fn, para_parse)
96 |
--------------------------------------------------------------------------------
/examples/jaxboard/README.md:
--------------------------------------------------------------------------------
1 | [home](../../README.md) > [examples](../README.md) > jaxboard
2 |
3 | # Saving to tensorboard
4 |
5 | This directory contains examples on how to visualize data in tensorboard with Objax.
6 |
7 | ```bash
8 | python3 examples/jaxboard/summary.py
9 | tensorboard --logdir experiments/summary_test/tb
10 | ```
11 |
--------------------------------------------------------------------------------
/examples/jaxboard/summary.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import numpy as np
16 |
17 | import objax
18 |
LOGDIR = 'experiments/summary_test/tb'
with objax.jaxboard.SummaryWriter(LOGDIR) as tensorboard:
    # Step 1: plain text, an HTML table, an image, and several scalars under one tag.
    summary = objax.jaxboard.Summary()
    summary.text('text', 'Hello this just text\nand a newline<br>')
    summary.text('html', '<table><tr><td>col1</td><td>col2</td></tr>'
                         '<tr><td>row1.1</td><td>row1.2</td></tr>'
                         '<tr><td>row2.1</td><td>row2.2</td></tr></table>')
    # 3x32x32 channels-first image:
    # - channel 0 varies along the last axis;
    # - channel 1 varies along the middle axis;
    # - channel 2 is their product.
    img = np.zeros((3, 32, 32), 'f')
    img[0] += np.linspace(-1, 1, 32)
    img[1] += np.linspace(-1, 1, 32)[:, None]
    img[2] += np.linspace(-1, 1, 32)[:, None] * np.linspace(-1, 1, 32)
    summary.image('image', img)
    # Several values recorded under the same tag within a single Summary.
    summary.scalar('avg', 0)
    summary.scalar('avg', 1)
    summary.scalar('avg', 2)
    summary.scalar('avg', 3)
    tensorboard.write(summary, step=1)

    # Step 2: a new Summary written at the next step.
    summary = objax.jaxboard.Summary()
    summary.text('text', 'Hello this just text\nat step 2<br>')
    summary.scalar('avg', 4)
    summary.scalar('avg', 7)
    tensorboard.write(summary, step=2)

print(f'Saved to {LOGDIR}')
print(f'Visualize with: tensorboard --logdir "{LOGDIR}"')
45 |
--------------------------------------------------------------------------------
/examples/maml/README.md:
--------------------------------------------------------------------------------
1 | [home](../../README.md) > [examples](../README.md) > maml
2 |
3 | # Model-Agnostic Meta-Learning (MAML)
4 |
5 | This directory contains examples on Model-Agnostic Meta-Learning (MAML).
6 |
--------------------------------------------------------------------------------
/examples/maml/maml.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """
16 | MAML implementation to demonstrate gradient of gradient.
17 |
18 | https://github.com/ericjang/maml-jax/blob/master/maml.ipynb
19 | """
20 |
21 | import jax.numpy as jn
22 | import matplotlib.pyplot as plt
23 | import numpy as np
24 | from tqdm import trange
25 |
26 | import objax
27 |
28 |
def sample_tasks(outer_batch_size, inner_batch_size):
    """Sample a batch of sinusoid regression tasks and two point batches per task.

    Each task is ``y = amplitude * sin(x + phase)``, with amplitude ~ U(0.1, 0.5) and
    phase ~ U(0, pi). For every task, two independent batches of ``inner_batch_size`` points
    with x ~ U(-5, 5) are drawn.

    Returns:
        (x1, y1, x2, y2), each of shape (outer_batch_size, inner_batch_size, 1).
    """
    # Draw parameters task by task: amplitude first, then phase.
    tasks = [(np.random.uniform(low=0.1, high=.5), np.random.uniform(low=0., high=np.pi))
             for _ in range(outer_batch_size)]

    def draw_points():
        inputs = [np.random.uniform(low=-5., high=5., size=(inner_batch_size, 1)) for _ in tasks]
        outputs = [amplitude * np.sin(x + phase) for (amplitude, phase), x in zip(tasks, inputs)]
        return np.stack(inputs), np.stack(outputs)

    x1, y1 = draw_points()
    x2, y2 = draw_points()
    return x1, y1, x2, y2
49 |
50 |
def make_net():
    """Return a 1 -> 40 -> 40 -> 1 ReLU MLP (no activation after the output layer)."""
    widths = [1, 40, 40, 1]
    layers = []
    for nin, nout in zip(widths[:-1], widths[1:]):
        layers += [objax.nn.Linear(nin, nout), objax.functional.relu]
    return objax.nn.Sequential(layers[:-1])  # drop the trailing ReLU after the output layer
57 |
58 |
source = jn.linspace(-5, 5, 100).reshape((100, 1))  # (k, 1)
target = jn.sin(source)

print('Standard training.')
# Baseline: fit one network directly to sin(x) on [-5, 5], without meta-learning.
net = make_net()
opt = objax.optimizer.Adam(net.vars())
65 |
66 |
@objax.Function.with_vars(net.vars())
def loss(x, y):
    """Mean squared error between targets ``y`` and ``net(x)``."""
    residual = y - net(x)
    return (residual ** 2).mean()
70 |
71 |
gv = objax.GradValues(loss, net.vars())


@objax.Function.with_vars(net.vars() + opt.vars())
def train_op():
    """Take one Adam step (lr 0.01) fitting ``net`` to the fixed (source, target) curve."""
    grads, values = gv(source, target)
    opt(0.01, grads)
    return values
80 |
81 |
train_op = objax.Jit(train_op)

# 100 full-batch Adam steps on the fixed sine curve.
for i in range(100):
    train_op()

# Plot the fitted curve, its pointwise squared error, and the target.
plt.plot(source, net(source), label='prediction')
plt.plot(source, (target - net(source)) ** 2, label='loss')
plt.plot(source, target, label='target')
plt.legend()
plt.show()
92 |
# MAML: meta-train a fresh network so that a gradient step on a few samples of a new sine task fits it.
print('MAML training')
net = make_net()
opt = objax.optimizer.Adam(net.vars())


# Mean squared error redefined so that it binds the new `net`.
@objax.Function.with_vars(net.vars())
def loss(x, y):
    return ((y - net(x)) ** 2).mean()


# (gradients, loss values) of `loss` with respect to the new net's variables.
gv = objax.GradValues(loss, net.vars())


@objax.Function.with_vars(net.vars())
def maml_loss(x1, y1, x2, y2, alpha=0.1):
    # Loss on (x2, y2) after one inner SGD step of size alpha on (x1, y1). The weights are
    # overwritten temporarily and restored before returning, so net is left unchanged; the outer
    # gradient of this function therefore differentiates through the inner gradient step.
    # NOTE(review): zipping weights with gradients assumes every variable of net is trainable
    # (true for the Linear-only make_net, presumably) — confirm if the architecture changes.
    net_vars = net.vars()
    original_weights = net_vars.tensors()  # Save original weights
    g_x1y1 = gv(x1, y1)[0]  # Compute gradient at (x1, y1)
    # Apply gradient update using SGD
    net_vars.assign([v - alpha * g for v, g in zip(original_weights, g_x1y1)])
    loss_x2y2 = loss(x2, y2)
    net_vars.assign(original_weights)  # Restore original weights
    return loss_x2y2


# Map maml_loss over tasks: axis 0 of x1, y1, x2, y2 indexes tasks; alpha (axis None) is shared.
vec_maml_loss = objax.Vectorize(maml_loss, batch_axis=(0, 0, 0, 0, None))


# Meta-objective: post-adaptation loss averaged over the tasks of the batch.
@objax.Function.with_vars(vec_maml_loss.vars())
def batch_maml_loss(x1, y1, x2, y2, alpha=0.1):
    return vec_maml_loss(x1, y1, x2, y2, alpha).mean()


maml_gv = objax.GradValues(batch_maml_loss, vec_maml_loss.vars())


# One meta-update with Adam (learning rate 0.001) on a batch of tasks; returns the meta-loss value(s).
@objax.Function.with_vars(vec_maml_loss.vars() + opt.vars())
def train_op(x1, y1, x2, y2):
    g, v = maml_gv(x1, y1, x2, y2)
    opt(0.001, g)
    return v


train_op = objax.Jit(train_op)

# 20000 meta-updates, each on 4 freshly sampled tasks with 20 points per set.
for i in trange(20000, leave=False):
    x1, y1, x2, y2 = sample_tasks(4, 20)
    train_op(x1, y1, x2, y2)

# Evaluation task: amplitude 1, phase 0 (the same curve as `target`), observed through 10 samples.
x1 = np.random.uniform(low=-5., high=5., size=(10, 1))
y1 = 1. * np.sin(x1 + 0.)

tensors = net.vars().tensors()  # Meta-learned weights, restored after the fine-tuning below.
for shot in range(1, 3):
    # One more SGD step (lr 0.1) per iteration; steps accumulate, so the curves show the model
    # after 1 and then 2 gradient steps.
    for v, g in zip(net.vars(), gv(x1, y1)[0]):
        if isinstance(v, objax.TrainVar):
            v.assign(v.value - 0.1 * g)
    plt.plot(source, net(source), label='%d-shot predictions' % shot)
net.vars().assign(tensors)

# Predictions of the meta-learned weights before any adaptation, for comparison.
plt.plot(source, net(source), label='pre-update predictions')
plt.plot(source, target, label='target')
plt.legend()
plt.show()
157 |
--------------------------------------------------------------------------------
/examples/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py
2 | tensorflow-cpu>=2.3.0
3 | tensorflow_datasets>=3.2.1
4 | tqdm
5 |
--------------------------------------------------------------------------------
/examples/text_generation/README.md:
--------------------------------------------------------------------------------
1 | [home](../../README.md) > [examples](../README.md) > text_generation
2 |
3 | # Text Generation Examples
4 |
5 | This directory contains text generation examples.
6 |
7 | See:
8 | * `shakespeare_rnn.py` - predict characters from Shakespeare's plays using an RNN.
9 |
--------------------------------------------------------------------------------
/objax/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import sys
16 |
17 | from ._patch_jax import *
18 |
19 | pass # To avoid reordering imports from above
20 |
21 | from . import functional
22 | from . import io
23 | from . import jaxboard
24 | from . import nn
25 | from . import optimizer
26 | from . import privacy
27 | from . import random
28 | from . import typing
29 | from . import util
30 | from ._version import __version__
31 | from .constants import *
32 | from .gradient import *
33 | from .module import *
34 | from .variable import *
35 |
# Minimum supported Python version. NOTE(review): this check only runs after all the imports
# above have succeeded, and `assert` statements are skipped entirely under `python -O`.
assert sys.version_info >= (3, 6)
37 |
--------------------------------------------------------------------------------
/objax/_patch_jax.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | __all__ = []
17 |
18 | from typing import Union, Sequence, Tuple, Callable, Optional
19 |
20 | import jax.numpy as jn
21 |
22 | from .typing import JaxArray
23 | from .util import re_sign
24 |
25 |
def _pad(array: JaxArray,
         pad_width: Union[Sequence[Tuple[int, int]], Tuple[int, int], int],
         mode: Optional[Union[str, Callable]] = 'constant',
         *,
         stat_length: Optional[Union[Sequence[Tuple[int, int]], int]] = None,
         constant_values: Optional[Union[Sequence[Tuple[int, int]], int]] = 0,
         end_values: Optional[Union[Sequence[Tuple[int, int]], int]] = None,
         reflect_type: Optional[str] = None):
    # Signature-only stub. This is just to have a proper signature for jax.numpy.pad since the API,
    # like in numpy, makes use of kwargs and doesn't expose its arguments properly.
    # The body is intentionally empty; presumably only the parameter list above is used, through
    # re_sign below.
    # NOTE(review): deliberately left without a docstring in case re_sign also copies __doc__ onto
    # jn.pad — confirm against objax.util.re_sign before adding one.
    pass


# Executed when this module is imported (objax/__init__ imports it first): replaces jax.numpy.pad
# with a version carrying the explicit signature declared by _pad.
jn.pad = re_sign(_pad)(jn.pad)
40 |
--------------------------------------------------------------------------------
/objax/_version.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | __version__ = '1.8.0'
16 |
--------------------------------------------------------------------------------
/objax/constants.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | __all__ = ['ConvPadding', 'Interpolate']
16 |
17 | import enum
18 |
19 |
class ConvPadding(enum.Enum):
    """Padding choices accepted by the convolution modules.

    Member values are the upper-case padding strings, so ``ConvPadding('SAME')`` returns
    ``ConvPadding.SAME``.
    """
    SAME = 'SAME'
    VALID = 'VALID'
24 |
25 |
class Interpolate(enum.Enum):
    """Interpolation methods available to the upsampling and resizing functions.

    Member values are lower-case method names, so a member can be looked up from its string,
    e.g. ``Interpolate('bilinear') is Interpolate.BILINEAR``.
    """
    NEAREST = 'nearest'
    LINEAR = 'linear'
    BILINEAR = 'bilinear'
    TRILINEAR = 'trilinear'
    TRIANGLE = 'triangle'
    CUBIC = 'cubic'
    BICUBIC = 'bicubic'
    TRICUBIC = 'tricubic'
    LANCZOS3 = 'lanczos3'
    LANCZOS5 = 'lanczos5'
38 |
--------------------------------------------------------------------------------
/objax/functional/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from . import divergence
16 | from . import loss
17 | from . import parallel
18 | from .core import *
19 |
--------------------------------------------------------------------------------
/objax/functional/core/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .activation import *
16 | from .ops import *
17 | from .pooling import *
18 |
--------------------------------------------------------------------------------
/objax/functional/core/activation.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | __all__ = ['celu', 'elu', 'leaky_relu', 'log_sigmoid', 'log_softmax', 'logsumexp', 'relu',
16 | 'selu', 'sigmoid', 'softmax', 'softplus', 'swish', 'tanh']
17 |
18 | import jax.nn
19 | import jax.scipy.special
20 | from jax import lax
21 |
22 | from objax.typing import JaxArray
23 |
# Activations re-exported directly from JAX under objax.functional.
celu = jax.nn.celu
elu = jax.nn.elu
leaky_relu = jax.nn.leaky_relu
log_sigmoid = jax.nn.log_sigmoid
log_softmax = jax.nn.log_softmax
selu = jax.nn.selu
sigmoid = jax.nn.sigmoid
softmax = jax.nn.softmax
softplus = jax.nn.softplus
swish = jax.nn.swish
logsumexp = jax.scipy.special.logsumexp  # Comes from jax.scipy rather than jax.nn.
tanh = lax.tanh  # The lax primitive, not a jax.nn function.
36 |
37 |
# relu is a module-level function rather than a plain alias because jax.nn.relu isn't picklable.
def relu(x: JaxArray) -> JaxArray:
    """Applies the rectified linear unit element-wise.

    Args:
        x: input tensor.

    Returns:
        Tensor holding max(x, 0) for every element of x.
    """
    jax_relu = jax.nn.relu
    return jax_relu(x)
49 |
--------------------------------------------------------------------------------
/objax/functional/core/ops.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | __all__ = ['dynamic_slice', 'flatten', 'interpolate', 'one_hot', 'pad', 'rsqrt', 'scan', 'stop_gradient',
17 | 'top_k', 'upsample_2d', 'upscale_nn']
18 |
19 | from typing import Union, Tuple
20 |
21 | import jax.nn
22 | from jax import numpy as jn, lax
23 |
24 | from objax import util
25 | from objax.constants import Interpolate
26 | from objax.typing import JaxArray
27 |
# JAX / lax operations re-exported under objax.functional.
dynamic_slice = lax.dynamic_slice
one_hot = jax.nn.one_hot
# Bound when this module is imported; via `import objax` that happens after objax._patch_jax
# has replaced jn.pad.
pad = jn.pad
rsqrt = lax.rsqrt
scan = lax.scan
stop_gradient = lax.stop_gradient
top_k = lax.top_k  # Current code doesn't work with gradient.
35 |
36 |
def flatten(x: JaxArray) -> JaxArray:
    """Collapses every axis after the first one into a single axis.

    Args:
        x: input tensor of shape (n_1, n_2, ..., n_k).

    Returns:
        x reshaped to (n_1, n_prod), where n_prod is the product of n_2 to n_k
        (n_prod is 1 for a 1D input).
    """
    leading = x.shape[0]
    return x.reshape((leading, -1))
48 |
49 |
def interpolate(input: JaxArray,
                size: Union[int, Tuple[int, ...]] = None,
                scale_factor: Union[int, Tuple[int, ...]] = None,
                mode: Union[Interpolate, str] = Interpolate.BILINEAR) -> JaxArray:
    """Interpolates a JaxArray to a given size or by a scaling factor.

    Exactly one of ``size`` and ``scale_factor`` must be provided.

    Args:
        input: input tensor.
        size: target size. An int sets the size of the last axis only; a tuple sets the sizes of
            the trailing ``len(size)`` axes and leaves the leading axes unchanged.
        scale_factor: scaling factor. An int multiplies every axis except the first one; a tuple
            multiplies the trailing ``len(scale_factor)`` axes element-wise.
        mode: interpolation method as a str or Interpolate, e.g. 'bilinear' or 'nearest'.

    Returns:
        The interpolated JaxArray.

    Raises:
        AssertionError: if both or neither of size and scale_factor are given, or if the size /
            scale_factor tuple has more entries than input has axes.
    """
    assert size or scale_factor, f'both size: {size} and scale_factor: {scale_factor} can not be None .'
    assert bool(size) ^ bool(scale_factor), f'either size or scale_factor must be none ' \
                                            f'size: {size}, scale_factor: {scale_factor} .'
    input_shape = input.shape
    input_dim = len(input_shape)
    # Output shapes are built from plain Python ints: jax.image.resize needs a static shape, while
    # jn.array arithmetic would produce traced values when interpolate runs under jit.
    if scale_factor:
        if isinstance(scale_factor, int):
            size = (input_shape[0], *(d * scale_factor for d in input_shape[1:]))
        elif isinstance(scale_factor, tuple):
            output_dim = len(scale_factor)
            assert input_dim >= output_dim, f'Number of dimensions of "{scale_factor}"' \
                                            f' must be < = to input shape"{input_shape}" '
            scaled = input_shape[input_dim - output_dim:]
            size = (*input_shape[:input_dim - output_dim],
                    *(d * s for d, s in zip(scaled, scale_factor)))
    else:
        if isinstance(size, int):
            size = (*input_shape[:-1], size)
        elif isinstance(size, tuple):
            output_dim = len(size)
            assert input_dim >= output_dim, f'Number of dimensions of "{size}"' \
                                            f' must be < = to input shape"{input_shape}" '
            size = (*input_shape[:input_dim - output_dim], *size)
    return jax.image.resize(input, shape=size, method=util.to_interpolate(mode))
89 |
90 |
def upsample_2d(x: JaxArray,
                scale: Union[Tuple[int, int], int],
                method: Union[Interpolate, str] = Interpolate.BILINEAR) -> JaxArray:
    """Upscales a batch of 2D images stored as (N, C, H, W).

    Args:
        x: input tensor of shape (N, C, H, W).
        scale: scaling factor for (H, W); an int is used for both.
        method: interpolation method as a str or Interpolate, e.g. 'bilinear' or 'nearest'.

    Returns:
        Tensor of shape (N, C, H * scale[0], W * scale[1]).
    """
    s = x.shape
    assert len(s) == 4, f'{s} must have 4 dimensions to be upsampled, or you can try interpolate function.'
    factors = util.to_tuple(scale, 2)
    batch, channels, height, width = s
    # Resize in channels-last layout, then move the channel axis back to position 1.
    channels_last = x.transpose([0, 2, 3, 1])
    target_shape = (batch, height * factors[0], width * factors[1], channels)
    resized = jax.image.resize(channels_last, shape=target_shape, method=util.to_interpolate(method))
    return resized.transpose([0, 3, 1, 2])
111 |
112 |
def upscale_nn(x: JaxArray, scale: int = 2) -> JaxArray:
    """Nearest-neighbour upscaling of an image batch laid out as (N, C, H, W).

    Args:
        x: input tensor of shape (N, C, H, W).
        scale: integer factor applied to both H and W.

    Returns:
        Tensor of shape (N, C, H * scale, W * scale) in which each input pixel becomes a
        scale x scale block.
    """
    s = x.shape
    # Insert a singleton axis after H and after W, copy each pixel `scale` times along them,
    # then fold every (size, scale) pair back into one axis.
    expanded = x.reshape(s[:2] + (s[2], 1, s[3], 1))
    repeats = (1, 1, 1, scale, 1, scale)
    tiled = jn.tile(expanded, repeats)
    return tiled.reshape(s[:2] + (s[2] * scale, s[3] * scale))
127 |
--------------------------------------------------------------------------------
/objax/functional/divergence.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | __all__ = ['kl']
16 |
17 | import jax.numpy as jn
18 |
19 | from objax.typing import JaxArray
20 |
21 |
def kl(p: JaxArray, q: JaxArray, eps: float = 2 ** -17) -> JaxArray:
    """Calculates the Kullback-Leibler divergence KL(p || q) between arrays p and q.

    The divergence is summed over the last axis. 1D inputs therefore give a scalar (same result as
    before), and batched inputs of shape (..., n) give one divergence per distribution, with shape
    (...). p and q broadcast against each other.

    Args:
        p: probabilities of the first distribution(s) along the last axis.
        q: probabilities of the second distribution(s) along the last axis.
        eps: small constant added inside the logarithms to avoid log(0).

    Returns:
        sum(p * (log(p + eps) - log(q + eps))) over the last axis.
    """
    # An element-wise product summed over the last axis equals p.dot(...) for 1D inputs, but it
    # stays correct for batches, where dot would perform a matrix product instead.
    return jn.sum(p * (jn.log(p + eps) - jn.log(q + eps)), axis=-1)
25 |
--------------------------------------------------------------------------------
/objax/functional/parallel.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | __all__ = ['partial', 'pmax', 'pmean', 'pmin', 'psum']
16 |
17 | from functools import partial
18 |
19 | import jax
20 | from jax import lax
21 |
22 |
def pmax(x: jax.Array, axis_name: str = 'device'):
    """All-reduce maximum of x across the mapped axis named `axis_name` (default 'device').

    Intended for use inside a mapped computation (e.g. pmap over devices) that binds that axis name.
    """
    return lax.pmax(x, axis_name=axis_name)
26 |
27 |
def pmean(x: jax.Array, axis_name: str = 'device'):
    """All-reduce mean of x across the mapped axis named `axis_name` (default 'device').

    Intended for use inside a mapped computation (e.g. pmap over devices) that binds that axis name.
    """
    return lax.pmean(x, axis_name=axis_name)
31 |
32 |
def pmin(x: jax.Array, axis_name: str = 'device'):
    """All-reduce minimum of x across the mapped axis named `axis_name` (default 'device').

    Intended for use inside a mapped computation (e.g. pmap over devices) that binds that axis name.
    """
    return lax.pmin(x, axis_name=axis_name)
36 |
37 |
def psum(x: jax.Array, axis_name: str = 'device'):
    """All-reduce sum of x across the mapped axis named `axis_name` (default 'device').

    Intended for use inside a mapped computation (e.g. pmap over devices) that binds that axis name.
    """
    return lax.psum(x, axis_name=axis_name)
41 |
--------------------------------------------------------------------------------
/objax/io/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .checkpoint import *
16 | from .ops import *
17 |
--------------------------------------------------------------------------------
/objax/io/ops.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | __all__ = ['load_var_collection', 'save_var_collection']
16 |
17 | import collections
18 | import os
19 | from typing import IO, BinaryIO, Union, Optional
20 |
21 | import jax.numpy as jn
22 | import numpy as np
23 |
24 | from objax.util import Renamer
25 | from objax.variable import TrainRef, VarCollection
26 |
27 |
def load_var_collection(file: Union[str, IO[BinaryIO]],
                        vc: VarCollection,
                        renamer: Optional[Renamer] = None):
    """Loads values of all variables in the given variables collection from file.

    Values loaded from file will replace old values in the variables collection.
    If variable exists in the file, but does not exist in the variables collection it will be ignored.
    If variable exists in the variables collection, but not found in the file then exception will be raised.
    A variable reachable under several names (or through a TrainRef) is restored once, from the first
    of its names present in the file.

    When file is a filename, it is opened here and always closed before returning, also when an
    error is raised. A caller-provided file handle is not closed by this function.

    Args:
        file: filename or python file handle of the input file.
        vc: variables collection which will be loaded from file.
        renamer: optional renamer to pre-process variables names from the file being read.

    Raises:
        ValueError: if variable from variables collection is not found in the input file.
        AssertionError: if a stored value cannot be assigned to its variable.
    """
    renamer = renamer or (lambda x: x)
    do_close = isinstance(file, str)
    if do_close:
        file = open(file, 'rb')
    try:
        data = np.load(file, allow_pickle=False)
        name_index = {renamer(k): str(i) for i, k in enumerate(data['names'])}
        # Group names by variable identity so a shared variable is restored only once.
        var_names = collections.defaultdict(list)
        var_values = {}
        for k, v in vc.items():
            if isinstance(v, TrainRef):
                v = v.ref
            var_names[id(v)].append(k)
            var_values[id(v)] = v
        misses = []
        used_vars = set()
        for var_id, names in var_names.items():
            v = var_values[var_id]
            for name in names:
                index = name_index.get(name)
                if index is not None:
                    used_vars.add(name)
                    try:
                        v.assign(jn.array(data[index]))
                    except AssertionError as e:
                        raise AssertionError(f'Error when restoring variable {name}: {str(e)}') from None
                    break
            else:
                # None of this variable's names is present in the file.
                misses += names
        if misses:
            not_used = set(name_index.keys()) - used_vars
            raise ValueError(f'Missing value for variables currently in the model: {misses}. '
                             f'The following variables on disk were not used, '
                             f'maybe the missing variable was renamed from one of these: {not_used}.')
    finally:
        # Previously the handle leaked whenever one of the errors above was raised.
        if do_close:
            file.close()
80 |
81 |
def _is_multi_device(value) -> bool:
    """Returns True when value is an array whose sharding spans more than one device."""
    sharding = getattr(value, 'sharding', None)  # numpy arrays and plain values have no sharding
    return sharding is not None and len(sharding.device_set) > 1


def save_var_collection(file: Union[str, IO[BinaryIO]], vc: VarCollection):
    """Saves variables collection into file.

    When file is a filename, data is first written to ``file + '.tmp'``, which then replaces the
    target with os.replace. An interrupted save therefore never leaves a truncated checkpoint, and an
    existing checkpoint is overwritten on every platform. Variables reachable under several names
    (or through a TrainRef) are stored once, under the first name seen.

    Args:
        file: filename or python file handle of the file where variables will be saved.
        vc: variables collection which will be saved into file.
    """
    do_close = isinstance(file, str)
    if do_close:
        filename, file = file, open(file + '.tmp', 'wb')  # Save to a temporary in case the job is killed while saving.
    try:
        data, names, seen, replicated = {}, [], set(), []
        for k, v in vc.items():
            if isinstance(v, TrainRef):
                v = v.ref
            if id(v) not in seen:
                names.append(k)
                data[str(len(data))] = v.value
                seen.add(id(v))
                # The warning below was unreachable because nothing ever filled `replicated`.
                if _is_multi_device(v.value):
                    replicated.append(k)
        if replicated:
            print('Warning: When saving VarCollection, some variables were replicated on multiple devices.')
            print('         While it is valid, in most use cases it is more disk efficient to save variables outside of ')
            print('         vars().replicate().')

        def _disabled_seek(*_):
            raise AttributeError('seek() is disabled on this object.')

        # Hiding seek() makes np.savez/zipfile treat the handle as non-seekable and write sequentially.
        _old_seek = getattr(file, 'seek')
        setattr(file, 'seek', _disabled_seek)
        try:
            np.savez(file, names=np.array(names), **data)
        finally:
            # Restore seek() even when writing fails: file may be a caller-owned handle.
            setattr(file, 'seek', _old_seek)
    finally:
        if do_close:
            file.close()
    if do_close:
        # Atomic replace to avoid broken file (when killed while saving); unlike os.rename it also
        # succeeds on Windows when the destination already exists.
        os.replace(filename + '.tmp', filename)
114 |
--------------------------------------------------------------------------------
/objax/jaxboard.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import enum
16 | import os
17 | from time import time
18 | from typing import Union, Callable, Tuple, ByteString
19 |
20 | import numpy as np
21 | from tensorboard.compat.proto import event_pb2
22 | from tensorboard.compat.proto import summary_pb2
23 | from tensorboard.summary.writer.event_file_writer import EventFileWriter
24 | from tensorboard.util.tensor_util import make_tensor_proto
25 |
26 | from objax import util
27 |
28 |
class Reducer(enum.Enum):
    """Reduces tensor batch into a single tensor.

    NOTE: the attributes below are plain functions. Enum treats functions as methods rather than
    members, so this Enum has no members and ``Reducer.MEAN`` (etc.) evaluates to the function
    itself. That is why DelayedScalar can call ``self.reduce(values)`` on it directly.
    """

    def FIRST(x):
        return x[0]

    def LAST(x):
        return x[-1]

    def MEAN(x):
        return np.mean(x)
34 |
35 |
class DelayedScalar:
    """Collects scalar values and reduces them only when the instance is called."""

    def __init__(self, reduce: Union[Callable, Reducer]):
        self.values = []  # Raw values, appended by Summary.scalar.
        self.reduce = reduce  # Applied to the whole list of values on each call.

    def __call__(self):
        collected = self.values
        return self.reduce(collected)
43 |
44 |
class Image:
    """Pending image entry: the PNG-encoded bytes plus the shape of the source image.

    Summary reads shape[0], shape[1] and shape[2] as colorspace, height and width (CHW layout).
    """

    def __init__(self, shape: Tuple[int, int, int], png: ByteString):
        self.png = png
        self.shape = shape
49 |
50 |
class Text:
    """Pending text entry; Summary encodes it as UTF-8 for the TensorBoard text plugin."""

    def __init__(self, text: str):
        self.text = text
54 |
55 |
class Summary(dict):
    """Writes entries to `Summary` protocol buffer.

    Maps each tag to a pending entry (DelayedScalar, Image or Text). Calling the instance converts
    all entries, in insertion order, into one summary_pb2.Summary.
    """

    def image(self, tag: str, image: np.ndarray):
        """Adds image to the summary. Float image in [-1, 1] in CHW format expected."""
        png = util.image.to_png(image)
        self[tag] = Image(image.shape, png)

    def scalar(self, tag: str, value: float, reduce: Union[Callable, Reducer] = Reducer.MEAN):
        """Adds scalar to the summary.

        Values added under one tag accumulate and are reduced when the summary is rendered; the
        reduce function only matters for the first value of a tag.
        """
        self.setdefault(tag, DelayedScalar(reduce)).values.append(value)

    def text(self, tag: str, text: str):
        """Adds text to the summary."""
        self[tag] = Text(text)

    def __call__(self):
        return summary_pb2.Summary(value=[self._entry(tag, value) for tag, value in self.items()])

    @staticmethod
    def _entry(tag, value):
        """Converts one pending entry into a summary_pb2.Summary.Value."""
        if isinstance(value, DelayedScalar):
            return summary_pb2.Summary.Value(tag=tag, simple_value=value())
        if isinstance(value, Image):
            image_summary = summary_pb2.Summary.Image(encoded_image_string=value.png,
                                                      colorspace=value.shape[0],
                                                      height=value.shape[1],
                                                      width=value.shape[2])
            return summary_pb2.Summary.Value(tag=tag, image=image_summary)
        if isinstance(value, Text):
            metadata = summary_pb2.SummaryMetadata(
                plugin_data=summary_pb2.SummaryMetadata.PluginData(plugin_name='text'))
            tensor = make_tensor_proto(values=value.text.encode('utf-8'), shape=(1,))
            return summary_pb2.Summary.Value(tag=tag, metadata=metadata, tensor=tensor)
        raise NotImplementedError(tag, value)
93 |
94 |
class SummaryWriter:
    """Writes entries to event files in the logdir to be consumed by Tensorboard.

    Can be used as a context manager, in which case close() is called on exit.
    """

    def __init__(self, logdir: str, queue_size: int = 5, write_interval: int = 5):
        """Creates SummaryWriter instance.

        Args:
            logdir: directory where event file will be written; created, with parents, if missing.
            queue_size: size of the queue for pending events and summaries
                before one of the 'add' calls forces a flush to disk.
            write_interval: how often, in seconds, to write the pending events and summaries to disk.
        """
        # exist_ok turns this into a no-op when logdir is already a directory.
        os.makedirs(logdir, exist_ok=True)
        self.writer = EventFileWriter(logdir, queue_size, write_interval)

    def write(self, summary: Summary, step: int):
        """Adds an event holding the rendered summary at `step`, stamped with the current wall time."""
        event = event_pb2.Event(step=step, summary=summary(), wall_time=time())
        self.writer.add_event(event)

    def close(self):
        """Flushes the event file to disk and closes the file."""
        self.writer.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
125 |
--------------------------------------------------------------------------------
/objax/nn/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from . import init
16 | from .layers import *
17 |
--------------------------------------------------------------------------------
/objax/optimizer/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .adam import *
16 | from .ema import *
17 | from .lars import *
18 | from .momentum import *
19 | from .sgd import *
20 | from . import scheduler
--------------------------------------------------------------------------------
/objax/optimizer/adam.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | __all__ = ['Adam']
16 |
17 | from typing import List, Optional
18 |
19 | from jax import numpy as jn
20 |
21 | from objax import functional
22 | from objax.module import Module, ModuleList
23 | from objax.typing import JaxArray
24 | from objax.util import class_name
25 | from objax.variable import TrainRef, StateVar, TrainVar, VarCollection
26 |
27 |
class Adam(Module):
    """Adam optimizer.

    Both bias corrections are folded into a single scalar step size, and ``eps`` is added to the
    second-moment estimate before taking its reciprocal square root.
    """

    def __init__(self, vc: VarCollection, beta1: float = 0.9, beta2: float = 0.999, eps: float = 1e-8):
        """Creates the optimizer state for the trainable variables of ``vc``.

        Args:
            vc: collection of variables to optimize; only ``TrainVar`` entries are tracked.
            beta1: decay rate of the first-moment estimate. Defaults to 0.9.
            beta2: decay rate of the second-moment estimate. Defaults to 0.999.
            eps: constant added to the second moment for numerical stability. Defaults to 1e-8.
        """
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps
        # Step counter used for bias correction; its reduce function keeps element 0
        # (presumably the first replica's copy when state is reduced across devices).
        self.step = StateVar(jn.array(0, jn.uint32), reduce=lambda x: x[0])
        self.train_vars = ModuleList(TrainRef(var) for var in vc.subset(TrainVar))
        self.m = ModuleList(StateVar(jn.zeros_like(ref.value)) for ref in self.train_vars)
        self.v = ModuleList(StateVar(jn.zeros_like(ref.value)) for ref in self.train_vars)

    def __call__(self, lr: float, grads: List[JaxArray], beta1: Optional[float] = None, beta2: Optional[float] = None):
        """Applies one Adam step to every trainable variable.

        Args:
            lr: the learning rate.
            grads: one gradient per trainable variable, in the order of ``train_vars``.
            beta1: optional override of the first-moment decay rate for this step.
            beta2: optional override of the second-moment decay rate for this step.
        """
        assert len(grads) == len(self.train_vars), 'Expecting as many gradients as trainable variables'
        b1 = self.beta1 if beta1 is None else beta1
        b2 = self.beta2 if beta2 is None else beta2
        self.step.value += 1
        t = self.step.value
        # Bias correction of both moment estimates, applied to the learning rate once per step.
        step_size = lr * (jn.sqrt(1 - b2 ** t) / (1 - b1 ** t))
        for grad, param, first, second in zip(grads, self.train_vars, self.m, self.v):
            first.value += (1 - b1) * (grad - first.value)
            second.value += (1 - b2) * (grad ** 2 - second.value)
            param.value -= step_size * first.value * functional.rsqrt(second.value + self.eps)

    def __repr__(self):
        args = f'beta1={self.beta1}, beta2={self.beta2}, eps={self.eps}'
        return f'{class_name(self)}({args})'
71 |
--------------------------------------------------------------------------------
/objax/optimizer/lars.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | __all__ = ['LARS']
16 |
17 | from typing import List
18 |
19 | import jax.numpy as jn
20 |
21 | from objax.module import Module, ModuleList
22 | from objax.typing import JaxArray
23 | from objax.util import class_name
24 | from objax.variable import TrainRef, StateVar, TrainVar, VarCollection
25 |
26 |
class LARS(Module):
    """Layerwise adaptive rate scaling (LARS) optimizer.

    See https://arxiv.org/abs/1708.03888
    """

    def __init__(self, vc: VarCollection,
                 momentum: float = 0.9,
                 weight_decay: float = 1e-4,
                 tc: float = 1e-3,
                 eps: float = 1e-5):
        """Constructor for LARS optimizer.

        Args:
            vc: collection of variables to optimize.
            momentum: coefficient used for the moving average of the gradient.
            weight_decay: weight decay coefficient.
            tc: trust coefficient eta ( < 1) for trust ratio computation.
            eps: epsilon used for trust ratio computation.
        """
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.tc = tc
        self.eps = eps
        self.train_vars = ModuleList(TrainRef(x) for x in vc.subset(TrainVar))
        self.m = ModuleList(StateVar(jn.zeros_like(x.value)) for x in self.train_vars)

    def __call__(self, lr: float, grads: List[JaxArray]):
        """Updates variables based on LARS algorithm.

        For each variable ``p`` with gradient ``g``, the local learning rate is ``lr * trust_ratio``, where
        ``trust_ratio = tc * |p| / (|g| + weight_decay * |p| + eps)``, or exactly 1 when ``|p|`` or ``|g|``
        is zero. The momentum buffer accumulates ``local_lr * (g + weight_decay * p)`` and is then
        subtracted from ``p``.

        Args:
            lr: learning rate. The LARS paper suggests using lr = lr_0 * (1 -t/T)**2,
                where t is the current epoch number and T the maximum number of epochs.
            grads: the gradients to apply.
        """
        assert len(grads) == len(self.train_vars), 'Expecting as many gradients as trainable variables'

        for g, p, m in zip(grads, self.train_vars, self.m):
            p_norm = jn.linalg.norm(p.value)
            g_norm = jn.linalg.norm(g)
            trust_ratio = self.tc * p_norm / (g_norm + self.weight_decay * p_norm + self.eps)
            # Fall back to a trust ratio of exactly 1 when either norm is zero. Selecting with ``where``
            # (instead of ``maximum(is_zero, trust_ratio)``) avoids keeping a ratio larger than 1 when only
            # the gradient is zero, and discards the NaN ratio produced when both norms are 0 and eps == 0.
            trust_ratio = jn.where(jn.logical_and(p_norm > 0, g_norm > 0), trust_ratio, 1.0)
            local_lr = lr * trust_ratio
            m.value = self.momentum * m.value + local_lr * (g + self.weight_decay * p.value)
            p.value -= m.value

    def __repr__(self):
        return f'{class_name(self)}(momentum={self.momentum}, weight_decay={self.weight_decay}, ' \
               f'tc={self.tc}, eps={self.eps})'
75 |
--------------------------------------------------------------------------------
/objax/optimizer/momentum.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | __all__ = ['Momentum']
16 |
17 | from typing import List, Optional
18 |
19 | from jax import numpy as jn
20 |
21 | from objax.module import Module, ModuleList
22 | from objax.util import class_name
23 | from objax.variable import TrainRef, StateVar, TrainVar, VarCollection
24 |
25 |
class Momentum(Module):
    """Momentum optimizer, with an optional Nesterov variant.

    Each trainable variable has a velocity buffer ``m`` updated as ``m <- g + momentum * m``.
    """

    def __init__(self, vc: VarCollection, momentum: float = 0.9, nesterov: bool = False):
        """Creates the optimizer and a zero-initialized velocity buffer per trainable variable.

        Args:
            vc: collection of variables to optimize.
            momentum: the momentum hyperparameter.
            nesterov: whether to apply the Nesterov update instead of the classical one.
        """
        self.momentum = momentum
        self.nesterov = nesterov
        self.train_vars = ModuleList(TrainRef(var) for var in vc.subset(TrainVar))
        self.m = ModuleList(StateVar(jn.zeros_like(ref.value)) for ref in self.train_vars)

    def __call__(self, lr: float, grads: List[jn.ndarray], momentum: Optional[float] = None):
        """Applies one momentum (or Nesterov) SGD step.

        Args:
            lr: the learning rate.
            grads: one gradient per trainable variable, in the order of ``train_vars``.
            momentum: optional override of the momentum for this step.
        """
        assert len(grads) == len(self.train_vars), 'Expecting as many gradients as trainable variables'
        mu = self.momentum if momentum is None else momentum
        for grad, param, velocity in zip(grads, self.train_vars, self.m):
            velocity.value = grad + mu * velocity.value
            # Nesterov looks ahead along the freshly updated velocity; classical momentum steps along it.
            step = grad + mu * velocity.value if self.nesterov else velocity.value
            param.value -= lr * step

    def __repr__(self):
        return '{}(momentum={}, nesterov={})'.format(class_name(self), self.momentum, self.nesterov)
64 |
--------------------------------------------------------------------------------
/objax/optimizer/scheduler.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | __all__ = ['LinearAnnealing', 'StepDecay']
16 |
17 |
18 | import abc
19 | from typing import List, Tuple, Union
20 |
21 | import jax.numpy as jn
22 |
23 |
class Scheduler(abc.ABC):
    """Base class of learning rate schedulers.

    A scheduler maps a training step to ``base_lr * multiplier(step)``; subclasses implement ``multiplier``.
    Deriving from ``abc.ABC`` makes the ``abstractmethod`` below actually enforced.
    """

    def __init__(self,
                 base_lr: float = 1.0):
        """Constructs an instance for learning rate scheduler.

        Args:
            base_lr: base learning rate.
        """
        self.base_lr = base_lr

    @abc.abstractmethod
    def multiplier(self, step: float):
        """Returns learning rate multiplier w.r.t. certain schedule."""
        raise NotImplementedError

    def __call__(self, step: float):
        """Returns the learning rate at a given step.

        Args:
            step: number of the training step (a Python number or a JAX scalar).

        Returns:
            ``base_lr * multiplier(step)``. With the default ``base_lr=1.0`` this is the multiplier itself.
        """
        return self.base_lr * self.multiplier(step)


class LinearAnnealing(Scheduler):
    """Linearly anneals the learning rate from ``base_lr`` at step 0 towards ``min_lr`` at ``max_step``."""

    def __init__(self,
                 max_step: float,
                 base_lr: float = 1.0,
                 is_cycle: bool = True,
                 min_lr: float = 0.0):
        """Constructs an instance for linear annealing learning rate scheduler.

        Args:
            max_step: maximum number of train step; must be positive.
            base_lr: base learning rate.
            is_cycle: when True, the schedule restarts from ``base_lr`` every ``max_step`` steps;
                otherwise the learning rate stays at ``min_lr`` once ``max_step`` is reached.
            min_lr: minimum learning rate at max_step.
        """
        super().__init__(base_lr=base_lr)
        assert base_lr >= min_lr, (
            'base_lr should be greater than or equal to min_lr.')
        assert max_step > 0, 'max_step should be positive.'
        self.max_step = max_step
        self.is_cycle = is_cycle
        # The schedule is a multiplier of base_lr. When base_lr is 0 every learning rate is 0 regardless of
        # the multiplier, so avoid the division by zero.
        self.min_lr_multiplier = min_lr / self.base_lr if self.base_lr else 0.0

    def multiplier(self, step: float):
        """Returns linear annealing learning rate multiplier."""

        # If is_cycle, we use the remainder of step; otherwise, we stop update.
        if self.is_cycle:
            step = jn.remainder(step, self.max_step)
        else:
            step = jn.minimum(step, self.max_step)

        return 1.0 - (step / self.max_step) * (
            1.0 - self.min_lr_multiplier)


class StepDecay(Scheduler):
    """Multiplies the learning rate by ``gamma`` at regular intervals or at given step boundaries."""

    def __init__(self,
                 step_size: Union[float, List, Tuple],
                 base_lr: float = 1.0,
                 gamma: float = 0.1):
        """Constructs an instance for step decay learning rate scheduler.

        Args:
            step_size: either a number, in which case the learning rate is multiplied by ``gamma`` every
                ``step_size`` steps, or a list/tuple of step boundaries, in which case it is multiplied by
                ``gamma`` once for every boundary that the current step has reached.
            base_lr: base learning rate.
            gamma: learning rate decay rate.
        """
        super().__init__(base_lr=base_lr)
        self.gamma = gamma
        self.step_size = step_size

    def multiplier(self, step: float):
        """Returns ``gamma ** k`` where ``k`` is the number of decays applied up to ``step``."""
        if isinstance(self.step_size, (tuple, list)):
            # Count the boundaries already reached (step >= boundary).
            exponent = jn.sum(jn.greater_equal(step, jn.array(self.step_size)))
        else:
            exponent = step // self.step_size
        return self.gamma ** exponent
110 |
--------------------------------------------------------------------------------
/objax/optimizer/sgd.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | __all__ = ['SGD']
16 |
17 | from typing import List
18 |
19 | from objax.module import Module, ModuleList
20 | from objax.typing import JaxArray
21 | from objax.util import class_name
22 | from objax.variable import TrainRef, TrainVar, VarCollection
23 |
24 |
class SGD(Module):
    """Stochastic Gradient Descent (SGD) optimizer: ``p <- p - lr * g`` for every trainable variable."""

    def __init__(self, vc: VarCollection):
        """Creates the optimizer for the trainable variables of ``vc``.

        Args:
            vc: collection of variables to optimize; only ``TrainVar`` entries are updated.
        """
        self.train_vars = ModuleList(TrainRef(var) for var in vc.subset(TrainVar))

    def __call__(self, lr: float, grads: List[JaxArray]):
        """Applies one plain gradient step.

        Args:
            lr: the learning rate.
            grads: one gradient per trainable variable, in the order of ``train_vars``.
        """
        assert len(grads) == len(self.train_vars), 'Expecting as many gradients as trainable variables'
        for param, grad in zip(self.train_vars, grads):
            param.value -= lr * grad

    def __repr__(self):
        return class_name(self) + '()'
49 |
--------------------------------------------------------------------------------
/objax/privacy/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from . import dpsgd
16 |
--------------------------------------------------------------------------------
/objax/privacy/dpsgd/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .gradient import *
16 | from .privacyaccountant import *
17 |
--------------------------------------------------------------------------------
/objax/random/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .random import *
16 |
--------------------------------------------------------------------------------
/objax/random/random.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | __all__ = ['DEFAULT_GENERATOR', 'Generator', 'randint', 'normal', 'truncated_normal', 'uniform']
16 |
17 | from typing import Optional, Tuple
18 |
19 | import jax.random as jr
20 |
21 | from objax.module import Module
22 | from objax.util import class_name
23 | from objax.variable import RandomState, VarCollection
24 |
25 |
class Generator(Module):
    """Random number generator module.

    The underlying ``RandomState`` is created lazily on first access to ``key``; every call to the
    generator splits that state to produce a fresh key.
    """

    def __init__(self, seed: int = 0):
        """Create a random key generator whose initial seed is ``seed``."""
        super().__init__()
        self.initial_seed = seed
        self._key: Optional[RandomState] = None

    @property
    def key(self):
        """The random generator state (a tensor of 2 int32), created on first access."""
        state = self._key
        if state is None:
            state = RandomState(self.initial_seed)
            self._key = state
        return state

    def seed(self, seed: int = 0):
        """Sets a new random generator seed, re-seeding the state if it has already been created."""
        self.initial_seed = seed
        if self._key is None:
            return
        self._key.seed(seed)

    def __call__(self):
        """Generate a new generator state."""
        new_keys = self.key.split(1)
        return new_keys[0]

    def vars(self, scope: str = '') -> VarCollection:
        # Touch the lazy ``key`` property so the RandomState exists and is included in the collection.
        _ = self.key
        return super().vars(scope)

    def __repr__(self):
        return '{}(seed={})'.format(class_name(self), self.initial_seed)
58 |
59 |
# Module-level generator used by the sampling functions of this module when no generator is passed.
DEFAULT_GENERATOR = Generator(seed=0)
61 |
62 |
def normal(shape: Tuple[int, ...], *, mean: float = 0, stddev: float = 1, generator: Generator = DEFAULT_GENERATOR):
    """Returns a ``JaxArray`` of shape ``shape`` with random numbers from a normal distribution
    with mean ``mean`` and standard deviation ``stddev``.

    A standard normal sample is drawn with a fresh key from ``generator``, then scaled and shifted.

    NOTE: if random numbers are generated inside a jitted, parallelized or vectorized function
    then generator variables (including DEFAULT_GENERATOR) have to be added to the
    variable collection."""
    standard = jr.normal(generator(), shape=shape)
    return standard * stddev + mean
71 |
72 |
def randint(shape: Tuple[int, ...], low: int, high: int, generator: Generator = DEFAULT_GENERATOR):
    """Returns a ``JaxArray`` of shape ``shape`` with random integers in {low, ..., high-1}.

    ``high`` is exclusive: it is passed to ``jax.random.randint`` as ``maxval``.

    NOTE: if random numbers are generated inside a jitted, parallelized or vectorized function
    then generator variables (including DEFAULT_GENERATOR) have to be added to the
    variable collection."""
    return jr.randint(generator(), shape=shape, minval=low, maxval=high)
80 |
81 |
def truncated_normal(shape: Tuple[int, ...], *,
                     stddev: float = 1,
                     lower: float = -2,
                     upper: float = 2,
                     generator: Generator = DEFAULT_GENERATOR):
    """Returns a ``JaxArray`` of shape ``shape`` with random numbers from a truncated normal distribution.

    Samples are drawn from a standard (zero-mean, unit-variance) normal distribution truncated to
    (``lower``, ``upper``) and are then multiplied by ``stddev``. The truncation bounds are therefore
    expressed in units of ``stddev``: the returned values lie in (``lower * stddev``, ``upper * stddev``).

    NOTE: if random numbers are generated inside a jitted, parallelized or vectorized function
    then generator variables (including DEFAULT_GENERATOR) have to be added to the
    variable collection."""
    return jr.truncated_normal(generator(), shape=shape, lower=lower, upper=upper) * stddev
94 |
95 |
def uniform(shape: Tuple[int, ...], generator: Generator = DEFAULT_GENERATOR):
    """Returns a ``JaxArray`` of shape ``shape`` with random numbers from a uniform distribution [0, 1).

    The upper bound 1 is exclusive, following ``jax.random.uniform``'s default ``maxval``.

    NOTE: if random numbers are generated inside a jitted, parallelized or vectorized function
    then generator variables (including DEFAULT_GENERATOR) have to be added to the
    variable collection."""
    return jr.uniform(generator(), shape=shape)
103 |
--------------------------------------------------------------------------------
/objax/typing.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """This module contains type declarations for Objax."""
16 |
17 | __all__ = ['FileOrStr', 'JaxArray', 'JaxDType']
18 |
19 | from typing import Union, IO, BinaryIO, Sequence, Tuple
20 |
21 | import jax
22 | import jax.numpy as jn
23 |
# Padding specification: a sequence of (low, high) pairs, a single pair, or one int
# (presumably the padding argument of convolution layers, given the name).
ConvPaddingInt = Union[Sequence[Tuple[int, int]], Tuple[int, int], int]

# A file name, or an already opened binary file object.
FileOrStr = Union[str, IO[BinaryIO]]

# The array type produced by JAX operations.
JaxArray = jax.Array

# Any of the JAX numeric scalar types.
JaxDType = Union[
    jn.complex64, jn.complex128,
    jn.bfloat16, jn.float16, jn.float32, jn.float64,
    jn.int8, jn.int16, jn.int32, jn.int64,
    jn.uint8, jn.uint16, jn.uint32, jn.uint64,
]
31 |
--------------------------------------------------------------------------------
/objax/util/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from . import image
16 | from . import check
17 | from .util import *
18 | from .objax2tf import Objax2Tf
19 | from .tracing import find_used_variables
--------------------------------------------------------------------------------
/objax/util/check.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | __all__ = ['assert_assigned_type_and_shape_match']
16 |
17 | import jax
18 |
19 |
# Tracer classes JAX uses while tracing functions (e.g. under jit). Assignments involving values of
# these types get the relaxed shape check in assert_assigned_type_and_shape_match.
TRACER_TYPES = (jax.interpreters.partial_eval.JaxprTracer,
                jax.interpreters.partial_eval.DynamicJaxprTracer)


def split_shape_and_device(array):
    """Splits a leading device axis off the shape of a pmap-sharded array.

    Returns:
        ``(num_devices, per_device_shape)`` when ``array`` is a ``jax.Array`` with a ``PmapSharding``;
        otherwise ``(None, array.shape)``.
    """
    if isinstance(array, jax.Array) and hasattr(array, 'sharding') and isinstance(
            array.sharding, jax.sharding.PmapSharding):
        return array.shape[0], array.shape[1:]
    else:
        return None, array.shape


def assert_assigned_type_and_shape_match(existing_tensor, new_tensor):
    """Asserts that ``new_tensor`` may replace ``existing_tensor`` as a variable's value.

    Checks that ``new_tensor`` is a ``jax.Array``, that the two values are not replicated over different
    numbers of devices, and that the shape is unchanged. When either value is a JAX tracer, it is
    enough for the trailing dimensions shared by both shapes to match.

    Raises:
        AssertionError: if any of the checks above fails.
    """
    assert isinstance(new_tensor, jax.Array), \
        f'Assignments to variable must be an instance of JaxArray, but received {type(new_tensor)}.'

    new_tensor_device, new_tensor_shape = split_shape_and_device(new_tensor)
    self_device, self_shape = split_shape_and_device(existing_tensor)

    device_mismatch_error = f'Can not replicate a variable that is currently on ' \
                            f'{self_device} devices to {new_tensor_device} devices.'
    assert (new_tensor_device is None) or (self_device is None) or (self_device == new_tensor_device), \
        device_mismatch_error

    # Compare only the trailing dimensions both shapes share. Slice with an explicit start index:
    # ``shape[-0:]`` would be the whole shape, which wrongly rejects tracers assigned to scalars.
    shorter_length = min(len(new_tensor.shape), len(existing_tensor.shape))
    existing_suffix = existing_tensor.shape[len(existing_tensor.shape) - shorter_length:]
    new_suffix = new_tensor.shape[len(new_tensor.shape) - shorter_length:]
    is_special_ok = (isinstance(new_tensor, TRACER_TYPES) or isinstance(existing_tensor, TRACER_TYPES))
    is_special_ok = is_special_ok and existing_suffix == new_suffix

    shape_mismatch_error = f'Assign can not change shape of variable. The current variable shape is {self_shape},' \
                           f' but the requested new shape is {new_tensor_shape}.'
    assert is_special_ok or new_tensor_shape == self_shape or new_tensor.shape == existing_tensor.shape, \
        shape_mismatch_error
52 |
--------------------------------------------------------------------------------
/objax/util/image.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | __all__ = ['from_file', 'image_grid', 'nchw', 'nhwc', 'normalize_to_uint8', 'normalize_to_unit_float', 'to_png']
16 |
17 | import io
18 | from typing import Union, BinaryIO, IO
19 |
20 | import jax.numpy as jn
21 | import numpy as np
22 | from PIL import Image
23 |
24 | from objax.typing import JaxArray
25 |
26 |
def from_file(file: Union[str, IO[BinaryIO]]) -> np.ndarray:
    """Read an image from a file, convert it to RGB and return it as an array.

    Args:
        file: filename or python file handle of the input file.

    Return:
        3D numpy array (C, H, W) with C == 3, normalized with normalize_to_unit_float.
    """
    # The context manager releases the file PIL opened for a filename; a caller-supplied handle is
    # left open for the caller to manage.
    with Image.open(file) as img:
        rgb = np.asarray(img.convert('RGB'))
    return normalize_to_unit_float(rgb.transpose((2, 0, 1)))
38 |
39 |
def image_grid(image: np.ndarray) -> np.ndarray:
    """Rearranges a grid of images into a single image.

    Args:
        image: array of shape (nh, nw, c, h, w) holding ``nh`` rows of ``nw`` images each.

    Returns:
        Array of shape (c, nh * h, nw * w) in which image ``[i, j]`` occupies rows ``i*h:(i+1)*h``
        and columns ``j*w:(j+1)*w``.
    """
    nh, nw, c, h, w = image.shape
    return image.transpose([2, 0, 3, 1, 4]).reshape([c, nh * h, nw * w])
44 |
45 |
def nchw(x: Union[np.ndarray, JaxArray]) -> Union[np.ndarray, JaxArray]:
    """Converts an array in (N,H,W,C) format to (N,C,H,W) format.

    The last axis is moved in front of the two axes that precede it; any leading axes keep their place.
    """
    axes = list(range(x.ndim))
    channel = axes.pop()
    return x.transpose(axes[:-2] + [channel] + axes[-2:])
51 |
52 |
def nhwc(x: Union[np.ndarray, JaxArray]) -> Union[np.ndarray, JaxArray]:
    """Converts an array in (N,C,H,W) format to (N,H,W,C) format.

    The third-from-last axis is moved to the end; any leading axes keep their place.
    """
    axes = list(range(x.ndim))
    reordered = axes[:-3] + axes[-2:] + [axes[-3]]
    return x.transpose(reordered)
58 |
59 |
def normalize_to_uint8(x: Union[np.ndarray, JaxArray]) -> Union[np.ndarray, JaxArray]:
    """Map a float image in [1/256-1, 1-1/256] to uint8 {0, 1, ..., 255}.

    Values outside that interval are clipped to 0 or 255.
    """
    offset = 1 - 1 / 256
    scaled = 128 * (x + offset)
    return scaled.clip(0, 255).round().astype('uint8')
63 |
64 |
def normalize_to_unit_float(x: Union[np.ndarray, JaxArray]) -> Union[np.ndarray, JaxArray]:
    """Map an uint8 image in {0, 1, ..., 255} to float interval [1/256-1, 1-1/256].

    Each value ``v`` becomes ``v / 128 + 1/256 - 1``; this is the inverse of ``normalize_to_uint8``.
    """
    scale, shift = 1 / 128, 1 / 256 - 1
    return x * scale + shift
68 |
69 |
def to_png(x: Union[np.ndarray, JaxArray]) -> bytes:
    """Encodes an image array as PNG bytes.

    Args:
        x: image to encode (numpy or JAX array). The expected layout depends on the dtype:

            * float16/float32/float64: (C, H, W), with values in the range handled by
              ``normalize_to_uint8`` (out-of-range values are clipped); it is converted to uint8 and
              transposed to (H, W, C).
            * uint8: used as-is, so it must already be in (H, W, C) layout.
              NOTE(review): unlike the float path, uint8 input is not transposed — confirm callers
              pass (H, W, C).

            A single-channel image is replicated to 3 channels.

    Returns:
        The PNG-encoded image.

    Raises:
        ValueError: if ``x`` is neither a float nor a uint8 array.
    """
    if isinstance(x, jn.ndarray):
        x = np.array(x)
    if x.dtype in (np.float64, np.float32, np.float16):
        x = np.transpose(normalize_to_uint8(x), (1, 2, 0))
    elif x.dtype != np.uint8:
        raise ValueError('Unsupported array type, expecting float or uint8', x.dtype)
    if x.shape[2] == 1:
        x = np.broadcast_to(x, x.shape[:2] + (3,))
    with io.BytesIO() as f:
        Image.fromarray(x).save(f, 'png')
        return f.getvalue()
83 |
--------------------------------------------------------------------------------
/objax/util/objax2tf.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import List
16 |
17 | from objax.module import Module
18 | from objax.typing import JaxArray
19 |
try:
    # Only import tensorflow if available.
    import tensorflow as tf
except ImportError:
    # Make fake tf, so code in this file will be successfully imported even when Tensorflow is not installed.
    # The stand-in only provides ``tf.Module`` (a plain ``object`` base class) and ``tf.function``
    # (a no-op decorator), which are used when the classes of this file are defined.
    tf = type('tf', (), {})
    setattr(tf, 'Module', object)

    def _fake_tf_function(func=None, **kwargs):
        """No-op stand-in for ``tf.function``; works both as ``@tf.function`` and ``@tf.function(...)``."""
        del kwargs
        if func is not None:
            return func
        else:
            return lambda x: x

    setattr(tf, 'function', _fake_tf_function)
else:
    try:
        # Hide GPUs from TensorFlow (presumably so it does not claim accelerator memory JAX needs).
        tf.config.experimental.set_visible_devices([], 'GPU')
    except RuntimeError as err:
        # TensorFlow refuses to change visible devices once its runtime is initialized (e.g. TF was
        # imported and used before Objax). That must not make importing Objax fail, so only warn.
        import warnings
        warnings.warn(f'Could not hide GPUs from TensorFlow: {err}')
38 |
39 |
class Objax2Tf(tf.Module):
    """Objax to Tensorflow converter, which converts Objax module to tf.Module."""

    def __init__(self, module: Module):
        """Create a Tensorflow module from Objax module.

        The module's variables are copied into ``tf.Variable`` objects, and the module call is converted
        with ``jax2tf`` into a function that receives those variables explicitly.

        Args:
            module: Objax module to be converted to Tensorflow tf.Module.
        """
        assert hasattr(tf, '__version__'), 'Tensorflow must be installed for Objax2Tf to work.'
        # Compare the major version numerically; a string comparison would treat e.g. '10.0' as < '2.0'.
        assert int(tf.__version__.split('.')[0]) >= 2, 'Objax2Tf works only with Tensorflow 2.'
        assert isinstance(module, Module), 'Input argument to Objax2Tf must be an Objax module.'
        # Imported after the checks (jax2tf itself depends on TensorFlow), so a missing TensorFlow is
        # reported with the message above instead of an import error.
        from jax.experimental import jax2tf

        super().__init__()

        module_vars = module.vars()

        def wrapped_op(tensor_list: List[JaxArray], kwargs, *args):
            """Runs ``module`` with its variables temporarily replaced by ``tensor_list``."""
            original_values = module_vars.tensors()
            try:
                module_vars.assign(tensor_list)
                return module(*args, **kwargs)
            finally:
                # Restore the original values even if the call raises.
                module_vars.assign(original_values)

        tf_function = jax2tf.convert(wrapped_op)
        self._tf_vars = [tf.Variable(v) for v in module_vars.tensors()]
        self._tf_call = tf_function

    @tf.function(autograph=False)
    def __call__(self, *args, **kwargs):
        """Calls Tensorflow function which was generated from Objax module."""
        return self._tf_call(self._tf_vars, kwargs, *args)
74 |
--------------------------------------------------------------------------------
/objax/zoo/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/objax/zoo/convnet.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import functools
16 |
17 | import objax
18 | from objax.typing import JaxArray
19 |
20 |
class ConvNet(objax.nn.Sequential):
    """Plain convolutional classifier.

    Layout:
    - one stem convolution;
    - ``scales`` stages of two convolutions plus a 2x2 pooling each;
    - a final convolution down to ``nclass`` channels;
    - a mean over axes 2 and 3 (the spatial axes, assuming NCHW input).
    """

    @staticmethod
    def _mean_reduce(x: JaxArray) -> JaxArray:
        return x.mean((2, 3))

    def __init__(self, nin, nclass, scales, filters, filters_max,
                 pooling=objax.functional.max_pool_2d, **kwargs):
        """Creates ConvNet instance.

        Args:
            nin: number of channels in the input image.
            nclass: number of output classes.
            scales: number of pooling stages; each pools with size=2 and strides=2.
            filters: number of filters of the first convolution. The filter count is doubled at
                every scale, capped at ``filters_max``.
            filters_max: maximum number of filters.
            pooling: pooling function, called with ``size=2, strides=2``.
            **kwargs: accepted for interface compatibility and ignored.
        """
        del kwargs

        def channels(scale):
            # filters * 2**scale, capped at filters_max.
            return min(filters_max, filters << scale)

        layers = [objax.nn.Conv2D(nin, channels(0), 3), objax.functional.leaky_relu]
        for scale in range(scales):
            layers += [
                objax.nn.Conv2D(channels(scale), channels(scale), 3),
                objax.functional.leaky_relu,
                objax.nn.Conv2D(channels(scale), channels(scale + 1), 3),
                objax.functional.leaky_relu,
                functools.partial(pooling, size=2, strides=2),
            ]
        layers += [objax.nn.Conv2D(channels(scales), nclass, 3), self._mean_reduce]
        super().__init__(layers)
53 |
--------------------------------------------------------------------------------
/objax/zoo/dnnet.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import Callable, Iterable
16 |
17 | from objax.nn import Linear, Sequential
18 |
19 |
class DNNet(Sequential):
    """Multi-layer perceptron made of ``Linear`` layers, each one followed by ``activation``."""

    def __init__(self, layer_sizes: Iterable[int], activation: Callable):
        """Creates DNNet instance.

        Args:
            layer_sizes: widths of the network, from input size to output size (at least two
                entries). Consecutive pairs define the ``Linear`` layers.
            activation: function applied after every ``Linear`` layer, including the last one.
        """
        sizes = list(layer_sizes)
        assert len(sizes) >= 2
        layers = []
        for fan_in, fan_out in zip(sizes, sizes[1:]):
            layers += [Linear(fan_in, fan_out), activation]
        super().__init__(layers)
36 |
--------------------------------------------------------------------------------
/objax/zoo/rnn.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import Callable
16 |
17 | import jax.numpy as jn
18 |
19 | from objax import Module
20 | from objax.nn import Linear
21 | from objax.nn.init import kaiming_normal
22 | from objax.typing import JaxArray
23 | from objax.variable import TrainVar, StateVar
24 |
25 |
class RNN(Module):
    """Recurrent Neural Network (RNN) block.

    The hidden state variable is created by :meth:`init_state`, so that method must be called
    with the batch size before the forward pass.
    """

    def __init__(self,
                 nstate: int,
                 nin: int,
                 nout: int,
                 activation: Callable = jn.tanh,
                 w_init: Callable = kaiming_normal):
        """Creates an RNN instance.

        Args:
            nstate: number of hidden units.
            nin: number of input units.
            nout: number of output units.
            activation: activation function for the hidden layer.
            w_init: weight initializer for the input-to-hidden and hidden-to-hidden weights
                (the output ``Linear`` layer uses its own initializer).
        """
        self.num_inputs = nin
        self.num_outputs = nout
        self.nstate = nstate
        self.activation = activation

        # Hidden layer parameters
        self.w_xh = TrainVar(w_init((self.num_inputs, self.nstate)))
        self.w_hh = TrainVar(w_init((self.nstate, self.nstate)))
        self.b_h = TrainVar(jn.zeros(self.nstate))

        self.output_layer = Linear(self.nstate, self.num_outputs)

    def init_state(self, batch_size):
        """Initialize hidden state for input batch of size ``batch_size``.

        Creates (or replaces) ``self.state`` as a zero-filled ``StateVar`` of shape
        ``(batch_size, nstate)``.
        """
        self.state = StateVar(jn.zeros((batch_size, self.nstate)))

    def __call__(self, inputs: JaxArray, only_return_final=False) -> JaxArray:
        """Forward pass through RNN.

        The hidden state in ``self.state`` is updated in place at every step.

        Args:
            inputs: ``JaxArray`` with dimensions ``num_steps, batch_size, nin``
                (e.g. ``nin == vocabulary_size`` for one-hot text).
            only_return_final: return only the last output if ``True``, or all outputs otherwise.

        Returns:
            Output tensor with dimensions ``num_steps * batch_size, nout``, or
            ``batch_size, nout`` when ``only_return_final`` is ``True``.

        Raises:
            AttributeError: if :meth:`init_state` has not been called yet.
        """
        if not hasattr(self, 'state'):
            # Same exception type as the implicit failure, but with an actionable message.
            raise AttributeError('RNN hidden state is not initialized; call init_state(batch_size) first.')
        # Dimensions: num_steps, batch_size, nin
        outputs = []
        for x in inputs:
            self.state.value = self.activation(
                jn.dot(x, self.w_xh.value)
                + jn.dot(self.state.value, self.w_hh.value)
                + self.b_h.value)
            y = self.output_layer(self.state.value)
            outputs.append(y)
        if only_return_final:
            return outputs[-1]
        return jn.concatenate(outputs, axis=0)
82 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | scipy
2 | numpy>=1.18.0
3 | pillow
4 | jaxlib>=0.4.19
5 | jax>=0.3.25
6 | tensorboard>=2.3.0
7 | parameterized
8 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the 'License');
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an 'AS IS' BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import re
16 |
17 | from pkg_resources import parse_requirements
18 | from setuptools import find_packages, setup
19 |
README_FILE = 'README.md'
REQUIREMENTS_FILE = 'requirements.txt'
VERSION_FILE = 'objax/_version.py'
VERSION_REGEXP = r'^__version__ = \'(\d+\.\d+\.\d+)\''


def _read_text(path):
    """Return the whole UTF-8 text of ``path``; the file is closed before returning."""
    with open(path, encoding='utf-8') as f:
        return f.read()


r = re.search(VERSION_REGEXP, _read_text(VERSION_FILE), re.M)
if r is None:
    raise RuntimeError(f'Unable to find version string in {VERSION_FILE}.')

version = r.group(1)
long_description = _read_text(README_FILE)
# parse_requirements accepts a string and yields one Requirement per non-comment line.
install_requires = [str(req) for req in parse_requirements(_read_text(REQUIREMENTS_FILE))]

setup(
    name='objax',
    version=version,
    description='Objax is a machine learning framework that provides an Object Oriented layer for JAX.',
    long_description=long_description,
    long_description_content_type='text/markdown',
    author='Objax team',
    author_email='objax-dev@google.com',
    url='https://github.com/google/objax',
    packages=find_packages(),
    classifiers=[
        'Development Status :: 5 - Production/Stable',
        'Intended Audience :: Developers',
        'Intended Audience :: Science/Research',
        'License :: OSI Approved :: Apache Software License',
        'Programming Language :: Python :: 3.9',
        'Programming Language :: Python :: 3.10',
        'Programming Language :: Python :: 3.11',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
    ],
    install_requires=install_requires,
)
55 |
--------------------------------------------------------------------------------
/tests/dropout.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Unit Tests for Dropout layer."""
16 |
17 | import unittest
18 |
19 | import jax.numpy as jn
20 |
21 | import objax
22 | from objax import random
23 |
24 |
class TestDropout(unittest.TestCase):
    """Unit tests for ``objax.nn.Dropout``.

    Seeded tests build a private ``random.Generator`` instead of re-seeding the shared
    ``random.DEFAULT_GENERATOR``. This keeps them from changing the random stream seen by
    other tests.
    """

    def test_on_dropout_0_5(self):
        """
        Pass an input through a Dropout layer
        that keeps half the input and test
        that half of the output values are zero,
        and that surviving values are rescaled by 1 / keep.
        """

        drop_input = jn.array([[1., 2., 3., 4., 5., 6.]])
        keep = 0.5
        test_generator = random.Generator(3)
        dropout_layer = objax.nn.Dropout(keep, test_generator)
        training = True
        drop_output = dropout_layer(drop_input, training)
        self.assertEqual(jn.count_nonzero(drop_output), 3)
        for index in range(drop_output.shape[1]):
            if drop_output[0][index] != 0:
                self.assertEqual(drop_output[0][index], drop_input[0][index] / keep)

    def test_on_dropout_two_dimension(self):
        """
        Pass a two dimensional input through a Dropout layer
        that keeps half the input and test
        that half of the output values are zero.
        """

        drop_input = jn.array([[1., 2., 3.], [5., 6., 7.]])
        keep = 0.5
        test_generator = random.Generator(3)
        dropout_layer = objax.nn.Dropout(keep, test_generator)
        training = True
        drop_output = dropout_layer(drop_input, training)
        self.assertEqual(jn.count_nonzero(drop_output), 3)

    def test_on_dropout_1_0(self):
        """
        Pass an input through a Dropout layer
        that keeps all of the input and test
        that all of the output values are non-zero.
        """

        drop_input = jn.array([[1., 2., 3., 4., 5., 6.]])
        keep = 1.0
        dropout_layer = objax.nn.Dropout(keep)
        training = True
        drop_output = dropout_layer(drop_input, training)
        self.assertEqual(jn.count_nonzero(drop_output), 6)
        self.assertTrue(jn.array_equal(drop_input, drop_output))

    def test_on_dropout_0_0(self):
        """
        Pass an input through a Dropout layer
        that keeps none of the input and test
        that all of the output values are zero.
        """

        drop_input = jn.array([[1., 2., 3., 4., 5., 6.]])
        keep = 0.0
        test_generator = random.Generator(1)
        dropout_layer = objax.nn.Dropout(keep, test_generator)
        training = True
        drop_output = dropout_layer(drop_input, training)
        self.assertEqual(jn.count_nonzero(drop_output), 0)

    def test_on_dropout_inference(self):
        """
        Pass an input to the Dropout layer when
        training is false and test that the output
        is equal to the input.
        """

        drop_input = jn.array([[1., 2., 3., 4., 5.]])
        dropout_layer = objax.nn.Dropout(0.5)
        training = False
        drop_output = dropout_layer(drop_input, training)
        self.assertTrue(jn.array_equal(drop_input, drop_output))


if __name__ == '__main__':
    unittest.main()
108 |
--------------------------------------------------------------------------------
/tests/functional_interpolate.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Unittests for functional upsample operations."""
16 |
17 | import unittest
18 | import jax
19 | import jax.numpy as jn
20 | import numpy as np
21 |
22 | import objax
23 |
24 |
def shaparange(s):
    """Return a float jax array of shape ``s`` holding 0, 1, 2, ... in row-major order."""
    count = np.prod(s)
    return jn.reshape(jn.arange(count, dtype=float), s)
27 |
28 |
class TestUpsample(unittest.TestCase):
    methods = ['nearest', 'linear', 'bilinear', 'trilinear', 'triangle', 'cubic',
               'bicubic', 'tricubic', 'lanczos3', 'lanczos5']

    def test_upsample2d(self):
        """upsample_2d on NCHW input must equal jax.image.resize applied in NHWC layout."""
        x = shaparange((2, 3, 10, 30))
        batch, channels, height, width = x.shape
        x_nhwc = x.transpose([0, 2, 3, 1])
        target = (batch, height * 2, width * 3, channels)
        for method in self.methods:
            upsampled = objax.functional.core.ops.upsample_2d(x, (2, 3), method)
            reference = jax.image.resize(x_nhwc, shape=target, method=method).transpose([0, 3, 1, 2])
            self.assertEqual(upsampled.tolist(), reference.tolist())

    def test_interpolate(self):
        """interpolate must produce the expected output shapes for scale_factor and size."""
        x = shaparange((1, 3, 2, 3, 10, 30))
        shape = x.shape
        interpolate = objax.functional.core.ops.interpolate
        for method in self.methods:
            # Scalar scale factor: every axis except the first is scaled.
            y = interpolate(x, scale_factor=2, mode=method)
            self.assertEqual(y.shape, (shape[0], *(d * 2 for d in shape[1:])))
            # Tuple scale factors apply to the trailing axes only.
            for factor in ((2, 2, 2), (2, 2, 2, 2)):
                lead = len(shape) - len(factor)
                y = interpolate(x, scale_factor=factor, mode=method)
                self.assertEqual(y.shape, (*shape[:lead], *(d * f for d, f in zip(shape[lead:], factor))))
            # Scalar size replaces the last axis only.
            y = interpolate(x, size=2, mode=method)
            self.assertEqual(y.shape, (*shape[:-1], 2))
            # Tuple sizes replace the trailing axes.
            for size in ((2, 2, 2), (2, 2, 2, 2)):
                y = interpolate(x, size=size, mode=method)
                self.assertEqual(y.shape, (*shape[:len(shape) - len(size)], *size))


if __name__ == '__main__':
    unittest.main()
71 |
--------------------------------------------------------------------------------
/tests/functional_pooling.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Unittests for functional pooling operations."""
16 |
17 | import unittest
18 |
19 | import jax.numpy as jn
20 | import numpy as np
21 |
22 | import objax
23 |
24 |
def shaparange(s):
    """Build ``arange(prod(s))`` as float values and reshape the jax array to ``s``."""
    flat = jn.arange(np.prod(s), dtype=float)
    return flat.reshape(s)
27 |
28 |
def pad(x, pad_width):
    """Zero-pad ``x`` with numpy; ``pad_width`` follows ``np.pad`` conventions."""
    return np.pad(x, pad_width=pad_width, mode='constant', constant_values=0)
31 |
32 |
class TestPooling(unittest.TestCase):
    def _check_unpadded(self, pool, how):
        """Compare ``pool`` against reshaping/sliding-window references reduced with method ``how``."""
        x = shaparange((2, 3, 10, 30))

        def reduce(a, axes):
            # Call the array's own reduction method (``mean`` or ``max``) over ``axes``.
            return getattr(a, how)(axes)

        # Non-overlapping 5x5 windows: split each spatial axis into (blocks, 5).
        self.assertEqual(pool(x, size=5).tolist(),
                         reduce(x.reshape((2, 3, 2, 5, 6, 5)), (-3, -1)).tolist())
        # Stride 1: explicit sliding window.
        sliding = np.zeros((2, 3, 6, 26), dtype=float)
        for i in range(6):
            for j in range(26):
                sliding[:, :, i, j] = reduce(x[:, :, i:i + 5, j:j + 5], (-2, -1))
        self.assertEqual(pool(x, size=5, strides=1).tolist(), sliding.tolist())
        # Rectangular 2x3 windows.
        self.assertEqual(pool(x, size=(2, 3)).tolist(),
                         reduce(x.reshape((2, 3, 5, 2, 10, 3)), (-3, -1)).tolist())

    def test_average_pooling2d(self):
        self._check_unpadded(objax.functional.average_pool_2d, 'mean')

    def test_max_pooling2d(self):
        self._check_unpadded(objax.functional.max_pool_2d, 'max')

    def test_pooling2d_padding(self):
        x = shaparange((2, 3, 10, 30))
        # (window size, padding argument, explicit spatial pad widths, block reshape)
        cases = [
            (5, (2, 3), ((2, 3), (2, 3)), (2, 3, 3, 5, 7, 5)),
            (5, ((2, 3), (3, 2)), ((2, 3), (3, 2)), (2, 3, 3, 5, 7, 5)),
            (2, 1, ((1, 1), (1, 1)), (2, 3, 6, 2, 16, 2)),
        ]
        pools = ((objax.functional.average_pool_2d, 'mean'), (objax.functional.max_pool_2d, 'max'))
        for size, padding, spatial_pad, blocks in cases:
            padded = pad(x, ((0, 0), (0, 0)) + spatial_pad).reshape(blocks)
            for pool, how in pools:
                self.assertEqual(pool(x, size=size, padding=padding).tolist(),
                                 getattr(padded, how)((-3, -1)).tolist())

    def test_space_batch(self):
        """Test batch_to_space2d and space_to_batch2d."""
        x = shaparange((2, 3, 10, 30))
        for size, expected_shape in ((5, (50, 3, 2, 6)), ((2, 3), (12, 3, 5, 10))):
            y = objax.functional.space_to_batch2d(x, size=size)
            z = objax.functional.batch_to_space2d(y, size=size)
            self.assertEqual(x.tolist(), z.tolist())
            self.assertEqual(y.shape, expected_shape)

    def test_space_channel(self):
        """Test channel_to_space2d and space_to_channel2d."""
        x = shaparange((2, 3, 10, 30))
        for size, expected_shape in ((5, (2, 75, 2, 6)), ((2, 3), (2, 18, 5, 10))):
            y = objax.functional.space_to_channel2d(x, size=size)
            z = objax.functional.channel_to_space2d(y, size=size)
            self.assertEqual(x.tolist(), z.tolist())
            self.assertEqual(y.shape, expected_shape)


if __name__ == '__main__':
    unittest.main()
112 |
--------------------------------------------------------------------------------
/tests/jit.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Unittests for ObJAX JIT."""
16 |
17 | import unittest
18 |
19 | import jax.numpy as jn
20 | from jax.core import ConcretizationTypeError
21 |
22 | import objax
23 | from objax.typing import JaxArray
24 |
25 |
class LinearArgs(objax.nn.Linear):
    def __call__(self, x: JaxArray, some_args: float) -> JaxArray:
        """Linear map whose matmul result is scaled by ``some_args`` before the bias is added."""
        scaled = jn.dot(x, self.w.value) * some_args
        return scaled if self.b is None else scaled + self.b.value
33 |
34 |
class LinearTrain(objax.nn.Linear):
    def __call__(self, x: JaxArray, training: bool) -> JaxArray:
        """Linear map that negates the matmul result when ``training`` is truthy.

        The Python-level branch on ``training`` needs a concrete value, so it fails when
        ``training`` is traced.
        """
        product = jn.dot(x, self.w.value)
        out = -product if training else product
        return out if self.b is None else out + self.b.value
44 |
45 |
class TestJit(unittest.TestCase):
    def _check_weight_updates_visible(self, module, jitted):
        """Assert that ``jitted`` sees in-place changes to ``module.w`` between calls."""
        x = objax.random.normal((64, 3))
        baseline = jitted(x)
        module.w.assign(module.w.value + 1)
        bumped = jitted(x)
        module.w.assign(module.w.value - 1)
        restored = jitted(x)
        self.assertAlmostEqual(((baseline - restored) ** 2).sum(), 0)
        self.assertNotEqual(((baseline - bumped) ** 2).sum(), 0)

    def test_on_linear(self):
        layer = objax.nn.Linear(3, 3)
        self._check_weight_updates_visible(layer, objax.Jit(layer))

    def test_double_jit(self):
        layer = objax.nn.Linear(3, 3)
        self._check_weight_updates_visible(layer, objax.Jit(objax.Jit(layer)))

    def test_jit_kwargs(self):
        x = objax.random.normal((64, 3))
        scaled = objax.Jit(LinearArgs(3, 3))
        positional = scaled(x, 1)
        keyword = scaled(x, some_args=1)
        doubled = scaled(x, some_args=2)
        self.assertEqual(positional.tolist(), keyword.tolist())
        self.assertNotEqual(positional.tolist(), doubled.tolist())
        # A traced ``training`` kwarg cannot drive a Python ``if``.
        branching = objax.Jit(LinearTrain(3, 3))
        with self.assertRaises(ConcretizationTypeError):
            branching(x, training=True)

    def test_trainvar_assign(self):
        m = objax.ModuleList([objax.TrainVar(jn.zeros(2))])

        def increase():
            m[0].assign(m[0].value + 1)
            return m[0].value

        objax.Jit(increase, m.vars())()
        self.assertEqual(m[0].value.tolist(), [1., 1.])

    def test_trainvar_and_ref_assign(self):
        m = objax.ModuleList([objax.TrainVar(jn.zeros(2))])
        m.append(objax.TrainRef(m[0]))

        def increase():
            m[0].assign(m[0].value + 1)
            m[1].assign(m[1].value + 1)
            return m[0].value

        result = objax.Jit(increase, m.vars())()
        self.assertEqual(result.tolist(), [2., 2.])
        self.assertEqual(m[0].value.tolist(), [2., 2.])

    def test_constant_optimization(self):
        m = objax.nn.Linear(3, 4)
        frozen = objax.Jit(m, objax.VarCollection())

        x = objax.random.normal((10, 3))
        self.assertEqual(((m(x) - frozen(x)) ** 2).sum(), 0)

        # Modify m (which was supposed to be constant!): the jitted version keeps the old bias,
        # so each of the 10 x 4 outputs differs by exactly 1.
        m.b.assign(m.b.value + 1)
        self.assertEqual(((m(x) - frozen(x)) ** 2).sum(), 40)


if __name__ == '__main__':
    unittest.main()
122 |
--------------------------------------------------------------------------------
/tests/linear.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Unittests for Linear Layer."""
16 |
17 | import unittest
18 |
19 | import jax.numpy as jn
20 |
21 | import objax
22 |
23 |
class TestLinear(unittest.TestCase):
    def _apply(self, use_bias, weights, bias=None):
        """Run a 1-input, 3-output Linear layer with the given parameters on rows [1] and [2]."""
        layer = objax.nn.Linear(1, 3, use_bias=use_bias)
        layer.w = objax.TrainVar(jn.array(weights))
        if bias is not None:
            layer.b = objax.TrainVar(jn.array(bias))
        return layer(jn.array([[1.], [2.]]))

    def test_on_linear_three_unit(self):
        """A bias-free linear layer with 3 units scales each input row by the weight row."""
        features = self._apply(False, [[1., 2., 1.]])
        self.assertEqual(features.shape, (2, 3))
        self.assertTrue(jn.array_equal(features, jn.array([[1., 2., 1.], [2., 4., 2.]])))

    def test_on_linear_three_unit_with_bias(self):
        """A linear layer with 3 units and a bias adds the bias to every output row."""
        features = self._apply(True, [[1., 2., 1.]], [2., 1., 2.])
        self.assertEqual(features.shape, (2, 3))
        self.assertTrue(jn.array_equal(features, jn.array([[3., 3., 3.], [4., 5., 4.]])))


if __name__ == '__main__':
    unittest.main()
66 |
--------------------------------------------------------------------------------
/tests/nn_moving_average.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Unittests for MovingAverage and ExponentialMovingAverage Layer."""
16 |
17 | import unittest
18 |
19 | import jax.numpy as jn
20 | import numpy as np
21 |
22 | import objax
23 |
24 |
class TestMovingAverage(unittest.TestCase):

    def test_MovingAverage(self):
        """Test MovingAverage."""
        steps = [jn.array([[0, 1, 2]]), jn.array([[0, 0, 0]]), jn.array([[-3, -4, 5]])]
        ma = objax.nn.MovingAverage(shape=steps[0].shape, buffer_size=2, init_value=100)
        averages = [ma(x) for x in steps]
        # A buffer of 2 starts as [100, 100]; each call averages the two most recent entries.
        expected = ([[50, 50.5, 51]], [[0, 0.5, 1]], [[-1.5, -2, 2.5]])
        for got, want in zip(averages, expected):
            np.testing.assert_allclose(got, np.array(want))

    def test_ExponentialMovingAverage(self):
        """Test ExponentialMovingAverage."""
        inputs = [jn.array([[0, 1, 2]]) * 100, jn.array([[-3, -4, 5]]) * 100]
        ema = objax.nn.ExponentialMovingAverage(shape=inputs[0].shape, init_value=100, momentum=0.8)
        smoothed = [ema(x) for x in inputs]
        # new = 0.8 * previous + 0.2 * x, starting from 100.
        expected = ([[80, 100, 120]], [[4, 0, 196]])
        for got, want in zip(smoothed, expected):
            np.testing.assert_allclose(got, np.array(want))


if __name__ == '__main__':
    unittest.main()
61 |
--------------------------------------------------------------------------------
/tests/objax2tf.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Unittests for Objax2Tf converter."""
16 |
17 | import shutil
18 | import tempfile
19 | import unittest
20 |
21 | import numpy as np
22 | import objax
23 | from objax.zoo.wide_resnet import WideResNet
24 | import tensorflow as tf
25 |
26 |
# Synthetic input geometry for the conversion tests: (BATCH_SIZE, NCHANNELS, IMAGE_SIZE, IMAGE_SIZE)
# images fed to a classifier with NCLASSES outputs.
BATCH_SIZE, NCHANNELS, NCLASSES, IMAGE_SIZE = 4, 3, 10, 32
31 |
32 |
33 | class TestObjax2Tf(unittest.TestCase):
34 |
35 | def verify_converted_predict_op(self, objax_op, tf_op, shape):
36 | x1 = np.random.normal(size=shape)
37 | x2 = np.random.normal(size=shape)
38 | # due to differences in op implementations, there might be small numerical
39 | # differences between TF and Objax, thus comparing up to 1e-4 relative tolerance
40 | np.testing.assert_allclose(objax_op(x1), tf_op(tf.convert_to_tensor(x1, dtype=tf.float32)), rtol=1e-4)
41 | np.testing.assert_allclose(objax_op(x2), tf_op(tf.convert_to_tensor(x2, dtype=tf.float32)), rtol=1e-4)
42 |
43 | # NOTE: Objax2Tf tests are temporary disabled until the release of TF 2.8
44 |
45 | def disabled_test_convert_wrn(self):
46 | # Make a model
47 | model = WideResNet(NCHANNELS, NCLASSES, depth=4, width=1)
48 | # Prediction op without JIT
49 | predict_op = objax.nn.Sequential([objax.ForceArgs(model, training=False), objax.functional.softmax])
50 | predict_tf = objax.util.Objax2Tf(predict_op)
51 | # Compare results
52 | self.verify_converted_predict_op(predict_op, predict_tf,
53 | shape=(BATCH_SIZE, NCHANNELS, IMAGE_SIZE, IMAGE_SIZE))
54 | # Predict op with JIT
55 | predict_op_jit = objax.Jit(predict_op)
56 | predict_tf_jit = objax.util.Objax2Tf(predict_op_jit)
57 | # Compare results
58 | self.verify_converted_predict_op(predict_op_jit, predict_tf_jit,
59 | shape=(BATCH_SIZE, NCHANNELS, IMAGE_SIZE, IMAGE_SIZE))
60 |
61 | def disabled_test_savedmodel_wrn(self):
62 | model_dir = tempfile.mkdtemp()
63 | # Make a model and convert it to TF
64 | model = WideResNet(NCHANNELS, NCLASSES, depth=4, width=1)
65 | predict_op = objax.Jit(objax.nn.Sequential([objax.ForceArgs(model, training=False), objax.functional.softmax]))
66 | predict_tf = objax.util.Objax2Tf(predict_op)
67 | # Save model
68 | input_shape = (BATCH_SIZE, NCHANNELS, IMAGE_SIZE, IMAGE_SIZE)
69 | tf.saved_model.save(
70 | predict_tf,
71 | model_dir,
72 | signatures=predict_tf.__call__.get_concrete_function(tf.TensorSpec(input_shape, tf.float32)))
73 | # Load model
74 | loaded_tf_model = tf.saved_model.load(model_dir)
75 | loaded_predict_tf_op = loaded_tf_model.signatures['serving_default']
76 | self.verify_converted_predict_op(predict_op,
77 | lambda x: loaded_predict_tf_op(x)['output_0'],
78 | shape=input_shape)
79 | self.verify_converted_predict_op(predict_op,
80 | lambda x: loaded_tf_model(x),
81 | shape=input_shape)
82 | # Cleanup
83 | shutil.rmtree(model_dir)
84 |
85 |
86 | if __name__ == '__main__':
87 | unittest.main()
88 |
--------------------------------------------------------------------------------
/tests/requirements.txt:
--------------------------------------------------------------------------------
1 | pytest
2 | numpy
3 | tensorflow
4 |
--------------------------------------------------------------------------------
/tests/run_linter.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Copyright 2020 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
# Change directory to repository root; abort if that fails so flake8 never lints the wrong tree.
cd "$( dirname "${BASH_SOURCE[0]}" )/.." || exit 1

# Run linter with following changes to default rules:
# - We allow assignment of lambda, thus ignore E731 error: https://www.flake8rules.com/rules/E731.html
# - Line break should occur before binary operator, thus between W503 and W504 ignore W503 and follow W504,
#   https://www.flake8rules.com/rules/W503.html
# - Set max line length to 120 characters
# - Separately lint __init__.py and other files, otherwise flake8 complains about unused imports in __init__.py
# Options shared by every flake8 invocation below.
FLAKE8_ARGS=(--max-line-length=120 --ignore=E731,W503)
flake8 --exclude=__init__.py "${FLAKE8_ARGS[@]}" objax/ || exit 1
flake8 --filename=__init__.py "${FLAKE8_ARGS[@]}" objax/ || exit 1
flake8 "${FLAKE8_ARGS[@]}" tests/ || exit 1
28 |
--------------------------------------------------------------------------------
/tests/run_tests.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Copyright 2020 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
# Change directory to repository root; abort if that fails so tests never run from the wrong tree.
cd "$( dirname "${BASH_SOURCE[0]}" )/.." || exit 1

if python3 -c "import pytest" &> /dev/null ; then
  # If pytest is installed then use it to run tests
  # Pytest has nicer output compared to unittest package and also it's used
  # to run automatic unit tests on GitHub.
  CUDA_VISIBLE_DEVICES= pytest tests/*.py
else
  # If pytest is not installed then use default unittest to run tests.
  # Each test file runs in the background; its stdout and stderr go to <file>.log.
  pids=()
  for i in tests/*.py; do
    CUDA_VISIBLE_DEVICES= python3 -m unittest "$i" &> "$i.log" &
    pids+=($!)
  done
  # Collect every job's exit status so the script fails when any test file fails.
  # The script must not end on grep: grep succeeds exactly when a FAILED line was found.
  status=0
  for pid in "${pids[@]}"; do
    wait "$pid" || status=1
  done
  # Print the summary lines of failing test files (prefixed with the log file name).
  grep -F FAILED tests/*.log
  exit "$status"
fi
32 |
--------------------------------------------------------------------------------
/tests/scan.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Unittests for scan method."""
16 |
17 | import unittest
18 |
19 | import jax.numpy as jn
20 |
21 | import objax
22 |
23 |
class TestScan(unittest.TestCase):
    """Tests for objax.functional.scan."""

    def test_scan(self):
        """Scan a cell that doubles the carry and emits three times the carry at each step."""
        def cell(carry, x):
            return jn.array([2]) * carry * x, jn.array([3]) * carry * x

        # Starting from carry=1 with x=1 for 3 steps, the carry goes 1 -> 2 -> 4 -> 8
        # while the per-step outputs are 3, 6, 12.
        carry = jn.array([8., 8.])
        output = jn.array([[3., 3.], [6., 6.], [12., 12.]])
        test_carry, test_output = objax.functional.scan(cell, jn.ones((2,)), jn.ones((3,)))
        # Compare as nested lists so a failure message shows the mismatching values.
        self.assertEqual(test_carry.tolist(), carry.tolist())
        self.assertEqual(test_output.tolist(), output.tolist())


if __name__ == '__main__':
    unittest.main()
34 |
--------------------------------------------------------------------------------
/tests/scheduler.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Unittests for optimizers."""
16 |
17 | import unittest
18 |
19 | import numpy as np
20 |
21 | import objax
22 |
23 |
class TestScheduler(unittest.TestCase):
    """Checks learning-rate values produced by objax.optimizer.scheduler classes."""

    @staticmethod
    def _lrs(sched, nsteps=10):
        """Return the learning rates [sched(step=0), ..., sched(step=nsteps - 1)]."""
        return [sched(step=i) for i in range(nsteps)]

    def test_linear_annealing(self):
        """Cyclic linear annealing from base_lr=1 down towards min_lr=0 over max_step=10 steps."""
        sched = objax.optimizer.scheduler.LinearAnnealing(max_step=10, base_lr=1, is_cycle=True, min_lr=0)
        lrs_gt = [1, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]
        np.testing.assert_array_almost_equal(self._lrs(sched), lrs_gt)

    def test_step_decay(self):
        """A scalar step_size multiplies the rate by gamma every 3 steps."""
        sched = objax.optimizer.scheduler.StepDecay(step_size=3, base_lr=1, gamma=0.9)
        lrs_gt = [1, 1, 1, 0.9, 0.9, 0.9, 0.81, 0.81, 0.81, 0.729]
        np.testing.assert_array_almost_equal(self._lrs(sched), lrs_gt)

    def test_multi_step_decay(self):
        """A list step_size multiplies the rate by gamma at steps 3, 5 and 8."""
        sched = objax.optimizer.scheduler.StepDecay(step_size=[3, 5, 8], base_lr=1, gamma=0.9)
        lrs_gt = [1, 1, 1, 0.9, 0.9, 0.81, 0.81, 0.81, 0.729, 0.729]
        np.testing.assert_array_almost_equal(self._lrs(sched), lrs_gt)


if __name__ == '__main__':
    unittest.main()
52 |
--------------------------------------------------------------------------------
/tests/sequential.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Unittests for Convolution Layer."""
16 |
17 | import unittest
18 |
19 | import jax.numpy as jn
20 |
21 | import objax
22 |
23 |
class TestSequential(unittest.TestCase):
    def test_on_sequential_linear_relu(self):
        """
        Pass an input through a linear filter with 3 units followed by ReLU and
        test the shape and contents of the output.
        """

        # Define linear filter with 2 input features and 3 output units
        linear_filter = objax.nn.Linear(2, 3, use_bias=False)
        weights = objax.TrainVar(jn.array([[1., 2., 1.], [2., 1., 2.]]))
        linear_filter.w = weights
        sequential = objax.nn.Sequential([linear_filter,
                                          objax.functional.relu])

        # Define data and compute output response of linear filter followed by ReLU
        data = jn.array([[1., -1.], [2., -2.]])
        features = sequential(data)
        expected_features = jn.array([[0., 1., 0.], [0., 2., 0.]])
        self.assertEqual(features.shape, (2, 3))
        self.assertTrue(jn.array_equal(features, expected_features))

    def test_on_sequential_relu_linear(self):
        """
        Pass an input through ReLU followed by a linear filter with 3 units and
        test the shape and contents of the output.
        """

        # Define linear filter with 2 input features and 3 output units
        linear_filter = objax.nn.Linear(2, 3, use_bias=False)
        weights = objax.TrainVar(jn.array([[1., 2., 1.], [2., 1., 2.]]))
        linear_filter.w = weights
        sequential = objax.nn.Sequential([objax.functional.relu,
                                          linear_filter])

        # Define data and compute output response of ReLU followed by linear filter
        data = jn.array([[1., -1.], [2., -2.]])
        features = sequential(data)
        expected_features = jn.array([[1., 2., 1.], [2., 4., 2.]])
        self.assertEqual(features.shape, (2, 3))
        self.assertTrue(jn.array_equal(features, expected_features))

    def test_kwargs(self):
        """Test sequential on modules that take named inputs in kwargs."""

        class MyModule:
            def __call__(self, x, some_param):
                return x + some_param

        seq = objax.nn.Sequential([MyModule(), MyModule()])
        self.assertEqual(seq(1, some_param=2), 5)
        with self.assertRaises(TypeError):
            seq(1)

    def test_variadic(self):
        """Test sequential on modules that take multiple inputs and have multiple outputs."""

        class MyModule:
            def __call__(self, x, y):
                return x + y, x - y

        seq = objax.nn.Sequential([MyModule(), MyModule()])
        self.assertEqual(seq(1, 2), (2, 4))

    def test_slice(self):
        """Test sequential slices with variadic module."""

        class MyModule:
            def __init__(self, m):
                self.m = m

            def __call__(self, x, y):
                return self.m * x + y, self.m * x - y

        seq = objax.nn.Sequential([MyModule(2), MyModule(3)])
        self.assertEqual(seq(5, 7), (54, 48))
        self.assertEqual(seq[:1](5, 7), (17, 3))
        self.assertEqual(seq[1:](5, 7), (22, 8))

    def test_on_sequential_missing_argument(self):
        """Calling without `training` raises TypeError naming it, both before and after `m.pop()`."""
        m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.nn.BatchNorm0D(3), objax.nn.Linear(3, 2)])
        x = jn.array([[1., -1.], [2., -2.]])
        msg = "missing 1 required positional argument: 'training'"
        # assertRaises (unlike `assert False` after the call) still fails when no error is raised under `python -O`.
        with self.assertRaises(TypeError) as ctx:
            m(x)
        self.assertIn(msg, str(ctx.exception))
        m.pop()
        with self.assertRaises(TypeError) as ctx:
            m(x)
        self.assertIn(msg, str(ctx.exception))


if __name__ == '__main__':
    unittest.main()
127 |
--------------------------------------------------------------------------------
/tests/util_image.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Unittests for objax.util.image."""
16 |
17 | import io
18 | import tempfile
19 | import unittest
20 | from typing import Tuple
21 |
22 | import jax.numpy as jn
23 | import numpy as np
24 | from PIL import Image
25 |
26 | import objax
27 |
28 |
class TestUtilImage(unittest.TestCase):
    def ndimarange(self, dims: Tuple[int, ...]):
        """Return a float array of shape `dims` holding 0, 1, 2, ... in row-major order."""
        return np.arange(np.prod(dims), dtype=float).reshape(dims)

    def test_nchw(self):
        """nchw moves the last axis in front of the two axes preceding it, for numpy and jax inputs."""
        x = self.ndimarange((2, 3, 4, 5))
        self.assertEqual(objax.util.image.nchw(x).tolist(), x.transpose((0, 3, 1, 2)).tolist())
        self.assertEqual(objax.util.image.nchw(jn.array(x)).tolist(), x.transpose((0, 3, 1, 2)).tolist())
        x = self.ndimarange((2, 3, 4, 5, 6))
        self.assertEqual(objax.util.image.nchw(x).tolist(), x.transpose((0, 1, 4, 2, 3)).tolist())
        self.assertEqual(objax.util.image.nchw(jn.array(x)).tolist(), x.transpose((0, 1, 4, 2, 3)).tolist())

    def test_nhwc(self):
        """nhwc moves the third-from-last axis to the end, for numpy and jax inputs."""
        x = self.ndimarange((2, 3, 4, 5))
        self.assertEqual(objax.util.image.nhwc(x).tolist(), x.transpose((0, 2, 3, 1)).tolist())
        self.assertEqual(objax.util.image.nhwc(jn.array(x)).tolist(), x.transpose((0, 2, 3, 1)).tolist())
        x = self.ndimarange((2, 3, 4, 5, 6))
        self.assertEqual(objax.util.image.nhwc(x).tolist(), x.transpose((0, 1, 3, 4, 2)).tolist())
        self.assertEqual(objax.util.image.nhwc(jn.array(x)).tolist(), x.transpose((0, 1, 3, 4, 2)).tolist())

    def test_normalize(self):
        """Test normalize methods."""
        x = np.arange(256)
        y = objax.util.image.normalize_to_unit_float(x)
        self.assertEqual((x / 128 - (1 - 1 / 256)).tolist(), y.tolist())
        self.assertEqual(y.tolist(), y.clip(-1, 1).tolist())
        z = objax.util.image.normalize_to_uint8(y)
        self.assertEqual(x.tolist(), z.tolist())
        # Shifting by one quantization step (1/128) moves the uint8 value by one, clipped to [0, 255].
        z = objax.util.image.normalize_to_uint8(y + 1 / 128)
        self.assertEqual((x + 1).clip(0, 255).tolist(), z.tolist())
        z = objax.util.image.normalize_to_uint8(y - 1 / 128)
        self.assertEqual((x - 1).clip(0, 255).tolist(), z.tolist())

    def test_to_png(self):
        """to_png output decodes to the same pixels as a reference PNG and round-trips the input values."""
        x = np.zeros((3, 32, 32), float) + 1 / 255
        x[:, :12, :12] = 1
        x[:, -12:, -12:] = -1
        y = objax.util.image.to_png(x)
        self.assertEqual(
            np.array(Image.open(io.BytesIO(y))).tolist(),
            np.array(Image.open(io.BytesIO(
                b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00 \x00\x00\x00 \x08\x02\x00\x00\x00\xfc'
                b'\x18\xed\xa3\x00\x00\x00FIDATx\x9cc\xfc\xff\xff?\x03!\xd0\xd8\xd8HP\r.\xc0D\xb6\xceQ'
                b'\x0bF-\x18\xb5`\x04Y\xc0BI9C\x0c\x18\xfaA4j\xc1\x08\xb0\x80\x85\x12\xcd\r\r\r\x04\xd5'
                b'\x0c\xfd \x1a\xb5`\xd4\x82Q\x0b\xe8`\x01\x00\xe3\xf1\x07\xc7\x82\x83p\xa5\x00\x00\x00\x00'
                b'IEND\xaeB`\x82'
            ))).tolist())
        # Decoded PNG is HWC uint8; convert back to CHW floats in [-1, 1] and compare with the input.
        z = np.array(Image.open(io.BytesIO(y)))
        z = (z.transpose((2, 0, 1)) - 127.5) / 127.5
        self.assertEqual(x.tolist(), z.tolist())

    def test_to_png_from_file(self):
        """Round-trip an image through to_png and from_file using an in-memory buffer."""
        x = objax.random.randint((3, 32, 24), 0, 256)
        x = objax.util.image.normalize_to_unit_float(x)
        # Named png_bytes rather than `bin` to avoid shadowing the builtin.
        png_bytes = objax.util.image.to_png(x)
        y = objax.util.image.from_file(io.BytesIO(png_bytes))
        self.assertEqual(x.tolist(), y.tolist())

    def test_image_grid(self):
        """image_grid tiles a (rows, cols, C, H, W) batch into a single (C, rows*H, cols*W) image."""
        x = objax.random.randint((5, 7, 3, 8, 4), 0, 256)
        y = objax.util.image.image_grid(x)
        z = x.transpose((2, 0, 3, 1, 4)).reshape((3, 40, 28))
        self.assertEqual(y.tolist(), z.tolist())

    def test_from_file_with_filename(self):
        """from_file accepts a filesystem path as well as a file object."""
        x = objax.random.randint((3, 32, 24), 0, 256)
        x = objax.util.image.normalize_to_unit_float(x)
        with tempfile.NamedTemporaryFile('wb', suffix='.png') as f:
            f.write(objax.util.image.to_png(x))
            f.flush()
            y = objax.util.image.from_file(f.name)
        self.assertEqual(x.tolist(), y.tolist())


if __name__ == '__main__':
    unittest.main()
105 |
--------------------------------------------------------------------------------
/tests/wide_resnet.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Unittests for Resnet v2."""
16 |
17 | import unittest
18 |
19 | import objax
20 | from objax.zoo.wide_resnet import WideResNet, WideResNetGeneral
21 |
22 |
class TestWideResNetGeneral(unittest.TestCase):
    """Output-shape checks for WideResNetGeneral and WideResNet."""

    def _check_output_shape(self, model, x, expected_shape):
        """Run `model` on `x` in eval mode, then in train mode, and check both output shapes."""
        # run in eval mode
        y_eval = model(x, training=False)
        self.assertEqual(y_eval.shape, expected_shape)
        # run in train mode
        y_train = model(x, training=True)
        self.assertEqual(y_train.shape, expected_shape)

    def test_wide_resnet_general(self):
        x = objax.random.normal((4, 3, 128, 128))
        model = WideResNetGeneral(nin=3, nclass=10, blocks_per_group=[4, 4, 4, 4], width=2)
        self._check_output_shape(model, x, (4, 10))

    def test_wide_resnet(self):
        x = objax.random.normal((4, 3, 32, 32))
        model = WideResNet(nin=3, nclass=10, depth=28, width=4)
        self._check_output_shape(model, x, (4, 10))


if __name__ == '__main__':
    unittest.main()
48 |
--------------------------------------------------------------------------------